我觉得这是一个很有意思的问题,简单但是很细节。先说结论,是为了保证梯度的平稳。那怎么个意思?

首先说向量(行向量和列向量都一样),他们的点乘和叉乘。

向量的内积:也叫点乘,结果是一个数。两个向量对应位相乘再求和。要求向量a和b的维度要一样。
a ⃗ ∗ b ⃗ = ( a 1 ∗ b 1 + a 2 ∗ b 2 + ⋯ + a n ∗ b n ) \vec{a}*\vec{b}=(a_1*b_1+a_2*b_2+\cdots+a_n*b_n) a ∗b =(a1​∗b1​+a2​∗b2​+⋯+an​∗bn​)
内积的几何意义:计算两个向量之间的夹角或者向量b在向量a上的投影。
a ⃗ ∗ b ⃗ = ∣ a ∣ ∗ ∣ b ∣ ∗ c o s ( θ ) ( 余 弦 定 理 可 以 证 ) \vec{a}*\vec{b}=|a|*|b|*cos(\theta)(余弦定理可以证) a ∗b =∣a∣∗∣b∣∗cos(θ)(余弦定理可以证)

θ = a r c c o s ( a ⃗ ∗ b ⃗ ∣ a ∣ ∣ b ∣ ) \theta=arccos(\frac{\vec{a}*\vec{b}}{|a||b|}) θ=arccos(∣a∣∣b∣a ∗b ​)
向量的外积:也叫叉乘,结果是一个新的向量。具体的来说它是a向量和b向量组成平面的法向量。
a ⃗ x b ⃗ = ∣ i j k x a y a z a x b y b z b ∣ \vec{a}x\vec{b}=\begin{vmatrix} i&j&k\\ x_a&y_a&z_a\\ x_b&y_b&z_b \end{vmatrix} a xb =∣∣∣∣∣∣​ixa​xb​​jya​yb​​kza​zb​​∣∣∣∣∣∣​
下边进入正文Self-attention的细节。
S e l f a t t e n t i o n = s o f t m a x ( Q K T d k ) V Self attention=softmax(\frac{QK^T}{\sqrt{d_k}})V Selfattention=softmax(dk​ ​QKT​)V
读过《Attention is all you need》我们就知道QKV三个矩阵都是X的线性变换。这里为了简单,我们认为QKV都是同样的一个矩阵,也就是Q=K=V。假设Q是一个行向量,维度为 d k d_k dk​。那么我们可以知道 Q Q T QQ^T QQT其实在计算每个元素之间的相似度。而且不存在上下文关系,也就是全局的相似性关系。那么继续我们如果假设Q是从一个标准正太分布(0均值,1方差的高斯分布)中产生的,那么 Q Q T QQ^T QQT就也是0均值, 2 d k 2d_k 2dk​为方差的。为什么?

因 Q Q T QQ^T QQT相乘之后我理解现在是一个卡方分布,不知道这里理解的对不对,希望和大家一起探讨。那么卡方分布的方差 E ( x 2 ) = 2 d k E(x^2)= 2d_k E(x2)=2dk​。

所以为了让 Q Q T QQ^T QQT的方差回到1,我们需要除 d k \sqrt{d_k} dk​ ​。又来了,为什么想要让方差回到1呢?

因为方差大, Q Q T QQ^T QQT中出现大值的可能性就大,下边引用我看到的文章的一段话。

当 d k d_k dk​很大时,意味着 Q Q T QQ^T QQT的方差就很大,分布会趋于陡峭(分布的方差大,分布就会集中在绝对值大的区域),就会使得softmax()之后使得值出现两极分化的状态。(https://blog.csdn.net/qq_44846512/article/details/114364559)

也就是说方差大,那么经过softmax后输出的矩阵softmax( Q Q T QQ^T QQT),会很陡峭。这句话乍一看可能比较含糊,我后来自己特意看了下结果,就明白了。

import torch
import matplotlib.pyplot as plt
import mathdef main():matsize = 10q = torch.randn(matsize)k = torch.randn(matsize*matsize)v = torch.randn(matsize*matsize*matsize)c1 = q*qc2 = k*kc3 = v*vfor var in q, k, v, c1, c2, c3:print("mean is %f, div is %f." % (var.mean(), var.var()))ax1 = plt.subplot(331)  # c1 originplt.plot(torch.arange(c1.shape[0]), c1)ax2 = plt.subplot(334)  # c1 softmaxplt.plot(torch.arange(c1.shape[0]), torch.nn.functional.softmax(c1))ax3 = plt.subplot(332)  # c2 originplt.plot(torch.arange(c2.shape[0]), c2)ax4 = plt.subplot(335)  # c2 softmaxplt.plot(torch.arange(c2.shape[0]), torch.nn.functional.softmax(c2))ax6 = plt.subplot(338)  # c2 softmax with sqrt dkplt.plot(torch.arange(c2.shape[0]),torch.nn.functional.softmax(c2/math.sqrt(c2.shape[0])))ax7 = plt.subplot(333)  # c3 originplt.plot(torch.arange(c3.shape[0]), c3)ax8 = plt.subplot(336)  # c3 softmaxplt.plot(torch.arange(c3.shape[0]), torch.nn.functional.softmax(c3))ax9 = plt.subplot(339)  # c3 softmax with sqrt dkplt.plot(torch.arange(c3.shape[0]),torch.nn.functional.softmax(c3/math.sqrt(c3.shape[0])))plt.savefig("cov.jpg")plt.show()if __name__ == "__main__":main()

通过下边的这个图

图中最上边一行是softmax之前的结果,中间一行是没有除 d k d_k dk​的softmax结果,最后一行是除了 d k d_k dk​的softmax结果。可以看出在不除 d k d_k dk​的时候softmax的结果只会在输入的最大值或者几个大值附近出现,看起来非常陡峭。当输入除了 d k d_k dk​以后我们发现输入数据的分布大部分都保留了下来,这样的好处就是可以在梯度回传的时候让梯度比较平稳。而且当 d k d_k dk​越大,影响越明显(从左向右 d k d_k dk​越来越大)。

这就是为什么Self-attention中要除 d k d_k dk​。这也是我看了一些网上的资料后自己的理解,只不过我觉得其中有不清楚的地方自己又想了下写下来而已。

参考文献

  1. https://www.zhihu.com/question/293696778 卡方分布方差计算
  2. https://blog.csdn.net/qq_44846512/article/details/114364559 关于这个问题写的也不错的博客

Self-attention中为什么softmax要除d_k相关推荐

  1. 如何理解self attention中的QKV矩阵

    如何理解self attention中的QKV矩阵 疑问:三个矩阵的形状是一样的(embd_dim*embd_dim),作用也都是对输入句子的embedding做线性变换(tf.matmul(Q,in ...

  2. 通俗易懂:Attention中的Q、K、V是什么?怎么得到Q、K、V?

    说一下Attention中的QKV是什么,再举点例子说明QKV怎么得到.还是结合例子明白的快. Attention中Q.K.V是什么? 首先Attention的任务是获取局部关注的信息.Attenti ...

  3. 神经网络学习中的SoftMax与交叉熵

    简 介: 对于在深度学习中的两个常见的函数SoftMax,交叉熵进行的探讨.在利用paddle平台中的反向求微分进行验证的过程中,发现结果 与数学定义有差别.具体原因还需要之后进行查找. 关键词: 交 ...

  4. Fashion Mnist中的softmax应用

    Fashion Mnist中的softmax应用 1.Fashion Mnist数据集 链接:https://pan.baidu.com/s/1jPYlBMg-MTQ7nxnLMwjsLQ 提取码:3 ...

  5. 极智AI | Attention 中 torch.chunk 的 TensorRT 实现

      欢迎关注我的公众号 [极智视界],获取我的更多笔记分享   大家好,我是极智视界,本文介绍一下 Attention 中 torch.chunk 的 TensorRT 实现.   Attention ...

  6. What is the Softmax Function?详解机器学习中的Softmax函数【小白菜可懂】

    目录 定义 公式 计算 Softmax vs Sigmoid Softmax vs Sigmoid 计算 Softmax vs Argmax Softmax vs Argmax 计算 应用 神经网络中 ...

  7. caffe中的softmax layer

    在caffe中的lenet实现最后一层是softmax layer,输出分类的结果,下面就简单介绍一下softmax回归. 1,首先,在caffe中,softmax layer输出的是原始的输入在每一 ...

  8. Attention中softmax的梯度消失及scaled原因

    在bert模型中的attention构造中 Q:[batch, 12, seq, dk] K:[batch, 12, seq, dk] softmax中的梯度消失 x=(a,a,2a,4a)x=(a, ...

  9. 神经网络中的Softmax激活函数

    Softmax回归模型是logistic回归模型在多分类问题上的推广,适用于多分类问题中,且类别之间互斥的场合. Softmax将多个神经元的输出,映射到(0,1)区间内,可以看成是当前输出是属于各个 ...

最新文章

  1. Python培训教程:pycharm常用的快捷键合集
  2. 自己动手在 Linux 系统实现一个 everything 程序
  3. java 鸡尾酒排序_冒泡排序及优化(Java实现)
  4. python输出数据到excel-如何使用python将传感器数据输出保存到excel中
  5. react draft api 简介
  6. 修改ALSM_EXCEL_TO_INTERNAL_TABLE的限制
  7. 最长公共子字符串(动态规划)
  8. 【editor】Source Insight定制之代码风格自动校准功能(AStyle的使用)
  9. 机器学习笔记(六) ---- 支持向量机(SVM)
  10. 净利下降7成、新业务“扛大旗” 阿里转型更需耐心
  11. wpf项目无法使用针式打印机_针式打印机的常见故障和解决方法2
  12. udp程序启动后检测都是未启动_【例子教程】联想Leez P710 物联网AI物体检测
  13. word公式大括号内容对齐
  14. html视频倍速播放,如何让网页视频倍速播放
  15. 糅合不好变搀合,搀合不好变搅合
  16. connected papers 白嫖攻略
  17. mysql主从配置duxi_配置MySQL主从复制(一主一从)
  18. 虚拟机Linux上网ping百度跳过的坑,亲测有效
  19. TabLayout自定义指示器及样式
  20. c语言打印五角星图案解锁教程,手机解锁图案五角星怎么?

热门文章

  1. 南威尔士警方称,2017年欧洲冠军联赛决赛使用的人脸识别技术错误率超过90%
  2. Python数据分析实用程序
  3. L1-054 福到了 (15分)题解
  4. 状态模式、有限状态机 Unity版本实现
  5. 什么东西改善睡眠质量,辅助睡眠好物推荐
  6. TX1刷机教程(安装caffe、cuda/cudnn)
  7. 制作Linux的优盘(usb)启动盘
  8. Php绘制棋盘,第二次DIY棋盘,纯手工绘制完美棋盘
  9. 网易视频云 php接口
  10. 席绢言情系列书评总序