Self-attention中为什么softmax要除d_k
我觉得这是一个很有意思的问题,简单但是很细节。先说结论,是为了保证梯度的平稳。那怎么个意思?
首先说向量(行向量和列向量都一样),他们的点乘和叉乘。
向量的内积:也叫点乘,结果是一个数。两个向量对应位相乘再求和。要求向量a和b的维度要一样。
a ⃗ ∗ b ⃗ = ( a 1 ∗ b 1 + a 2 ∗ b 2 + ⋯ + a n ∗ b n ) \vec{a}*\vec{b}=(a_1*b_1+a_2*b_2+\cdots+a_n*b_n) a ∗b =(a1∗b1+a2∗b2+⋯+an∗bn)
内积的几何意义:计算两个向量之间的夹角或者向量b在向量a上的投影。
a ⃗ ∗ b ⃗ = ∣ a ∣ ∗ ∣ b ∣ ∗ c o s ( θ ) ( 余 弦 定 理 可 以 证 ) \vec{a}*\vec{b}=|a|*|b|*cos(\theta)(余弦定理可以证) a ∗b =∣a∣∗∣b∣∗cos(θ)(余弦定理可以证)
θ = a r c c o s ( a ⃗ ∗ b ⃗ ∣ a ∣ ∣ b ∣ ) \theta=arccos(\frac{\vec{a}*\vec{b}}{|a||b|}) θ=arccos(∣a∣∣b∣a ∗b )
向量的外积:也叫叉乘,结果是一个新的向量。具体的来说它是a向量和b向量组成平面的法向量。
a ⃗ x b ⃗ = ∣ i j k x a y a z a x b y b z b ∣ \vec{a}x\vec{b}=\begin{vmatrix} i&j&k\\ x_a&y_a&z_a\\ x_b&y_b&z_b \end{vmatrix} a xb =∣∣∣∣∣∣ixaxbjyaybkzazb∣∣∣∣∣∣
下边进入正文Self-attention的细节。
S e l f a t t e n t i o n = s o f t m a x ( Q K T d k ) V Self attention=softmax(\frac{QK^T}{\sqrt{d_k}})V Selfattention=softmax(dk QKT)V
读过《Attention is all you need》我们就知道QKV三个矩阵都是X的线性变换。这里为了简单,我们认为QKV都是同样的一个矩阵,也就是Q=K=V。假设Q是一个行向量,维度为 d k d_k dk。那么我们可以知道 Q Q T QQ^T QQT其实在计算每个元素之间的相似度。而且不存在上下文关系,也就是全局的相似性关系。那么继续我们如果假设Q是从一个标准正太分布(0均值,1方差的高斯分布)中产生的,那么 Q Q T QQ^T QQT就也是0均值, 2 d k 2d_k 2dk为方差的。为什么?
因 Q Q T QQ^T QQT相乘之后我理解现在是一个卡方分布,不知道这里理解的对不对,希望和大家一起探讨。那么卡方分布的方差 E ( x 2 ) = 2 d k E(x^2)= 2d_k E(x2)=2dk。
所以为了让 Q Q T QQ^T QQT的方差回到1,我们需要除 d k \sqrt{d_k} dk 。又来了,为什么想要让方差回到1呢?
因为方差大, Q Q T QQ^T QQT中出现大值的可能性就大,下边引用我看到的文章的一段话。
当 d k d_k dk很大时,意味着 Q Q T QQ^T QQT的方差就很大,分布会趋于陡峭(分布的方差大,分布就会集中在绝对值大的区域),就会使得softmax()之后使得值出现两极分化的状态。(https://blog.csdn.net/qq_44846512/article/details/114364559)
也就是说方差大,那么经过softmax后输出的矩阵softmax( Q Q T QQ^T QQT),会很陡峭。这句话乍一看可能比较含糊,我后来自己特意看了下结果,就明白了。
import torch
import matplotlib.pyplot as plt
import mathdef main():matsize = 10q = torch.randn(matsize)k = torch.randn(matsize*matsize)v = torch.randn(matsize*matsize*matsize)c1 = q*qc2 = k*kc3 = v*vfor var in q, k, v, c1, c2, c3:print("mean is %f, div is %f." % (var.mean(), var.var()))ax1 = plt.subplot(331) # c1 originplt.plot(torch.arange(c1.shape[0]), c1)ax2 = plt.subplot(334) # c1 softmaxplt.plot(torch.arange(c1.shape[0]), torch.nn.functional.softmax(c1))ax3 = plt.subplot(332) # c2 originplt.plot(torch.arange(c2.shape[0]), c2)ax4 = plt.subplot(335) # c2 softmaxplt.plot(torch.arange(c2.shape[0]), torch.nn.functional.softmax(c2))ax6 = plt.subplot(338) # c2 softmax with sqrt dkplt.plot(torch.arange(c2.shape[0]),torch.nn.functional.softmax(c2/math.sqrt(c2.shape[0])))ax7 = plt.subplot(333) # c3 originplt.plot(torch.arange(c3.shape[0]), c3)ax8 = plt.subplot(336) # c3 softmaxplt.plot(torch.arange(c3.shape[0]), torch.nn.functional.softmax(c3))ax9 = plt.subplot(339) # c3 softmax with sqrt dkplt.plot(torch.arange(c3.shape[0]),torch.nn.functional.softmax(c3/math.sqrt(c3.shape[0])))plt.savefig("cov.jpg")plt.show()if __name__ == "__main__":main()
通过下边的这个图
图中最上边一行是softmax之前的结果,中间一行是没有除 d k d_k dk的softmax结果,最后一行是除了 d k d_k dk的softmax结果。可以看出在不除 d k d_k dk的时候softmax的结果只会在输入的最大值或者几个大值附近出现,看起来非常陡峭。当输入除了 d k d_k dk以后我们发现输入数据的分布大部分都保留了下来,这样的好处就是可以在梯度回传的时候让梯度比较平稳。而且当 d k d_k dk越大,影响越明显(从左向右 d k d_k dk越来越大)。
这就是为什么Self-attention中要除 d k d_k dk。这也是我看了一些网上的资料后自己的理解,只不过我觉得其中有不清楚的地方自己又想了下写下来而已。
参考文献
- https://www.zhihu.com/question/293696778 卡方分布方差计算
- https://blog.csdn.net/qq_44846512/article/details/114364559 关于这个问题写的也不错的博客
Self-attention中为什么softmax要除d_k相关推荐
- 如何理解self attention中的QKV矩阵
如何理解self attention中的QKV矩阵 疑问:三个矩阵的形状是一样的(embd_dim*embd_dim),作用也都是对输入句子的embedding做线性变换(tf.matmul(Q,in ...
- 通俗易懂:Attention中的Q、K、V是什么?怎么得到Q、K、V?
说一下Attention中的QKV是什么,再举点例子说明QKV怎么得到.还是结合例子明白的快. Attention中Q.K.V是什么? 首先Attention的任务是获取局部关注的信息.Attenti ...
- 神经网络学习中的SoftMax与交叉熵
简 介: 对于在深度学习中的两个常见的函数SoftMax,交叉熵进行的探讨.在利用paddle平台中的反向求微分进行验证的过程中,发现结果 与数学定义有差别.具体原因还需要之后进行查找. 关键词: 交 ...
- Fashion Mnist中的softmax应用
Fashion Mnist中的softmax应用 1.Fashion Mnist数据集 链接:https://pan.baidu.com/s/1jPYlBMg-MTQ7nxnLMwjsLQ 提取码:3 ...
- 极智AI | Attention 中 torch.chunk 的 TensorRT 实现
欢迎关注我的公众号 [极智视界],获取我的更多笔记分享 大家好,我是极智视界,本文介绍一下 Attention 中 torch.chunk 的 TensorRT 实现. Attention ...
- What is the Softmax Function?详解机器学习中的Softmax函数【小白菜可懂】
目录 定义 公式 计算 Softmax vs Sigmoid Softmax vs Sigmoid 计算 Softmax vs Argmax Softmax vs Argmax 计算 应用 神经网络中 ...
- caffe中的softmax layer
在caffe中的lenet实现最后一层是softmax layer,输出分类的结果,下面就简单介绍一下softmax回归. 1,首先,在caffe中,softmax layer输出的是原始的输入在每一 ...
- Attention中softmax的梯度消失及scaled原因
在bert模型中的attention构造中 Q:[batch, 12, seq, dk] K:[batch, 12, seq, dk] softmax中的梯度消失 x=(a,a,2a,4a)x=(a, ...
- 神经网络中的Softmax激活函数
Softmax回归模型是logistic回归模型在多分类问题上的推广,适用于多分类问题中,且类别之间互斥的场合. Softmax将多个神经元的输出,映射到(0,1)区间内,可以看成是当前输出是属于各个 ...
最新文章
- Python培训教程:pycharm常用的快捷键合集
- 自己动手在 Linux 系统实现一个 everything 程序
- java 鸡尾酒排序_冒泡排序及优化(Java实现)
- python输出数据到excel-如何使用python将传感器数据输出保存到excel中
- react draft api 简介
- 修改ALSM_EXCEL_TO_INTERNAL_TABLE的限制
- 最长公共子字符串(动态规划)
- 【editor】Source Insight定制之代码风格自动校准功能(AStyle的使用)
- 机器学习笔记(六) ---- 支持向量机(SVM)
- 净利下降7成、新业务“扛大旗” 阿里转型更需耐心
- wpf项目无法使用针式打印机_针式打印机的常见故障和解决方法2
- udp程序启动后检测都是未启动_【例子教程】联想Leez P710 物联网AI物体检测
- word公式大括号内容对齐
- html视频倍速播放,如何让网页视频倍速播放
- 糅合不好变搀合,搀合不好变搅合
- connected papers 白嫖攻略
- mysql主从配置duxi_配置MySQL主从复制(一主一从)
- 虚拟机Linux上网ping百度跳过的坑,亲测有效
- TabLayout自定义指示器及样式
- c语言打印五角星图案解锁教程,手机解锁图案五角星怎么?