点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

作者 | 小新

来源 | https://lhyxx.top

编辑 | 极市平台

导读

本文从理论和实践两方面来全面梳理一下常用的损失函数。(避免自己总是一瓶子不满半瓶子晃荡……)。要么理论满分,编码时不会用;要么编码是会调包,但是不明白其中的计算原理。本文来科普一下。

本文从理论和实践两方面来全面梳理一下常用的损失函数。(避免自己总是一瓶子不满半瓶子晃荡……)。要么理论满分,编码时不会用;要么编码是会调包,但是不明白其中的计算原理。本文来科普一下。

我们将每个损失函数分别从理论和pytorch中的实现两个方面来拆解一下。

另外,解释一下torch.nn.Module 和 torch.nn.functional(俗称F)中损失函数的区别。

Module的损失函数例如CrossEntropyLoss、NLLLoss等是封装之后的损失函数类,是一个类,因此其中的变量可以自动维护。经常是对F中的函数的封装。而F中的损失函数只是单纯的函数。

当然我们也可以自己构造自己的损失函数对象。有时候损失函数并不需要太复杂,没有必要特意封装一个类,直接调用F中的函数也是可以的。使用哪种看具体实现需求而定。

CrossEntropyLoss

交叉熵损失,是分类任务中最常用的一个损失函数。

理论

直接上理论公式:

其中 是真实标签, 是预测的类分布(通常是使用softmax将模 型输出转换为概率分布), 也就是 与 中的元素分别表示对应类 别的概率。

举个例子,清晰明了:

# 假设该样本属于第二类 # 因为是分布, 所以属于各个类的和为 1

pytorch-实现

from torch.nn import CrossEntropyLoss

举例:

实际使用中需要注意几点:

  • torch.nn.CrossEntropyLoss(input, target)中的标签target使用的不是one-hot形式,而是类别的序号。形如 target = [1, 3, 2] 表示3个样本分别属于第1类、第3类、第2类。

  • torch.nn.CrossEntropyLoss(input, target)input没有归一化的每个类的得分,而不是softmax之后的分布。

举例,输入的形式大概就像相面这种格式:

然后就将他们扔到CrossEntropyLoss函数中,就可以得到损失。

loss = CrossEntropyLoss(input, target)

我们看CrossEntropyLoss函数里面的实现,是下面这样子的:

def forward(self, input, target):return F.cross_entropy(input, target, weight=self.weight,ignore_index=self.ignore_index, reduction=self.reduction)

是调用的torch.nn.functional(俗称F)中的cross_entropy()函数。

参数

  • input:预测值,(batch,dim),这里dim就是要分类的总类别数

  • target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。

  • weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。

  • ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。

  • reduction:在[none, mean, sum]中选,string型。none表示不降维,返回和target相同形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。

其中参数weight、ignore_index、reduction要在实例化CrossEntropyLoss对象时指定,例如:

loss = torch.nn.CrossEntropyLoss(reduction='none')

我们再看一下F中的cross_entropy的实现

return nll_loss(log_softmax(input, dim=1), target, weight, None, ignore_index, None, reduction)

可以看到就是先调用log_softmax,再调用nll_loss

log_softmax就是先softmax再取log

nll_loss 是negative log likelihood loss:

详细介绍见下面torch.nn.NLLLoss,计算公式如下:

例如假设 , class ,则,class class

源码中给了个用法例子:

# input is of size N x C = 3 x 5
input = torch.randn(3, 5, requires_grad=True)
# each element in target has to have 0 <= value < C
target = torch.tensor([1, 0, 4])
output = F.nll_loss(F.log_softmax(input), target)
output.backward()

因此,其实CrossEntropyLoss损失,就是softmax + log + nll_loss的集成。

CrossEntropyLoss(input, target) = nll_loss(log_softmax(input, dim=1), target)

CrossEntropyLoss中的target必须是LongTensor类型。

实验如下:

pred = torch.FloatTensor([[2, 1], [1, 2]])
target = torch.LongTensor([1, 0])loss_fun = nn.CrossEntropyLoss()loss = loss_fun(pred, target)
print(loss)  # 输出为tensor(1.3133)
loss2 = F.nll_loss(F.log_softmax(pred, dim=1), target)
print(loss2)  # 输出为tensor(1.3133)

数学形式就是:

torch-nn-BCELoss

理论

CrossEntropy损失函数适用于总共有N个类别的分类。当N=2时,即二分类任务,只需要判断是还是否的情况,就可以使用二分类交叉熵损失:BCELoss 二分类交叉熵损失。上公式 (y是真实标签,x是预测值)

其实这个函数就是CrossEntropyLoss的当类别数N=2时候的特例。因为类别数为2,属于第一类的概率为y,那么属于第二类的概率自然就是(1-y)。因此套用与CrossEntropy损失的计算方法,用对应的标签乘以对应的预测值再求和,就得到了最终的损失。

实践

torch.nn.BCELoss(x, y)

x形状(batch,*),y形状与x相同。

x与y中每个元素,表示的是该维度上属于(或不属于)这个类的概率。

另外,pytorch中的BCELoss可以为每个类指定权重。通常,当训练数据中正例和反例的比例差别较大时,可以为其赋予不同的权重,weight的形状应该是一个一维的,元素的个数等于类别数。

实际使用如下例,计算BCELoss(pred, target):

pred = torch.FloatTensor([0.4, 0.1])  # 可以理解为第一个元素分类为是的概率为0.4,第二个元素分类为是的概率为0.1。
target = torch.FloatTensor([0.2, 0.8])  # 实际上第一个元素分类为是的概率为0.2,第二个元素分类为是的概率为0.8。
loss_fun = nn.BCELoss(reduction='mean')  # reduction可选 none, sum, mean, batchmean
loss = loss_fun(pred, target)
print(loss)  # tensor(1.2275)a = -(0.2 * np.log(0.4) + 0.8 * np.log(0.6) + 0.8 * np.log(0.1) + 0.2 * np.log(0.9))/2
print(a)  # 1.2275294114572126

可以看到,计算BCELoss(pred,target)与上面理论中的公式一样。

内部实现

pytorch 中的torch.nn.BCELoss类,实际上就是调用了F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)

torch.nn.BCEWithLogitsLoss

理论

该函数实际上与BCELoss相同,只是BCELoss的输入x,在输入之前需要先手动经过sigmoid激活函数映射到(0, 1)区间,而该函数将sigmoid与BCELoss整合到一起了

也就是先将输入经过sigmoid函数,然后计算BCE损失。

实践

torch.nn.BCEWithLogitsLoss(x, y)

x与y的形状要求与BCELoss相同。

pred = torch.FloatTensor([0.4, 0.1])
target = torch.FloatTensor([0.2, 0.8])
loss_fun = nn.BCEWithLogitsLoss(reduction='mean')  # reduction可选 none, sum, mean, batchmean
loss = loss_fun(pred, target)
print(loss)  # tensor(0.7487)# 上面的过程与下面的过程结果相同
loss_fun = nn.BCELoss(reduction='mean')  # reduction可选 none, sum, mean, batchmean
loss = loss_fun(torch.sigmoid(pred), target)  # 先经过sigmoid,然后与target计算BCELoss
print(loss)  # tensor(0.7487)

可以看出,先对输入pred调用sigmoid,在调用BCELoss,结果就等于直接调用BCEWithLogitsLoss。

torch.nn.L1Loss

理论

L1损失很简单,公式如下:

x是预测值,y是真实值。

实践

torch.nn.L1Loss(x, y)

x形状:任意形状

y形状:与输入形状相同

pred = torch.FloatTensor([[3, 1], [1, 0]])
target = torch.FloatTensor([[1, 0], [1, 0]])
loss_fun = nn.L1Loss()
loss = loss_fun(pred, target)
print(loss)  # tensor(0.7500)

其中L1Loss的内部实现为:

def forward(self, input, target):return F.l1_loss(input, target, reduction=self.reduction)

我们可以看到,其实还是对F.l1_loss的封装。

torch.nn.MSELoss

理论

L1Loss可以理解为向量的1-范数,MSE均方误差就可以理解为向量的2-范数,或矩阵的F-范数。

x是预测值,y是真实值。

实践

torch.nn.MSELoss(x, y)

x任意形状,y与x形状相同。

pred = torch.FloatTensor([[3, 1], [1, 0]])
target = torch.FloatTensor([[1, 0], [1, 0]])
loss_fun = nn.MSELoss()
loss = loss_fun(pred, target)
print(loss)  # tensor(1.2500)

其中MSELoss内部实现为:

def forward(self, input, target):return F.mse_loss(input, target, reduction=self.reduction)

本质上是对F中mse_loss函数的封装。

torch.nn.NLLLoss

理论

NLLLoss(Negative Log Likelihood Loss),其数学表达形式为:

前面讲到CrossEntropyLoss中用的nll_loss,实际上,该损失函数就是对F.nll_loss的封装,功能也和nll_loss相同。

正如前面所说,先把输入x进行softmax,在进行log,再输入该函数中就是CrossEntropyLoss

实践

torch.nn.NLLLoss(x, y)

x是预测值,形状为(batch,dim)

y是真实值,形状为(batch)

形状要求与CrossEntropyLoss相同。

pred = torch.FloatTensor([[3, 1], [2, 4]])
target = torch.LongTensor([0, 1])  #target必须是Long型
loss_fun = nn.NLLLoss()
loss = loss_fun(pred, target)
print(loss)  # tensor(-3.5000)

其内部实现实际上就是调用了F.nll_loss():

def forward(self, input, target):return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)

torch.nn.KLDivLoss

理论

KL散度通常用来衡量两个连续分布之间的距离。两个分布越相似,KL散度越接近0。

KL散度又叫相对熵,具体理论可以参考:https://lhyxx.top/2019/09/15/%E4%BF%A1%E6%81%AF%E8%AE%BA%E5%9F%BA%E7%A1%80-%E7%86%B5/

注意,这里 x 与 y 都是分布,分布就意味着其中所有元素求和概率为1。

例如

则:

本例中计算的 都是以e为底的。

实践

torch.nn.KLDivLoss(input, target)

试验测试torch.nn.KLDivLoss,计算KL(pred|target)

pred = torch.FloatTensor([0.1, 0.2, 0.7])
target = torch.FloatTensor([0.5, 0.2, 0.3])
loss_fun = nn.KLDivLoss(reduction='sum')  # reduction可选 none, sum, mean, batchmean
loss = loss_fun(target.log(), pred)
print(loss)  # tensor(0.4322)#上面的计算过程等价于下面
a = (0.1 * np.log(1/5) + 0.2 * np.log(1) + 0.7 * np.log(7/3))
print(a)  # 0.43216

input应该是log-probabilities,target是probabilities。inputtarget形状相同。

该函数是对F.kl_div(input, target, reduction=self.reduction)的封装。其原型为:torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

注意,使用nn.KLDivLoss计算KL(pred|target)时,需要将predtarget调换位置,而且target需要先取对数:

loss_fun(target.log(), pred)

如果觉得有用,就请分享到朋友圈吧!

点个在看 paper不断!

实操教程|Pytorch常用损失函数拆解相关推荐

  1. Pytorch 常用损失函数拆解

    作者 | 小新 ,编辑 | 极市平台 来源 | https://lhyxx.top 本文从理论和实践两方面来全面梳理一下常用的损失函数.(避免自己总是一瓶子不满半瓶子晃荡--).要么理论满分,编码时不 ...

  2. 实操教程|PyTorch AutoGrad C++层实现

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨xxy-zhihu@知乎 来源丨https://zhuanla ...

  3. [转载]你们要的GIF动图制作全攻略!看完就会做!(实操教程)

    非常实用呀 原文地址:你们要的GIF动图制作全攻略!看完就会做!(实操教程)作者:木木老贼 来源:文案匠(ID:sun-work) 作者:一木(授权转载,如需转载请联系文案匠) 文章配图的GIF动图怎 ...

  4. 通过大白菜u盘启动工具备份/还原/重装/激活系统/修复引导 实操教程(上)

    通过大白菜u盘启动工具备份/还原/重装/激活系统/修复引导 实操教程(上) 前言 进入大白菜u盘的pe系统 用GHOST进行系统盘备份/还原 在D盘上安装新系统(以win10-2004为例) 镜像下载 ...

  5. 寻找亚马逊测评师邮箱_美国及欧盟亚马逊产品外观专利查询步骤实操教程(已验证)...

    亚马逊产品外观专利防不胜防:美国及欧盟外观专利查询步骤实操教程(已验证) 欧洲 https://www.tmdn.org/tmdsview-web/dsview-logo-white.15c95da2 ...

  6. 实操教程|火遍全网的剪纸风格究竟是怎么做出来的?

    原文来自公众号:希音的设计笔记 > 添加微信:xiyin0820 获取高质量样机 | C4D教程 | OC渲染教程 | Sketch教程 Adobe2021 | Adobe2020 | LED字 ...

  7. mysql教程乛it教程网_MySQL数据库实操教程(35)——完结篇

    版权声明 专栏概况 从2019年7月21日至今,约莫一个月的时间终于写完了MySQL教程,我已将其集结在专栏<MySQL数据库实操教程>,概述如下: 共计35篇文章 每篇文章均附源码和运行 ...

  8. MySQL数据库实操教程(35)——完结篇

    版权声明 本文原创作者:谷哥的小弟 作者博客地址:http://blog.csdn.net/lfdfhl 专栏概况 从2019年7月21日至今,约莫一个月的时间终于写完了MySQL教程,我已将其集结在 ...

  9. MetagenoNets:在线宏基因组网络分析实操教程

    宏基因组研究中网络分析已经十分普及,但却缺少整合的分析方法,限制了广大同行的使用. 关于网络分析的基本步骤,和现在工具的比较,详见原文解读 - NAR:宏基因组网络分析工具MetagenoNets 本 ...

最新文章

  1. AS更改初始布局遇到的问题
  2. 架构设计的目标与衡量
  3. mac 系统使用macaca inspector 获取iphone真机应用元素
  4. WAP2.0(XHTML MP)基础介绍
  5. php redis 签到,基于Redis位图实现用户签到功能
  6. IC设计常用文件及格式介绍
  7. 8月8日白暨豚宣告灭绝
  8. 管理感悟:学会推论及验证
  9. SSM常用面试题整理一
  10. vbs脚本学习整人Demo
  11. maya中英文对比_[转载]maya中英文对照
  12. 18获得触发事件元素节点的方法
  13. UE4超过20万个动画角色的优化实战
  14. 云集微店亿级交易额下的Order子系统架构演变
  15. MKR:协同过滤算法效果不佳,知识图谱来帮忙
  16. 2017福建省计算机一级应用技术,2017年一级计算机信息技术及应用考试试题级答案...
  17. c语言中十六进制乘以16啥意思,C语言16进制中16怎么表示?
  18. 用html制作百度地图,canvas实现百度地图个性化底图绘制
  19. 高阶篇:4.3)FTA故障树分析法-DFMEA的另外一张脸
  20. 基于SNMP/MIB的网络数据获取系统设计与实现(三)

热门文章

  1. js实现页面跳转的几种方式
  2. Hibernate 获取某个表全部记录时 奇怪现象 (重复出现某个记录)
  3. stella forum v 2.0 的两款主题样式
  4. Visual Studio UML Use Case Diagram(1)
  5. 【CTF】实验吧 Fair-Play
  6. python 第六章 函数 pta(1)
  7. GitLab 在中国成立公司极狐,GitHub 还会远吗?
  8. 网红“AI大佬”被爆论文剽窃,Jeff Dean都看不下去了
  9. 清华大学提出APDrawingGAN,人脸照片秒变艺术肖像画
  10. 公开课报名 | 基于自定义模板的OCR结果的结构化处理技术