简介
损失函数在机器学习模型训练中至关重要,在大多数机器学习项目中,如果没有损失函数,就无法驱动模型做出正确的预测。用通俗的话来说,损失函数是一种数学函数或表达式,用于衡量模型在某些数据集上的表现。了解模型在特定数据集上的表现,可以让开发者洞察到在训练过程中做出许多决策,比如使用一个新的、更强大的模型,或者甚至将损失函数本身更改为另一种类型。说到损失函数的类型,多年来已经开发出几种这样的损失函数,每种都适用于特定的训练任务。
先决条件
阅读这篇文章需要理解神经网络。在较高层次上,神经网络由互连的节点(“神经元”)组成,这些节点组织成层。他们通过一个称为“训练”的过程进行学习和做出预测,该过程调整神经元之间的连接权重和偏置。理解神经网络包括了解它们的不同的层(输入层、隐藏层、输出层)、激活函数、优化算法(梯度下降的变体)、损失函数等。
此外,熟悉 Python 语法和 PyTorch 库对于理解本文中提供的代码片段是必不可少的。
在本文中,我们将探讨 PyTorch nn 模块中的不同损失函数。我们还将通过构建一个自定义的损失函数,深入了解 PyTorch 如何通过其 nn 模块 API 向用户暴露这些损失函数。
现在我们对损失函数有了一个高层次的理解,让我们进一步探讨损失函数的技术细节以及它们是如何工作的。
什么是损失函数?
我们之前提到,损失函数告诉我们模型在特定数据集上的表现如何。从技术上讲,它通过测量预测值与实际值的接近程度来实现这一点。当我们的模型在训练和测试数据集上做出非常接近实际值的预测时,意味着我们的模型相当健壮。
尽管损失函数为我们提供了有关模型性能的重要信息,但这并不是损失函数的主要功能,因为还有更健壮的技术来评估我们的模型,如准确率和 F 分数。损失函数的重要性主要体现在训练过程中,通过调整模型权重以最小化损失,我们增加了模型做出正确预测的概率,而这在没有损失函数的情况下可能无法实现。
不同的损失函数适用于不同的问题,每个损失函数都是研究人员精心设计的,以确保训练过程中梯度的稳定流动。
有时,损失函数的数学表达式可能会有些令人生畏,这导致一些开发者将它们视为黑箱。稍后我们将揭示 PyTorch 中一些最常用的损失函数,但在此之前,让我们先来看看如何在 PyTorch 中使用损失函数。
PyTorch 中的损失函数
PyTorch 自带了许多标准的损失函数,其简化的设计模式使开发者能够在训练过程中快速迭代这些不同的损失函数。所有 PyTorch 的损失函数都打包在 nn 模块中,这是 PyTorch 所有神经网络的基类。这使得将损失函数添加到项目中变得非常简单,只需要添加一行代码即可。让我们看看如何在 PyTorch 中添加一个均方误差损失函数。
上面代码返回的函数可以用于计算预测值与实际值之间的差距,使用以下格式。
现在我们已经了解了如何在 PyTorch 中使用损失函数,让我们深入了解 PyTorch 提供的几种损失函数的幕后原理。
PyTorch 中有哪些损失函数可用?
PyTorch 提供的许多损失函数大致分为三类——回归损失、分类损失和排序损失。
回归损失主要涉及可以在两个极限之间取任何值的连续值。其中一个例子是对一个社区房价的预测。
分类损失函数处理离散值,比如将物体分类为盒子、笔或瓶子的任务。
排序损失预测值之间的相对距离。这方面的一个例子是人脸验证,我们希望知道哪些人脸图像属于特定的人脸,并可以通过排名哪些人脸属于和不属于原始人脸持有者来实现,这通过它们相对于目标人脸扫描的相对近似程度来实现。
L1 损失函数/平均绝对误差
L1损失函数计算预测张量和目标之间的平均绝对误差。它首先计算预测张量中每个值与目标之间的绝对差,然后计算所有绝对差计算返回值的和。最后,它计算这个和值的平均值,得到均方误差(MAE)。L1损失函数在处理噪声方面非常稳健。
返回的单个值是两个维度为3×5的张量之间的计算损失。
均方误差
均方误差(MSE)与平均绝对误差(MAE)有显著的相似之处。与平均绝对误差不同,均方误差计算预测张量和目标张量中值之间的平方差。这样做的话,相对较大的差异会受到更大的惩罚,而相对较小的差异则受到较小的惩罚。与MAE相比,MSE在处理异常值和噪声方面被认为不太鲁棒。
交叉熵损失
交叉熵损失用于涉及多个离散类别的分类问题。它衡量给定随机变量的两个概率分布之间的差异。通常,在使用交叉熵损失时,我们网络的输出是一个softmax层,这确保了神经网络的输出是一个概率值(0-1之间的值)。
softmax层由两部分组成——特定类的预测的指数。
yi 是神经网络对特定类的输出。如果yi较大且为负,该函数的输出接近于零,但永远不会为零;如果yi为正且非常大,则输出更接近于1。
第二部分是归一化值,用于确保softmax层的输出始终是概率值。
这是通过将每个类别值的指数相加得到的。softmax的最终方程看起来像这样:
]
在PyTorch的nn模块中,交叉熵损失将逻辑softmax和负对数似然(NLL)损失组合成一个损失函数。
注意打印输出中的梯度函数是一个NLL损失。这实际上揭示了交叉熵损失在幕后将NLL损失与逻辑softmax层结合在一起。
负对数似然(NLL)损失
NLL损失函数的工作方式与交叉熵损失函数非常相似。交叉熵损失通过结合逻辑softmax层和NLL损失来得到交叉熵损失的值。这意味着可以通过将神经网络的最后一层替换为逻辑softmax层而不是普通softmax层来获得交叉熵损失值。
二元交叉熵损失
二元交叉熵损失是用于将数据点分类为仅两个类别的特殊交叉熵损失类。这种问题类型的标签通常是二进制的,因此我们的目标是让模型预测值为零的标签接近零,预测值为一的标签接近一。在用于二分类的 BCE 损失时,通常会使用 Sigmoid 层确保神经网络的输出值为零或一附近。
二进制交叉熵损失
我们在上一节提到,二进制交叉熵损失通常输出为一个sigmoid层,以确保输出在0到1之间。带有logits的二进制交叉熵损失将这两个层合并为一个层。根据PyTorch文档,这版更加数值稳定,因为它利用了log-sum exp技巧。
平滑L1损失
光滑的L1损失函数通过启发式值beta结合了MSE损失和MAE损失的优点。这个标准是在Fast R-CNN论文中提出的。当真实值与预测值之间的绝对差异小于beta时,标准使用平方差异,类似于MSE损失。MSE损失的图形是一条连续曲线,这意味着每个损失值的梯度是可变的,并且可以在任何地方求导。此外,随着损失值的减少,梯度减小,这在梯度下降期间是方便的。然而,对于非常大的损失值,梯度会爆炸,因此在绝对差异大于beta时,标准会切换到MAE损失,此时每个损失值的梯度几乎是恒定的,从而消除了潜在的梯度爆炸。
铰链嵌入损失
铰链嵌入损失主要用于半监督学习任务中,用于衡量两个输入之间的相似性。它用于有一个输入张量和一个包含值为1或-1的标签张量的情况。它主要用于涉及非线性嵌入和半监督学习的问题。
边缘排名损失
边缘排名损失是一种排名损失,其主要目标与其他损失函数不同,它是用来衡量数据集中一组输入之间的相对距离。边缘排名损失函数接受两个输入和一个仅包含1或-1的标签。如果标签为1,则假设第一个输入应该排在第二个输入之前,如果标签为-1,则假设第二个输入应该排在第一个输入之前。这种关系由下面的方程式和代码表示。
三元组损失
该准则通过使用训练数据样本的三元组来度量数据点之间的相似性。涉及的三元组包括锚样本、正样本和负样本。目标有两个) 将正样本与锚样本之间的距离尽可能最小化,锚样本与负样本之间的距离大于边距值加上正样本与锚样本之间的距离。通常,正样本属于锚样本的同一类,但负样本不属于。因此,通过使用此损失函数,我们旨在使用三元组边际损失来预测锚样本与正样本之间的高相似度值和锚样本与负样本之间的低相似度值。
余弦嵌入损失
余弦嵌入损失测量给定输入x1、x2和包含值1或-1的标签张量y的损失。它用于测量两个输入的相似或不相似程度。
该准则通过计算空间中两个数据点之间的余弦距离来测量相似性。余弦距离与两个点之间的角度相关,这意味着角度越小,输入越接近,因此它们越相似。
库尔兰-列贝尔损失函数
给定两个分布P和Q,库尔兰-列贝尔(KL)散度损失衡量的是当P(假设为真实分布)被Q替换时损失了多少信息。通过测量使用Q逼近P时损失的信息量,我们能够得到P和Q之间的相似度,从而驱动我们的算法产生非常接近真实分布P的分布。当用Q逼近P时损失的信息量与用P逼近Q时不相同,因此KL散度不是对称的。
构建自定义损失函数
PyTorch为我们提供了两种流行的方法来构建适合我们问题的自定义损失函数;这两种方法分别是用类实现和使用函数实现。让我们看看如何从函数实现开始实现这两种方法。
这无疑是编写自定义损失函数的最简单方法。这和创建一个函数一样简单,只需要传入所需的输入和其他参数,使用PyTorch的核心API或功能API执行某些操作,然后返回一个值。让我们通过一个自定义均方误差示例来看看。
在上述代码中,我们定义了一个自定义损失函数,以计算给定预测张量和目标张量时的均方误差
我们可以使用我们的自定义损失函数和PyTorch的MSE损失函数来观察我们得到了相同的结果。
使用Python类的自定义损失
这种方法可能是PyTorch中定义自定义损失的标准和推荐方法。通过从nn模块派生来创建损失函数,作为神经网络图中的一个节点。这意味着我们的自定义损失函数和卷积层一样,是一个PyTorch层。让我们看看这个示例是如何工作的。
最终思考
我们已经讨论了PyTorch中可用的许多损失函数,并深入了解了这些损失函数的内部工作原理。为特定问题选择合适的损失函数可能是一项令人望而却步的任务。希望,本教程以及官方PyTorch文档能为您提供指导,帮助您理解哪种损失函数更适合您的问题。
Source:
https://www.digitalocean.com/community/tutorials/pytorch-loss-functions