目录

  • 一、前言
  • 二、模型搭建
    • 2.1 编码器
    • 2.2 注意力机制
    • 2.3 解码器
    • 2.4 Seq2Seq模型
  • 三、模型的训练与评估
  • 附录一、翻译效果比较
  • 附录二、完整代码

一、前言

在此之前,我们实现了最普通的seq2seq模型,该模型的编码器和解码器均采用的是两层单向的GRU。本篇文章将基于注意力机制改进之前的seq2seq模型,其中编码器采用两层双向的LSTM,解码器采用含有注意力机制的两层单向LSTM。由于数据预处理部分相同,因此本文不再赘述,详情可参考之前的文章。

二、模型搭建

本文接下来的叙述将沿用这篇文章中的符号。

2.1 编码器

编码器我们采用两层双向LSTM。编码器的输入形状为 (N,L)(N,L)(N,L),输出 output 的形状为 (L,N,2h)(L,N,2h)(L,N,2h),它是正向LSTM和反向LSTM输出进行了concat后的结果,包含了正反向的信息。编码器输出的 h_nc_n 的形状均为 (2n,N,h)(2n,N,h)(2n,N,h),需要将其形状改变为 (n,N,2h)(n,N,2h)(n,N,2h) 后才可作为解码器的初始隐状态。

至于为什么要改变 h_nc_n 的形状以及为什么不能直接用 reshape 去改变会在后面提到。

编码器的实现如下:

class Seq2SeqEncoder(nn.Module):def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)self.rnn = nn.LSTM(emb_size, hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=True)def forward(self, encoder_inputs):encoder_inputs = self.embedding(encoder_inputs).permute(1, 0, 2)output, (h_n, c_n) = self.rnn(encoder_inputs)  # output shape: (seq_len, batch_size, 2 * hidden_size)h_n = torch.cat((h_n[::2], h_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)c_n = torch.cat((c_n[::2], c_n[1::2]), dim=2)  # (num_layers, batch_size, 2 * hidden_size)return output, h_n, c_n

2.2 注意力机制

在原先的seq2seq模型中,解码器在每一个时间步所使用的上下文向量均相同。现在我们希望解码器在不同的时间步上能够注意到源序列中不同的信息,因此考虑采用注意力机制。

解码器的核心架构为两层单向的LSTM(只能是单向),在 ttt 时刻,我们采用解码器在 t−1t-1t1 时刻最后一个隐层的输出作为查询,每个 output[t] 既作为键也作为值,相应的计算上下文向量的公式如下:

context[t]=∑t=1Lα(decoder_state[t−1],output[t])⋅output[t]\text{context}[t]=\sum_{t=1}^L \alpha(\text{decoder\_state}[t-1], \text{output}[t])\cdot \text{output}[t] context[t]=t=1Lα(decoder_state[t1],output[t])output[t]

其中 α(q,k)\alpha(q,k)α(q,k) 是注意力权重。

假设编码器所采用的LSTM的隐层大小为 hhh,解码器所采用的LSTM的隐层大小为 h′h'h。因 output[t] 的形状为 (N,2h)(N,2h)(N,2h)decoder_state[t - 1] 的形状为 (N,h′)(N,h')(N,h),要使用缩放点积注意力,则必须有 h′=2hh'=2hh=2h,否则无法进行内积操作,所以可以得出:解码器隐层大小是编码器的两倍

注意力机制实现如下:

class AttentionMechanism(nn.Module):def __init__(self):super().__init__()def forward(self, decoder_state, encoder_output):# 解码器的隐藏层大小必须是编码器的两倍,否则无法进行接下来的内积操作# decoder_state shape: (batch_size, 2 * hidden_size)# encoder_output shape: (seq_len, batch_size, 2 * hidden_size)decoder_state = decoder_state.unsqueeze(1)  # (batch_size, 1, 2 * hidden_size)encoder_output = encoder_output.transpose(0, 1)  # (batch_size, seq_len, 2 * hidden_size)# scores shape: (batch_size, seq_len)scores = torch.sum(decoder_state * encoder_output, dim=-1) / math.sqrt(decoder_state.shape[2])  # 广播机制attn_weights = F.softmax(scores, dim=-1)# context shape: (batch_size, 2 * hidden_size)context = torch.sum(attn_weights.unsqueeze(-1) * encoder_output, dim=1)  # 广播机制return context

2.3 解码器

解码器的原始输入的形状为 (N,L)(N,L)(N,L),通过嵌入以及 permute 操作后其形状变为 (L,N,d)(L,N,d)(L,N,d),而上下文的形状为 (N,2h)(N,2h)(N,2h)(L,N,d)(L,N,d)(L,N,d)(N,2h)(N,2h)(N,2h) 进行concat后才会作为其内部LSTM的输入。因此解码器采用的LSTM的 input_sized+2hd+2hd+2h。为了保证注意力机制正常运作,其隐藏层大小也应为编码器的两倍,即 2h2h2h

从编码器我们得到了形状为 (2n,N,h)(2n,N,h)(2n,N,h)h_n,而解码器采用的LSTM是单向的,从而其接受的 h_0 的形状应为 (n,N,2h)(n,N,2h)(n,N,2h)。一个很自然的想法是直接使用 reshape 完成形状的转化,但这样做会带来一个问题,即无法保证 h_0[-1] 对应的是正反向编码器在最后一个时间步最后一个隐层的输出的拼接,为此可考虑采用如下方式解决:

h0=Concat((hn[::2],hn[1::2]),dim=2)h_0 =\text{Concat}((h_n [ : \, : 2],h_n [ 1: \, : 2]),\;\text{dim} = 2) h0=Concat((hn[::2],hn[1::2]),dim=2)

至于为什么这样做,可以参考这篇文章。

在评估阶段中,我们往往需要利用模型的解码器一步一步地输出,每一时刻都会利用上一时刻解码器输出的隐状态,类似于下面的伪代码:

decoder_output, hidden_state = decoder(decoder_input, hidden_state)

这要求输入到解码器中的隐状态和解码器输出的隐状态的形状必须相同因此 h_nc_n 的形状转化必须在编码器中完成

解码器的实现:

class Seq2SeqDecoder(nn.Module):def __init__(self, vocab_size, emb_size, hidden_size, num_layers=2, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=1)self.attn = AttentionMechanism()self.rnn = nn.LSTM(emb_size + 2 * hidden_size, 2 * hidden_size, num_layers=num_layers, dropout=dropout)self.fc = nn.Linear(2 * hidden_size, vocab_size)def forward(self, decoder_inputs, encoder_output, h_n, c_n):decoder_inputs = self.embedding(decoder_inputs).permute(1, 0, 2)  # (seq_len, batch_size, emb_size)# 注意将其移动到GPU上decoder_output = torch.zeros(decoder_inputs.shape[0], *h_n.shape[1:]).to(device)  # (seq_len, batch_size, 2 * hidden_size)for i in range(len(decoder_inputs)):context = self.attn(h_n[-1], encoder_output)  # (batch_size, 2 * hidden_size)# single_step_output shape: (1, batch_size, 2 * hidden_size)single_step_output, (h_n, c_n) = self.rnn(torch.cat((decoder_inputs[i], context), -1).unsqueeze(0), (h_n, c_n))decoder_output[i] = single_step_output.squeeze()logits = self.fc(decoder_output)  # (seq_len, batch_size, vocab_size)return logits, h_n, c_n

2.4 Seq2Seq模型

整体架构如下:

只需要将编码器和解码器封装在一起即可:

class Seq2SeqModel(nn.Module):def __init__(self, encoder, decoder):super().__init__()self.encoder = encoderself.decoder = decoderdef forward(self, encoder_inputs, decoder_inputs):return self.decoder(decoder_inputs, *self.encoder(encoder_inputs))

三、模型的训练与评估

因为输入输出发生了一些变化,我们只需要对原先的 train 函数和 evaluate 函数稍作修改

def train(train_loader, model, criterion, optimizer, num_epochs):train_loss = []model.train()for epoch in range(num_epochs):for batch_idx, (encoder_inputs, decoder_targets) in enumerate(train_loader):encoder_inputs, decoder_targets = encoder_inputs.to(device), decoder_targets.to(device)bos_column = torch.tensor([tgt_vocab['<bos>']] * decoder_targets.shape[0]).reshape(-1, 1).to(device)decoder_inputs = torch.cat((bos_column, decoder_targets[:, :-1]), dim=1)pred, _, _ = model(encoder_inputs, decoder_inputs)loss = criterion(pred.permute(1, 2, 0), decoder_targets)optimizer.zero_grad()loss.backward()optimizer.step()train_loss.append(loss.item())if (batch_idx + 1) % 50 == 0:print(f'[Epoch{epoch + 1}] [{(batch_idx + 1) * len(encoder_inputs)}/{len(train_loader.dataset)}] loss:{loss:.4f}')print()return train_lossdef evaluate(test_loader, model, bleu_k):bleu_scores = []translation_results = []model.eval()for src_seq, tgt_seq in test_loader:encoder_inputs = src_seq.to(device)encoder_output, h_n, c_n = model.encoder(encoder_inputs)pred_seq = [tgt_vocab['<bos>']]for _ in range(SEQ_LEN):decoder_inputs = torch.tensor(pred_seq[-1]).reshape(1, 1).to(device)pred, h_n, c_n = model.decoder(decoder_inputs, encoder_output, h_n, c_n)next_token_idx = pred.squeeze().argmax().item()if next_token_idx == tgt_vocab['<eos>']:breakpred_seq.append(next_token_idx)pred_seq = tgt_vocab[pred_seq[1:]]tgt_seq = tgt_seq.squeeze().tolist()tgt_seq = tgt_vocab[tgt_seq[:tgt_seq.index(tgt_vocab['<eos>'])]] if tgt_vocab['<eos>'] in tgt_seq else tgt_vocab[tgt_seq]translation_results.append((' '.join(tgt_seq), ' '.join(pred_seq)))if len(pred_seq) >= bleu_k:bleu_scores.append(bleu(tgt_seq, pred_seq, k=bleu_k))return bleu_scores, translation_results

保持其他超参数不变,使用 NVIDIA A40 进行训练(亲测 RTX 3090 会爆掉显存),大概需要6个小时,损失函数曲线如下:

与之前不同的是,在评估阶段,我们会分别计算平均BLEU-{2,3,4}分数并与原先的模型进行比较

bleu_2_scores, _ = evaluate(test_loader, net, bleu_k=2)
bleu_3_scores, _ = evaluate(test_loader, net, bleu_k=3)
bleu_4_scores, _ = evaluate(test_loader, net, bleu_k=4)
print(f"BLEU-2:{np.mean(bleu_2_scores)}| BLEU-3:{np.mean(bleu_3_scores)}| BLEU-4:{np.mean(bleu_4_scores)}")

比较结果列在下表中

模型 平均BLEU-2 平均BLEU-3 平均BLEU-4
Vanilla Seq2Seq(链接) 0.4799 0.3229 0.2144
Attention-based Seq2Seq(本文) 0.5711 0.4195 0.3036

可以看出加入了注意力机制后,BLEU得分提升了约十个百分点

一些可以改进的地方:

  • 完全可以先将 translation_results 计算出来再计算每种BLEU得分,这样做可以大大节省时间;
  • 训练过程中Teacher Forcing的比率为100%,可以尝试降低此比率以达到更好的效果;
  • BLEU无法理解同义词,导致一些合理的翻译会被否定,可以尝试换用其他的度量来更准确地评估模型。

附录一、翻译效果比较

translation_results 中随机抽取十个。

target:     je suis plutôt occupée .
vanilla:    je suis plutôt occupé .
attn-based: je suis plutôt occupé .target:     ça t'arrive de dormir ?
vanilla:    t'arrive-t-il de dormir ?
attn-based: t'arrive-t-il de dormir ?target:     je ne partirai probablement pas demain .
vanilla:    je ne vais probablement pas vouloir demain .
attn-based: je ne serai probablement pas demain .target:     je suis prudent .
vanilla:    je suis prudente .
attn-based: je suis prudente .target:     je suis sure que c'était juste un malentendu .
vanilla:    je suis sûr que c'était un malentendu .
attn-based: je suis sûr que ce fut un malentendu .target:     je me demandais ce qui t'avait fait changer d'avis .
vanilla:    je me demandais ce que tu ressens .
attn-based: je me demandais ce qui aurait réussi à ce sujet .target:     il me jeta un regard sévère .
vanilla:    il me fit une robe bleue .
attn-based: il m'a donné un grand regard .target:     te fies-tu à qui que ce soit ?
vanilla:    vous fiez-vous à quiconque ?
attn-based: te fies-tu à quiconque ?target:     es-tu sûre d'avoir assez chaud ?
vanilla:    es-tu sûr que tu es allé ?
attn-based: êtes-vous sûr d'avoir assez chaud ?target:     je commençais à me faire du souci à ton sujet .
vanilla:    je commençais à m'inquiéter pour toi .
attn-based: je commençais à m'inquiéter à votre sujet .

附录二、完整代码

基于注意力机制的seq2seq模型相关推荐

  1. 可视化神经机器翻译模型(基于注意力机制的Seq2seq模型)

    可视化神经机器翻译模型(基于注意力机制的Seq2seq模型)   序列到序列模型是深度学习模型,在机器翻译.文本摘要和图像字幕等任务中取得了很大的成功.谷歌翻译在2016年底开始在生产中使用这样的模型 ...

  2. seq2seq模型_具有注意力机制的seq2seq模型

    在本文中,你将了解: 为什么我们需要seq2seq模型的注意力机制? Bahdanua的注意力机制是如何运作的? Luong的注意力机制是如何运作的? 什么是局部和全局注意力? Bahdanua和Lu ...

  3. AI实战:搭建带注意力机制的 seq2seq 模型来做数值预测

    AI实战:搭建带注意力机制的 seq2seq 模型来做数值预测 seq2seq 框架图 环境依赖 Linux python3.6 tensorflow.keras 源码搭建模型及说明 依赖库 impo ...

  4. PyTorch中文教程 | (14) 基于注意力机制的seq2seq神经网络翻译

    Github地址 在这个项目中,我们将编写一个把法语翻译成英语的神经网络. [KEY: > input, = target, < output]> il est en train d ...

  5. PyTorch 1.0 中文官方教程:基于注意力机制的 seq2seq 神经网络翻译

    译者:mengfu188 作者: Sean Robertson 在这个项目中,我们将教一个把把法语翻译成英语的神经网络. [KEY: > input, = target, < output ...

  6. 基于注意力机制的seq2seq网络

    六月 北京 | 高性能计算之GPU CUDA培训 6月22-24日三天密集式学习  快速带你入门阅读全文> 正文共1680个字,26张图,预计阅读时间10分钟. seq2seq的用途有很多,比如 ...

  7. L11注意力机制和Seq2seq模型

    注意力机制 在"编码器-解码器(seq2seq)"⼀节⾥,解码器在各个时间步依赖相同的背景变量(context vector)来获取输⼊序列信息.当编码器为循环神经⽹络时,背景变量 ...

  8. 循环神经网络、注意力机制、Seq2Seq、Transformer与卷积神经网络(打卡2)

    一.过拟合和欠拟合 接下来,我们将探究模型训练中经常出现的两类典型问题: 一类是模型无法得到较低的训练误差,我们将这一现象称作欠拟合(underfitting): 另一类是模型的训练误差远小于它在测试 ...

  9. 基于注意力机制的循环神经网络对 金融时间序列的应用 学习记录

    摘要: 概况论文内容,包含解决的问题,解决的方法,成果 金融时间序列由于高噪声性以及序列间的相关性,导致传统模型的预测精度和泛化能力往往较低.为了克服这一问题,提出一种基于注意力机制的循环神经网络预测 ...

最新文章

  1. .net反射详解(转)
  2. silverlight、wpf中 dispatcher和timer区别
  3. python 基础教程:对 property 属性的讲解及用法
  4. 【微信小程序】wx:for
  5. Java Spring初学者之调试器里括号包含的类含义
  6. 使用FindBugs-IDEA插件找到代码中潜在的问题
  7. 如何在服务器上使用宝塔面板?
  8. 从简单的 XSS 到完整的 Google Cloud Shell 实例接管,值5000美元
  9. logrotate日志轮转
  10. redis集群环境搭建入门
  11. Bzoj2527--Poi2011Meteor
  12. 深度解析脑机接口技术的现状与未来!
  13. OISPT 内网安全项目组A1-渗透测试基础项目训练文档
  14. raised exception class EAccessViolation with message 'Access violation ataddress 64FF0002. Read of a
  15. STM32F103C8T6 操作矩阵键盘
  16. Request method ‘GET‘ not supported 405错误辨析总结
  17. 前端页面局部(全局)刷新方法
  18. 啊哈添柴挑战Java1827. 顺序输出(难)
  19. MySQL数据库学习笔记(2)
  20. 基因注释 InterProScan的三种使用方法

热门文章

  1. COMSOL如何绘制空间某点声压级曲线
  2. 使用weixin-java-open请求微信第三方平台接口超时
  3. 深入理解计算机系统——知识总结
  4. Sms多平台短信服务商系统~完成阿里云短信服务发送可自行配置
  5. matlab 去除协变量,什么是协变量?
  6. 交叉验证评估模型性能
  7. 驾驶员考试科目二完成
  8. 扫雷大战(命令行版,可以连续扫除一片空白区域)
  9. 7-184 正整数A+B (15 分)
  10. python主成分分析代码_PCA主成分分析 原理讲解 python代码实现