How to custom losses by subclass tf.keras.losses.Loss class in Tensorflow2.x(如何通过Tensorflow 2.x中的子类tf.keras.losse.Loss类自定义损耗)
本文介绍了如何通过Tensorflow 2.x中的子类tf.keras.losse.Loss类自定义损耗的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
当我在TensorFlow的网站上阅读guides时,我发现了两种自定义损失的方法。第一种是定义损失函数,就像:
def basic_loss_function(y_true, y_pred):
return tf.math.reduce_mean(tf.abs(y_true - y_pred))
为了简单起见,我们假设批大小也是1,因此y_true
和y_pred
的形状都是(1,c),其中c是类的数量。因此,在此方法中,我们给出两个向量y_true
和y_pred
,并返回一个值(Scala)。
然后,第二种方法是子类化tf.keras.losses.Loss
类,指南中的代码是:
class WeightedBinaryCrossEntropy(keras.losses.Loss):
"""
Args:
pos_weight: Scalar to affect the positive labels of the loss function.
weight: Scalar to affect the entirety of the loss function.
from_logits: Whether to compute loss from logits or the probability.
reduction: Type of tf.keras.losses.Reduction to apply to loss.
name: Name of the loss function.
"""
def __init__(self, pos_weight, weight, from_logits=False,
reduction=keras.losses.Reduction.AUTO,
name='weighted_binary_crossentropy'):
super().__init__(reduction=reduction, name=name)
self.pos_weight = pos_weight
self.weight = weight
self.from_logits = from_logits
def call(self, y_true, y_pred):
ce = tf.losses.binary_crossentropy(
y_true, y_pred, from_logits=self.from_logits)[:,None]
ce = self.weight * (ce*(1-y_true) + self.pos_weight*ce*(y_true))
return ce
在调用方法中,像往常一样,我们给出了两个向量y_true
和y_pred
,但我注意到它返回的是ce
,这是一个形状为(1,c)!
那么在上面的玩具示例中有什么问题吗?或者Tensorflow 2.x背后有什么神奇的东西?
推荐答案
除了实现之外,两者的主要区别在于损耗函数的类型。第一个是L1损失(根据定义,绝对差异的平均值,主要用于类似回归的问题),第二个是二元交叉(用于分类)。它们不是意味着相同损失的不同实现,这在您链接的指南中有说明。
多标签、多类别分类设置中的二进制交叉标记会为每个类别输出一个值,就好像它们彼此独立一样。
编辑:
在第二个损失函数中,reduction
参数控制聚合输出的方式,例如。默认情况下,您的代码使用keras.losses.Reduction.AUTO
,如果您选中the source code,这将转换为对批处理求和。这意味着,最终的损失将是一个矢量,但还有其他可用的减少,您可以在docs中检查它们。我相信,即使您没有定义折减来获取损失向量中损失元素的总和,TF优化器也会这样做,以避免反向传播向量所产生的错误。向量上的反向传播会在权重上造成问题,这些问题会对每个损失元素产生影响。但是,我还没有在源代码中检查这一点。:)
这篇关于如何通过Tensorflow 2.x中的子类tf.keras.losse.Loss类自定义损耗的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
沃梦达教程
本文标题为:如何通过Tensorflow 2.x中的子类tf.keras.losse.Loss类自定义损耗
基础教程推荐
猜你喜欢
- 使用 Google App Engine (Python) 将文件上传到 Google Cloud Storage 2022-01-01
- 使用Python匹配Stata加权xtil命令的确定方法? 2022-01-01
- 使 Python 脚本在 Windows 上运行而不指定“.py";延期 2022-01-01
- 哪些 Python 包提供独立的事件系统? 2022-01-01
- Python 的 List 是如何实现的? 2022-01-01
- 如何在 Python 中检测文件是否为二进制(非文本)文 2022-01-01
- 合并具有多索引的两个数据帧 2022-01-01
- 将 YAML 文件转换为 python dict 2022-01-01
- 症状类型错误:无法确定关系的真值 2022-01-01
- 如何在Python中绘制多元函数? 2022-01-01