前言

本文代码基于Pytorch实现。

一、softmax的定义及代码实现

1.1 定义

softmax(xi)=exp(xi)∑jnexp(xj)softmax(x_i) = \frac{exp(x_i)}{\sum_j^nexp(x_j)} softmax(xi​)=∑jn​exp(xj​)exp(xi​)​

1.2 代码实现

def softmax(X):'''实现softmax输入X的形状为[样本个数,输出向量维度]'''return torch.exp(X) / torch.sum(torch.exp(X), dim=1).reshape(-1, 1)>>> X = torch.randn(5, 5)
>>> y = softmax(X)
>>> torch.sum(y, dim=1)
tensor([1.0000, 1.0000, 1.0000, 1.0000])

二、softmax的作用


softmax可以对线性层的输出做规范化校准:保证输出为非负且总和为1。
因为如果直接将未规范化的输出看作概率,会存在2点问题:

  1. 线性层的输出并没有限制每个神经元输出数字的总和为1;
  2. 根据输入的不同,线性层的输出可能为负值。
    这2点违反了概率的基本公理。

三、softmax的上溢出(overflow)与下溢出(underflow)

3.1 上溢出

当xix_ixi​的取值过大时,指数运算的取值过大,若超出精度表示范围,则上溢出。

>>> torch.exp(torch.tensor([1000]))
tensor([inf])

3.2 下溢出

当向量x\boldsymbol xx的每个元素xix_ixi​的取值均为绝对值很大的负数时,则exp(xi)exp(x_i)exp(xi​)的数值很小超出了精度范围向下取0,分母∑jexp(j)\sum_jexp(j)∑j​exp(j)的取值为0。

>>> X = torch.ones(1, 3) * (-1000)
>>> softmax(X)
tensor([[nan, nan, nan]])

3.3 避免溢出

参考1中的技巧:

  1. 找到向量x\boldsymbol xx中的最大值:
    c=max(x)c=max(\boldsymbol x) c=max(x)
  2. softmaxsoftmaxsoftmax的分子、分母同时除以ccc
    softmax(xi−c)=exp(xi−c)∑jnexp(xj−c)=exp(xi)exp(−c)∑jnexp(xi)exp(−c)=softmax(xi)softmax(x_i - c) = \frac{exp(x_i-c)}{\sum_j^nexp(x_j-c)}=\frac{exp(x_i)exp(-c)}{\sum_j^nexp(x_i)exp(-c)}=softmax(x_i) softmax(xi​−c)=∑jn​exp(xj​−c)exp(xi​−c)​=∑jn​exp(xi​)exp(−c)exp(xi​)exp(−c)​=softmax(xi​)
    经过上述变换,分子的最大取值变为了exp(0)=1exp(0)=1exp(0)=1,避免了上溢出;
    分母中至少会+1+1+1,避免了分母为0造成下溢出。
    ∑jnexp(xj−c)=exp(xi−c)+exp(x2−c)+...+exp(xmax−c)=exp(x1−c)+exp(x2−c)+...+1\sum_j^nexp(x_j-c) =exp(x_i-c)+exp(x_2-c)+...+exp(x_{max}-c)\\ =exp(x_1-c) + exp(x_2-c)+...+1 j∑n​exp(xj​−c)=exp(xi​−c)+exp(x2​−c)+...+exp(xmax​−c)=exp(x1​−c)+exp(x2​−c)+...+1
def softmax_trick(X):c, _ = torch.max(X, dim=1, keepdim=True)return torch.exp(X - c) / torch.sum(torch.exp(X - c), dim=1).reshape(-1, 1)
>>> X = torch.tensor([[-1000, 1000, -1000]])
>>> softmax_trick(X)
tensor([0., 1., 0.])
>>> softmax(X)
tensor([[0., nan, 0.]])

pytorch的实现中已经做过了防止溢出的处理,所以,其运行结果与softmax_trick一致。

import pytorch.nn.functional as F
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> F.softmax(X, dim=1)
tensor([[0., 1., 0.]])

3.4 Log-Sum_Exp Trick2(取log操作)

1. 避免下溢出
对数运算可以将相乘变为相加,即:log(x1x2)=log(x1)+log(x2)log(x_1x_2) = log(x_1) + log(x_2)log(x1​x2​)=log(x1​)+log(x2​)。 当两个很小的数x1、x2x_1、x_2x1​、x2​相乘时,其乘积会变得更小,超出精度则下溢出;而对数操作将乘积变为相加,降低了下溢出的风险。
2. 避免上溢出
log−softmaxlog-softmaxlog−softmax的定义:
log−softmax=log[softmax(xi)]=log(exp(xi)∑jnexp(xj))=xi−log[∑jnexp(xj)]\begin{aligned} log-softmax &=log[softmax(x_i)] \\ &= log(\frac{exp(x_i)}{\sum_j^nexp(x_j)}) \\ &=x_i - log[\sum_j^nexp(x_j)] \end{aligned}log−softmax​=log[softmax(xi​)]=log(∑jn​exp(xj​)exp(xi​)​)=xi​−log[j∑n​exp(xj​)]​
令y=log∑jnexp(xj)y=log\sum_j^nexp(x_j)y=log∑jn​exp(xj​),当xjx_jxj​的取值过大时,yyy存在上溢出的风险,因此,采用与3.3中同样的Trick:
y=log∑jnexp(xj)=log∑jnexp(xj−c)exp(c)=c+log∑jnexp(xj−c)\begin{aligned} y &= log\sum_j^nexp(x_j) \\ & = log\sum_j^nexp(x_j-c)exp(c) \\ & = c +log\sum_j^nexp(x_j-c) \end{aligned}y​=logj∑n​exp(xj​)=logj∑n​exp(xj​−c)exp(c)=c+logj∑n​exp(xj​−c)​
当c=max(x)c=max(\boldsymbol x)c=max(x)时,可避免上溢出。
此时,log−softmaxlog-softmaxlog−softmax的计算公式变为:(其实等价于直接对3.3节的Trick取对数
log−softmax=(xi−c)−log∑jnexp(xj−c)log-softmax = (x_i-c)-log\sum_j^nexp(x_j-c) log−softmax=(xi​−c)−logj∑n​exp(xj​−c)
代码实现:

def log_softmax(X):c, _ = torch.max(X, dim=1, keepdim=True)return X - c - torch.log(torch.sum(torch.exp(X-c), dim=1, keepdim=True))
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> torch.exp(log_softmax(X))
tensor([[0., 1., 0.]])
# pytorch API实现
>>> torch.exp(F.log_softmax(X, dim=1))
tensor([[0., 1., 0.]])

3.5 log-softmax与softmax的区别3

结合3.3节的Trick及我自己的理解:

  1. 在pytorch的实现中,softmax的运算结果等价于对log_softmax的结果作指数运算
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> torch.exp(F.log_softmax(X, dim=1)) == F.softmax(X)
tensor([[True, True, True]])
  1. 使用logloglog运算之后求导更方便,可以加快反向传播的速度4
    ∂∂xilogsoftmax=∂∂xi[xi−log∑jnexp(xj)]=1−softmax(xi)\begin{aligned} \frac{\partial}{\partial x_i}logsoftmax&=\frac{\partial}{\partial x_i} [{x_i - log\sum_j^nexp(x_j)]} \\ &= 1 - softmax(x_i) \end{aligned}∂xi​∂​logsoftmax​=∂xi​∂​[xi​−logj∑n​exp(xj​)]=1−softmax(xi​)​


  1. BTTB你不知道的softmax ↩︎

  2. The Log-Sum_Exp Trick ↩︎

  3. log-softmax与softmax的区别 ↩︎

  4. 动手学深度学习:softmax回归 ↩︎

深入理解softmax相关推荐

  1. 一分钟理解softmax函数(超简单)

    做过多分类任务的同学一定都知道softmax函数.softmax函数,又称归一化指数函数.它是二分类函数sigmoid在多分类上的推广,目的是将多分类的结果以概率的形式展现出来.下图展示了softma ...

  2. 深入理解softmax函数

    Softmax回归模型,该模型是logistic回归模型在多分类问题上的推广,在多分类问题中,类标签  可以取两个以上的值.Softmax模型可以用来给不同的对象分配概率.即使在之后,我们训练更加精细 ...

  3. 理解 softmax 和 NLL 损失函数 (the negative log-likelihood) 以及求导过程

    本文转载自 https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/ 有 ...

  4. 神经网络中的激活函数与损失函数深入理解推导softmax交叉熵

    神经网络中的激活函数与损失函数&深入理解softmax交叉熵 前面在深度学习入门笔记1和深度学习入门笔记2中已经介绍了激活函数和损失函数,这里做一些补充,主要是介绍softmax交叉熵损失函数 ...

  5. 【机器学习基础】(三):理解逻辑回归及二分类、多分类代码实践

    本文是机器学习系列的第三篇,算上前置机器学习系列是第八篇.本文的概念相对简单,主要侧重于代码实践. 上一篇文章说到,我们可以用线性回归做预测,但显然现实生活中不止有预测的问题还有分类的问题.我们可以从 ...

  6. Softmax(假神经网络)与词向量的训练

    今天终于要完成好久之前的一个约定了~在很久很久以前的<如果风停了,你会怎样>中,小夕提到了"深刻理解了sigmoid的同学一定可以轻松的理解用(假)深度学习训练词向量的原理&qu ...

  7. 机器学习:理解逻辑回归及二分类、多分类代码实践

    作者 | caiyongji   责编 | 张红月 来源 | 转载自 caiyongji(ID:cai-yong-ji) 本文的概念相对简单,主要侧重于代码实践.现实生活中不止有预测的问题还有分类的问 ...

  8. softmax函数名字的由来(代数几何原理)——softmax前世今生系列(2)

    导读: softmax的前世今生系列是作者在学习NLP神经网络时,以softmax层为何能对文本进行分类.预测等问题为入手点,顺藤摸瓜进行的一系列研究学习.其中包含: 1.softmax函数的正推原理 ...

  9. softmax 和 log-likelihood(对数似然) 损失函数

    文章目录 `softmax`神经元 `log-likelihood` 损失函数 softmax神经元 softmax神经元的想法其实就是位神经网络定义一种新式的输出层,开始时和S型神经元一样,首先计算 ...

最新文章

  1. spark streaming 入门例子
  2. uniapp商城_【程序源代码】商城小程序
  3. 键盘工具栏的快速集成--IQKeyboardManager
  4. HDU 5936 Difference
  5. osx doc to html,macos – 在OSX上安装Git HTML帮助
  6. pb 动态改变DW的WHERE子句
  7. java获取inputstream_Java:我怎样才能从inputStream获取编码?
  8. matlab图上面加箭头,如何在matlab中显示箭头
  9. Xcode7.0 更新完后,网络请求报错
  10. 万物并作,吾以观复|OceanBase 政企行业实践
  11. 宝塔面板本地调试网站提示域名解析错误的问题
  12. android root权限获取失败,安卓手机为什么获取Root权限失败?Root失败是什么原因...
  13. 2021届Java开发求职-------面试实战之Vivo提前批
  14. 防劫持工具,介绍几款浏览器劫持修复工具
  15. 【分享视频资源】React JS教程
  16. Gym实践(一)——环境安装
  17. 给媳妇做一个记录心情的小程序
  18. 脑波设备mindwave二次开发框架
  19. 监视资本主义:智能陷阱
  20. kkFileView代码分析(四)——office文件的转换(1)office插件管理

热门文章

  1. 拉普拉斯分布,高斯分布,L1 L2
  2. 螺旋无限延伸_如果有一个无限向上螺旋延伸的楼梯,人能到达多高的位置?
  3. Thymeleaf语法详解
  4. 【重磅】雷军直播小米无人机,万万没想到会炸!机!
  5. 各种品牌液晶显示器的面板类型
  6. Binder service入门–创建native binder service
  7. go语言程序逆向整理
  8. [LGOJ3950]部落冲突——[LCT]
  9. 如何制作网页2:如何学习制作网页
  10. 做了一个 仿qq的APP