详解自注意力机制及其在LSTM中的应用

注意力机制(Attention Mechanism)最早出现在上世纪90年代,应用于计算机视觉领域。2014年,谷歌Mnih V等人[1] 在图像分类中将注意力机制融合至RNN中,取得了令人瞩目的成绩,随后注意力机制也开始在深度学习领域受到广泛关注,在自然语言处理领域,Bahdanau等人[2] 将注意力机制融合至编码-解码器中,在翻译任务取得不错的效果。而真正让注意力机制大火的是2017年,谷歌提出的Transformer[3],它提出了自注意力机制(self-Attention Mechanism),摒弃了RNN和CNN,充分挖掘了DNN的特性,刷新了11项NLP任务的精度,震惊了深度学习领域。
注意力机制基于人类的视觉注意力,人在观察物体的时候往往会把重点放在部分特征上,注意力机制就是根据这个特点,基于我们的目标,给强特征给予更大的权重,而弱特征给予较小权重,甚至0权重。

文章目录

  • 详解自注意力机制及其在LSTM中的应用
  • 1.注意力机制本质
  • 2.自注意力机制
  • 3.多头自注意力
  • 4.(多头)自注意力机制在LSTM中的应用
  • 5.结语

1.注意力机制本质

注意力机制(Attention Mechanism)的本质是:对于给定目标,通过生成一个权重系数对输入进行加权求和,来识别输入中哪些特征对于目标是重要的,哪些特征是不重要的;
为了实现注意力机制,我们将输入的原始数据看作< Key, Value>键值对的形式,根据给定的任务目标中的查询值 Query 计算 Key 与 Query 之间的相似系数,可以得到Value值对应的权重系数, 之后再用权重系数对 Value 值进行加权求和, 即可得到输出。我们使用Q,K,V分别表示Query, Key和Value,注意力权重系数W的公式如下:
W = s o f t m a x ⁡ ( Q K T ) W =softmax⁡(QK^T ) W=softmax(QKT)
将注意力权重系数W与Value做点积操作(加权求和)得到融合了注意力的输出:
A t t e n t i o n ( Q , K , V ) = W ⋅ V = s o f t m a x ⁡ ( Q K T ) ⋅ V Attention(Q,K,V) = W·V=softmax⁡(QK^T )·V Attention(Q,K,V)=WV=softmax(QKT)V
注意力模型的详细结构如下图所示:

需要注意,如果Value是向量的话,加权求和的过程中是对向量进行加权,最后得到的输出也是一个向量。
可以看到,注意力机制可以通过对< Key, Query>的计算来形成一个注意力权重向量,然后对Value进行加权求和得到融合了注意力的全新输出,注意力机制在深度学习各个领域都有很多的应用。不过需要注意的是,注意力并不是一个统一的模型,它只是一个机制,在不同的应用领域,Query, Key和Value有不同的来源方式,也就是说不同领域有不同的实现方法。

2.自注意力机制

自注意力机制(self-Attention Mechanism),它最早由谷歌团队[34]在2017年提出,并应用于Transformer语言模型。自注意力机制可以在编码或解码中单独使用,相对于注意力机制,它更关注输入内部的联系,区别就是Q,K和V来自同一个数据源,也就是说Q,K和V由同一个矩阵通过不同的线性变换而来。
比如对于文本矩阵来说,利用自注意力机制可以实现文本内各词“互相注意”,即词与词之间产生注意力权重矩阵,然后对Value加权求和产生一个融合了自注意力的新文本矩阵。文本自注意力的实现步骤如下:

  1. 假设文本矩阵 i n p u t = R ( a × b ) input=R^{(a×b)} input=R(a×b),三个变换矩阵(卷积核): ω q , ω k ∈ R ( b × d ) 、 ω v ∈ R ( b × c ) ω^q,ω^k∈ R^{(b×d)}、ω^v∈ R^{(b×c)} ωq,ωkR(b×d)ωvR(b×c)
  2. Q、K、V变换:文本矩阵和三个权重矩阵做线性变换,得到 Q , K ∈ R ( a × d ) 、 V ∈ R ( a × c ) Q,K∈ R^{(a×d)}、V∈ R^{(a×c)} Q,KR(a×d)VR(a×c):
    Q = i n p u t ω q , K = i n p u t ω k , V = i n p u t ω v Q =input ω^q, K =input ω^k, V =input ω^v Q=inputωq,K=inputωk,V=inputωv
  3. 缩放点积: Q × K T Q×K^T Q×KT然后乘以一个 1 / d k 1/\sqrt{d_k} 1/dk d k d_k dk为K的维度, 1 / d k 1/\sqrt{d_k} 1/dk

    为缩放因子,防止内积数值过大影响神经网络的学习),得到注意力得分矩阵 G ∈ R ( a × a ) G∈ R^{(a×a)} GR(a×a),G的行表示某个词在各个词上的得分:
    G = Q K T / d k G =QK^T/\sqrt{d_k} G=QKT/dk

  4. 得到注意力权重矩阵: s o f t m a x ( G ) softmax(G) softmax(G)表示注意力权重矩阵W :
    W = s o f t m a x ( ( Q K T ) / d k ) W=softmax((QK^T)/\sqrt{d_k}) W=softmax((QKT)/dk

    )
  5. 得到结果矩阵: W* V 得到一个结果矩阵Attention∈ R^(a×c),该矩阵就是一个全新的融合了注意力机制的文本矩阵z:
    z = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T / d k ) V z=Attention(Q,K,V) =softmax(QK^T/\sqrt{d_k})V z=Attention(Q,K,V)=softmax(QKT/dk

    )V

在上述公式中,变换矩阵 ω q 、 ω k 、 ω v ω^q、ω^k、ω^v ωqωkωv都是神经网络的参数,可以随着反向传播而修改,通过修改这些变换矩阵来达到自注意力转移的目的。

3.多头自注意力

若为多头自注意力机制,则有多组卷积核 ω i q , ω i v , ω i k ω_i^q,ω_i^v,ω_i^k ωiq,ωiv,ωik,将步骤2-5进行h次得到h组结果矩阵 ( z 1 , . . . , z h ) (z_1,...,z_h ) (z1,...,zh),将 ( z 1 , . . . , z h ) (z_1,...,z_h ) (z1,...,zh)拼接并做一次线性变换 ω z ω^z ωz就得到了我们想要的文本矩阵:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( z 1 , . . . , z h ) ω z MultiHead(Q,K,V )= Concat(z_1,...,z_h ) ω^z MultiHead(Q,K,V)=Concat(z1,...,zh)ωz
缩放点积计算和多头自注意力机制计算过程如下图:

自注意力机制将文本输入视为一个矩阵,没有考虑文本序列信息,例如将K、V按行打乱,那么计算之后的结果是一样的,但是文本的序列是包含大量信息的,比如“虽然他很坏,但是我喜欢他”、“虽然我喜欢他,但是他很坏”,这是两个极性相反的句子,因此需要提取输入的相对或绝对的位置信息。
Positional Encoding计算公式如下:

式中,pos 表示位置index,i表示位置嵌入index。
得到位置编码后将原来的word embedding和Positional Encoding拼接形成最终的embedding作为多头自注意力计算的输入input embedding。

4.(多头)自注意力机制在LSTM中的应用

LSTM包含两个输出:

  • 所有时间步输出 O = [ O 1 , O 2 , … , O D ] O= [O_1,O_2,…,O_D] O=[O1,O2,,OD]
  • 最后时间步D的隐藏状态 H D H_D HD

由于 O = [ O 1 , O 2 , … , O D ] O= [O_1,O_2,…,O_D] O=[O1,O2,,OD]表示字/词的特征, H D H_D HD表示文本的特征,为了识别字对于文本的重要性,我们需要建立 H D H_D HDO O O的自注意力关系,即建立各时间步输出 O t O_t Ot对于 H D H_D HD的权重,由于LSTM本身就考虑了位置信息,因此不需要额外设置位置编码,自注意力机制在LSTM中的实现方法有两种:
1. 点积注意力[2]:Transfromer提出的自注意力实现方法
各时间步的输出 O t O_t Ot经线性变换后作为Key和Value,最后时间步的输出 H D H_D HD乘以矩阵 ω Q ω_Q ωQ作为Query。
在时间步t时, K e y t , V a l u e t , Q u e r y , 得 分 e t 和 权 重 a t Key_t,Value_t,Query,得分e_t和权重a_t KeytValuetQueryetat有如下计算公式:

式中,Query不随时间步而改变, ω K , ω Q , ω V ω_K,ω_Q,ω_V ωKωQωV是神经网络的参数,随反向传播而修改。将各时间步权重 a t a_t atV a l u e t Value_t Valuet加权求和,得到带有自注意力的文本向量:

为了获取多头自注意力,将上述公式进行h次,得到多头自注意力文本 z 1 , . . . , z h z_1,...,z_h z1,...,zh,将其拼接并做一次线性变换后作为最后输出:
M u l t i H e a d ( Q , K , V ) = C o n c a t ( z 1 , . . . , z h ) ω z MultiHead(Q,K,V)= Concat(z_1,...,z_h ) ω_z MultiHead(Q,K,V)=Concat(z1,...,zh)ωz
式中,h为自注意力的头数,多头自注意力结构如下图:

点积注意力Pytorch源码:

self.w_q = nn.Linear(ARGS.hidden_dim * (ARGS.bidirect + 1) * ARGS.n_layers, ARGS.dim_k, bias=False)self.w_k = nn.Linear((ARGS.bidirect + 1) * ARGS.hidden_dim, ARGS.dim_k, bias=False)self.w_v = nn.ModuleList([nn.Linear((ARGS.bidirect + 1) * ARGS.hidden_dim, ARGS.dim_v, bias=False)for _ in range(ARGS.num_heads)])self.w_z2 = nn.Linear(ARGS.num_heads * ARGS.dim_v, ARGS.dim_v, bias=False)def MultiAttention1(self, lstm_out, h_n):batch_size, Doc_size, dim = lstm_out.shapex = []for i in range(h_n.size(0)):x.append(h_n[i, :, :])hidden = torch.cat(x, dim=-1)dk = ARGS.dim_k // ARGS.num_heads  # dim_k of each headq_n = self.w_q(hidden).reshape(batch_size, ARGS.num_heads, dk).unsqueeze(dim=-1)key = self.w_k(lstm_out).reshape(batch_size, Doc_size, ARGS.num_heads, dk).transpose(1, 2)value = [wv(lstm_out).transpose(1, 2) for wv in self.w_v]  # value: n* [batch_size, dim_v, Doc_size]weights = torch.matmul(key, q_n).transpose(0, 1) / sqrt(dk)  # weights: [n, batch_size, Doc_size, 1]soft_weights = F.softmax(weights, 2)out = [torch.matmul(v, w).squeeze() for v,w in zip(value,soft_weights)]# out[i]:[batch_size, dim_v, Doc_size] × [batch_size, Doc_size, 1] -> [batch_size, dim_v]# out: [batch_size, dim] * nout = torch.cat(out, dim=-1)out = self.w_z2(out)return out, soft_weights.data  # out : [batch_size, dim_v]

2. 加法注意力[3]:Bahdanau 提出的加法注意力
将最后时间步的隐藏状态 H D H_D HD和各时间步输出 O t O_t Ot拼接作为Query,各时间步输出 O t O_t Ot线性变换后作为Value,线性变换矩阵 ω k ω_k ωk作为Key,Query和Value相乘后作为结果矩阵z,时刻t有如下公式:
Q t = ω q ( O t + H D ) Q_t=ω_q(O_t+H_D) Qt=ωq(Ot+HD)
V t = ω v O t V_t=ω_vO_t Vt=ωvOt
z t = t a n h ( Q t ω k ) V t z_t=tanh(Q_tω_k)V_t zt=tanh(Qtωk)Vt
结果矩阵为 z = [ z 1 , z 2 , . . . , z t , . . . , z D ] z=[z_1,z_2,...,z_t,...,z_D] z=[z1,z2,...,zt,...,zD]
若为多头自注意力,则进行h次操作后获得多个结果矩阵后拼接再做一次线性变换作为输出,如下图:

加法注意力Pytorch源码:

       self.w_v = nn.ModuleList([nn.Linear((ARGS.bidirect + 1) * ARGS.hidden_dim, ARGS.dim_v, bias=False)for _ in range(ARGS.num_heads)])self.w_z = nn.Linear(ARGS.num_heads * ARGS.dim_v, ARGS.dim_v, bias=False)self.w_q = nn.Linear((ARGS.bidirect + 1) * ARGS.hidden_dim * (ARGS.n_layers + 1), ARGS.dim_k, bias=True)self.w_k_Mul = nn.Linear(ARGS.dim_k // ARGS.num_heads, 1, bias=False)def MultiAttention4(self, lstm_out, h_n):batch_size, Doc_size, dim = lstm_out.shapex = []for i in range(h_n.size(0)):x.append(h_n[i, :, :])hidden = torch.cat(x, dim=-1).unsqueeze(dim=-1)ones = torch.ones(batch_size, 1, Doc_size).to(device)hidden = torch.bmm(hidden, ones).transpose(1, 2)# 对lstm_out和hidden进行concath_i = torch.cat((lstm_out, hidden), dim=-1)dk = ARGS.dim_k // ARGS.num_heads  # dim_k of each head# 分头,即,将h_i和权值矩阵w_q相乘的结果按列均分为n份,纬度变化如下:# [batch_size, Doc_size, num_directions*hidden_dim*(1+n_layer)] -> [batch_size, Doc_size, dim_k]# ->[batch_size, Doc_size, n, dk] -> [batch_size, n, Doc_size, dk]query = self.w_q(h_i).reshape(batch_size, Doc_size, ARGS.num_heads, dk).transpose(1, 2)query = torch.tanh(query)  # query: [batch_size, n, Doc_size, dk]# 各头分别乘以不同的key,纬度变化如下:# [batch_size, n, Doc_size, dk] * [batch_size, n, dk, 1]# -> [batch_size, n, Doc_size, 1] -> [batch_size, n, Doc_size]weights = self.w_k_Mul(query).transpose(0, 1) / sqrt(dk)  # weights: [n, batch_size, Doc_size, 1]value = [wv(lstm_out).transpose(1, 2) for wv in self.w_v]  # value: n* [batch_size, dim_v, Doc_size]soft_weights = F.softmax(weights, 2)# value:[batch_size, dim, Doc_size]out = [torch.matmul(v, w).squeeze() for v, w in zip(value, soft_weights)]# out[i]:[batch_size, dim, Doc_size] × [batch_size, Doc_size, 1] -> [batch_size, dim]# out: [batch_size, dim] * nout = torch.cat(out, dim=-1)# out: [batch_size, dim * n]# print(out.size())out = self.w_z(out)  # 做一次线性变换,进一步提取特征return out, soft_weights.data  # out : [batch_size, hidden_dim * num_directions]

Transformer中指出,在高维度的情况下加法注意力的精度优于点积注意力,但是可以通过乘以缩放因子 1 / d k 1/\sqrt{d_k} 1/dk

抵消这种影响:
以下是论文[2]的3.2.1小节的原文翻译:
最常用的两个注意力函数是加法注意力[3]和点积(多重复制)注意力。点积注意与我们的算法相同,只是比例因子为 1 / d k 1/\sqrt{d_k} 1/dk

。加法注意力利用一个具有单个隐层的前馈网络来计算相容函数。虽然两者在理论复杂度上相似,但由于可以使用高度优化的矩阵乘法码来实现,因此在实践中,点积注意力速度更快,空间效率更高。
对于较小的 d k d_k dk,这两种机制效果相近,对于较大的 d k d_k dk值,加法注意力优于点积注意力。我们怀疑,对于较大的 d k d_k dk值,点积在数量级上增长很大,从而将softmax函数推到梯度非常小的区域。为了抵消这种影响,我们将点积缩放 1 / d k 1/\sqrt{d_k} 1/dk

5.结语

本文内容主要参考自下述三篇论文以及知乎博文,并对其进行理解和整理。本人也并未完全掌握自注意力机制,想要完全掌握自注意力机制建议阅读这三篇论文的原文。

有错误欢迎指正!

需要源码可私信我哦^ ^


[1] Mnih V, Heess N, Graves A, et al. Recurrent Models of Visual Attention. arXiv preprint, arXiv: 1406.6247 [ cs. CL] 2014.
[2] Vaswani A, Attention Is All You Need, arXiv preprint, arXiv: 1706.03762 [cs.CL] 2017.
[3] Bahdanau D, Cho K, Bengio Y. Neural Machine Translation by Jointly Learning to Align and Translate[J]. Computer Science, 2014.

详解自注意力机制及其在LSTM中的应用相关推荐

  1. 详解LibraBFT共识机制

    [Libra 技术解读]详解LibraBFT共识机制 ---------------- 版权声明:本文为CSDN博主「百度超级链xuper」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上 ...

  2. 【Libra 技术解读】详解LibraBFT共识机制

    Libra技术系列解读 往期回顾: move语言简介 move语法.解释器和验证器 本期详解"LibraBFT共识机制" Libra白皮书中关于共识机制的描述 Libra 区块链采 ...

  3. 详解JVM类加载机制

    详解JVM类加载机制 笔者的笔记都记录在有道云里面,因为公司原因办公电脑无法使用有道云,正好借此机会整理下以前的笔记顺便当做巩固复习了,也因为记笔记的时候不会记录这些知识来源何地,所以如果发现原创后可 ...

  4. c语言handler指针,详解C++ new-handler机制

    当 operator new 不能满足一个内存分配请求时,它抛出一个 exception(异常).很久以前,他返回一个 null pointer(空指针),而一些比较老的编译器还在这样做.你依然能达到 ...

  5. Linux系统调用详解(实现机制分析)

    为什么需要系统调用   linux内核中设置了一组用于实现系统功能的子程序,称为系统调用.系统调用和普通库函数调用非常相似,只是系统调用由操作系统核心提供,运行于内核态,而普通的函数调用由函数库或用户 ...

  6. 自适应注意力机制在Image Caption中的应用

    在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点,时刻紧跟 AI 前沿成果. 点击本文底部的「阅读原文」即刻加入社区 ...

  7. python最小值函数_Python3 min() 函数详解 获取多个参数或列表中的最小值

    Python3 min() 函数详解 获取多个参数或列表中的最小值 min()函数的主要作用是获取对象中最小的值,参数可以是任何可迭代对象(字符串.列表.元组.字典等),可以是一个参数内的值进行对比, ...

  8. xp如何添加桌面计算机回收站,详解桌面回收站图标在XP电脑中操作删除的详细步骤...

    我们在电脑的很多的设置中,很多的用户都是可以打开不同的版本来设置电脑的问题的,对于电脑的回收站图标的设置有的小伙伴不是很喜欢使用桌面的这个图标怎么直接删除回收站图标的呢,今天小编就来跟大家分享一下XP ...

  9. Tensorflow 2.x(keras)源码详解之第七章:keras中的tf.keras.layers

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

最新文章

  1. 计算机财务应用实验心得,计算机会计实习心得-20210628124643.doc-原创力文档
  2. 算法刷题宝典.pdf
  3. [SUCTF2018]babyre [ACTF新生赛2020]fungame
  4. linux 服务配置
  5. python 打开某个exe_python定时检查启动某个exe程序(如果exe挂了)
  6. dev sda2 linux lvm,VM下LINUX完美增加硬盘空间(LVM)
  7. 解析WeNet云端推理部署代码
  8. 知识蒸馏 | 综述:蒸馏机制
  9. linux配置c11,C11标准的泛型机制
  10. 2020腾讯广告算法大赛——算法小白的复盘
  11. 浏览器产生乱码的原因
  12. strtok函数的实现
  13. 过滤器和拦截器(SpringMVC实现)
  14. mysql脏读解决方案_MySQL为什么可以解决脏读和不可重复读?
  15. 那一年,我与电脑结下了不解之缘
  16. unity跑酷怎么添加金币_叫好不叫买?《跑酷老奶奶》游戏评测
  17. unity jump游戏源码展示图
  18. 一个牛逼的Bug!一张“壁纸”让三星手机秒变砖!
  19. Gromacs-自由能微扰(FEP)
  20. win10 安装eclipse一直报错 ERROR: org.eclipse.equinox.p2.engine code=4 An error occurr

热门文章

  1. C语言0基础全面教程
  2. VSCode安装和Python安装及其配置
  3. 应届生嵌入式面试题总结——嵌入式基础
  4. easyexcel 合并单元格
  5. 计算机专业英语简介模板,计算机专业英文简历模版
  6. ${ew.customSqlSegment}和${ew.sqlSegment}
  7. 硬件编解码(一)硬件编解码介绍
  8. mysql索引的使用
  9. RPC实现和原理解析
  10. 【FPGA】SCCB协议+ov5640摄像头