交叉熵损失与均方误差损失

常规分类网络最后的softmax层如下图所示,传统机器学习方法以此类比,

一共有\(K\)类,令网络的输出为\([\hat{y}_1,\dots, \hat{y}_K]\),对应每个类别的概率,令label为 \([y_1, \dots, y_K]\)。对某个属于\(p\)类的样本,其label中\(y_p=1\),\(y_1, \dots, y_{p-1}, y_{p+1}, \dots, y_K\)均为0。

对这个样本,交叉熵(cross entropy)损失为

\[\begin{aligned}L &= - (y_1 \log \hat{y}_1 + \dots + y_K \log \hat{y}_K) \\&= -y_p \log \hat{y}_p \\ &= - \log \hat{y}_p\end{aligned}

\]

均方误差损失(mean squared error,MSE)为

\[\begin{aligned}L &= (y_1 - \hat{y}_1)^2 + \dots + (y_K - \hat{y}_K)^2 \\&= (1 - \hat{y}_p)^2 + (\hat{y}_1^2 + \dots + \hat{y}_{p-1}^2 + \hat{y}_{p+1}^2 + \dots + \hat{y}_K^2)\end{aligned}

\]

则\(m\)个样本的损失为

\[\ell = \frac{1}{m} \sum_{i=1}^m L_i

\]

对比交叉熵损失与均方误差损失,只看单个样本的损失即可,下面从两个角度进行分析。

损失函数角度

损失函数是网络学习的指挥棒,它引导着网络学习的方向——能让损失函数变小的参数就是好参数。

所以,损失函数的选择和设计要能表达你希望模型具有的性质与倾向。

对比交叉熵和均方误差损失,可以发现,两者均在\(\hat{y} = y = 1\)时取得最小值0,但在实践中\(\hat{y}_p\)只会趋近于1而不是恰好等于1,在\(\hat{y}_p < 1\)的情况下,

交叉熵只与label类别有关,\(\hat{y}_p\)越趋近于1越好

均方误差不仅与\(\hat{y}_p\)有关,还与其他项有关,它希望\(\hat{y}_1, \dots, \hat{y}_{p-1}, \hat{y}_{p+1}, \dots, \hat{y}_K\)越平均越好,即在\(\frac{1-\hat{y}_p}{K-1}\)时取得最小值

分类问题中,对于类别之间的相关性,我们缺乏先验。

虽然我们知道,与“狗”相比,“猫”和“老虎”之间的相似度更高,但是这种关系在样本标记之初是难以量化的,所以label都是one hot。

在这个前提下,均方误差损失可能会给出错误的指示,比如猫、老虎、狗的3分类问题,label为\([1, 0, 0]\),在均方误差看来,预测为\([0.8, 0.1, 0.1]\)要比\([0.8, 0.15, 0.05]\)要好,即认为平均总比有倾向性要好,但这有悖我们的常识。

而对交叉熵损失,既然类别间复杂的相似度矩阵是难以量化的,索性只能关注样本所属的类别,只要\(\hat{y}_p\)越接近于1就好,这显示是更合理的。

softmax反向传播角度

softmax的作用是将\((-\infty, +\infty)\)的几个实数映射到\((0,1)\)之间且之和为1,以获得某种概率解释。

令softmax函数的输入为\(z\),输出为\(\hat{y}\),对结点\(p\)有,

\[\hat{y}_p = \frac{e^{z_p}}{\sum_{k=1}^K e^{z_k}}

\]

\(\hat{y}_p\)不仅与\(z_p\)有关,还与\(\{z_k | k\neq p\}\)有关,这里仅看$z_p $,则有

\[\frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p(1-\hat{y}_p)

\]

\(\hat{y}_p\)为正确分类的概率,为0时表示分类完全错误,越接近于1表示越正确。根据链式法则,按理来讲,对与\(z_p\)相连的权重,损失函数的偏导会含有\(\hat{y}_p(1-\hat{y}_p)\)这一因子项,\(\hat{y}_p = 0\)时分类错误,但偏导为0,权重不会更新,这显然不对——分类越错误越需要对权重进行更新。

对交叉熵损失,

\[\frac{\partial L}{\partial \hat{y}_p} = -\frac{1}{\hat{y}_p}

\]

则有

\[\frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = \hat{y}_p - 1

\]

恰好将\(\hat{y}_p(1-\hat{y}_p)\)中的\(\hat{y}_p\)消掉,避免了上述情形的发生,且\(\hat{y}_p\)越接近于1,偏导越接近于0,即分类越正确越不需要更新权重,这与我们的期望相符。

而对均方误差损失,

\[\frac{\partial L}{\partial \hat{y}_p} = -2(1-\hat{y}_p)=2(\hat{y}_p - 1)

\]

则有,

\[\frac{\partial L}{\partial \hat{z}_p} = \frac{\partial L}{\partial \hat{y}_p} \cdot \frac{\partial \hat{y}_p}{\partial z_p} = -2 \hat{y}_p (1 - \hat{y}_p)^2

\]

显然,仍会发生上面所说的情况——\(\hat{y}_p = 0\),分类错误,但不更新权重。

综上,对分类问题而言,无论从损失函数角度还是softmax反向传播角度,交叉熵都比均方误差要好。

参考

均方误差越大越好_直观理解为什么分类问题用交叉熵损失而不用均方误差损失?...相关推荐

  1. 交叉熵损失函数分类_逻辑回归(Logistic Regression)二分类原理,交叉熵损失函数及python numpy实现...

    本文目录: 1. sigmoid function (logistic function) 2. 逻辑回归二分类模型 3. 神经网络做二分类问题 4. python实现神经网络做二分类问题 ----- ...

  2. 二分类交叉熵损失函数python_二分类问题的交叉熵损失函数多分类的问题的函数交叉熵损失函数求解...

    二分类问题的交叉熵损失函数; 在二分类问题中,损失函数为交叉熵损失函数.对于样本(x,y)来讲,x为样本 y为对应的标签.在二分类问题中,其取值的集合可能为{0,1},我们假设某个样本的真实标签为yt ...

  3. 多分类问题的交叉熵计算

    多分类问题的交叉熵   在多分类问题中,损失函数(loss function)为交叉熵(cross entropy)损失函数.对于样本点(x,y)来说,y是真实的标签,在多分类问题中,其取值只可能为标 ...

  4. seq2seq模型_直观理解并使用Tensorflow实现Seq2Seq模型的注意机制

    采用带注意机制的序列序列结构进行英印地语神经机器翻译 Seq2seq模型构成了机器翻译.图像和视频字幕.文本摘要.聊天机器人以及任何你可能想到的包括从一个数据序列到另一个数据序列转换的任务的基础.如果 ...

  5. kkt条件的理解_直观理解KKT条件

    KKT最优化条件是Karush[1939],以及Kuhn和Tucker[1951]先后独立发表出來的.这组最优化条件在Kuhn和Tucker发表之后才逐渐受到重视,因此许多情况下只记载成库恩塔克条件( ...

  6. 快速理解binary cross entropy 二元交叉熵

    Binary cross entropy 二元交叉熵是二分类问题中常用的一个Loss损失函数,在常见的机器学习模块中都有实现.本文就二元交叉熵这个损失函数的原理,简单地进行解释. 首先是二元交叉熵的公 ...

  7. 均方误差越大越好_什么是峰值信噪比(PSNR)及均方误差(MSE)

    展开全部 峰值信噪比(英语:32313133353236313431303231363533e58685e5aeb931333431356632Peak signal-to-noise ratio,常 ...

  8. 均方误差越大越好_超详细 | 如何写好计量经济学实证分析论文?

    经济学研究的主要目的是用经济理论解释所预测到的经济现象,预测经济走势,并提出政策建议.计量经济学是检验经济理论,解释.预测经济现象的最主要数量化方法.其重要性是因为绝大多数经济现象不能像自然科学那样通 ...

  9. ker矩阵是什么意思_直观理解!你一定要读一下的“矩阵和线性代数入门”

    首发于 | 知乎 作者 | 家里有只肉丸子 链接 | https://zhuanlan.zhihu.com/p/137112358 许多同学一听到高等代数(线性代数)的名字就瑟瑟发抖,觉得似乎是极困难 ...

最新文章

  1. 《LeetCode力扣练习》剑指 Offer 24. 反转链表 Java
  2. java占位符填充_实现java中的占位符
  3. 【STM32 .Net MF开发板学习-21】蓝牙遥控智能小车(PC模式)
  4. 打车APP大数据宰客套路多:苹果比安卓贵、熟客比新客贵
  5. Loadrunner基础:Loadrunner Vuser基本概念和应用
  6. mocha 测试 mysql_GitHub - zouzhenxing/lei: 整合Express mysql ioredis ejs 的一开发框架,使用mocha对api进行测试...
  7. ShadeGraph教程之节点详解2:Channel Nodes
  8. c++的类中,声明一个对象好还是用指针申请一块空间好?
  9. 26岁辞职、365天创业,就让程序员任性一回
  10. Python : Arrow、Pyarrow库、以及与Julia互读
  11. EasyRecovery2022强力数据恢复软件
  12. 了解传销系列之三 : 开心门
  13. 更新Win10版本后,wifi图标不见了,并且连接不到wifi和宽带,以及点击网络和Internet闪退的情况
  14. SQL高级——PLSQL数据库编程
  15. stl格式文件导入Unity
  16. java水果超市课程设计_(学习java)水果超市管理系统
  17. 七印部落送给大家的《启示录》
  18. Ansys workbench分析应用基础(2)
  19. c语言is_int(),C程序设计英文试题
  20. STM32—规则通道和注入通道的知识总结

热门文章

  1. 谢天谢地,AI开发者的“吐槽大会”终于结束了
  2. 传z播客 刘意_2015年Java基础视频笔记(day18~day20(2016年3月20日14:36:05)
  3. python电影评价分析.dat_python读DAT - IT屋-程序员软件开发技术分享社区
  4. importerror: cannot import name ‘HTTPClientFactory‘ from ‘twisted.web.client‘ (unknown location)
  5. matlab学习四,一元函数绘图方法
  6. 零基础学PS平面设计基础有哪些?
  7. 2014年大学生创业项目大全
  8. 简短的超市管理c语言程序设计,C语言程序设计超市管理系统1.doc
  9. 创新电影院与市场呈鲜明对比产业现状
  10. 七、webpack 介绍