交叉熵损失函数

  • 前言
  • 交叉熵损失函数
    • 信息量
    • 信息熵
    • 交叉熵
    • 求导过程
  • 应用
  • 扩展
    • Binary_Crossentropy
    • 均方差损失函数(MSE)

前言

深度学习中的损失函数的选择,需要注意一点是直接衡量问题成功的指标不一定总可行。损失函数需要在只有小批量数据时即可计算,而且还必须可微。下表列出常见问题类型的最后一层以及损失函数的选择,仅供参考。

问题类型 最后一层激活 损失函数
二分类问题 sigmoid binary_crossentropy
多分类、单标签 softmax categorical_crossentropy
多分类、多标签问题 sigmoid binary_crossentropy
回归到任意值 mse
回归到0~1范围内的值 sigmoid mse或binary_crossentropy
  • 二分类问题
    二分类:表示分类任务中有两个类别,比如我们想识别一幅图片是不是猫。也就是说,训练一个分类器,输入一幅图片,用特征向量x表示,输出是不是猫,用y=0或1表示。二类分类是假设每个样本都被设置了一个且仅有一个标签 0 或者 1。
  • 多分类问题
    多类分类(Multiclass classification): 表示分类任务中有多个类别, 比如对一堆水果图片分类, 它们可能是橘子、苹果、梨等. 多类分类是假设每个样本都被设置了一个且仅有一个标签: 一个水果可以是苹果或者梨, 但是同时不可能是两者。
  • 多标签分类
    多标签分类(Multilabel classification): 给每个样本一系列的目标标签. 可以想象成一个数据点的各属性不是相互排斥的(一个水果既是苹果又是梨就是相互排斥的), 比如一个文档相关的话题. 一个文本可能被同时认为是宗教、政治、金融或者教育相关话题。

接下来我们从信息熵出发介绍交叉熵损失函数。

交叉熵损失函数

信息量

香农提出“信息是用来消除随机不确定的东西”。设一个事件发生的概率为P(x)P(x)P(x),其信息量表示为:I(x)=−log⁡(P(x))I(x)=-\log(P(x))I(x)=−log(P(x))

信息熵

信息熵描述所有信息量的期望,公式是H(X)=−∑i=1nP(xi)log⁡(P(xi)),X=x1,x2,x3,...xnH(X)=-\sum^n_{i=1}P(x_i)\log(P(x_i)), X=x_1, x_2, x_3,... x_nH(X)=−i=1∑n​P(xi​)log(P(xi​)),X=x1​,x2​,x3​,...xn​。以下用一个例子说明,假设明天周末,有三个活动:打球,逛街以及爬山,他们的概率以及对应的信息量如下表所示:

event P signal
打球 0.2 −log⁡(0.2)-\log(0.2)−log(0.2)
逛街 0.3 −log⁡(0.3)-\log(0.3)−log(0.3)
爬山 0.5 −log⁡(0.5)-\log(0.5)−log(0.5)

那么信息熵H(X)H(X)H(X)则为:H(X)=−(0.2∗log⁡(0.2)+0.3∗log⁡(0.3)+0.5∗log⁡(0.5))H(X)=-(0.2*\log(0.2)+0.3*\log(0.3)+0.5*\log(0.5))H(X)=−(0.2∗log(0.2)+0.3∗log(0.3)+0.5∗log(0.5))

交叉熵

KL散度描述两个分布之间的差异:DKL(p∣∣q)=∑i=1np(xi)log⁡(p(xi)q(xi))=∑i=1np(xi)log⁡(p(xi)))−∑i=1np(xi)log⁡(q(xi)))=−H(p(x))+[−∑i=1np(xi)log⁡(q(xi)))]D_{KL}(p||q)=\sum^n_{i=1}p(x_i)\log(\frac{p(x_i)}{q(x_i)}) \\=\sum^n_{i=1}p(x_i)\log(p(x_i)))-\sum^n_{i=1}p(x_i)\log(q(x_i)))\\=-H(p(x))+[-\sum^n_{i=1}p(x_i)\log(q(x_i)))]DKL​(p∣∣q)=i=1∑n​p(xi​)log(q(xi​)p(xi​)​)=i=1∑n​p(xi​)log(p(xi​)))−i=1∑n​p(xi​)log(q(xi​)))=−H(p(x))+[−i=1∑n​p(xi​)log(q(xi​)))],其中H(p(x))H(p(x))H(p(x))表示信息熵,后者则为交叉熵:H(p,q)=−∑i=1np(xi)log⁡(q(xi)))H(p, q)=-\sum^n_{i=1}p(x_i)\log(q(x_i)))H(p,q)=−i=1∑n​p(xi​)log(q(xi​)))
交叉熵损失函数经常用于分类问题中,特别是在神经网络做分类问题时,也经常使用交叉熵作为损失函数,此外,由于交叉熵涉及到计算每个类别的概率,所以交叉熵几乎每次都和sigmoid(或softmax)函数一起出现。

求导过程

交叉熵的公式是H(p,q)=−∑i=1np(xi)log⁡(q(xi)))H(p, q)=-\sum^n_{i=1}p(x_i)\log(q(x_i)))H(p,q)=−i=1∑n​p(xi​)log(q(xi​))),为了好看,我们令p(xi)=yip(x_i)=y_ip(xi​)=yi​,其中yiy_iyi​表示真实label标签;q(xi)=aiq(x_i)=a_iq(xi​)=ai​,aia_iai​表示神经元的输出。则交叉熵损失函数可以表示成:Cost=−∑i=1nyilog⁡(ai)Cost=-\sum^n_{i=1}y_i\log(a_i)Cost=−i=1∑n​yi​log(ai​)。Softmax公式为:Si=ezi∑kezkS_i=\frac{e^{z_i}}{\sum_ke^{z_k}}Si​=∑k​ezk​ezi​​。一个神经元的输出如下图所示:

假设神经元的输出为:zi=∑jwxjxxj+bz_i=\sum_jw_{xj}x_{xj}+bzi​=j∑​wxj​xxj​+b,其中wijw_{ij}wij​表示第iii个神经元第jjj个权重。ziz_izi​通过softmax输出则是:ai=ezi∑kezka_i=\frac{e^{z_i}}{\sum_ke^{z_k}}ai​=∑k​ezk​ezi​​一切准备就绪,开始我们的反向推导。

  1. 求导公式∂C∂zi\frac{\partial C}{\partial z_i}∂zi​∂C​根据复合函数求导法则:∂C∂zi=∂C∂aj∂aj∂zi\frac{\partial C}{\partial z_i}=\frac{\partial C}{\partial a_j}\frac{\partial a_j}{\partial z_i}∂zi​∂C​=∂aj​∂C​∂zi​∂aj​​这里可能有个疑问:对ziz_izi​求导怎么会出现aja_jaj​?这是因为softmax公式里面分母包含了所有神经元参数,这里需要分情况进行讨论。

  2. 对∂C∂aj\frac{\partial C}{\partial a_j}∂aj​∂C​求导:∂C∂aj=∂(−∑jyjln⁡aj)∂aj=−∑jyj1aj\frac{\partial C}{\partial a_j}=\frac{\partial(-\sum_jy_j\ln a_j)}{\partial a_j}=-\sum_jy_j\frac{1}{a_j}∂aj​∂C​=∂aj​∂(−∑j​yj​lnaj​)​=−j∑​yj​aj​1​

  3. 对∂aj∂zi\frac{\partial a_j}{\partial z_i}∂zi​∂aj​​求导,需要分情况。

  4. 最后合在一起:
    ∑jyj=1\sum_j y_j=1∑j​yj​=1,于是有:∂C∂zi=ai−yi\frac{\partial C}{\partial z_i}=a_i-y_i∂zi​∂C​=ai​−yi​

应用

  1. 为什么用交叉熵,而不用均方差作为损失函数?

扩展

Binary_Crossentropy

二值交叉熵损失函数很好理解,就是取值只有0或者1的分类,他的公式是:Hp(q)=−1N∑i=1Nyilog⁡(p(yi))+(1−yi)log⁡(1−p(yi))H_p(q)=-\frac{1}{N}\sum_{i=1}^Ny_i\log(p(y_i))+(1-y_i)\log(1-p(y_i))Hp​(q)=−N1​i=1∑N​yi​log(p(yi​))+(1−yi​)log(1−p(yi​))

均方差损失函数(MSE)

该损失函数通过计算真实值与预测值的欧式距离直观反馈了预测值与真实值得误差。预测值与真实值越接近,则两者的均方差越小。公式如下:loss=12(z−y)2(单样本)loss = \frac{1}{2}(z-y)^2 (单样本)loss=21​(z−y)2(单样本),J=12m∑i=1m(zi−yi)2(多样本)J=\frac{1}{2m}\sum^m_{i=1}(z_i-y_i)^2 (多样本)J=2m1​i=1∑m​(zi​−yi​)2(多样本)具体案例可以参考:均方差损失函数

19个损失函数汇总:https://zhuanlan.zhihu.com/p/258395701

深度学习 交叉熵损失函数相关推荐

  1. 【深度学习原理】交叉熵损失函数的实现

    交叉熵损失函数 一般我们学习交叉熵损失函数是在二元分类情况下: L=−[ylogy^+(1−y)log(1−y^)]L=−[ylog ŷ +(1−y)log (1−ŷ )]L=−[ylogy^​+ ...

  2. 交叉熵损失函数优缺点_如何简单通俗的理解交叉熵损失函数?

    前面小编给大家简单介绍过损失函数,今天给大家继续分享交叉熵损失函数,直接来看干货吧. 一.交叉熵损失函数概念 交叉熵损失函数CrossEntropy Loss,是分类问题中经常使用的一种损失函数.公式 ...

  3. 交叉熵损失函数的通用性(为什么深度学习DL普遍用它):预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大是一种类似指数增长的级别,结论:它对结果有引导性

    交叉熵损失函数的通用性(为什么深度学习DL普遍用它):预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 " 惩罚 " 越大,而且是非线性增大是一种类似指数增长的级别,结 ...

  4. 二分类交叉熵损失函数python_【深度学习基础】第二课:softmax分类器和交叉熵损失函数...

    [深度学习基础]系列博客为学习Coursera上吴恩达深度学习课程所做的课程笔记. 本文为原创文章,未经本人允许,禁止转载.转载请注明出处. 1.线性分类 如果我们使用一个线性分类器去进行图像分类该怎 ...

  5. 深度学习基础入门篇[五]:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

    [深度学习入门到进阶]必看系列,含激活函数.优化策略.损失函数.模型调优.归一化算法.卷积模型.序列模型.预训练模型.对抗神经网络等 专栏详细介绍:[深度学习入门到进阶]必看系列,含激活函数.优化策略 ...

  6. 【深度学习】——分类损失函数、回归损失函数、交叉熵损失函数、均方差损失函数、损失函数曲线、

    目录 代码 回归问题的损失函数 分类问题的损失函数 1. 0-1损失 (zero-one loss) 2.Logistic loss 3.Hinge loss 4.指数损失(Exponential l ...

  7. 深度学习中softmax交叉熵损失函数的理解

    1. softmax层的作用 通过神经网络解决多分类问题时,最常用的一种方式就是在最后一层设置n个输出节点,无论在浅层神经网络还是在CNN中都是如此,比如,在AlexNet中最后的输出层有1000个节 ...

  8. 深度学习-tensorflow1.x之交叉熵损失函数(softmax_cross_entropy_with_logits)代码实现 Tensorflow1.x 和 Numpy

    交叉熵损失函数 神经网络(机器学习)中作为损失函数 具体的理解可以看 https://blog.csdn.net/SIGAI_CSDN/article/details/86554230 交叉熵损失函数 ...

  9. 【深度学习】损失函数系列 (一) 平方损失函数、交叉熵损失函数 (含label_smoothing、ignore_index等内容)

    一.平方损失函数(Quadratic Loss / MSELoss): Pytorch实现: from torch.nn import MSELoss loss = nn.MSELoss(reduct ...

最新文章

  1. 数据结构(严蔚敏)之三——顺序栈之c语言实现
  2. power design教程
  3. XML反序列化出错,XML 文档(2, 2)中有错误
  4. 全网最细之static关键字讲解
  5. catia怎么将特征参数化_浅谈Catia VBA与参数化建模的结合
  6. 西门子v90伺服说明书_西门子V90伺服驱动器的的EPOS控制模式
  7. spring学习--bean--普通bean与工厂bean(FactoryBean)区别
  8. MongoDB 数据库管理(不定时更新)
  9. java.util.concurrent.ExecutionException: java.lang.OutOfMemoryError: PermGen space
  10. 一个DirectInput演示程序
  11. HAUT OJ 1504: CXK的篮球数(加强版)--差分法
  12. 大数据的数据库设计原则有哪些
  13. hge引擎配置登录器教程_Hge引擎程序+登录器配置器+配套工具+全套入门教程
  14. xampp mysql密码忘记_XAMPP重置MySQL密码
  15. 世界著名的品牌啤酒——网络整理X
  16. Web 实现前后端分离,前后端解耦
  17. 还有未完待续的瓜哦!
  18. ubuntu 下安装labelImg报错
  19. mysql根据表的一个字段决定去关联(join)那张表格
  20. 3ds max 2015 安装方法

热门文章

  1. Unity - 优化 Vector3.ProjectOnPlane
  2. 大数据统计:发布近期网络辟谣TOP10,看看你信过几条?
  3. 这本对我影响最大的书,想与你分享!
  4. 解决:启动Mybatis自动生成代码插件出现低级异常
  5. 人生哲理---人的一生
  6. 对记忆化搜索(ms)和动态规划(dp)的深入理解
  7. 微信域名检测采用官方接口
  8. Qt 集成miniblink浏览器库之5 支持独立窗口和子窗口
  9. IDEAL葵花宝典:java代码开发规范插件 FindBugs-IDEA
  10. DVWA上XSS(DOM)(基于 DOM 的跨站脚本)全难度