原标题:Softmax和交叉熵的深度解析和Python实现

作者 | PARAS DAHAL

译者 | zzq

【导读】如果你稍微了解一点深度学习的知识或者看过深度学习的在线课程,你就一定知道最基础的多分类问题。当中,老师一定会告诉你在全连接层后面应该加上 Softmax 函数,如果正常情况下(不正常情况指的是类别超级多的时候)用交叉熵函数作为损失函数,你就一定可以得到一个让你基本满意的结果。而且,现在很多开源的深度学习框架,直接就把各种损失函数写好了(甚至在 Pytorch中 CrossEntropyLoss 已经把 Softmax函数集合进去了),你根本不用操心怎么去实现他们,但是你真的理解为什么要这么做吗?这篇小文就将告诉你:Softmax 是如何把 CNN 的输出转变成概率,以及交叉熵是如何为优化过程提供度量。为了让读者能够深入理解,我们将会用 Python 一一实现他们。

▌Softmax函数

Softmax 函数接收一个 这N维向量作为输入,然后把每一维的值转换成(0,1)之间的一个实数,它的公式如下面所示:

正如它的名字一样,Softmax 函数是一个“软”的最大值函数,它不是直接取输出的最大值那一类作为分类结果,同时也会考虑到其它相对来说较小的一类的输出。

说白了,Softmax 可以将全连接层的输出映射成一个概率的分布,我们训练的目标就是让属于第k类的样本经过 Softmax 以后,第 k 类的概率越大越好。这就使得分类问题能更好的用统计学方法去解释了。

使用 Python,我们可以这么去实现 Softmax 函数:

我们需要注意的是,在 numpy 中浮点类型是有数值上的限制的,对于float64,它的上限是。对于指数函数来说,这个限制很容易就会被打破,如果这种情况发生了 python 便会返回nan。

为了让 Softmax 函数在数值计算层面更加稳定,避免它的输出出现nan这种情况,一个很简单的方法就是对输入向量做一步归一化操作,仅仅需要在分子和分母上同乘一个常数C,如下面的式子所示

理论上来说,我们可以选择任意一个值作为,但是一般我们会选择

,通过这种方法就使得原本非常大的指数结果变成0,避免出现nan的情况。

同样使用 Python,改进以后的 Softmax 函数可以这样写:

▌Softmax 函数的导数推倒过程

通过上文我们了解到,Softmax 函数可以将样本的输出转变成概率密度函数,由于这一很好的特性,我们就可以把它加装在神经网络的最后一层,随着迭代过程的不断深入,它最理想的输出就是样本类别的 One-hot 表示形式。进一步我们来了解一下如何去计算 Softmax 函数的梯度(虽然有了深度学习框架这些都不需要你去一步步推导,但为了将来能设计出新的层,理解反向传播的原理还是很重要的),对 Softmax 的参数求导:

根据商的求导法则,对于

其导数为

。对于我们来说

。在中,

一直都是,但是在中,当且仅当的时候,

才为。具体的过程,我们看一下下面的步骤:

如果 ,

如果

所以 Softmax 函数的导数如下面所示:

▌交叉熵损失函数

下面我们来看一下对模型优化真正起到作用的损失函数——交叉熵损失函数。交叉熵函数体现了模型输出的概率分布和真实样本的概率分布的相似程度。它的定义式就是这样:

在分类问题中,交叉熵函数已经大范围的代替了均方误差函数。也就是说,在输出为概率分布的情况下,就可以使用交叉熵函数作为理想与现实的度量。这也就是为什么它可以作为有 Softmax 函数激活的神经网络的损失函数。

我们来看一下,在 Python 中是如何实现交叉熵函数的:

▌交叉熵损失函数的求导过程

就像我们之前所说的,Softmax 函数和交叉熵损失函数是一对好兄弟,我们用上之前推导 Softmax 函数导数的结论,配合求导交叉熵函数的导数:

加上 Softmax 函数的导数:

y 代表标签的 One-hot 编码,因此

,并且 。因此我们就可以得到:

可以看到,这个结果真的太简单了,不得不佩服发明它的大神们!最后,我们把它转换成代码:

▌小结

需要注意的是,正如我之前提到过的,在许多开源的深度学习框架中,Softmax 函数被集成到所谓的 CrossEntropyLoss 函数中。比如 Pytorch 的说明文档,就明确地告诉读者 CrossEntropyLoss 这个损失函数是 Log-Softmax 函数和负对数似然函数(NLLoss)的组合,也就是说当你使用它的时候,没有必要再在全连接层后面加入 Softmax 函数。还有许多文章中会提到 SoftmaxLoss,其实它就是 Softmax 函数和交叉熵函数的组合,跟我们说的 CrossEntropyLoss 函数是一个意思,这点需要读者自行分辨即可。

https://deepnotes.io/softmax-crossentropy

GitHub 地址:

https://github.com/parasdahal/deepnet

参考链接:

The Softmax function and its derivative

Bendersky, E., 2016.

https://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/

CS231n Convolutional Neural Networks for Visual Recognition

Andrej Karpathy, A.K., 2016.

http://cs231n.github.io/convolutional-networks/返回搜狐,查看更多

责任编辑:

python交叉熵损失函数实现_Softmax和交叉熵的深度解析和Python实现相关推荐

  1. python交叉熵损失函数实现_大话交叉熵损失函数

    使用keras进行二分类时,常使用binary_crossentropy作为损失函数.那么它的原理是什么,跟categorical_crossentropy.sparse_categorical_cr ...

  2. 交叉熵损失函数和focal loss_理解熵、交叉熵和交叉熵损失

    交叉熵损失是深度学习中应用最广泛的损失函数之一,这个强大的损失函数是建立在交叉熵概念上的.当我开始使用这个损失函数时,我很难理解它背后的直觉.在google了不同材料后,我能够得到一个令人满意的理解, ...

  3. 交叉熵损失函数和似然估计_熵、交叉熵及似然函数的关系

    熵.交叉熵及似然函数的关系 1. 熵 1.1 信息量 信息量:最初的定义是信号取值数量m的对数为信息量\(I\),即 \(I=log_2m\).这是与比特数相关的,比如一个信号只有两个取值,那么用1个 ...

  4. 深度学习入门之Python小白逆袭大神系列(三)—深度学习常用Python库

    深度学习常用Python库介绍 目录 深度学习常用Python库介绍 简介 Numpy库 padas库 PIL库 Matplotlib库 简介 Python被大量应用在数据挖掘和深度学习领域,其中使用 ...

  5. python 静态方法_Python编程思想(25):方法深度解析

    -----------支持作者请转发本文-----------李宁老师已经在「极客起源」 微信公众号推出<Python编程思想>电子书,囊括了Python的核心技术,以及Python的主要 ...

  6. 【深度学习原理】交叉熵损失函数的实现

    交叉熵损失函数 一般我们学习交叉熵损失函数是在二元分类情况下: L=−[ylogy^+(1−y)log(1−y^)]L=−[ylog ŷ +(1−y)log (1−ŷ )]L=−[ylogy^​+ ...

  7. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  8. 【NLP】再看交叉熵损失函数

    交叉熵 在深度学习领域出现交叉熵(cross entropy)的地方就是交叉熵损失函数了.通过交叉熵来衡量目标与预测值之间的差距.了解交叉熵还需要从信息论中的几个概念说起. 信息量 如何衡量一条信息包 ...

  9. 深度学习基础入门篇[五]:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

    [深度学习入门到进阶]必看系列,含激活函数.优化策略.损失函数.模型调优.归一化算法.卷积模型.序列模型.预训练模型.对抗神经网络等 专栏详细介绍:[深度学习入门到进阶]必看系列,含激活函数.优化策略 ...

最新文章

  1. ARP欺骗原理与模拟
  2. 【新星计划】MATLAB-字符串处理
  3. 手把手教你逼走员工的23种套路,大写的服!
  4. centos7.4 mysql启动,centos7下mysql服务启动失败_网站服务器运行维护,centos7,mysql
  5. 代码 or 指令,浅析ARM架构下的函数的调用过程
  6. 分享一套基于SpringBoot和Vue的企业级中后台开源项目,代码很规范!
  7. thinkphp 表单令牌
  8. C# XML文件读取
  9. Flutter之路由系列之Navigator简析
  10. Laravel 数据库 - 数据填充
  11. Springboot 静态资源路径配置 实例介绍
  12. 一周一英文测试文稿翻译 质量保障测试人员的一天
  13. python图形包是什么_介绍Python 图形计算工具包
  14. (C/C++/Java)判断中文、字符串、数字是否为“回文”
  15. 云中马在A股上市:总市值约为40亿元,叶福忠为实际控制人
  16. 【转载】贵妃醉酒百态(原创)
  17. 苹果手机html转pdf文件怎么打开吗,今天才知道,苹果手机打开这个功能,可以将纸质文档转为Word...
  18. 如何更方便的探讨技术
  19. UEFI开发探索32 – 有趣的图像特效
  20. 【Stephen Boyd】【2009】凸优化

热门文章

  1. 日语语法归纳--「に」的用法
  2. OL记载Arcgis Server切片
  3. Brief Bioinform | 农科院深圳基因组所王怡雯组提出一种去除微生物组数据中批次效应的多元算法框架...
  4. 谁动了我的奶酪[By tina]
  5. 微信小程序购物车组件
  6. 二叉树遍历方法——前、中、后序遍历(图解)
  7. week8 B - 猫猫向前冲(拓扑排序)
  8. Ubuntu下的docker搭建及基础使用
  9. GDC - 《幽灵行动:荒野》地形技术和工具(三)
  10. 折腾个自己的nps服务器