原文名称:Attention Is All You Need
原文链接:https://arxiv.org/abs/1706.03762

如果不想看文章的可以看下我在b站上录的视频:https://b23.tv/gucpvt

最近Transformer在CV领域很火,Transformer是2017年Google在Computation and Language上发表的,当时主要是针对自然语言处理领域提出的(之前的RNN模型记忆长度有限且无法并行化,只有计算完tit_iti时刻后的数据才能计算ti+1t_{i+1}ti+1时刻的数据,但Transformer都可以做到)。在这篇文章中作者提出了Self-Attention的概念,然后在此基础上提出Multi-Head Attention,所以本文对Self-Attention以及Multi-Head Attention的理论进行详细的讲解。在阅读本文之前,建议大家先去看下李弘毅老师讲的Transformer的内容。本文的内容是基于李宏毅老师讲的内容加上自己阅读一些源码进行的总结。


文章目录

  • 前言
  • Self-Attention
  • Multi-Head Attention
  • Self-Attention与Multi-Head Attention计算量对比
  • Positional Encoding
  • 超参对比

前言

如果之前你有在网上找过self-attention或者transformer的相关资料,基本上都是贴的原论文中的几张图以及公式,如下图,讲的都挺抽象的,反正就是看不懂(可能我太菜的原因)。就像李弘毅老师课程里讲到的"不懂的人再怎么看也不会懂的"。那接下来本文就结合李弘毅老师课上的内容加上原论文的公式来一个个进行详解。


Self-Attention

下面这个图是我自己画的,为了方便大家理解,假设输入的序列长度为2,输入就两个节点x1,x2x_1, x_2x1,x2,然后通过Input Embedding也就是图中的f(x)f(x)f(x)将输入映射到a1,a2a_1, a_2a1,a2。紧接着分别将a1,a2a_1, a_2a1,a2分别通过三个变换矩阵Wq,Wk,WvW_q, W_k, W_vWq,Wk,Wv(这三个参数是可训练的,是共享的)得到对应的qi,ki,viq^i, k^i, v^iqi,ki,vi(这里在源码中是直接使用全连接层实现的,这里为了方便理解,忽略偏执)。

其中

  • qqq代表query,后续会去和每一个kkk进行匹配
  • kkk代表key,后续会被每个qqq匹配
  • vvv代表从aaa中提取得到的信息
  • 后续qqqkkk匹配的过程可以理解成计算两者的相关性,相关性越大对应vvv的权重也就越大

假设a1=(1,1),a2=(1,0),Wq=(1,10,1)a_1=(1, 1), a_2=(1,0), W^q= \binom{1, 1}{0, 1}a1=(1,1),a2=(1,0),Wq=(0,11,1)那么:
q1=(1,1)(1,10,1)=(1,2),q2=(1,0)(1,10,1)=(1,1)q^1 = (1, 1) \binom{1, 1}{0, 1} =(1, 2) , \ \ \ q^2 = (1, 0) \binom{1, 1}{0, 1} =(1, 1) q1=(1,1)(0,11,1)=(1,2),q2=(1,0)(0,11,1)=(1,1)
前面有说Transformer是可以并行化的,所以可以直接写成:
(q1q2)=(1,11,0)(1,10,1)=(1,21,1)\binom{q^1}{q^2} = \binom{1, 1}{1, 0} \binom{1, 1}{0, 1} = \binom{1, 2}{1, 1} (q2q1)=(1,01,1)(0,11,1)=(1,11,2)
同理我们可以得到(k1k2)\binom{k^1}{k^2}(k2k1)(v1v2)\binom{v^1}{v^2}(v2v1),那么求得的(q1q2)\binom{q^1}{q^2}(q2q1)就是原论文中的QQQ(k1k2)\binom{k^1}{k^2}(k2k1)就是KKK(v1v2)\binom{v^1}{v^2}(v2v1)就是VVV。接着先拿q1q^1q1和每个kkk进行match,点乘操作,接着除以d\sqrt{d}d

得到对应的α\alphaα,其中ddd代表向量kik^iki的长度,在本示例中等于2,除以d\sqrt{d}d

的原因在论文中的解释是“进行点乘后的数值很大,导致通过softmax后梯度变的很小”,所以通过除以d\sqrt{d}d

来进行缩放。比如计算α1,i\alpha_{1, i}α1,i
α1,1=q1⋅k1d=1×1+2×02=0.71α1,2=q1⋅k2d=1×0+2×12=1.41\alpha_{1, 1} = \frac{q^1 \cdot k^1}{\sqrt{d}}=\frac{1\times 1+2\times 0}{\sqrt{2}}=0.71 \\ \alpha_{1, 2} = \frac{q^1 \cdot k^2}{\sqrt{d}}=\frac{1\times 0+2\times 1}{\sqrt{2}}=1.41 α1,1=d

q1k1
=
2

1×1+2×0
=
0.71α1,2=d

q1k2
=
2

1×0+2×1
=
1.41

同理拿q2q^2q2去匹配所有的kkk能得到α2,i\alpha_{2, i}α2,i,统一写成矩阵乘法形式:
(α1,1α1,2α2,1α2,2)=(q1q2)(k1k2)Td\binom{\alpha_{1, 1} \ \ \alpha_{1, 2}}{\alpha_{2, 1} \ \ \alpha_{2, 2}}=\frac{\binom{q^1}{q^2}\binom{k^1}{k^2}^T}{\sqrt{d}} (α2,1α2,2α1,1α1,2)=d

(q2q1)(k2k1)T

接着对每一行即(α1,1,α1,2)(\alpha_{1, 1}, \alpha_{1, 2})(α1,1,α1,2)(α2,1,α2,2)(\alpha_{2, 1}, \alpha_{2, 2})(α2,1,α2,2)分别进行softmax处理得到(α^1,1,α^1,2)(\hat\alpha_{1, 1}, \hat\alpha_{1, 2})(α^1,1,α^1,2)(α^2,1,α^2,2)(\hat\alpha_{2, 1}, \hat\alpha_{2, 2})(α^2,1,α^2,2),这里的α^\hat{\alpha}α^相当于计算得到针对每个vvv的权重。到这我们就完成了Attention(Q,K,V){\rm Attention}(Q, K, V)Attention(Q,K,V)公式中softmax(QKTdk){\rm softmax}(\frac{QK^T}{\sqrt{d_k}})softmax(dk

QKT
)
部分。


上面已经计算得到α\alphaα,即针对每个vvv的权重,接着进行加权得到最终结果:
b1=α^1,1×v1+α^1,2×v2=(0.33,0.67)b2=α^2,1×v1+α^2,2×v2=(0.50,0.50)b_1 = \hat{\alpha}_{1, 1} \times v^1 + \hat{\alpha}_{1, 2} \times v^2=(0.33, 0.67) \\ b_2 = \hat{\alpha}_{2, 1} \times v^1 + \hat{\alpha}_{2, 2} \times v^2=(0.50, 0.50) b1=α^1,1×v1+α^1,2×v2=(0.33,0.67)b2=α^2,1×v1+α^2,2×v2=(0.50,0.50)
统一写成矩阵乘法形式:
(b1b2)=(α^1,1α^1,2α^2,1α^2,2)(v1v2)\binom{b_1}{b_2} = \binom{\hat\alpha_{1, 1} \ \ \hat\alpha_{1, 2}}{\hat\alpha_{2, 1} \ \ \hat\alpha_{2, 2}}\binom{v^1}{v^2} (b2b1)=(α^2,1α^2,2α^1,1α^1,2)(v2v1)
到这,Self-Attention的内容就讲完了。总结下来就是论文中的一个公式:
Attention(Q,K,V)=softmax(QKTdk)V{\rm Attention}(Q, K, V)={\rm softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk

QKT)V


Multi-Head Attention

刚刚已经聊完了Self-Attention模块,接下来再来看看Multi-Head Attention模块,实际使用中基本使用的还是Multi-Head Attention模块。原论文中说使用多头注意力机制能够联合来自不同head部分学习到的信息。Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.其实只要懂了Self-Attention模块Multi-Head Attention模块就非常简单了。

首先还是和Self-Attention模块一样将aia_iai分别通过Wq,Wk,WvW^q, W^k, W^vWq,Wk,Wv得到对应的qi,ki,viq^i, k^i, v^iqi,ki,vi,然后再根据使用的head的数目hhh进一步把得到的qi,ki,viq^i, k^i, v^iqi,ki,vi均分成hhh份。比如下图中假设h=2h=2h=2然后q1q^1q1拆分成q1,1q^{1,1}q1,1q1,2q^{1,2}q1,2,那么q1,1q^{1,1}q1,1就属于head1,q1,2q^{1,2}q1,2属于head2。


看到这里,如果读过原论文的人肯定有疑问,论文中不是写的通过WiQ,WiK,WiVW^Q_i, W^K_i, W^V_iWiQ,WiK,WiV映射得到每个head的Qi,Ki,ViQ_i, K_i, V_iQi,Ki,Vi吗:
headi=Attention(QWiQ,KWiK,VWiV)head_i = {\rm Attention}(QW^Q_i, KW^K_i, VW^V_i) headi=Attention(QWiQ,KWiK,VWiV)
但我在github上看的一些源码中就是简单的进行均分,其实也可以将WiQ,WiK,WiVW^Q_i, W^K_i, W^V_iWiQ,WiK,WiV设置成对应值来实现均分,比如下图中的Q通过W1QW^Q_1W1Q就能得到均分后的Q1Q_1Q1


通过上述方法就能得到每个headihead_iheadi对应的Qi,Ki,ViQ_i, K_i, V_iQi,Ki,Vi参数,接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。
Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi{\rm Attention}(Q_i, K_i, V_i)={\rm softmax}(\frac{Q_iK_i^T}{\sqrt{d_k}})V_i Attention(Qi,Ki,Vi)=softmax(dk

QiKiT)Vi


接着将每个head得到的结果进行concat拼接,比如下图中b1,1b_{1,1}b1,1head1head_1head1得到的b1b_1b1)和b1,2b_{1,2}b1,2head2head_2head2得到的b1b_1b1)拼接在一起,b2,1b_{2,1}b2,1head1head_1head1得到的b2b_2b2)和b2,2b_{2,2}b2,2head2head_2head2得到的b2b_2b2)拼接在一起。


接着将拼接后的结果通过WOW^OWO(可学习的参数)进行融合,如下图所示,融合后得到最终的结果b1,b2b_1, b_2b1,b2


到这,Multi-Head Attention的内容就讲完了。总结下来就是论文中的两个公式:
MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhereheadi=Attention(QWiQ,KWiK,VWiV){\rm MultiHead}(Q, K, V) = {\rm Concat(head_1,...,head_h)}W^O \\ {\rm where \ head_i = Attention}(QW_i^Q, KW_i^K, VW_i^V) MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhereheadi=Attention(QWiQ,KWiK,VWiV)


Self-Attention与Multi-Head Attention计算量对比

在原论文章节3.2.2中最后有说两者的计算量其实差不多。Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.下面做了个简单的实验,这个model文件大家先忽略哪来的。这个Attention就是实现Multi-head Attention的方法,其中包括上面讲的所有步骤。

  • 首先创建了一个Self-Attention模块(单头)a1,然后把proj变量置为Identity(Identity对应的是Multi-Head Attention中最后那个WoW^oWo的映射,单头中是没有的,所以置为Identity即不做任何操作)。
  • 再创建一个Multi-Head Attention模块(多头)a2,然后设置8个head。
  • 创建一个随机变量,注意shape
  • 使用fvcore分别计算两个模块的FLOPs
import torch
from fvcore.nn import FlopCountAnalysisfrom model import Attentiondef main():# Self-Attentiona1 = Attention(dim=512, num_heads=1)a1.proj = torch.nn.Identity()  # remove Wo# Multi-Head Attentiona2 = Attention(dim=512, num_heads=8)# [batch_size, num_tokens, total_embed_dim]t = (torch.rand(32, 1024, 512),)flops1 = FlopCountAnalysis(a1, t)print("Self-Attention FLOPs:", flops1.total())flops2 = FlopCountAnalysis(a2, t)print("Multi-Head Attention FLOPs:", flops2.total())if __name__ == '__main__':main()

终端输出如下, 可以发现确实两者的FLOPs差不多,Multi-Head AttentionSelf-Attention略高一点:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 68719476736

其实两者FLOPs的差异只是在最后的WOW^OWO上,如果把Multi-Head AttentioWOW^OWO也删除(即把a2的proj也设置成Identity),可以看出两者FLOPs是一样的:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 60129542144

Positional Encoding

如果仔细观察刚刚讲的Self-Attention和Multi-Head Attention模块,在计算中是没有考虑到位置信息的。假设在Self-Attention模块中,输入a1,a2,a3a_1, a_2, a_3a1,a2,a3得到b1,b2,b3b_1, b_2, b_3b1,b2,b3。对于a1a_1a1而言,a2a_2a2a3a_3a3离它都是一样近的而且没有先后顺序。假设将输入的顺序改为a1,a3,a2a_1, a_3, a_2a1,a3,a2,对结果b1b_1b1是没有任何影响的。下面是使用Pytorch做的一个实验,首先使用nn.MultiheadAttention创建一个Self-Attention模块(num_heads=1),注意这里在正向传播过程中直接传入QKVQKVQKV,接着创建两个顺序不同的QKVQKVQKV变量t1和t2(主要是将q2,k2,v2q^2, k^2, v^2q2,k2,v2q3,k3,v3q^3, k^3, v^3q3,k3,v3的顺序换了下),分别将这两个变量输入Self-Attention模块进行正向传播。

import torch
import torch.nn as nnm = nn.MultiheadAttention(embed_dim=2, num_heads=1)t1 = [[[1., 2.],   # q1, k1, v1[2., 3.],   # q2, k2, v2[3., 4.]]]  # q3, k3, v3t2 = [[[1., 2.],   # q1, k1, v1[3., 4.],   # q3, k3, v3[2., 3.]]]  # q2, k2, v2q, k, v = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result1: \n", m(q, k, v))q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2: \n", m(q, k, v))

对比结果可以发现,即使调换了q2,k2,v2q^2, k^2, v^2q2,k2,v2q3,k3,v3q^3, k^3, v^3q3,k3,v3的顺序,但对于b1b_1b1是没有影响的。


为了引入位置信息,在原论文中引入了位置编码positional encodingsTo this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks.如下图所示,位置编码是直接加在输入的a={a1,...,an}a=\{a_1,...,a_n\}a={a1,...,an}中的,即pe={pe1,...,pen}pe=\{pe_1,...,pe_n\}pe={pe1,...,pen}a={a1,...,an}a=\{a_1,...,a_n\}a={a1,...,an}拥有相同的维度大小。关于位置编码在原论文中有提出两种方案,一种是原论文中使用的固定编码,即论文中给出的sine and cosine functions方法,按照该方法可计算出位置编码;另一种是可训练的位置编码,作者说尝试了两种方法发现结果差不多(但在ViT论文中使用的是可训练的位置编码)。


超参对比

关于Transformer中的一些超参数的实验对比可以参考原论文的表3,如下图所示。其中:

  • N表示重复堆叠Transformer Block的次数
  • dmodeld_{model}dmodel表示Multi-Head Self-Attention输入输出的token维度(向量长度)
  • dffd_{ff}dff表示在MLP(feed forward)中隐层的节点个数
  • h表示Multi-Head Self-Attention中head的个数
  • dk,dvd_k, d_vdk,dv表示Multi-Head Self-Attention中每个head的key(K)以及query(Q)的维度
  • PdropP_{drop}Pdrop表示dropout层的drop_rate

到这,关于Self-Attention、Multi-Head Attention以及位置编码的内容就全部讲完了,如果有讲的不对的地方希望大家指出。

详解Transformer中Self-Attention以及Multi-Head Attention相关推荐

  1. 一文详解transformer(Attention Is All You Need)原理

    谈起自然语言,就不得不说到现在大火的bert以及openai gpt-2,但是在理解这些模型之前,我觉得首先应该了解transformer,因本人水平有限,在看了transformer的论文之后也一知 ...

  2. redis watch使用场景_详解redis中的锁以及使用场景

    分布式锁 什么是分布式锁? 分布式锁是控制分布式系统之间同步访问共享资源的一种方式. 为什么要使用分布式锁? ​ 为了保证共享资源的数据一致性. 什么场景下使用分布式锁? ​ 数据重要且要保证一致性 ...

  3. 详解OpenCV中的Lucas Kanade稀疏光流单应追踪器

    详解OpenCV中的Lucas Kanade稀疏光流单应追踪器 1. 效果图 2. 源码 参考 这篇博客将详细介绍OpenCV中的Lucas Kanade稀疏光流单应追踪器. 光流是由物体或相机的运动 ...

  4. python操作目录_详解python中的文件与目录操作

    详解python中的文件与目录操作 一 获得当前路径 1.代码1 >>>import os >>>print('Current directory is ',os. ...

  5. python3中unicode怎么写_详解python3中ascii与Unicode使用

    这篇文章主要为大家详解python3中ascii与Unicode使用的相关资料,需要的朋友可以参考下# Auther: Aaron Fan ''' ASCII:不支持中文,1个英文占1个字节 Unic ...

  6. foreach php,详解PHP中foreach的用法和实例

    本篇文章介绍了详解PHP中foreach的用法和实例,详细介绍了foreach的用法,感兴趣的小伙伴们可以参考一下. 在PHP中经常会用到foreach的使用,而要用到foreach,就必须用到数组. ...

  7. python open 打开是什么类型的文件-详解Python中open()函数指定文件打开方式的用法...

    文件打开方式 当我们用open()函数去打开文件的时候,有好几种打开的模式. 'r'->只读 'w'->只写,文件已存在则清空,不存在则创建. 'a'->追加,写到文件末尾 'b'- ...

  8. python中list[1啥意思_详解Python中list[::-1]的几种用法

    本文主要介绍了Python中list[::-1]的几种用法,分享给大家,具体如下: s = "abcde" list的[]中有三个参数,用冒号分割 list[param1:para ...

  9. java 死锁 内存消耗_详解Java中synchronized关键字的死锁和内存占用问题

    先看一段synchronized 的详解: synchronized 是 java语言的关键字,当它用来修饰一个方法或者一个代码块的时候,能够保证在同一时刻最多只有一个线程执行该段代码. 一.当两个并 ...

  10. pythonnamedtuple定义类型_详解Python中namedtuple的使用

    namedtuple是Python中存储数据类型,比较常见的数据类型还有有list和tuple数据类型.相比于list,tuple中的元素不可修改,在映射中可以当键使用. namedtuple: na ...

最新文章

  1. 经典PID控制算法用C语言实现!
  2. [Leetcode]50. Pow(x, n)
  3. thinkphp F方法
  4. linux hibernate suspend 区别,实现Linux休眠(sleep/hibernate)和挂起(suspend)
  5. 5.4 Spring AOP
  6. 新年就是要你红!华为Mate 20 Pro馥蕾红璨星蓝来袭
  7. 建筑与建筑群综合布线系统工程验收规范_如果这9个方面考虑周到 你的综合布线系统工程可以竣工验收了...
  8. Android开发笔记(二十三)文件对话框FileDialog
  9. 连麦互动技术及其连麦调研
  10. 指示灯亮着,但是右边的数字小键盘不可用
  11. 斯坦福大学自然语言处理第二课“文本处理基础(Basic Text Processing)”
  12. 王者荣耀改重复名,空白名最低战力查询助手微信小程序源码下载
  13. 鸡啄米C++编程入门教程系列
  14. php sphinx应用场景,Sphinx+Scws 搭建千万级准实时搜索应用场景详解
  15. 我遇到的一些问题(空指针异常、jsp页面传值)
  16. python 高阶函数作业(3.16)
  17. 分布式快速批量获取网站标题关键字描述(TDK)接口api文档说明
  18. 微信小程序 open-type=contact
  19. 破解meclipse8.5方法
  20. 快乐地打牢基础(4)——树状数组

热门文章

  1. Winserver AD管理Powershell——GUI 计算机加入域
  2. ed2k链接文件,最快下载方式
  3. linux skype 4.3,在Arch Linux上安装Skype 4.3(最新版本)
  4. 服务器违反了协议,IMAP协议违规:未知消息的EXPUNGE响应?
  5. 南丁格尔玫瑰图 python_央视都在用的“南丁格尔玫瑰图”,原来Python也可以画...
  6. MATLAB绘制三维地图
  7. 倒计时1天,IMG、完美、腾讯技术大咖相聚直播间详解光线追踪技术
  8. win7 显示快捷方式扩展名 lnk
  9. 视频画面大小剪裁操作教程
  10. 金庸教你谈恋爱[这个写的太牛逼了,加上了天龙八部,感谢原作者]