前言:该教程利用fairseq增加一个新的FaiseqEncoderDecoderModel,该新模型利用LSTM来encodes一个sentence输入句;接着把最后的hidden state传给第二个LSTM,用于decodes出target sentence目标句。

教程包括:

a.编写Encoder和Decoder,分别用于encode、decode输入句与目标句;

b.注册一个适用于Command-line Tools的新模型;

c.使用现有的command-line tools来训练这个模型;

d.使用Incremental decoding,让修改Decoder更加迅速敏捷making generation faster;

1.搭建Encoder和Decoder

利用FairseqEncoder和FairDecoder接口来搭建自己的Encoder和Decoder,这两个接口继承torch.nn.Module,因此FairseqEncoder和FairseqDecoder能够使用传统的PyTorch Modules来进行编写。

Encoder

该encoder会在输入句中嵌入tokens,然后传入torch.nn.LSTM并且返回最终的隐藏状态hidden state。创建的encoder保存为fairseq/models/simple_lstm.py

import torch.nn as nn
from fairseq import utils
from fairseq.models import FairseqEncoderclass SimpleLSTMEncoder(FairseqEncoder):def __init__(self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,):super().__init__(dictionary)self.args = args# Our encoder will embed the inputs before feeding them to the LSTM.self.embed_tokens = nn.Embedding(num_embeddings=len(dictionary),embedding_dim=embed_dim,padding_idx=dictionary.pad(),)self.dropout = nn.Dropout(p=dropout)# We'll use a single-layer, unidirectional LSTM for simplicity.self.lstm = nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=1,bidirectional=False,batch_first=True,)def forward(self, src_tokens, src_lengths):# The inputs to the ``forward()`` function are determined by the# Task, and in particular the ``'net_input'`` key in each# mini-batch. We discuss Tasks in the next tutorial, but for now just# know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*# has shape `(batch)`.# Note that the source is typically padded on the left. This can be# configured by adding the `--left-pad-source "False"` command-line# argument, but here we'll make the Encoder handle either kind of# padding by converting everything to be right-padded.if self.args.left_pad_source:# Convert left-padding to right-padding.src_tokens = utils.convert_padding_direction(src_tokens,padding_idx=self.dictionary.pad(),left_to_right=True)# Embed the source.x = self.embed_tokens(src_tokens)# Apply dropout.x = self.dropout(x)# Pack the sequence into a PackedSequence object to feed to the LSTM.x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)# Get the output from the LSTM._outputs, (final_hidden, _final_cell) = self.lstm(x)# Return the Encoder's output. This can be any object and will be# passed directly to the Decoder.return {# this will have shape `(bsz, hidden_dim)`'final_hidden': final_hidden.squeeze(0),}# Encoders are required to implement this method so that we can rearrange# the order of the batch elements during inference (e.g., beam search).def reorder_encoder_out(self, encoder_out, new_order):"""Reorder encoder output according to `new_order`.Args:encoder_out: output from the ``forward()`` methodnew_order (LongTensor): desired orderReturns:`encoder_out` rearranged according to `new_order`"""final_hidden = encoder_out['final_hidden']return {'final_hidden': final_hidden.index_select(0, new_order),}

Decoder

基于Encoder的最终hidden state和嵌入的目标单词target word,利用Decoder来预测下一个word,有时候也叫teacher forcing。具体来说,使用torch.nn.LSTM来生成一个隐藏状态hidden state序列,之后通过函数映射到输出的词汇中用于预测每一个目标词target word。

import torch
from fairseq.models import FairseqDecoderclass SimpleLSTMDecoder(FairseqDecoder):def __init__(self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,dropout=0.1,):super().__init__(dictionary)# Our decoder will embed the inputs before feeding them to the LSTM.self.embed_tokens = nn.Embedding(num_embeddings=len(dictionary),embedding_dim=embed_dim,padding_idx=dictionary.pad(),)self.dropout = nn.Dropout(p=dropout)# We'll use a single-layer, unidirectional LSTM for simplicity.self.lstm = nn.LSTM(# For the first layer we'll concatenate the Encoder's final hidden# state with the embedded target tokens.input_size=encoder_hidden_dim + embed_dim,hidden_size=hidden_dim,num_layers=1,bidirectional=False,)# Define the output projection.self.output_projection = nn.Linear(hidden_dim, len(dictionary))# During training Decoders are expected to take the entire target sequence# (shifted right by one position) and produce logits over the vocabulary.# The *prev_output_tokens* tensor begins with the end-of-sentence symbol,# ``dictionary.eos()``, followed by the target sequence.def forward(self, prev_output_tokens, encoder_out):"""Args:prev_output_tokens (LongTensor): previous decoder outputs of shape`(batch, tgt_len)`, for teacher forcingencoder_out (Tensor, optional): output from the encoder, used forencoder-side attentionReturns:tuple:- the last decoder layer's output of shape`(batch, tgt_len, vocab)`- the last decoder layer's attention weights of shape`(batch, tgt_len, src_len)`"""bsz, tgt_len = prev_output_tokens.size()# Extract the final hidden state from the Encoder.final_encoder_hidden = encoder_out['final_hidden']# Embed the target sequence, which has been shifted right by one# position and now starts with the end-of-sentence symbol.x = self.embed_tokens(prev_output_tokens)# Apply dropout.x = self.dropout(x)# Concatenate the Encoder's final hidden state to *every* embedded# target token.x = torch.cat([x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],dim=2,)# Using PackedSequence objects in the Decoder is harder than in the# Encoder, since the targets are not sorted in descending length order,# which is a requirement of ``pack_padded_sequence()``. Instead we'll# feed nn.LSTM directly.initial_state = (final_encoder_hidden.unsqueeze(0),  # hiddentorch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell)output, _ = self.lstm(x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`initial_state,)x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`# Project the outputs to the size of the vocabulary.x = self.output_projection(x)# Return the logits and ``None`` for the attention weightsreturn x, None

2.注册模型(Registering the Model)

利用fairseq中的register_model()函数修饰器来注册模型。一旦模型注册成功,就可以使用Command-line Tools。

所有的注册模型必须实现BaseFairseqModel接口,对于sequence-to-sequence模型(如,任何模包含Encoder和Decoder的模型)需要使用FairseqEncoderDecoderModel接口。

创建SimpleLSTMModel类,在类中利用函数包装类wrapper class,命名该函数为simple_lstm。

from fairseq.models import FairseqEncoderDecoderModel, register_model# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):@staticmethoddef add_args(parser):# Models can override this method to add new command-line arguments.# Here we'll add some new command-line arguments to configure dropout# and the dimensionality of the embeddings and hidden states.parser.add_argument('--encoder-embed-dim', type=int, metavar='N',help='dimensionality of the encoder embeddings',)parser.add_argument('--encoder-hidden-dim', type=int, metavar='N',help='dimensionality of the encoder hidden state',)parser.add_argument('--encoder-dropout', type=float, default=0.1,help='encoder dropout probability',)parser.add_argument('--decoder-embed-dim', type=int, metavar='N',help='dimensionality of the decoder embeddings',)parser.add_argument('--decoder-hidden-dim', type=int, metavar='N',help='dimensionality of the decoder hidden state',)parser.add_argument('--decoder-dropout', type=float, default=0.1,help='decoder dropout probability',)@classmethoddef build_model(cls, args, task):# Fairseq initializes models by calling the ``build_model()``# function. This provides more flexibility, since the returned model# instance can be of a different type than the one that was called.# In this case we'll just return a SimpleLSTMModel instance.# Initialize our Encoder and Decoder.encoder = SimpleLSTMEncoder(args=args,dictionary=task.source_dictionary,embed_dim=args.encoder_embed_dim,hidden_dim=args.encoder_hidden_dim,dropout=args.encoder_dropout,)decoder = SimpleLSTMDecoder(dictionary=task.target_dictionary,encoder_hidden_dim=args.encoder_hidden_dim,embed_dim=args.decoder_embed_dim,hidden_dim=args.decoder_hidden_dim,dropout=args.decoder_dropout,)model = SimpleLSTMModel(encoder, decoder)# Print the model architecture.print(model)return model# We could override the ``forward()`` if we wanted more control over how# the encoder and decoder interact, but it's not necessary for this# tutorial since we can inherit the default implementation provided by# the FairseqEncoderDecoderModel base class, which looks like:## def forward(self, src_tokens, src_lengths, prev_output_tokens):#     encoder_out = self.encoder(src_tokens, src_lengths)#     decoder_out = self.decoder(prev_output_tokens, encoder_out)#     return decoder_out

最后利用这个configuration配置方式为模型定义一个architecture,这个定义结构的方式是通过register_model_architecture函数修饰器完成的。之后就可以使用--arch命令参数,如--arch tutorial_simple_lstm。

from fairseq.models import register_model_architecture# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'simple_lstm'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):# We use ``getattr()`` to prioritize arguments that are explicitly given# on the command-line, so that the defaults defined below are only used# when no other value has been specified.args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)

3.模型训练

可以利用fairseq-train命令行工具Command-line tool来训练模型,并且确保你的新模型结构为(--arch tutorial_simple_lstm)

> fairseq-train data-bin/iwslt14.tokenized.de-en \--arch tutorial_simple_lstm \--encoder-dropout 0.2 --decoder-dropout 0.2 \--optimizer adam --lr 0.005 --lr-shrink 0.5 \--max-tokens 12000
(...)
| epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
| epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954
The model files should appear in the checkpoints/ directory. While this model architecture is not

模型文件将存储到./checkpoints文件夹中。可以使用fairseq-generate命令在测试集上来计算BLEU score。

> fairseq-generate data-bin/iwslt14.tokenized.de-en \--path checkpoints/checkpoint_best.pt \--beam 5 \--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

4.让代码飞起来Making generation faster

虽然sequence-to-sequence 模型天生就很慢,但我们上面的实现也特别慢,因为它为每个输出标记重新计算整个Decoder隐藏状态序列(即,它是O(n^2))。我们可以通过缓存以前的隐藏状态来大大加快这个过程。

在fairseq中,这被称为增量译码incremental decoding。增量解码是推理时的一种特殊模式,其中Model只接收与前一个输出令牌(用于教师强制)对应的单个时间步输入,并且必须增量地产生下一个输出。因此,模型必须缓存序列所需的任何长期状态,如隐藏状态、卷积状态等。

为了实现增量解码,将修改模型来实现FairseqIncrementalDecoder接口。与标准的FairseqDecoder接口相比,增量解码器接口允许forward()方法接受一个额外的关键字参数(incremental_state),该参数可用于跨时间步缓存状态。

用一个增量的替换的SimpleLSTMDecoder:

import torch
from fairseq.models import FairseqIncrementalDecoderclass SimpleLSTMDecoder(FairseqIncrementalDecoder):def __init__(self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,dropout=0.1,):# This remains the same as before.super().__init__(dictionary)self.embed_tokens = nn.Embedding(num_embeddings=len(dictionary),embedding_dim=embed_dim,padding_idx=dictionary.pad(),)self.dropout = nn.Dropout(p=dropout)self.lstm = nn.LSTM(input_size=encoder_hidden_dim + embed_dim,hidden_size=hidden_dim,num_layers=1,bidirectional=False,)self.output_projection = nn.Linear(hidden_dim, len(dictionary))# We now take an additional kwarg (*incremental_state*) for caching the# previous hidden and cell states.def forward(self, prev_output_tokens, encoder_out, incremental_state=None):if incremental_state is not None:# If the *incremental_state* argument is not ``None`` then we are# in incremental inference mode. While *prev_output_tokens* will# still contain the entire decoded prefix, we will only use the# last step and assume that the rest of the state is cached.prev_output_tokens = prev_output_tokens[:, -1:]# This remains the same as before.bsz, tgt_len = prev_output_tokens.size()final_encoder_hidden = encoder_out['final_hidden']x = self.embed_tokens(prev_output_tokens)x = self.dropout(x)x = torch.cat([x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],dim=2,)# We will now check the cache and load the cached previous hidden and# cell states, if they exist, otherwise we will initialize them to# zeros (as before). We will use the ``utils.get_incremental_state()``# and ``utils.set_incremental_state()`` helpers.initial_state = utils.get_incremental_state(self, incremental_state, 'prev_state',)if initial_state is None:# first time initialization, same as the original versioninitial_state = (final_encoder_hidden.unsqueeze(0),  # hiddentorch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell)# Run one step of our LSTM.output, latest_state = self.lstm(x.transpose(0, 1), initial_state)# Update the cache with the latest hidden and cell states.utils.set_incremental_state(self, incremental_state, 'prev_state', latest_state,)# This remains the same as beforex = output.transpose(0, 1)x = self.output_projection(x)return x, None# The ``FairseqIncrementalDecoder`` interface also requires implementing a# ``reorder_incremental_state()`` method, which is used during beam search# to select and reorder the incremental state.def reorder_incremental_state(self, incremental_state, new_order):# Load the cached state.prev_state = utils.get_incremental_state(self, incremental_state, 'prev_state',)# Reorder batches according to *new_order*.reordered_state = (prev_state[0].index_select(1, new_order),  # hiddenprev_state[1].index_select(1, new_order),  # cell)# Update the cached state.utils.set_incremental_state(self, incremental_state, 'prev_state', reordered_state,)

最后,可以重新运行生成并观察加速情况:

# Before> fairseq-generate data-bin/iwslt14.tokenized.de-en \--path checkpoints/checkpoint_best.pt \--beam 5 \--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)# After> fairseq-generate data-bin/iwslt14.tokenized.de-en \--path checkpoints/checkpoint_best.pt \--beam 5 \--remove-bpe
(...)
| Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
| Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

fairseq入门教程相关推荐

  1. Kafka入门教程与详解

    1 Kafka入门教程 1.1 消息队列(Message Queue) Message Queue消息传送系统提供传送服务.消息传送依赖于大量支持组件,这些组件负责处理连接服务.消息的路由和传送.持久 ...

  2. 【CV】Pytorch一小时入门教程-代码详解

    目录 一.关键部分代码分解 1.定义网络 2.损失函数(代价函数) 3.更新权值 二.训练完整的分类器 1.数据处理 2. 训练模型(代码详解) CPU训练 GPU训练 CPU版本与GPU版本代码区别 ...

  3. python tornado教程_Tornado 简单入门教程(零)——准备工作

    前言: 这两天在学着用Python + Tornado +MongoDB来做Web开发(哈哈哈这个词好高端).学的过程中查阅了无数资料,也收获了一些经验,所以希望总结出一份简易入门教程供初学者参考.完 ...

  4. python向量计算库教程_NumPy库入门教程:基础知识总结

    原标题:NumPy库入门教程:基础知识总结 视学算法 | 作者 知乎专栏 | 来源 numpy可以说是 Python运用于人工智能和科学计算的一个重要基础,近段时间恰好学习了numpy,pandas, ...

  5. mysql query browswer_MySQL数据库新特性之存储过程入门教程

    MySQL数据库新特性之存储过程入门教程 在MySQL 5中,终于引入了存储过程这一新特性,这将大大增强MYSQL的数据库处理能力.在本文中将指导读者快速掌握MySQL 5的存储过程的基本知识,带领用 ...

  6. python tensorflow教程_TensorFlow入门教程TensorFlow 基本使用T

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 TensorFlow入门教程 TensorFlow 基本使用 TensorFlow官方中文教程 TensorFlow 的特点: 使用图 (graph) 来 ...

  7. air调用java,AIR2.0入门教程:与Java应用交互

    在之前的一篇文章中,我介绍了如何使用AIR2.0新增的NativeProcess类与本地进程进行交互和通讯,在那个例子里面我们使用了C++ 的代码,实际上只要是基于命令行的标准输入输出,AIR2.0的 ...

  8. 【Arduino】开发入门教程【一】什么是Arduino

    Arduino Arduino 是一款便捷灵活.方便上手的开源电子原型平台,包含硬件(各种型号的arduino板)和软件(arduino IDE).它适用于艺术家.设计师.爱好者和对于"互动 ...

  9. python 三分钟入门_Cython 三分钟入门教程

    作者:perrygeo 译者:赖勇浩(http://laiyonghao.com) 原文:http://www.perrygeo.net/wordpress/?p=116 我最喜欢的是Python,它 ...

最新文章

  1. 为什么你问问题,别人都已读不回?
  2. bigdecimal不保留小数_金钱要使用BigDecimal数据类型(使用double的已经被公司开除了)...
  3. 使用 Chrome Dev tools 分析应用的内存泄漏问题
  4. 个人作业7 第一阶段SCRUM冲刺(七)
  5. g30u盘启动 中科曙光1620_I620-G30
  6. Linux 文件区块连续吗,关于Linux文件系统的的简单理解和认识
  7. 【转】ORACLE_SID、INSTANCE_NAME、DB_NAME
  8. javashop配置微信支付
  9. 微服务设计笔记——几种远程过程调用方法
  10. waitpid最后以一个参数设为0_变频器用远传压力表控制恒压供水参数设置
  11. PLSQL 安装教程
  12. Android 弹幕(一)自定义
  13. wamp安装composer
  14. 补码加减运算及判断溢出方法
  15. SICP读书笔记(5) —— Sec1.1.2-Sec1.1.3
  16. 随机背景在随机位置添加随机颜色的文字
  17. 【2020最新】人工智能实战就业(面试)学习路线图
  18. 计算机专业高级职称评定条件,计算机职称考试初级高级中级职称评定申报条件...
  19. 以智能制造推进制造业智能化转型
  20. WPF 项目开发入门(一) 安装运行

热门文章

  1. VS C++项目打开时报 fatal error RC1015
  2. 如何将html转换成url,HTML之Data URL(转)
  3. STM32F1 W5500 TCP Client 回环测试
  4. 爬虫技术(04)神箭手爬虫field的属性
  5. R语言-因子的构造-factor函数
  6. 2020年鼠年正月十二 淡然面对
  7. unity团队大作业-足球射门游戏
  8. C#数据结构与算法总结
  9. SPDA-CNN:Unifying Semantic Part Detection and Abstraction for Fine-grained Recognition
  10. 为什么word文档或EXCET表格从电脑传到手机上格式就变了