点击上方,选择星标置顶,每天给你送干货

阅读大概需要23分钟

跟随小博主,每天进步一丢丢

福利:AI书籍免费领取

来自:AI蜗牛车

前言

对于一个深度学习的训练过程,可以将它描述为让网络输出值和实际值越来越接近的过程。我们通过训练优化器来完成这个过程,还需要一个评估函数来为我们的优化器指明方向。这个评估函数用来估量模型的预测值和真实值的不一致程度,也就是所谓的损失函数。Loss函数有很多,并且在很多的深度学习任务中,有时候是需要我们自行去根据任务相关来设计Loss函数的。

1. 回归任务中的损失函数

1.1 MAE loss(L1)

L1 Loss 是一个衡量输入x(模型预测输出)和目标y之间差的绝对值的平均值,也叫MAE Loss。由于L1 Loss 具有稀疏性,为了惩罚较大的值,因此常常将其作为正则项添加到其他Loss中作为约束。L1 Loss的最大问题是梯度在0点不平滑,导致会跳过极小值。在Pytorch中,L1 Loss的实例化类为:

class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

其中N为一个batch的样本数,参数reduction控制batch的loss取每个样本L1 loss的均值还是总和,缺省为mean。

1.2 MSE loss(L2)

L2 Loss是输入x(模型预测输出)和目标y之间均方误差,所以也叫做MSE Loss:同样,L2 Loss也常常作为正则项。当y和f(x)也就是真实值和预测值的差值大于1时,会放大误差;而当差值小于1时,则会缩小误差,这是平方运算决定的。MSE对于较大的误差(>1)给予较大的惩罚,较小的误差(<1)给予较小的惩罚。也就是说,对离群点比较敏感,受其影响较大。如果样本中存在离群点,MSE会给离群点更高的权重,这就会牺牲其他正常点数据的预测效果,最终降低整体的模型性能。在Pytorch中,L2 Loss的实例化类为:

class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

同样,N为一个batch的样本数,参数reduction控制batch的loss取每个样本L2 loss的均值还是总和,缺省为mean。

1.3 选MSE还是MAE?

L1 Loss作为损失函数更稳定,并且对离群值不敏感,而且 L1 Loss 在0处不可导,大部分情况下梯度都是相等的,这意味着即使对于小的损失值,其梯度也是大的。这不利于函数的收敛和模型的学习。另外,在深度学习中,收敛较慢。

L2 Loss导数求解速度高,但是其对离群值敏感,不过可以将离群值的导数设为0(导数值大于某个阈值)来避免这种情况。

在实际的应用中,这两种损失函数的选择要视情况而定:从计算机求解梯度的复杂度来说,MSE 要优于 MAE,而且梯度也是动态变化的,能较快准确达到收敛。但是从离群点角度来看,如果离群点是实际数据或重要数据,而且是应该被检测到的异常值,那么我们应该使用MSE。另一方面,离群点仅仅代表数据损坏或者错误采样,无须给予过多关注,那么我们应该选择MAE作为损失。

1.4  Huber loss 和 Smooth L1  loss

Huber loss结合了MSE和MAE,定义如下:Huber Loss 包含了一个超参数 δ。δ 值的大小决定了 Huber Loss 对 MSE 和 MAE 的侧重性,当 |y−f(x)| ≤ δ 时,变为 MSE;当 |y−f(x)| > δ 时,则变成类似于 MAE,因此 Huber Loss 同时具备了 MSE 和 MAE 的优点,减小了对离群点的敏感度问题,实现了处处可导的功能。Smooth L1  loss就是Huber loss的参数δ取值为1时的形式。在Faster R-CNN以及SSD中对边框的回归使用的损失函数都是Smooth L1  loss。Smooth L1 Loss 能从两个方面限制梯度:

  • 1.当预测框与 ground truth 差别过大时,梯度值不至于过大

  • 2.当预测框与 ground truth 差别很小时,梯度值足够小

从上面可以看出,Smooth L1 loss函数实际上就是一个分段函数,在[-1,1]之间实际上就是L2损失,这样解决了L1的不光滑问题,在[-1,1]区间外,实际上就是L1损失,这样就解决了离群点梯度爆炸的问题。

在Pytorch中,Smooth L1 Loss的实例化类为:

class torch.nn.SmoothL1Loss(size_average=None, reduce=None, reduction='mean')

2. 分类任务中的损失函数

2.1 交叉熵损失

2.1.1  什么是交叉熵损失?(举例)

在一个多分类任务中,交叉熵损失函数是非常常见的,其定义如下:其中:

  • [




    ] ——类别的数量;

  • [







    ] ——指示变量(0或1),如果该类别和样本的类别相同就是1,否则是0;

  • [







    ] ——对于观测样本属于类别 [c] 的预测概率。

交叉熵,实际上就是真实标签和预测标签两个分布的交叉熵。举个例子: 假设一个5分类问题,然后一个样本I的标签







=[0,0,0,1,0],也就是说样本I的真实标签是4:

  • 假设模型预测的结果概率







    =[0.1,0.15,0.05,0.6,0.1],可以看出这个预测是对的,也就是类别4,那么对应的损失值为












  • 假设







    =[0.15,0.2,0.4,0.1,0.15],这个预测结果就很离谱了,因为真实标签是4,而你觉得这个样本是4的概率只有0.1(远不如其他概率高,如果是在测试阶段,那么模型就会预测该样本属于类别3),对应损失值L=-log(0.1)。

  • 假设







    =[0.05,0.15,0.4,0.3,0.1],这个预测结果虽然也错了,但是没有前面那个那么离谱,对应的损失L=-log(0.3)。根据log函数的性质,有-log(0.6) < -log(0.3) < -log(0.1)。可以看出预测错比预测对的损失要大,预测错得离谱比预测错得轻微的损失要大。

2.1.2 softmax loss

对于网络层中常用的softmax loss,其实,在交叉熵损失的公式里面,如果预测概率







是由softmax函数(softmax函数输出向量为样本在N个类别中,属于每个类别的概率)得到的。那么此时的softmax loss就是交叉熵loss。

2.1.3  Pytorch中的二分类交叉熵损失

在Pytorch中,交叉熵 Loss有几个函数,其中,二分类的交叉熵为:

1. class torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
2. class torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

对于BCELoss,由于二分类样本的输出只有两维,所以有:其中参数reduction表示一个batch样本loss的统计方式,默认为均值统计。API提供权重参数weight来调整loss值,weight是和分类维度一样的tensor,一般weight默认即可。

BCEWithLogitsLoss相当于在BCELoss的基础上加了sigmoid层:这样做的好处是可以使用一个tricks:log_sum_exp ,使得数值结果更加稳定,实际任务时,二分类交叉熵损失建议使用BCEWithLogitsLoss。

2.1.4  Pytorch中的多分类交叉熵损失

多分类任务的交叉熵loss为:

class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

Pytorch的CrossEntropyLoss实际上做了这么几件事情:

1.计算了一层softmax:

softmax函数会返回样本分类成每一个类别的概率分数,值在0~1之间。

2.将Softmax之后的结果取log,将乘法改成加法减少计算量,同时保障函数的单调性 .

3.上面的输出与Label对应的那个值拿出来,乘以权重weight(用于数据样本分布不均衡的调整),去掉负号,再求均值(reduction缺省为mean)。

Pytorch中也提供了两个函数:

1.  class torch.nn.LogSoftmax(dim=None)
2. class torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

而nn.CrossEntropyLoss的作用就相当于nn.LogSoftmax + nn.NLLLoss。

  • nn.LogSoftmax完成上面的步骤1-2:

  • nn.NLLLoss完成上面的步骤3(取出label对应的值):

这里一个需要注意的点是nn.CrossEntropyLoss已经做了一次softmax,所以它的input在之前不需要再在网络中添加一个softmax层了。

2.2 铰链损失(Hinge loss)

铰链损失的出名应用是作为SVM的损失函数,其名字来自于Hinge loss的图像:其中,









是预测值,为一概率分数,




是标签值。与0-1损失相比,Hinge loss的图像如下:同样对于多分类问题,Pytorch提供如下函数表示多分类hinge loss:

class torch.nn.MultiMarginLoss(p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction='mean')

其中次数p一般缺省为1。

weight为根据样本类别分布而设置的权重,可选择性设置。margin为hinge的阈值,就像图像表示的函数,1也是margin值。







为该样本错误预测的得分,







为正确预测的得分。两者的差值可用来表示两种预测结果的相似关系,margin是一个由自己指定的安全系数。

我们希望正确预测的得分高于错误预测的得分,且高出一个边界值 margin,换句话说,







越高越好,







越低越好,(















)越大越好,(















)越小越好,但二者得分之差最多为margin就足够了,差距更大并不会有任何奖励。这样设计的目的在于,对单个样本正确分类只要有margin的把握就足够了,更大的把握则不必要,过分注重单个样本的分类效果反而有可能使整体的分类效果变坏。分类器应该更加专注于整体的分类误差。

2.3 KL散度

KL散度也被称为相对熵,常被用于生成模型,比如GAN。在信息论中,关于熵有如下表述:

  • 熵:可以表示一个事件P包含多少信息。

  • KL散度:可以表述事件P和事件P的拟合事件Q有多大不同

  • 交叉熵:可以表述从事件P的角度如何去描述P的拟合事件Q。

前面说到的交叉熵,便是表达了预测事件和真实事件的相关程度,同样,KL散度也同样能描述两个时间分布的关系,并作为损失函数使用。上面公式是描述连续型事件分布的KL散度公式,不难发现,第一项便是之前说到的交叉熵的连续型,而后一项则是熵本身的定义,反映了事件P的信息量大小,所以,对于真实事件P和预测事件Q,熵,相对熵(KL散度),交叉熵有如下关系:*** P与Q的交叉熵 = P与Q的KL散度 - P的熵*** Pytorch中提供KLDivLoss函数来表述离散型KL散度损失:

class torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean')

对于一个N个样本的batch,KL散度损失做如下计算:参数reduction控制batch的loss取每个样本loss的均值还是总和,缺省为mean。

2.4 Triplet loss

Triplet loss用于训练差异性较小的样本,最初出现在FaceNet的论文中:FaceNet: A Unified Embedding for Face Recognition and Clustering (https://arxiv.org/pdf/1503.03832.pdf),可以学到较好的人脸的embedding。Triplet loss的输入是一个三元组:(anchor,positive, negative),其中,从训练数据集中随机选一个样本,该样本称为anchor,然后随机选取和anchor同类的样本positive和不同类的样本negative。下图是人脸embedding产生的Triplet loss:训练模型使得Triplet loss最小就是拉近同类(anchor,positive)距离,拉远异类(anchor,negative)距离,如下图:Triplet loss的公式如下:在训练的时候会得到很多的三元组(a,p,n),他们可以分为以下几类:

  • easy triplets:loss = 0,d(a, p) + margin < d(a, n),ap对的距离远远小于an对的距离。即,类内距离很小,类间很大距离,这种情况不需要优化。

  • hard triplets:d(a, n)   <  d(a, p) ,ap对的距离大于于an对的距离,即类内距离大于类间距离。这种情况比较难优化。

  • semi-hard triplets:d(a, p) < d(a, n) < d(a, p) + margin。ap对的距离和an对的距离比较高近。即,<a, p>和<a, n>很近,但都在一个margin内。比较容易优化。

一般在训练的时候是随机选取semi-hard triplets 进行训练的,但早期为了网络loss平稳,一般选择easy triplets进行优化,后期为了优化训练关键是要选择hard triplets,他们是活跃的,因此可以帮助改进模型。Triplet loss有两种训练方法,

  1. OFFLINE将训练集所有数据经过计算得到对应的 embeddings, 然后再计算 triplet loss,这种方式的效率不高,因为要遍历所有的数据得到三元组。

  2. ONLINE在ReID的论文:In Defense of the Triplet Loss for Person Re-Identification中使用了这样的方式。在训练时,分为Batch All和Batch Hard。Batch All计算了一个batch中所有val的的hard triplet 和 semi-hard triplet, 然后取平均得到Triplet loss。而Batch Hard则是对于每一个anchor,都选择距离最大的d(a, p) 和距离最大的d(a, n)。论文中选择Batchhard,随机抽取P个人,每个人K张图片形成一个batch,每个人的K张图片之间形成K*(K-1)个ap对,再在剩下其他人里取一个与该ap距离最近的negative,组成apn组并将apn组按照下面式子中的公式取模型里进行训练,使得下面的式子最小化。

在这里插入图片描述

Pytorch中提供TripletMarginLoss函数来实现Triplet loss,其中p为距离范数,默认为2,即2-范数:

class torch.nn.TripletMarginLoss(margin=1.0, p=2.0,size_average=None, reduce=None, reduction='mean')

3. PyTorch 如何自定义损失函数

关于PyTorch 如何自定义损失函数?总的来说,大体有三种方法:

3.1 调用torch.Tensor的原生接口

和一般的自定义函数一样只需要在__init__()里面定义好超参数,再在forward里写好计算过程就可以了。因为继承了nn.Module,所以这个loss类在实例化之后可以直接运行__call__()方法。这里以center loss为例(center loss来自于ECCV 2016 的一篇论文,被使用在ReID任务中,论文地址)

import torch as t
import torch.nn as nn
import torch.nn.functional as Fclass CenterLoss(nn.Module):def __init__(self,cls_num,featur_num):super().__init__()self.cls_num = cls_numself.featur_num=featur_numself.center = nn.Parameter(t.rand(cls_num,featur_num))def forward(self, xs,ys):   self.center_exp = self.center.index_select(dim=0,index=ys.long())count = t.histc(ys,bins=self.cls_num,min=0,max=self.cls_num-1)self.count_dis = count.index_select(dim=0,index=ys.long())+1loss = t.sum(t.sum((xs-self.center_exp)**2,dim=1)/2.0/self.count_dis.float())return loss

3.2 Pytorch使用numpy/scipy扩展

原生接口提供了torch.nn.functional模块来代替一些函数操作,当该模块功能不能满足自定义函数的功能实现要求时,我们可以先将tensor转换为numpy,再使用numpy/scipy来实现函数功能,最后再返回tensor。下面是Pytorch官网给出的使用numpy/scipy扩展自定义快速傅里叶变换的案例:

import torch
from torch.autograd import Function
from numpy.fft import rfft2, irfft2class BadFFTFunction(Function):@staticmethoddef forward(ctx, input):numpy_input = input.detach().numpy()  #先转换成Numpyresult = abs(rfft2(numpy_input))return input.new(result)@staticmethoddef backward(ctx, grad_output):numpy_go = grad_output.numpy()result = irfft2(numpy_go)return grad_output.new(result)def incorrect_fft(input):return BadFFTFunction.apply(input)input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))   #返回tensor
print(input)'''
运算结果:
tensor([[ 5.9226,  6.3469,  4.8813,  8.1814,  3.2829],[ 7.1644,  6.0789,  3.7858,  6.2823,  2.3036],[ 2.1042,  5.3961,  1.2794, 13.0508,  7.8831],[11.0937,  6.8053,  4.2092,  2.3636,  4.6894],[11.4806,  9.2691,  2.2958,  4.5882, 15.1742],[11.0937, 12.0654,  7.2496,  4.1519,  4.6894],[ 2.1042,  5.0917,  4.5153, 12.4071,  7.8831],[ 7.1644,  4.8029,  7.4481,  1.6412,  2.3036]],grad_fn=<BadFFTFunctionBackward>)
tensor([[ 0.6201,  1.2492,  0.1847,  0.1239,  0.0867,  0.1096,  1.1102,  0.8024],[-1.7598,  0.3906,  1.2448, -0.1645, -0.7275, -1.7156, -0.7443, -1.5144],[-0.3217,  0.1188,  0.1551, -0.9676, -0.8834,  0.8660,  0.2944,  2.7816],[-0.0698,  0.3642, -1.0339, -0.1114,  0.0208, -1.3441,  0.0184,  0.1927],[ 0.1153,  1.5583, -0.9675,  0.3124, -0.2498, -0.5960,  1.4346, -0.5523],[ 0.6704, -0.1076,  0.6561,  0.4233,  1.0294, -0.4443,  0.2737,  0.7467],[-0.1177,  0.6641,  0.8596,  0.5245, -0.4537,  0.8934, -2.1302,  1.0770],[-0.5317,  0.0276,  0.5124, -0.3272,  0.8176, -0.0871,  1.2068, -0.6912]],requires_grad=True)
'''

参考

[1].https://www.cnblogs.com/wangguchangqing/p/12021638.html

[2].https://blog.csdn.net/wonengguwozai/article/details/74066157

[3].https://msd.misuland.com/pd/2884250171976192486

[4].https://mp.weixin.qq.com/s/Xbi5iOh3xoBIK5kVmqbKYA

[5].https://blog.csdn.net/weixin_40671425/article/details/98068190

[6].https://blog.csdn.net/weixin_45191152/article/details/97762005

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。
方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。记得备注呦让更多的人知道你“在看”

【Pytorch版本】损失函数大汇总相关推荐

  1. 不同PyTorch版本训练同一个代码结果差异巨大

    问题描述 笔者在训练一个深度学习网络时,发现使用不同的PyTorch版本运行同一个训练代码,训练出来的网络结果差异巨大.具体来说,笔者训练得到的结果如下所示: PyTorch版本 Torchvisio ...

  2. 【原创】强化学习笔记|从零开始学习PPO算法编程(pytorch版本)

    从零开始学习PPO算法编程(pytorch版本)_melody_cjw的博客-CSDN博客_ppo算法 pytorch 从零开始学习PPO算法编程(pytorch版本)(二)_melody_cjw的博 ...

  3. 超轻量目标检测模型NanoDet(速度很快)PyTorch版本实践

    文章目录 前言 NanoDet 模型介绍 1)NanoDet 模型性能 2)NanoDet 模型架构 3)NanoDet损失函数 4)NanoDet 优势 基于PyTorch 实现NanoDet 1) ...

  4. 47种常见的浏览器兼容性问题大汇总

    浏览器兼容性问题大汇总 Ø JavaScript 3 1. HTML对象获取问题 3 2. const问题 3 3. event.x与event.y问题 3 4. window.location.hr ...

  5. ANDROID内存优化(大汇总——中)

    转载请注明本文出自大苞米的博客(http://blog.csdn.net/a396901990),谢谢支持! 写在最前: 本文的思路主要借鉴了2014年AnDevCon开发者大会的一个演讲PPT,加上 ...

  6. (转载)浏览器兼容性问题大汇总

    浏览器兼容性问题大汇总 Ø JavaScript 3 1. HTML对象获取问题 3 2. const问题 3 3. event.x与event.y问题 3 4. window.location.hr ...

  7. c++ 字符串合并_C语言输入字符和字符串(所有函数大汇总)

    C语言输入字符和字符串(所有函数大汇总) C语言有多个函数可以从键盘获得用户输入,它们分别是: scanf():和 printf() 类似,scanf() 可以输入多种类型的数据. getchar() ...

  8. 比特飞解决方案大汇总,你的贴身家教

    该文章的最新版本已迁移至个人博客[比特飞],单击链接:.Net中文网解决方案大汇总,你的贴身家教 | .Net中文网. 概述 本系列文章将会向大家介绍本人实际开发过程中所遇的解决方案大汇总,旨在抛砖引 ...

  9. 利用公式画图_【高中数学】重要公式大汇总!

    公式口诀 一.集合与函数 内容子交并补集,还有幂指对函数. 性质奇偶与增减,观察图象最明显. 复合函数式出现,性质乘法法则辨, 若要详细证明它,还须将那定义抓. 指数与对数函数,两者互为反函数. 底数 ...

  10. python3廖雪峰云-python3基础教程廖雪峰云_Python GUI库大汇总

    Python GUI库大汇总 所有程序都是基于命令行的,这序可能只有一些"专的计算机人士才会使用.例如前面编写的五等程序,恐怕只有程序员自己才愿意玩这么"糟糕"的游戏,很 ...

最新文章

  1. 双十一,程序员前女友发来消息。。。
  2. BZOJ4298 : [ONTAK2015]Bajtocja
  3. How Kafka’s Storage Internals Work
  4. mysql sqlserver 性能优化_SQLSERVER SQL性能优化技巧
  5. 【增强】批次特性增强案例
  6. mac搭建本地svn
  7. 阿里云推出“磐久”云原生服务器系列 能效和交付效率大幅提升
  8. 解决手机端上的iframe无法触摸滚动
  9. 【学习笔记】深入理解Linux内核第三版 ——第二章 内存寻址
  10. 分享个最终幻想勇气启示录脚本,手游上能一键推图自动升级
  11. maven+Tomcat热部署
  12. 神经网络学习笔记4:CPN网络的实现
  13. outlook图片显示红叉
  14. 玩转华为ENSP模拟器系列 | 配置L3VdPdNd迭代SR-BE隧道示例
  15. 爬虫python创意_爬虫案例:利用python爬虫关键词批量下载高清大图
  16. Linux 自签名ssl证书生成
  17. Web安全技术—常见的攻击和防御
  18. 常见的几种ADSL 路由器的端口映射方法
  19. Google Pay 谷歌支付(gateway = stripe)
  20. 一个大数据架构师应该掌握的技能

热门文章

  1. JavaScript_HTML DEMO_2_事件
  2. Trie树 01Trie
  3. 移植u-boot.2012.04.01
  4. Nodejs express、html5实现拖拽上传(转载)
  5. UVa10047 BFS
  6. java day56【 Mybatis 延迟加载策略 、 Mybatis 缓存、Mybatis 注解开发 】
  7. 常用正则表达式(regular expression)
  8. java day09【继承、super、this、抽象类】
  9. Git 与 GitHub
  10. 【CSS3】 线性渐变