Decoder 也是N层堆叠的结构。被分为3个 SubLayer,可以看出 Encoder 与 Decoder 三大主要的不同:

  1. Diff_1:Decoder SubLayer-1 使用的是 “masked” Multi-Headed Attention 机制,防止为了模型看到要预测的数据,防止泄露。
  2. Diff_2:SubLayer-2 是一个 encoder-decoder multi-head attention。
  3. Diff_3:LinearLayer 和 SoftmaxLayer 作用于 SubLayer-3 的输出后面,来预测对应的 word 的 probabilities 。

1 Diff_1 : “masked” Multi-Headed Attention

mask 的目标在于防止 decoder “seeing the future”,就像防止考生偷看考试答案一样。mask包含1和0:

用作者的话说, “We […] modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.”

2 Diff_2 : encoder-decoder multi-head attention

重点在于 x = self.sublayer1 self.src_attn 是 MultiHeadedAttention 的一个实例。query = x,key = m, value = m, mask = src_mask,这里x来自上一个 DecoderLayer,m来自 Encoder的输出。

到这里 Transformer 中三种不同的 Attention 都已经集齐了:

3 Diff_3 : Linear and Softmax to Produce Output Probabilities

最后的 linear layer 将 decoder 的输出扩展到与 vocabulary size 一样的维度上。经过 softmax 后,选择概率最高的一个 word 作为预测结果。

假设我们有一个已经训练好的网络,在做预测时,步骤如下:

  1. 给 decoder 输入 encoder 对整个句子 embedding 的结果 和一个特殊的开始符号 </s>。decoder 将产生预测,在我们的例子中应该是 ”I”。
  2. 给 decoder 输入 encoder 的 embedding 结果和 “</s>I”,在这一步 decoder 应该产生预测 “Love”。
  3. 给 decoder 输入 encoder 的 embedding 结果和 “</s>I Love”,在这一步 decoder 应该产生预测 “China”。
  4. 给 decoder 输入 encoder 的 embedding 结果和 “</s>I Love China”, decoder应该生成句子结尾的标记,decoder 应该输出 ”</eos>”。
  5. 然后 decoder 生成了 </eos>,翻译完成。

循环结果

动图:http://jalammar.github.io/images/t/transformer_decoding_2.gif

但是在训练过程中,decoder 没那么好时,预测产生的词很可能不是我们想要的。这个时候如果再把错误的数据再输给 decoder,就会越跑越偏:

这里在训练过程中要使用到 “teacher forcing”。利用我们知道他实际应该预测的 word 是什么,在这个时候喂给他一个正确的结果作为输入。

相对于选择最高的词 (greedy search),还有其他选择是比如 “beam search”,可以保留多个预测的 word。 Beam Search 方法不再是只得到一个输出放到下一步去训练了,我们可以设定一个值,拿多个值放到下一步去训练,这条路径的概率等于每一步输出的概率的乘积,具体可以参考李宏毅老师的课程:

或者 “Scheduled Sampling”:一开始我们只用真实的句子序列进行训练,而随着训练过程的进行,我们开始慢慢加入模型的输出作为训练的输入这一过程。

这部分对应 Annotated Transformer 中的实现为:

class Generator(nn.Module):"Define standard linear + softmax generation step."def __init__(self, d_model, vocab):super(Generator, self).__init__()self.proj = nn.Linear(d_model, vocab)def forward(self, x):return F.log_softmax(self.proj(x), dim=-1)

transformer 模型的decoder部分 带gif动图相关推荐

  1. Transformer模型是什么?带你从零详细解读Transformer模型(图解最完整版)

    前言 Transformer是一个利用注意力机制来提高模型训练速度的模型.关于注意力机制可以参看这篇文章,trasnformer可以说是完全基于自注意力机制的一个深度学习模型,因为它适用于并行化计算, ...

  2. 【NLP】Transformer模型深度解读

    " 本文对Transoformer模型进行了深度解读,包括整体架构,Attention结构的背景和细节,QKV的含义,Multi-head Attention的本质,FFN,Position ...

  3. transformer模型_【经典精读】Transformer模型深度解读

    文字长度: ★★★★★ 阅读难度: ★★☆☆☆ 原创程度: ★★★★☆ Transformer是2017年的一篇论文<Attention is All You Need>提出的一种模型架构 ...

  4. Transformer模型深度解读

    " 本文对Transoformer模型进行了深度解读,包括整体架构,Attention结构的背景和细节,QKV的含义,Multi-head Attention的本质,FFN,Position ...

  5. 论文《Attention Is All You Need》及Transformer模型

    目录 1. Introduction 2. 模型结构                        ​ 2.1 Transformer模型 2.2 输入层 2.3 位置向量:给单词赋予上下文语境 2. ...

  6. 图解Transformer模型(Multi-Head Attention)

    本文内容主要源于Attention is all you need: https://arxiv.org/abs/1706.03762 由于本人最近在研究文本的判别式模型,如文本分类任务,所以学习了T ...

  7. 【python量化】将Transformer模型用于股票价格预测

    前言 下面的这篇文章主要教大家如何搭建一个基于Transformer的简单预测模型,并将其用于股票价格预测当中.原代码在文末进行获取. 1.Transformer模型 Transformer 是 Go ...

  8. Transformer模型学习笔记

    Transformer模型 1 seq2seq方法对比 CNN:将序列分为多个窗口(卷积核),每个窗口具有相同的权重,可以带来平移不变性的好处:卷积核之间可以进行并行计算:根据局部关联性建模,若想获得 ...

  9. transformer模型的奥秘-学习笔记

          本文主要介绍了transformer模型的大概原理及模型结构.这篇学习笔记的学习资料主要是<Attention is All you Need>这篇神作,还有两位大神的指点(见 ...

  10. 5、注意力机制和Transformer模型

    1.人类的视觉注意力 从注意力模型的命名方式看,很明显其借鉴了人类的注意力机制,因此,我们首先简单介绍人类视觉的选择性注意力机制. 视觉注意力机制是人类视觉所特有的大脑信号处理机制.人类视觉通过快速扫 ...

最新文章

  1. 假设检验怎么做?这次把方法+Python代码一并教给你
  2. C语言实训指导数组,c语言实训指导书
  3. linux 磁盘控制器,linux – 戴尔R710上的PERC 6 / i RAID:单个控制器上的慢速磁盘…… RAID10?...
  4. GoldenGate应用拓扑结构(三)
  5. 通过反射创建新类示例的两种方式及比较
  6. [学习笔记]舞蹈链(Dancing Links)C++实现(指针版)
  7. linux 析构函数地址获取_析构函数实现多态
  8. 公安交管网服务器维护,交管网总是维护
  9. 如何简单地设置一个LoRa网关?
  10. FL studio 20简易入门教程 -- 第八篇 -- 技巧合集
  11. mysql 索引原理详解
  12. 电脑时间不同步怎么办?
  13. java常量 修改_Java 自定义常量
  14. React 组件的三种写法总结
  15. MySQL中show profile详解
  16. Dlink路由器后门分析
  17. 扁平化风格博客——后续
  18. 实时计算 java基础:类的结构之五:内部类
  19. android中文首字母排序,Android 实现中文按拼音排序方法
  20. 矩阵树定理--luoguP4208 [JSOI2008]最小生成树计数

热门文章

  1. 【写给初发论文的人】撰写综述性科技论文常见问题
  2. form-group 两种常用使用
  3. Java Swing 开发总结汇总贴
  4. 李沐动手学深度学习V2-BERT预训练和代码实现
  5. LINUX中ECHO命令的使用
  6. 已是操作系统的一部分_什么是操作系统 第2部分
  7. WIN10进不了BIOS的解决办法
  8. 家有经济适用男牛仔很忙
  9. 配置GeeM2传奇登陆器详细图文教程
  10. uni-app使用 getUserInfo 报错 fail can only be invoked by user TAP gesture 解决方法