nn.CrossEntropyLoss总结

  • 目录
    • nn.CrossEntropyLoss
    • nn.LogSoftmax
    • nn.NLLLoss
    • Cross entropy

目录

  • 版本

    • pytorch 1.0

nn.CrossEntropyLoss

  • 依次执行nn.LogSoftmax(), nn.NLLLoss()两个操作(不了解可见下面的section)

  • 调用时输入参数:

    • Input:(N,C)(N,C)(N,C),N是样本数,C是类别数
    • Target:(N)(N)(N),每个值在[0,C−1][0,C-1][0,C−1]之间
  • 用于多分类任务

  • 数学公式
    loss(x,class)=−log⁡(exp⁡(x[class])∑jexp⁡(x[j]))=−x[class]+log⁡(∑jexp⁡(x[j]))loss(x,class)=-\log(\frac{\exp(x[class])}{\sum_j{\exp(x[j])}}) = -x[class] + \log(\sum_j{\exp(x[j])}) loss(x,class)=−log(∑j​exp(x[j])exp(x[class])​)=−x[class]+log(j∑​exp(x[j]))

    该公式计算的是样本xxx的预测向量和其真实类别classclassclass的误差,这里解释一下:CrossEntropy应该是CrossEntropy=−∑iNyi⋅log⁡(xi)\mathtt{CrossEntropy}=-\sum_i^{N}{y_i \cdot \log(x_i)}CrossEntropy=−∑iN​yi​⋅log(xi​),(这里公式不了解可见下面的section),这里讨论一个样本,一共有NNN个类别,对每个类别计算并求和,但N−1N-1N−1个yiy_iyi​都为0,就可以化简为CrossEntropy=−log⁡(xn)\mathtt{CrossEntropy}=- \log(x_n)CrossEntropy=−log(xn​)。

nn.LogSoftmax

  • Softmax() + Log
  • 数学公式(直接,清晰)
    LogSoftmax(xi)=log⁡(exp⁡(xi)∑jexp⁡(xj))\mathtt{LogSoftmax}(x_i) = \log(\frac{\exp(x_i)}{\sum_j{\exp(x_j)}}) LogSoftmax(xi​)=log(∑j​exp(xj​)exp(xi​)​)

nn.NLLLoss

  • 负对数似然损失(Negative Log Likelihood Loss),其实就是加个负号再求均值。

  • 用于:多分类任务

  • 调用时输入参数:

    • Input:(N,C)(N,C)(N,C),C=C=C=类别数
    • Target:(N)(N)(N),就是label,它的每个值在[0,C−1][0,C-1][0,C−1]范围内。
  • 数学公式
    l(x,y)=L={l1,…,lN}T,ln=−wynxn,yn,wc=weight[c]⋅1{c≠ignore_index}l(x,y)=L=\{l_1,\dots,l_N\}^T,l_n=-w_{y_n}x_{n,y_n},w_c=weight[c] \cdot 1\{c\neq {ignore\_index}\} l(x,y)=L={l1​,…,lN​}T,ln​=−wyn​​xn,yn​​,wc​=weight[c]⋅1{c​=ignore_index}

    其中,NNN表示batch大小,xn,ynx_{n,y_n}xn,yn​​表示第nnn个样本对其真实标签yny_nyn​的对数似然。而wcw_cwc​中,默认情况下,weightweightweight的值都为1,而c≠ignore_indexc\neq {ignore\_index}c​=ignore_index也是为真的。所以,可以将l_n化简为ln=−xn,yn.l_n=-x_{n,y_n}.ln​=−xn,yn​​.

    现在的LLL实际上是一个向量,默认的话,最后还会对其求均值。

  • 实例:

a = torch.rand(3,5)
# tensor([[0.4417, 0.2536, 0.6055, 0.3409, 0.3773],
#        [0.1164, 0.4653, 0.9451, 0.9057, 0.9112],
#        [0.1945, 0.3237, 0.9122, 0.6768, 0.4759]])
b = torch.rand(3,5)
label = torch.argmax(b,dim=-1)
# tensor([3, 0, 4])
nn.NLLLoss(reduce=False)(a,label)
# tensor([-0.3409, -0.1164, -0.4759])

Cross entropy

在信息论中,在相同潜在事件集合下,概率分布ppp和qqq间的交叉熵(Cross entropy)是指,当基于一个“非自然”(相对于“真实”分布ppp而言)的概率分布qqq进行编码时,在事件集合中唯一标识一个事件所需要的平均比特数(bit)。

给定两个概率分布ppp和qqq,ppp相对于qqq的交叉熵定义为:
H(p,q)=Ep[−log⁡q]=H(p)+DKL(p∣∣q)H(p,q) = E_p{[-\log{q}]}=H(p)+D_{KL}(p||q) H(p,q)=Ep​[−logq]=H(p)+DKL​(p∣∣q)
其中,H(p)H(p)H(p)是ppp的熵,DKL(p∣∣q)D_{KL}(p||q)DKL​(p∣∣q)是ppp与qqq的KL散度(相对熵)。


H(p)=∑x∈Xp(x)⋅log⁡(1p(x)),DKL(p∣∣q)=∑x∈Xp(x)⋅log⁡p(x)q(x)H(p)=\sum_{x \in X}p(x)\cdot\log(\frac{1}{p(x)}), D_{KL}(p||q)=\sum_{x \in X}p(x)\cdot\log\frac{p(x)}{q(x)} H(p)=x∈X∑​p(x)⋅log(p(x)1​),DKL​(p∣∣q)=x∈X∑​p(x)⋅logq(x)p(x)​
对于离散分布ppp和qqq,交叉熵定义为:
H(p,q)=−∑xp(x)log⁡q(x)H(p,q)=-\sum_x{p(x)\log{q(x)}} H(p,q)=−x∑​p(x)logq(x)

参考:

  1. torch.nn.CrossEntropyLoss
  2. (wiki)Cross_entropy
  3. (维基)交叉熵

nn.CrossEntropyLoss总结相关推荐

  1. nn.CrossEntropyLoss()

    用于多分类,直接写标签序号就可以:0,1,2. 预测需要维度与标签长度一致. import torch import torch.nn as nn import math criterion = nn ...

  2. nn.BCELoss与nn.CrossEntropyLoss的区别

    以前我浏览博客的时候记得别人说过,BCELoss与CrossEntropyLoss都是用于分类问题.可以知道,BCELoss是Binary CrossEntropyLoss的缩写,BCELoss Cr ...

  3. pytorch的nn.CrossEntropyLoss()函数使用方法

    nn.CrossEntropyLoss()函数计算交叉熵损失 用法: # output是网络的输出,size=[batch_size, class] #如网络的batch size为128,数据分为1 ...

  4. PyTorch之torch.nn.CrossEntropyLoss()

    简介 信息熵: 按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度 交叉熵: 使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度 多分类任务中的交叉熵损失函数 代码 ...

  5. pytorch nn.CrossEntropyLoss

    应用 概念讲解 1)假设有m张图片,经过神经网络后输出为m*n的矩阵(m是图片个数,n是图片类别),下例中: m=2,n=2既有两张图片,供区分两种类别比如猫狗.假设第0维为猫,第1维为狗 impor ...

  6. PyTorch nn.CrossEntropyLoss() dimension out of range (expected to be in range of [-1, 0], but got 1)

    import torch import torch.nn as nn loss_fn = nn.CrossEntropyLoss() # 方便理解,此处假设batch_size = 1 x_input ...

  7. nn.BCELoss和nn.CrossEntropyloss

    nn.BCELoss和nn.CrossEntropyloss总结 nn.BCEloss 公式如下: 1.输入的X 代表模型的最后输出 y 代表你的label 我们的目的就是为了让模型去更好的学习lab ...

  8. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  9. 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用

    在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...

  10. nn.CrossEntropyLoss的ignore_index标签(CE loss)

    例如我的pred是(b,2,w,h),而label索引是(b,1,w,h)的矩阵,其中只有0,1值,0值代表从pred的第0个通道选择像素值,1值代表从pred的第1个通道选择像素值. 而此时我发现因 ...

最新文章

  1. poj 1740 A New Stone Game 博弈
  2. java并发环境安全初始化
  3. Day 12: OpenCV —— Java开发者的人脸检测
  4. 小程序如何封装自定义组件(Toast)
  5. python数据分析实验报告_Python 数据分析入门实战
  6. 51单片机 | 模拟PWM调制控制实验
  7. mysql innobackupex 备份及恢复
  8. 易语言解析html实例,易语言总使用正则表达式实例解析
  9. 单目标跟踪算法:SiamRPN++
  10. 了解CompletableFuture
  11. 【日常实用篇】解决2345压缩软件自带的流氓广告
  12. 深入浅出计算机原理组成--->指令与运算——指令跳转(2)
  13. CentOS7安装kangle和easypanel
  14. [VN2020 公开赛]内存取证
  15. 1.1微信支付之现金红包 - Java 开发
  16. 显示测试漏光软件,屏幕漏光测试怎么做(液晶显示器屏幕漏光的检测方法)
  17. 浅谈网页设计中的构图
  18. 二维码怎么制作?手把手教你制作生成
  19. 【解决方案】t2gp.exe - 损坏的映像 | libcef.dll没有被指定在 Windows 上运行
  20. 信奥中的数学基础:分解质因数

热门文章

  1. 作为程序员,赚取额外收入的 4个简单副业!
  2. Qt5笔记之Qt5插件的生成与加载及json文件的读取
  3. 最好的vsftpd配置教程
  4. 谷胱甘肽(GSH)修饰的CdTe/CdS量子点(GSH-CdTe/CdSQDs)|PEG修饰水溶性量子点ZnS:Mn
  5. RabbitMQ之mandatory和immediate介绍
  6. 基于微信小程序的快递取件及上门服务
  7. Microsoft Visual Studio + Qt插件编程出现错误error MSB4184问题
  8. 共享单车、公交车辆位置、地铁等50+个交通数据集
  9. C语言程序设计 现代设计方法_第8章代码、练习题及编程题答案
  10. 2022-2028年全球与中国单过硫酸氢钾行业市场需求预测分析