seq2seq + attention 详解

作者:xy_free \qquad 时间:2018.05.21

1. seq2seq模型

seq2seq模型最早可追溯到2014年的两篇paper [1, 2],主要用于机器翻译任务(MT)。seq2seq本质上是一种encoder-decoder框架,以翻译任务中的“英译汉”为例,模型首先使用编码器对英文进行编码,得到英文的向量化表示S,然后使用解码器对S进行解码,得到对应的中文。由于encoder与decoder两端处理的都是序列数据,所以被称为sequence-to-sequence,简称seq2seq。另外,目前应用最多的编/解码器是RNN(LSTM,GRU),但编/解码器并不限于RNN,如也有人拿MLP作为编码器。
paper[1, 2]的主要结构如下图:

2. attention模型

attention模型最早出现于cv领域,而首次用于解决nlp问题是在2014年[3],seq2seq+attention 应用于机器翻译任务。以英译汉为例,当解码器对英文进行解码时,是一个词一个词生成的,而所生成的每个词对应的英文部分应该是不同,换句话说就是,解码器解码时不同step所分配的注意力是不同的。 再举一个例子,如看图说话(用一句话描述一幅图),所生成的词语应该对应图中的不同部分,即解码器在解码时,应该给图中“合适”的部位,分配更多的注意力(权重)。
paper[3]的主要结构如下图:

红圈标识的是编码器,其中h代表源文本的语义表示;紫圈标识的解码器,其中s代表目标文本的序列状态。c表示注意力向量,用来在解码时,控制源文本不同位置的attention分配

3. seq2seq + attention

以paper[3] 为例,对seq2seq + attention 的计算过程,进行详细说明,见上图(Translation: Attention Mechanism)
1. 使用 Bi-GRU 作为编码器,得到源文本的向量表示hthth_t

ht=fe(xt,ht−1)ht=fe(xt,ht−1)h_t = f_e(x_t, h_{t-1}) 详解如下:

  • htf=f(ht−1f,xt)hft=f(hft−1,xt) h_{f}^t= f(h_{f}^{t-1}, x_t)
  • htb=f(ht+1b,xt)hbt=f(hbt+1,xt) h_{b}^t = f(h_{b}^{t+1}, x_t)
  • ht=[htf,htb]ht=[hft,hbt] h_t = [h_{f}^t, h_{b}^t]
  • 其中,fefef_e 表示Bi-GRU,hfhfh_f表示正向GRU的输出, hbhbh_b表示反向GRU的输出,[]表示串联

2.对hthth_t进行解码,获得目标序列
模型所要生成的目标是个“词序列”,处理方式是每次生成一个词,迭代进行
p(yt|y<t,x)=f(yt−1,st,ct)p(yt|y<t,x)=f(yt−1,st,ct) p(y_t | y
其中f是 维度映射 + maxout,maxout是一种激活函数,维度映射是把所生成的向量转化为词表大小

  • yt−1yt−1y_{t-1}是目标序列上一个词的词向量
    在模型训练阶段,yt−1yt−1y_{t-1}有两种选择(按比例选):一种是真实的训练样本词向量,另一种是生成的词的词向量,前一种方式也被称为 teacher forcing
    在模型测试阶段,yt−1yt−1y_{t-1} 是指生成的词的词向量
  • ststs_t是序列的当前状态,st=fd(yt−1,st−1,ct)st=fd(yt−1,st−1,ct) s_t = f_d(y_{t-1}, s_{t-1}, c_t) ,其中 fdfdf_d 表示GRU
  • ctctc_t表示注意力分配,详细计算如下:
    ct=∑αtjhjct=∑αtjhj c_t = \sum \alpha_{tj} h_j
    αij=softmax(eij)αij=softmax(eij) \alpha_{ij} = softmax(e_{ij})
    etj=a(st−1,hj)=vTatanh(Wst−1+Uhj)etj=a(st−1,hj)=vaTtanh(Wst−1+Uhj) e_{tj} = a(s_{t-1}, h_j) = v_a^Ttanh(Ws_{t-1} + Uh_j)
    其中 va,W,Uva,W,Uv_a, W, U都是待学习参数,ctctc_t 可以理解为 关于hjhjh_j的一个加权平均值,权重为αtjαtj\alpha_{tj}

4. attention 扩展

attention很火,paper[4] 提出了一种attention改良方案,将attention划分为了两种形式:global, local.
global方式认为attention应该在所有源文本上进行,而local方式认为attention仅应该在部分源文本上进行。global理念与paper[3]相同,具体计算方式如下图所示:

其中“concat” 与 paper[3] 中的计算方式相同
另外,paper[4]除了改良了attention计算方式以外,还调整了decoder的计算方式,简化计算,优化编码

  • p(yt|y<t,x)=f(yt−1,st,ct)p(yt|y<t,x)=f(yt−1,st,ct) p(y_t | y

    • p(yt|y<t,x)=softmax(Wsst~)p(yt|y<t,x)=softmax(Wsst~)p(y_t | y
    • st~=tanh(Wc[ct;st])st~=tanh(Wc[ct;st]) \tilde{s_t} = tanh(W_c[c_t; s_t])
  • st=fd(yt−1,st−1,ct)st=fd(yt−1,st−1,ct) s_t=f_d(y_{t-1}, s_{t-1}, c_t) st=fd(yt−1,st−1)st=fd(yt−1,st−1)s_t=f_d(y_{t-1}, s_{t-1})

差异:

  • 改变了ctctc_t的计算方式,除concat外,dot、general 可以作为备选
  • paper[3]中,ststs_t由 yt−1,st−1,ctyt−1,st−1,cty_{t-1}, s_{t-1}, c_t组成,而最终计算p(yt|y<t,x)p(yt|y<t,x)p(y_t | y时,仍需考虑yt−1yt−1y_{t-1} 和 ctctc_t,冗余
    另外,fdfdf_d是个RNN,在计算ststs_t时,需要考虑ctctc_t,coding时,需使用for循环,会拖慢计算效率
  • paper[4]中,ststs_t仅由yt−1,st−1yt−1,st−1y_{t-1}, s_{t-1}组成,最终计算p(yt|y<t,x)p(yt|y<t,x)p(y_t | y 时,仅考虑st,ctst,cts_t, c_t,未冗余
    另外,在计算ststs_t时,yt−1yt−1y_{t-1} 已知,coding时,可算出所有step的sss,进而计算所有的c" role="presentation">ccc,所有操作都是向量化操作,不需使用for循环,会快很多
  • 改变了p(yt|y<t,x)p(yt|y<t,x) p(y_t | y的计算方式
    paper[3]中,使用maxout作为最后的激活函数, 即维度映射 + maxout
    paper[4]中,使用softmax作为最后的激活函数,即维度映射 + softmax

5. 需要注意的地方

  • decoder 端的ststs_t初始化: s0=tanh(Wh1b)s0=tanh(Whb1)s_0 = tanh(Wh_b^1), 取encoder的反向RNN的初态的非线性,作为decoder的初态
  • teacher forcing模式与测试时(生成模式)不同,所以训练过程不能完全都用teacher forcing,teacher forcing 与 生成模式应按比例分配
  • beamsearch 只是在测试的时候用到
  • 如果encoder 与 decoder 的序列都很长,显存装不下。可考虑对decoder端进行截断,分步优化(pytorch中 使用 state = state.detach())
  • coding时,尽量别用for循环,会极大降低计算效率

6. 总结

paper[4]无论从理论结构,还是从coding上来看,都非常棒,计算细节赘述如下:

  • p(yt|y<t,x)=softmax(Wsst~)p(yt|y<t,x)=softmax(Wsst~)p(y_t | y
  • st~=tanh(Wc[ct;st])st~=tanh(Wc[ct;st]) \tilde{s_t} = tanh(W_c[c_t; s_t])
  • st=fd(yt−1,st−1)st=fd(yt−1,st−1)s_t=f_d(y_{t-1}, s_{t-1})
  • ct=∑αtjhjct=∑αtjhj c_t = \sum \alpha_{tj} h_j
  • αij=softmax(eij)αij=softmax(eij) \alpha_{ij} = softmax(e_{ij})
  • etj=a(st−1,hj)=vTatanh(Wst−1+Uhj)etj=a(st−1,hj)=vaTtanh(Wst−1+Uhj) e_{tj} = a(s_{t-1}, h_j) = v_a^Ttanh(Ws_{t-1} + Uh_j)
  • ht=fe(xt,ht−1)ht=fe(xt,ht−1)h_t = f_e(x_t, h_{t-1}) 注:fefef_e 可使用LSTM, GRu, Bi-LSTM 等

参考

  1. Sutskever I, Vinyals O, Le Q V. Sequence to sequence learning with neural networks[C]//Advances in neural information processing systems. 2014.
  2. Cho K, Van Merriënboer B, Gulcehre C, et al. Learning phrase representations using RNN encoder-decoder for statistical machine translation[J]. arXiv, 2014.
  3. Bahdanau D, Cho K, Bengio Y. Neural machine translation by jointly learning to align and translate[J]. arXiv, 2014. & ICLR, 2015.
  4. Luong M T, Pham H, Manning C D. Effective approaches to attention-based neural machine translation[J]. arXiv, 2015.

seq2seq + attention 详解相关推荐

  1. 史上最小白之Attention详解

    1.前言 在自然语言处理领域,近几年最火的是什么?是BERT!谷歌团队2018提出的用于生成词向量的BERT算法在NLP的11项任务中取得了非常出色的效果,堪称2018年深度学习领域最振奋人心的消息. ...

  2. 注意力机制Attention详解

    注意力机制Attention详解 一.前言 2018年谷歌提出的NLP语言模型Bert一提出,便在NLP领域引起热议,之所以Bert模型能够火出圈,是由于Bert模型在NLP的多项任务中取得了之前所有 ...

  3. 深度学习之图像分类(十七)-- Transformer中Self-Attention以及Multi-Head Attention详解

    深度学习之图像分类(十七)Transformer中Self-Attention以及Multi-Head Attention详解 目录 深度学习之图像分类(十七)Transformer中Self-Att ...

  4. Self Attention 详解

    Self Attention 详解 前言 注意力机制(Attention),之前也是一直有所听闻的,也能够大概理解 Attention 的本质就是加权,对于 Google 的论文<Attenti ...

  5. seq2seq模型详解

    在李纪为博士的毕业论文中提到,基于生成的闲聊机器人中,seq2seq是一种很常见的技术.例如,在法语-英语翻译中,预测的当前英语单词不仅取决于所有前面的已翻译的英语单词,还取决于原始的法语输入;另一个 ...

  6. Seq2Seq 模型详解

    在NLP任务中,我们通常会遇到不定长的语言序列,比如机器翻译任务中,输入可能是一段不定长的英文文本,输出可能是不定长的中文或者法语序列.当遇到输入和输出都是不定长的序列时,可以使用编码器-解码器(en ...

  7. 史上最直白之Attention详解(原理+代码)

    目录 为什么要了解Attention机制 Attention 的直观理解 图解深度学习中的Attention机制 总结 为什么要了解Attention机制   在自然语言处理领域,近几年最火的是什么? ...

  8. 【NLP】Seq2Seq原理详解

    一.Seq2Seq简介 seq2seq 是一个Encoder–Decoder 结构的网络,它的输入是一个序列,输出也是一个序列.Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Dec ...

  9. 注意力机制详解(Attention详解)

    注意力机制与人眼类似,例如我们在火车站看车次信息,我们只关注大屏的车次信息,而忽略大屏外其他内容,从而导致钱包被偷... 注意力机制只关注重点信息,忽略不重要的信息,关注最核心的内容. 主要就是这个公 ...

最新文章

  1. 3.6.1 局域网的基本概念和体系结构
  2. 杭州软件测试培训有用吗,杭州软件测试培训靠谱吗
  3. python boxplot用法_Boxplot的介绍和使用
  4. Host aggregate分区
  5. 靠着零代码报表工具,转行报表开发后月薪超过3万
  6. 计算机辅助设计与制造考试题,计算机辅助设计与制造考试习题大集合..
  7. Out of Browser 开篇
  8. 连续系统离散化_连续系统转化为离散系统之 z 变换
  9. 网络教育源码 java在线教育源码
  10. Java实现 蓝桥杯VIP 算法训练 会议中心
  11. python中没有严格意义上的私有成员_尔雅尔雅学习通APP家园的治理:环境科学概论题库及答案...
  12. 51单片机红外遥控继电器电路部分设计
  13. 机器人动力学与控制学习笔记(七)————基于计算力矩法的滑模控制
  14. stm32h743单片机嵌入式学习笔记5-液晶屏汉字库原理
  15. an ancestor violates the following Content Security Policy directive: “frame-ancestors ‘none‘”.
  16. instant java,java.time.Instant.compareTo()方法
  17. uva 10041 - Vito's Family
  18. 微信严正提醒!再做这件事,封号处理
  19. 解决项目中使用kotlin不能直接引用xml中id
  20. ResidualCoder

热门文章

  1. 关于安卓中的语言国际化问题
  2. 解决Cannot read property ‘onCheckForUpdate’ of undefined问题
  3. MOUSE WITHOUT BORDERS连接失败原因猜测
  4. 监控过程组--项目管理
  5. The sound of music
  6. Redis——使用 python 操作 redis 之从 hmse 迁移到 hset
  7. Java程序员转大数据的学习路线(完整版)
  8. 装货单Shipping Order
  9. 今天送修yoga book有感
  10. 用python编写缠论中枢_【量化投资】缠论面面观(附Python源码)