公众号关注 “视学算法”

设为 “星标”,DLCV消息即可送达!

来自 | 知乎   作者 | 董鑫

https://www.zhihu.com/question/294679135/answer/885285177

本文仅作学术分享,著作权归作者所有,如有侵权,请联系删除

softmax 虽然简单,但是其实这里面有非常的多细节值得一说。

我们挨个捋一捋。

   1. 什么是 Softmax?

首先,softmax 的作用是把 一个序列,变成概率。

他能够保证:

  1. 所有的值都是 [0, 1] 之间的(因为概率必须是 [0, 1])

  2. 所有的值加起来等于 1

从概率的角度解释 softmax 的话,就是

   2. 文档里面跟 Softmax 有关的坑

这里穿插一个“小坑”,很多deep learning frameworks的 文档 里面 (PyTorch,TensorFlow)是这样描述 softmax 的,

take logits and produce probabilities

很明显,这里面的 logits 就是 全连接层(经过或者不经过 activation都可以)的输出, probability 就是 softmax 的输出结果。这里  logits 有些地方还称之为 unscaled log probabilities。这个就很意思了,unscaled probability可以理解,那又为什么 全连接层直接出来结果会和 log 有关系呢?

原因有两个:

  1. 因为 全连接层 出来的结果,其实是无界的(有正有负),这个跟概率的定义不一致,但是你如果他看成 概率的 log,就可以理解了。

  2. softmax 的作用,我们都知道是 normalize probability。在 softmax 里面,输入  都是在指数上的 ,所有把 想成 log of probability 也就顺理成章了。

   3. Softmax 就是 Soft 版本的 ArgMax

好的,我们把话题拉回到 softmax。

softmax,顾名思义就是 soft 版本的 argmax。我们来看一下为什么?

举个栗子,假如 softmax 的输入是:

softmax 的结果是:

我们稍微改变一下输入,把 3 改大一点,变成 5,输入是

softmax 的结果是:

可见 softmax 是一种非常明显的 “马太效应”:强(大)的更强(大),弱(小)的更弱(小)。假如你要选一个最大的数出来,这个其实就是叫 hardmax。那么 softmax 呢,其实真的就是 soft 版本的 max,以一定的概率选一个最大值出来。在hardmax中,真正最大的那个数,一定是以1(100%) 的概率被选出来,其他的值根本一点机会没有。但是在 softmax 中,所有的值都有机会被作为最大值选出来。只不过,由于 softmax 的 “马太效应”,次大的数,即使跟真正最大的那个数差别非常少,在概率上跟真正最大的数相比也小了很多。

所以,前面说,“softmax 的作用是把 一个序列,变成概率。” 这个概率不是别的,而是被选为 max 的概率。

这种 soft 版本的 max 在很多地方有用的上。因为 hard 版本的 max 好是好,但是有很严重的梯度问题,求最大值这个函数本身的梯度是非常非常稀疏的(比如神经网络中的 max pooling),经过hardmax之后,只有被选中的那个变量上面才有梯度,其他都是没有梯度。这对于一些任务(比如文本生成等)来说几乎是不可接受的。所以要么用 hard max 的变种,比如Gumbel

Categorical Reparameterization with Gumbel-Softmax

链接:https://arxiv.org/abs/1611.01144

亦或是 ARSM

ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variable

链接:http://proceedings.mlr.press/v97/yin19c.html

,要么就直接 softmax。

   4. Softmax 的实现以及数值稳定性

softmax 的代码实现看似是比较简单的,直接套上面的公式就好

def softmax(x):"""Compute the softmax of vector x."""exps = np.exp(x)return exps / np.sum(exps)

但是这种方法非常的不稳定。因为这种方法要算指数,只要你的输入稍微大一点,比如:

分母上就是

很明显,在计算上一定会溢出。解决方法也比较简单,就是我们在分子分母上都乘上一个系数,减小数值大小,同时保证整体还是对的

把常数 C 吸收进指数里面

这里的D是可以随便选的,一般可以选成

具体实现可以写成这样

def stablesoftmax(x):"""Compute the softmax of vector x in a numerically stable way."""shiftx = x - np.max(x)exps = np.exp(shiftx)return exps / np.sum(exps)

这样一种实现数值稳定性已经好了很多,但是仍然会有数值稳定性的问题。比如输入的值差别过大的时候,比如

这种情况即使用了上面的方法,可能还是报 NaN 的错误。但是这个就是数学本身的问题了,大家使用的时候稍微注意下。

一种可能的替代的方案是使用 LogSoftmax (然后再求 exp),数值稳定性比 softmax 好一些。

可以看到LogSoftmax省了一个指数计算,省了一个除法,数值上相对稳定一些。另外,其实 Softmax_Cross_Entropy 里面也是这么实现的

   5. Softmax 的梯度

下面我们来看一下 softmax 的梯度问题。整个 softmax 里面的操作都是可微的,所以求梯度就非常简单了,就是基础的求导公式,这里就直接放结果了。

所以说,如果某个变量做完 softmax 之后很小,比如 ,那么他的梯度也是非常小的,几乎得不到任何梯度。有些时候,这会造成梯度非常的稀疏,优化不动。

   6. Softmax 和 Cross-Entropy 的关系

先说结论,

softmax 和 cross-entropy 本来太大的关系,只是把两个放在一起实现的话,算起来更快,也更数值稳定。

cross-entropy 不是机器学习独有的概念,本质上是用来衡量两个概率分布的相似性的。简单理解(只是简单理解!)就是这样,

如果有两组变量:

如果你直接求 L2 距离,两个距离就很大了,但是你对这俩做 cross entropy,那么距离就是0。所以 cross-entropy 其实是更“灵活”一些。

那么我们知道了,cross entropy 是用来衡量两个概率分布之间的距离的,softmax能把一切转换成概率分布,那么自然二者经常在一起使用。但是你只需要简单推导一下,就会发现,softmax + cross entropy 就好像

“往东走五米,再往西走十米”,

我们为什么不直接

“往西走五米”呢?

cross entropy 的公式是

这里的 就是我们前面说的 LogSoftmax。这玩意算起来比 softmax 好算,数值稳定还好一点,为啥不直接算他呢?

所以说,这有了 PyTorch 里面的 torch.nn.CrossEntropyLoss (输入是我们前面讲的 logits,也就是 全连接直接出来的东西)。这个 CrossEntropyLoss 其实就是等于 torch.nn.LogSoftmax + torch.nn.NLLLoss。

Softmax和Cross-entropy是什么关系?相关推荐

  1. 为什么要返回softmax_为什么softmax搭配cross entropy是解决分类问题的通用方案?

    众所周知,softmax+cross entropy是在线性模型.神经网络等模型中解决分类问题的通用方案,但是为什么选择这种方案呢?它相对于其他方案有什么优势?笔者一直也困惑不解,最近浏览了一些资料, ...

  2. softmax ce loss_手写softmax和cross entropy

    import 解释下给定的数据,x假设是fc layer的输出,可以看到这里x是(3,3)的,也就是batch_size=3,n_classes=3.但是label给出了三个数,取值是0,1,因此这里 ...

  3. 卷积神经网络系列之softmax,softmax loss和cross entropy

    全连接层到损失层间的计算 先理清下从全连接层到损失层之间的计算. 这张图的等号左边部分就是全连接层做的事,W是全连接层的参数,我们也称为权值,X是全连接层的输入,也就是特征. 从图上可以看出特征X是N ...

  4. 卷积神经网络系列之softmax,softmax loss和cross entropy的讲解

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  5. softmax,softmax loss和cross entropy

    我们知道卷积神经网络(CNN)在图像领域的应用已经非常广泛了,一般一个CNN网络主要包含卷积层,池化层(pooling),全连接层,损失层等.虽然现在已经开源了很多深度学习框架(比如MxNet,Caf ...

  6. cross entropy loss = log softmax + nll loss

    代码如下: import torchlogits = torch.randn(3,4,requires_grad=True) labels = torch.LongTensor([1,0,2]) pr ...

  7. TensorFlow学习笔记(二十三)四种Cross Entropy交叉熵算法实现和应用

    交叉熵(Cross-Entropy) 交叉熵是一个在ML领域经常会被提到的名词.在这篇文章里将对这个概念进行详细的分析. 1.什么是信息量? 假设是一个离散型随机变量,其取值集合为,概率分布函数为 p ...

  8. TensorFlow 实战(一)—— 交叉熵(cross entropy)的定义

    对多分类问题(multi-class),通常使用 cross-entropy 作为 loss function.cross entropy 最早是信息论(information theory)中的概念 ...

  9. pytorch:交叉熵(cross entropy)

    1.entropy entropy中文叫做熵,也叫不确定性,某种程度上也叫惊喜度(measure of surprise) = 如果p(x)采用0-1分部,那么entropy=1log1=0 而对于e ...

  10. 平均符号熵的计算公式_交叉熵(Cross Entropy)从原理到代码解读

    交叉熵(Cross Entropy)是Shannon(香浓)信息论中的一个概念,在深度学习领域中解决分类问题时常用它作为损失函数. 原理部分:要想搞懂交叉熵需要先清楚一些概念,顺序如下:==1.自信息 ...

最新文章

  1. 建高性能ASP.NET站点 第五章—性能调优综述(中篇)
  2. Spring in Action 4 读书笔记之使用标签创建 AOP
  3. Web Service入门简介(一个简单的WebService示例)
  4. 用Python学分析 - 单因素方差分析
  5. step4 . day4 库函数和库函数的制作
  6. <scope>test</scope>的作用
  7. node python复用代码_python-代码复用(函数、lambda、递归、PyInstaller库)
  8. JavaScript中通过点击单选框动态显示和隐藏组件
  9. redis win连接以及配置连接密码
  10. 拓端tecdat|Python对商店数据进行lstm和xgboost销售量时间序列建模预测分析
  11. 软件测试的测试方法及测试流程
  12. STM32平台RT-Thread最小系统移植搭建 - STM32F107VCT6
  13. 修真院教学模式三大阶段之真实项目
  14. System.Globalization 命名空间
  15. 【luogu P5055】【模板】可持久化文艺平衡树
  16. register hotkey
  17. 网络游戏服务器之 日志系统
  18. 文件服务器怎么限制速度,文件服务器的速度
  19. 如何发掘各种暴利的赚钱项目,如何知道别人在干什么赚钱
  20. STM32CubeIDE HAL库IIC实现气压计MS5637的数据读取

热门文章

  1. 刻意练习:机器学习实战 -- Task01. K邻近算法
  2. 利用BP神经网络教计算机进行非线函数拟合(代码部分多层)
  3. Python 写了一个网页版的「P图软件」,惊呆了!
  4. 观点:AI 与自动化是矛盾的
  5. TIOBE 6 月榜单: Python 有望超越 C 语言成为第一名
  6. 深度学习中的注意力机制(三)
  7. 两大AI技术集于一身,有道词典笔3从0到1的飞跃
  8. 豪赌 ARM 梦碎:63 岁孙正义的「花甲历险记」
  9. 技术直播:1小时突击Java工程师面试核心(限免报名)
  10. 干货!3 个重要因素,带你看透 AI 技术架构方案的可行性!