Seq2Seq中的Attention
《Seq2Seq中的Attention》
Sequence to Sequence的结构在整个深度学习的进程中占有重要的角色,我在2017年做OCR的时候用这个,当时语音组做语音识别的同事也是用这个,而nlp组的做机器翻译的同事更是利用这个取得了不错的效果,尤其是Attention的引入让Sequence to Sequence的表现更加惊艳,所以这一经典的结构是值得被反复揣摩的,在此记录一下我对Seq2Seq的理解尤其是其中Attention机制,本文介绍Seq2Seq中一种计算Attention的方式。本文的Attention区别于self-Attention。
Key Words:Seq2Seq、RNN、Attention
Beijing, 2020
作者:RaySue
Code:https://github.com/bentrevett/pytorch-seq2seq
Agile Pioneer
文章目录
- 前言
- RNN 简介
- Encoder - Decoder
- Attention
- Attention 数学公式
- Attention代码
- Attention 计算示意图
- Hard Attention
- Soft Attention
- Attention 是如何解码的
- Attention应用
- 语音识别
- OCR
- Image captioning
- 参考
- Q&A
前言
Attention 被广泛用于序列到序列(seq2seq)模型,这是一种深度学习模型,在很多任务上都取得了成功,如:机器翻译、文本摘要、图像描述生成。谷歌翻译在 2016 年年末开始使用这种模型。有 2 篇开创性的论文对这些模型进行了解释,论文连接如下:
- Sequence to Sequence Learning with Neural Networks
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
无论是语音识别或是机器翻译等任务都存在序列间的对齐问题,和同为解决对齐问题的 CTC Loss 相媲美的就是本文所说的 Attention 机制,Attention模型虽然好,但是还是有自身的问题:1.适合短语识别,对长句子识别比较差;2.noisy data的时候训练不稳定,因此比较好的方案是使得Attention与CTC进行结合,实验证实对比Attention模型还有CTC模型,Attention+CTC模型更快的收敛了,这得益于初始阶段CTC的阶段对齐更准确,使得Attention模型训练收敛更快,本文旨在讲清楚嵌入在seq2seq中的 Attention 机制。
- seq2seq模型一般形式,以机器问答为例,在Encoder阶段,每个时刻RNN的输入 = 每个时刻的输入 + 上一时刻的 hidden state;而在Decoder阶段,每个时刻RNN的输入 = 前一时刻的输出+前一时刻的hidden state。
RNN 简介
- RNN 的运行机制
RNN 在每个时间步,采用上一个时间步的 hidden state(隐藏层状态) 和当前时间步的输入向量,来得到输出。白话说就是利用相同的权重在序列数据上滚动,和普通的ANN相比,引入了time step概念,这样就可以学习到特征在时间线上的关系了,RNN不同time step的权值是共享的。
RNN的第0时刻的hidden state是初始化的,RNN具体运算看下图,
搞清楚三个变量 batch_size、input_size、time_steps,两个函数输出outputs、state(hidden state)
以文字数据为例:
如果数据有1000段时序的句子,每句话有25个字,对每个字进行向量化,每个字的向量维度为300,那么batch_size=1000,time_steps=25,input_size=300。解析:time_steps一般情况下就是等于句子的长度,input_size等于字量化后向量的长度。以图片数据为例:
拿 MNIST 手写数字集来说,训练数据有 6000 个手写数字图像,每个数字图像大小为28*28,batch_size=6000没的说,time_steps=28,input_size=28,我们可以理解为把图片图片分成28份,每份shape=(1, 28),然后利用最后一个时刻的状态来解码分类即可。outputs:time_steps步里所有输出,shape=(batch_size, time_steps, cell.output_size)
state:最后一步的隐状态,它的形状为(batch_size, cell.state_size)
RNN的缺点: RNN机制实际中存在长程梯度消失的问题,对于较长的句子,我们很难寄希望于将输入的序列转化为定长的向量而保存所有的有效信息,所以随着所需翻译句子的长度的增加,这种结构的效果会显著下降。
Encoder - Decoder
Encoder-Decoder是个非常通用的计算框架,至于Encoder和Decoder具体使用什么模型都是由研究者自己定的,常见的比如CNN/RNN/BiRNN/GRU/LSTM/Deep LSTM等,可以做机器翻译,也可以做图像语义分割。
数据先在Encoder部分将输入数据通过非线性变换转化为中间语义 C,然后传入到解码器,在解码器的每个time step的输入是前一时刻的预测结果和hidden state一起作为输入的,这种经典的seq2seq的模型在解码部分的输入只是编码器最后一个时刻的hidden state,然后和历史预测的结果进行逐步解码,实际上无论第几个time step的解码都是依赖同一个Encoder的结果,即 Y1=F(C,<EOS>);Y_1 = F(C, <EOS>);\space\spaceY1=F(C,<EOS>); Y2=F(C,Y1);Y_2 = F(C, Y_1);\space\spaceY2=F(C,Y1); Y3=F(C,Y1,Y2);Y_3 = F(C, Y_1, Y_2);Y3=F(C,Y1,Y2);,这就导致了模型注意力不集中。
Attention
Attention模型是对Encoder的所有hidden state与Decoder的每个 time step 的hidden state 做了一个加权融合,这样在不同的Decoder的time step具有不同的中间语义,比如在机器翻译中本身就应该是对齐的,我们在解码的过程中并不需要整个句子得到的中间语义,而是需要特定单词占权重很高的中间语义来解码,Attention 通过计算解码器每个time step的hidden state与所有Encoder的中间语义的“相似度”得到了一组权重,用于求加权平均即Attention的结果,用于解码。
Attention 数学公式
下面公式中的 ztz_tzt 要和Encoder的 nnn 个 hih_ihi 做相似度运算,然后得到 nnn 个权重系数 aita_i^tait, 然后对Encoder的 nnn 个 hih_ihi 做weighted sum就得到了 ctc_tct。
uit=vTtanh(W1hi+W2zt)(1)u_i^t=v^Ttanh(W_1h_i + W_2z_t) \space\space\space\space (1) uit=vTtanh(W1hi+W2zt) (1)
ait=softmax(uit)(2)a_i^t=softmax(u_i^t) \space\space\space\space (2)ait=softmax(uit) (2)
ct=∑i=1TAaithi(3)c_t=\sum_{i=1}^{T_A}a_i^th_i \space\space\space\space (3)ct=i=1∑TAaithi (3)
- hih_ihi 表示Encoder的第 iii 个 time step 的 hidden state
- ztz_tzt 表示Decoder的第 ttt 个 time step 的 hidden state
- W1、W2、vTW_1、W_2、v^TW1、W2、vT 是需要学习的参数,通过vTv^TvT将结果映射到与Encoder的time step一致
- TAT_ATA 是 Encoder 的 time steps
- ctc_tct 表示第 ttt 个 time step 的 Attention 结果
Attention代码
# https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb
class Attention(nn.Module):def __init__(self, enc_hid_dim, dec_hid_dim):super().__init__()# enc_dim * 2 if encoder_is_bidirectional else enc_dimself.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)self.v = nn.Linear(dec_hid_dim, 1, bias = False)def forward(self, hidden, encoder_outputs):#hidden = [batch size, dec hid dim]# src len 表示 Encoder 有多少个hi#encoder_outputs = [src len, batch size, enc hid dim * 2]batch_size = encoder_outputs.shape[1]src_len = encoder_outputs.shape[0]#repeat decoder hidden state src_len timeshidden = hidden.unsqueeze(1).repeat(1, src_len, 1)encoder_outputs = encoder_outputs.permute(1, 0, 2)#hidden = [batch size, src len, dec hid dim]#encoder_outputs = [batch size, src len, enc hid dim * 2]energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) #energy = [batch size, src len, dec hid dim]attention = self.v(energy).squeeze(2)#attention= [batch size, src len]return F.softmax(attention, dim=1)
Attention 计算示意图
Hard Attention
Hard attention is a stochastic process: instead of using all the hidden states as an input for the decoding, the system samples a hidden state hih_ihi with the probabilities sis_isi. In order to propagate a gradient through this process, we estimate the gradient by Monte Carlo sampling.
Hard attention 是一个随机过程:不使用Encoder全部的hidden state作为解码的输入,取而代之的是对每个 hih_ihi 依概率 sis_isi 进行采样。为了能让梯度在这个过程传播,我们通过蒙特卡洛随机抽样来估计梯度。
Soft Attention
我们常用的Attention就是Soft Attention,这种思路在图像分割网络 SE-Net 中也有很好的应用,对每个channel依其重要性进行加权平均得到更好的特征。
Attention 是如何解码的
不同的论文的做法不一样我们这里说的流程是参考Neural Machine Translation by Jointly Learning to Align and Translate论文的做法
- Step 1: Decoder包含attention层,attention层的输入是前一时刻的hidden state st−1s_{t-1}st−1 和所有的Encoder的hidden states HHH 一起计算得到的attention向量,记为 ata_tat。然后我们利用ata_tat与所有的Encoder的hidden states加权相加得到一个加权向量 wtw_twt,表示attention的输出。
a1=softmax(vTtanh(W1s0,W2H));...;at=softmax(vTtanh(W1st−1,W2H))a_1=softmax(v^Ttanh(W_1s_{0}, W_2H));...;a_t=softmax(v^Ttanh(W_1s_{t-1}, W_2H))a1=softmax(vTtanh(W1s0,W2H));...;at=softmax(vTtanh(W1st−1,W2H))
wt=atHw_t = a_t Hwt=atH
- Step 2: 计算Decoder hidden state,输入:输入词嵌入向量 d(yt)d(y_t)d(yt) 和attention向量wtw_twt的concat结果,和前一时刻的 Decoder 的 hidden state st−1s_{t-1}st−1,得到 sts_tst。
st=Decoder(d(yt),wt,st−1)s_t = Decoder(d(y_t),w_t,s_{t-1})st=Decoder(d(yt),wt,st−1)
- Step 3: 预测y^t+1\hat{y}_{t+1}y^t+1,我们把 d(yt)d(y_t)d(yt), wtw_twt 和 sts_tst concat到一起之后经过一个全连接层 fff, 来预测目标句子中的下一个单词 y^t+1\hat{y}_{t+1}y^t+1。
yt+1^=f(d(yt),wt,st)\hat{y_{t+1}}=f(d(y_t), w_t, s_t)yt+1^=f(d(yt),wt,st)
示意图:
Attention应用
Attention的应用比较多,这里列举几个图例:
语音识别
OCR
Image captioning
参考
RNN
RNN
Neural Machine Translation by Jointly Learning to Align and Translate
Neural Machine Translation by Jointly Learning to Align and Translate - Code
https://www.cnblogs.com/mdumpling/p/8657070.html
https://mp.weixin.qq.com/s/SlO-onoAmsea64HtcGzVfQ
https://blog.csdn.net/u010159842/article/details/80473462
https://zhuanlan.zhihu.com/p/47063917
https://github.com/keon/seq2seq
Q&A
Q1: seq2seq 损失函数用什么?
A1: 交叉熵,nn.CrossEntropyLoss
Q2: 在不加入Attention的decoder部分,是如何使用Encoder最后一个时刻的hidden state的?
A2:一般的做法是把Encoder最后一个时刻的hidden state直接作为Decoder的hidden state传入,每步解码用ht−1h_{t-1}ht−1 和 yt^\hat{y_t}yt^一起来预测 yt+1^\hat{y_{t+1}}yt+1^,也有的文章的做法是在Decoder的每个时间步都输入Encoder最后一个时刻的hidden state来一起解码,也是work的。
Q3: Attention 解决什么问题?
A3: 解决长程梯度消失的问题,以及对齐问题。
Seq2Seq中的Attention相关推荐
- NLP中的Attention注意力机制+Transformer详解
关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 作者: JayLou娄杰 知乎链接:https://zhuanlan.zhihu. ...
- 「NLP」 聊聊NLP中的attention机制
https://www.toutiao.com/i6716536091681227267/ 本篇介绍在NLP中各项任务及模型中引入相当广泛的Attention机制.在Transformer中,最重要的 ...
- 理解LSTM/RNN中的Attention机制
转自:http://www.jeyzhang.com/understand-attention-in-rnn.html,感谢分享! 导读 目前采用编码器-解码器 (Encode-Decode) 结构的 ...
- 【NLP】 聊聊NLP中的attention机制
本篇介绍在NLP中各项任务及模型中引入相当广泛的Attention机制.在Transformer中,最重要的特点也是Attention.首先详细介绍其由来,然后具体介绍了其编解码结构的引入和原理,最后 ...
- Seq2Seq模型及Attention机制
Seq2Seq模型及Attention机制 Seq2Seq模型 Encoder部分 Decoder部分 seq2seq模型举例 LSTM简单介绍 基于CNN的seq2seq Transformer A ...
- 通道注意力机制 cnn keras_【CV中的Attention机制】简单而有效的CBAM模块
前言: CBAM模块由于其使用的广泛性以及易于集成得到很多应用.目前cv领域中的attention机制也是在2019年论文中非常火.这篇cbam虽然是在2018年提出的,但是其影响力比较深远,在很多领 ...
- 一文深入浅出cv中的Attention机制
在深度学习领域中,存在很多专业名词,第一次看的时候总会很懵逼-后面慢慢看得时候才会有那么感觉,但是总觉得差点意思.今天我们要说的一个专业名词,就叫做Attention机制! 1. 直观理解Attent ...
- Seq2Seq中Exposure Bias现象的浅析与对策
©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 前些天笔者写了CRF用过了,不妨再了解下更快的MEMM?,里边提到了 MEMM 的局部归一化和 CRF 的 ...
- 推荐中的attention有什么作用?
文 | 水哥 源 | 知乎 Saying 1. attention要解决两个问题:(1)attention怎么加,在哪个层面上做attention:(2)attention的系数怎么来,谁来得到att ...
最新文章
- python基础教程书籍推荐-入门python有什么好的书籍推荐?
- 数据处理程序的一点经验
- operator.ne_Python operator.ne()函数与示例
- hive 配置用户名_配置HiveServer2的安全策略之自定义用户名密码验证
- 《你必须知道的.NET》,评价和推荐
- 657. 机器人能否返回原点
- Javascript四种调用模式中的this指向
- 让getElementsByName适应IE和firefox
- 数据交换平台有哪些功能特点
- java里程碑之泛型--使用泛型
- 什么软件测试电脑分辨率,分辨率测试卡
- cass软件yy命令_南方CASS软件快捷命令大全,高手必备。。。
- 第一遍C++Primer5th读完感
- android图片的透明度变化,Android如何实现改变图片的透明度
- 矢量数据 秦岭淮河_秦岭-淮河一线的大致纬度
- 怎么屏蔽还有照片_【文末福利】在朋友圈发男神照片忘了屏蔽父母,麻麻的回应亮了…...
- CCPC 1010 YJJ's Salesman
- android 指南针 原理,手机指南针原理是什么?安卓/苹果手机指南针app工作原理介绍...
- Miktex 修改经验
- 服务网格在百度核心业务大规模落地实践
热门文章
- node.js 创建服务器_Node.js HTTP软件包–创建HTTP服务器
- android实例教程_活动之间的Android意向处理示例教程
- scala 协变和逆变_Scala方差:协变,不变和逆变
- 台湾大学生来厦门参访交流
- Java 的强引用、弱引用、软引用、虚引用
- 打造属于自己的underscore系列 ( 一 ) - 框架设计
- java中int和Integer对比的一些坑
- JavaScript中的“黑话”
- Ratingbar UseGuide
- Ubuntu Linux下安装MySQL