众所周知,softmax+cross entropy是在线性模型、神经网络等模型中解决分类问题的通用方案,但是为什么选择这种方案呢?它相对于其他方案有什么优势?笔者一直也困惑不解,最近浏览了一些资料,有一些小小心得,希望大家指正~

损失函数:交叉熵Cross Entropy

我们可以从三个角度来理解cross entropy的物理意义

从实例上直观理解

我们首先来看Cross Entropy 的公式:

假设存在两个分布

为样本的真实分布,
为模型预测出的样本分布,则在给定的样本集
上,交叉熵的计算方式为

通常情况下在线性模型、神经网络等模型中,关于样本的真实分布可以用one-hot的编码来表示,比如男、女分别可以用[0,1]和[1,0]来表示,同样的,C种类别的样本可以用长度为C的向量来表示,且一个样本的表示向量中有且仅有一个维度为1,其余为0。那会造成什么后果呢?我们来看一个例子,假设一个样本的真实label为

,预测的分布为
,则交叉熵为:

如果预测分布为

,则交叉熵为:

可以看出其实

只与label中1所对应下标的预测值有关,且该预测值越大,

越小

只要label中1所对应下标的预测值越接近1,则损失函数越小,这在直观上就是符合我们对于损失函数的预期

交叉熵为什么比均方误差好

作为回归问题的常见损失函数,均方误差公式为

,好像也可以用来计算分类问题的损失函数,那它为什么不适合分类问题呢?我们再来看一个例子假设一个样本的真实label为[0,0,0,1,0],预测的分布为
,预测分布
,此时
,也就是说对于
而言,

即使与label中1所对应下标的预测值是正确的,其他项预测值的分布也会影响损失的大小,这不符合我们对于分类问题损失函数的预期

似然估计的视角

我们知道,对于一个多分类问题,给定样本

,它的似然函数可以表示为

其中

是模型预测的概率,
是对应类的label,
为类别的个数,那么其

负对数似然估计则为:

,
对应于
对应于
,其实

交叉熵就是对应于该样本的负对数似然估计

KL散度视角

KL散度又被称为相对熵,可以用来衡量两个分布之间的距离,想了解KL散度可以参考如何理解K-L散度。需要了解的是:KL散度越小,两个分布越相近。这么看KL散度是不是很符合我们对于两个分布损失函数的定义呢?

KL散度的公式为:

其中

的熵,注意这里的
是样本的真实分布,所以
为常数,因此,KL散度与交叉熵事实上是等价的,所以

交叉熵也可以用来衡量两个分布之间的距离,符合我们对于损失函数的期待

softmax+cross entropy到底学到了什么?

我们知道在回归问题中的最常用的损失函数是均方误差

,那么在反向传播时,
,即

均方误差在反向传播时传递的是预测值与label值的偏差,这显然是一个符合我们预期的、非常直觉的结果。

假定分类问题的最后一个隐藏层和输出层如下图所示

为最后一个隐藏层的C个类别,
为输出层,则有

因此softmax+cross entropy在反向传播时传递的同样是预测值与label值的偏差,即

,如果对于证明不感兴趣的,那么这篇文章就可以到此结束了~以下均为证明过程。

图中

,我们用
表示分母
,则

注意这里的

与所有的
都相关,因此需要用链式法则求导

下面求

的求导分为两种情况

时,

时,

代入上式得

注意这里
为所有label的和,应该等于1。

为什么要返回softmax_为什么softmax搭配cross entropy是解决分类问题的通用方案?相关推荐

  1. softmax ce loss_手写softmax和cross entropy

    import 解释下给定的数据,x假设是fc layer的输出,可以看到这里x是(3,3)的,也就是batch_size=3,n_classes=3.但是label给出了三个数,取值是0,1,因此这里 ...

  2. 卷积神经网络系列之softmax,softmax loss和cross entropy

    全连接层到损失层间的计算 先理清下从全连接层到损失层之间的计算. 这张图的等号左边部分就是全连接层做的事,W是全连接层的参数,我们也称为权值,X是全连接层的输入,也就是特征. 从图上可以看出特征X是N ...

  3. 卷积神经网络系列之softmax,softmax loss和cross entropy的讲解

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  4. softmax,softmax loss和cross entropy

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  5. cross entropy loss = log softmax + nll loss

    代码如下: import torchlogits = torch.randn(3,4,requires_grad=True) labels = torch.LongTensor([1,0,2]) pr ...

  6. 机器学习中交叉熵cross entropy是什么,怎么计算?

    项目背景:人体动作识别(分类),CNN或者RNN网络,softmax分类输出,输出为one-hot型标签. loss可以理解为预测输出pred与实际输出Y之间的差距,其中pred和Y均为one-hot ...

  7. python损失函数实现_pytorch 实现cross entropy损失函数计算方式

    均方损失函数: 这里 loss, x, y 的维度是一样的,可以是向量或者矩阵,i 是下标. 很多的 loss 函数都有 size_average 和 reduce 两个布尔类型的参数.因为一般损失函 ...

  8. 平均符号熵的计算公式_交叉熵(Cross Entropy)从原理到代码解读

    交叉熵(Cross Entropy)是Shannon(香浓)信息论中的一个概念,在深度学习领域中解决分类问题时常用它作为损失函数. 原理部分:要想搞懂交叉熵需要先清楚一些概念,顺序如下:==1.自信息 ...

  9. TensorFlow学习笔记(二十三)四种Cross Entropy交叉熵算法实现和应用

    交叉熵(Cross-Entropy) 交叉熵是一个在ML领域经常会被提到的名词.在这篇文章里将对这个概念进行详细的分析. 1.什么是信息量? 假设是一个离散型随机变量,其取值集合为,概率分布函数为 p ...

最新文章

  1. 交换机的VACL测试
  2. Python 0/1背包、动态规划
  3. java 读取网络图片_每日一学:如何读取网络图片
  4. 廖雪峰Java1-3流程控制-9break、continue
  5. 如何自己动手写一个搜索引擎?我是一份害羞的教程
  6. u-boot_NAND_Flash操作命令及烧录Linux内核和文件系统
  7. gulp.js 自动化构建工具学习入门
  8. 输入输出流——字符流部分
  9. 泛型(java菜鸟的课堂笔记)
  10. 英伟达CUDA 10终于开放下载了
  11. c语言调用python变量_在c中读取python的全局变量
  12. qcom charger
  13. 自己做量化交易软件(18)小白量化平台
  14. Kafka拉取某一个时间段內的消息
  15. 计算机专业背景的大学,不要求专业背景的计算机专业!
  16. java EE单例Singleton自启动
  17. java io流分为,Java中的IO流按照传输数据不同,可分为和
  18. 使用Python turtle快速实现七夕情人节礼物
  19. 漫画 | 外行对程序员误会有多深!
  20. COOX培训材料 — MTG

热门文章

  1. Dreamweaver Flash Photoshop网页设计综合应用 (智云科技) [iso] 1.86G​
  2. springcloud 入门 10 (eureka高可用)
  3. Shell 简单的java微服务jar包 -- 部署脚本
  4. Python中生成器generator和迭代器Iterator的使用方法
  5. excel中如何取消自动超链接?
  6. iPhone UITableViewCell如何滚动到视图顶端。
  7. zepto学习之路--源代码提取
  8. hdu 1564 Play a game
  9. commons-lang的FastDateFormat性能测试
  10. css中em与px的介绍及换算方法