Table of Contents

  • Imports and preprocessing
  • Defining the model structure
    • Encoder structure
    • Decoder structure
  • A quick test
  • Training and evaluation functions
  • Training the model
  • Adding attention to the decoder

Imports and preprocessing

%load_ext autoreload
%autoreload 2
%matplotlib inline

import random
import math
import time

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from typing import *
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Select the device to run on
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

def setup_seed(seed):
    """Make every run produce the same results."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be False for deterministic behaviour
    torch.backends.cudnn.benchmark = False

Defining the model structure

Encoder structure

nn.Embedding is a simple lookup table that stores a fixed-size dictionary of embeddings. It is commonly used to store word embeddings and retrieve them by index: the input to the module is a list of indices, and the output is the corresponding word embeddings. Here input_size is the number of words in the vocabulary and hidden_size is the dimensionality of the word vectors; the vectors themselves are randomly initialized.
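As a quick illustration of the lookup behaviour (a minimal sketch with arbitrary sizes, separate from the model code below):

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)    # a vocabulary of 10 words, 4-dim vectors
indices = torch.tensor([1, 5, 5])  # a "sentence" of three word indices
vectors = embedding(indices)       # rows 2 and 3 are identical lookups
print(vectors.shape)               # torch.Size([3, 4])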

lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=n_layers)

  • input_dim = the number of input features (a dimensionality of 20 represents 20 inputs)
  • hidden_dim = the size of the hidden state; the number of outputs each LSTM cell produces at each time step
  • n_layers = the number of stacked LSTM layers; usually a value between 1 and 3; a value of 1 means each LSTM cell has a single hidden state. The default is 1.

out, hidden = lstm(input.view(1, 1, -1), (h0, c0))
The input to the LSTM is (input, (h0, c0)):

  • input = a Tensor holding the input sequence; shape (seq_len, batch, input_size)
  • h0 = a Tensor holding the initial hidden state for each element in the batch
  • c0 = a Tensor holding the initial cell state for each element in the batch
  • h0 and c0 default to zeros if not specified. Their shape is (n_layers, batch, hidden_dim). The sketch below makes these shapes concrete.
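To check the shapes, here is a minimal run with dummy tensors (the sizes are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

seq_len, batch, input_dim, hidden_dim, n_layers = 5, 1, 20, 32, 1
lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=n_layers)

inputs = torch.randn(seq_len, batch, input_dim)
h0 = torch.zeros(n_layers, batch, hidden_dim)
c0 = torch.zeros(n_layers, batch, hidden_dim)

out, (hn, cn) = lstm(inputs, (h0, c0))
print(out.shape)  # torch.Size([5, 1, 32]) — one output per time step
print(hn.shape)   # torch.Size([1, 1, 32]) — final hidden state per layer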
class EncoderLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, inputs: Tensor, state: Tuple[Tensor]):
        (hidden, cell) = state
        # Both seq_len and batch are 1 here
        embedded = self.embedding(inputs).view(1, 1, -1)
        output = embedded
        output, (hidden, cell) = self.lstm(output, (hidden, cell))
        return output, (hidden, cell)

    def init_hidden(self):
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell

Decoder structure

class DecoderLSTM(nn.Module):
    def __init__(self, hidden_size: int, output_size: int):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.activation_function = F.relu

    def forward(self, inputs, state):
        (hidden, cell) = state
        output = self.embedding(inputs).view(1, 1, -1)
        output = self.activation_function(output)
        output, (hidden, cell) = self.lstm(output, (hidden, cell))
        output = self.log_softmax(self.out(output[0]))
        return output, (hidden, cell)

    def init_hidden(self):
        """Init hidden

        Returns:
            hidden:
            cell:
        """
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell
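A single encoder and decoder step can be sanity-checked with toy vocabulary sizes (a sketch; the vocabulary sizes are made up, and SOS_token = 0 matches the training code further down):

encoder = EncoderLSTM(input_size=10, hidden_size=16).to(device)
decoder = DecoderLSTM(hidden_size=16, output_size=12).to(device)

state = encoder.init_hidden()
word = torch.tensor([3], device=device)           # one source-word index
enc_out, state = encoder(word, state)             # enc_out: [1, 1, 16]

SOS_token = 0
dec_in = torch.tensor([[SOS_token]], device=device)
dec_out, state = decoder(dec_in, state)           # dec_out: [1, 12] log-probs
print(dec_out.shape)                              # torch.Size([1, 12])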

A quick test

input_lang.name
'fra'
output_lang.name
'eng'
pairs
[['j ai ans .', 'i m .'],['je vais bien .', 'i m ok .'],['ca va .', 'i m ok .'],
testpair = random.choice(pairs)
testpair
['il ne se presente pas aux prochaines elections .','he is not running in the coming election .']
tensor_from_sentence(input_lang, testpair[0])
tensor([[  24],[ 297],[ 882],[2113],[ 246],[ 241],[4280],[3522],[   5],[   1]])
tensor_from_sentence(output_lang, testpair[1])
tensor([[  14],[  40],[ 147],[ 335],[ 102],[ 294],[ 142],[2744],[   4],[   1]])
tensor_from_sentence(output_lang, 'i .')
tensor([[2],[4],[1]])
tensor_from_pair(testpair, input_lang, output_lang)
(tensor([[  24],[ 297],[ 882],[2113],[ 246],[ 241],[4280],[3522],[   5],[   1]]), tensor([[  14],[  40],[ 147],[ 335],[ 102],[ 294],[ 142],[2744],[   4],[   1]]))

Training and evaluation functions

Teacher forcing is a method for training recurrent neural network models quickly and effectively: instead of feeding the model's own output from the previous time step back in as the next input, it feeds the ground-truth target token. This alleviates slow convergence and instability during training. However, it can hurt the model in practice: when the generated sequence diverges from anything the model saw during training (i.e. it encounters data not present in the training set), the model may perform poorly.
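The choice shows up as a single branch inside the decoding loop below; the following standalone sketch mimics it with a toy GRU cell standing in for the decoder (all sizes are made up for illustration):

import torch
import torch.nn as nn

vocab_size, hidden_size = 5, 8
embed = nn.Embedding(vocab_size, hidden_size)
cell = nn.GRUCell(hidden_size, hidden_size)       # stand-in for one decoder LSTM step
out = nn.Linear(hidden_size, vocab_size)

target = torch.tensor([2, 4, 1])                  # ground-truth token ids
state = torch.zeros(1, hidden_size)
use_teacher_forcing = True

decoder_input = torch.tensor([0])                 # SOS token
for di in range(len(target)):
    state = cell(embed(decoder_input), state)
    logits = out(state)                           # [1, vocab_size]
    if use_teacher_forcing:
        decoder_input = target[di:di + 1]         # feed the ground-truth token
    else:
        _, topi = logits.topk(1)
        decoder_input = topi.squeeze(0).detach()  # feed the model's own prediction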

Note that the encoder and decoder share the hidden and cell states: the encoder's final (hidden, cell) pair is used as the decoder's initial state.

def train_by_sentence(input_tensor, target_tensor, encoder, decoder,
                      encoder_optimizer, decoder_optimizer, loss_fn,
                      use_teacher_forcing=True, reverse_source_sentence=True,
                      max_length=MAX_LENGTH):
    """Train on a single sentence using EncoderLSTM and DecoderLSTM,
    including the forward pass and the parameter update.

    Args:
        input_tensor: [input_sequence_len, 1, hidden_size]
        target_tensor: [target_sequence_len, 1, hidden_size]
        encoder: EncoderLSTM
        decoder: DecoderLSTM
        encoder_optimizer: optimizer for encoder
        decoder_optimizer: optimizer for decoder
        loss_fn: loss function
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own predictions as the next input
        max_length: max length for input and output
    Returns:
        loss: scalar
    """
    # Reverse the source sentence if requested
    if reverse_source_sentence:
        input_tensor = torch.flip(input_tensor, [0])

    hidden, cell = encoder.init_hidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Lengths of the input and target sequences
    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # encoder_outputs: [max_length, hidden_size], defined here so it is available later
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    loss = 0

    # Run the encoder over the source sentence
    for ei in range(input_length):
        encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
        # Both batch size and sequence length are 1 here
        encoder_outputs[ei] = encoder_output[0, 0]

    # Start decoding from a (1, 1) tensor holding the SOS token
    SOS_token = 0
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = (hidden, cell)

    for di in range(target_length):
        decoder_output, (hidden, cell) = decoder(decoder_input, (hidden, cell))
        if use_teacher_forcing:
            # Feed the target as the next input
            loss += loss_fn(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
        else:
            # Feed the model's own prediction as the next input
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += loss_fn(decoder_output, target_tensor[di])
            # Stop once EOS is produced
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
def train(encoder, decoder, n_iters, reverse_source_sentence=True,
          use_teacher_forcing=True, print_every=1000, plot_every=100,
          learning_rate=0.01):
    """Train the Seq2seq model.

    Args:
        encoder: EncoderLSTM
        decoder: DecoderLSTM
        n_iters: train on n_iters randomly chosen sentences
        reverse_source_sentence: True reverses the source sentence while
            keeping the order of the target unchanged,
            False keeps the order of both unchanged
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own predictions as the next input
        print_every: print a log line every print_every iterations
        plot_every: record the loss every plot_every iterations
        learning_rate: learning rate for both optimizers
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    # Use SGD as the optimizer
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Prepare the training data
    training_pairs = [tensor_from_pair(random.choice(pairs), input_lang, output_lang)
                      for _ in range(n_iters)]

    # Loss function
    loss_fn = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        training_pair = training_pairs[i - 1]
        input_tensor = training_pair[0].to(device)
        target_tensor = training_pair[1].to(device)

        loss = train_by_sentence(input_tensor, target_tensor, encoder, decoder,
                                 encoder_optimizer, decoder_optimizer, loss_fn,
                                 use_teacher_forcing=use_teacher_forcing,
                                 reverse_source_sentence=reverse_source_sentence)
        print_loss_total += loss
        plot_loss_total += loss

        if i % print_every == 0:
            # Print loss
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%s (%d %d%%) %.4f" % (time_since(start, i / n_iters),
                                         i, i / n_iters * 100, print_loss_avg))

        if i % plot_every == 0:
            # Record loss for plotting
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    # Show the loss curve
    show_plot(plot_losses)
def evaluate_by_sentence(encoder, decoder, sentence, reverse_source_sentence,
                         max_length=MAX_LENGTH):
    """Evaluate on a source sentence.

    Args:
        encoder
        decoder
        sentence
        max_length
    Return:
        decoded_words: predicted sentence
    """
    with torch.no_grad():
        # Get tensor of sentence
        input_tensor = tensor_from_sentence(input_lang, sentence).to(device)
        input_length = input_tensor.size(0)
        if reverse_source_sentence:
            input_tensor = torch.flip(input_tensor, [0])

        # Init state for encoder
        (hidden, cell) = encoder.init_hidden()

        # encoder outputs: [max_length, hidden_size]
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        for ei in range(input_length):
            encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
            encoder_outputs[ei] += encoder_output[0, 0]

        # Last state of encoder as the init state of decoder
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = (hidden, cell)
        decoded_words = []

        # When evaluating, use the model's own predictions as the next input
        for di in range(max_length):
            decoder_output, (hidden, cell) = decoder(decoder_input, (hidden, cell))
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append("<EOS>")
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()

    return decoded_words
def evaluate_randomly(encoder, decoder, n=10, reverse_source_sentence=True):
    """Randomly pick sentences from the dataset and observe the translation quality.

    Args:
        encoder:
        decoder:
        n: number of sentences to evaluate
    """
    for _ in range(n):
        pair = random.choice(pairs)
        # Source sentence
        print(">", pair[0])
        # Target sentence
        print("=", pair[1])
        output_words = evaluate_by_sentence(encoder, decoder, pair[0],
                                            reverse_source_sentence)
        output_sentence = " ".join(output_words)
        # Predicted sentence
        print("<", output_sentence)
        print("")
def show_plot(points):
    """Plot the recorded loss values."""
    plt.figure()
    fig, ax = plt.subplots()
    # Put a tick every 0.2 on the y axis
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.show()

Training the model

input_lang, output_lang, pairs = prepare_data('eng', 'fra', reverse=True)
print(random.choice(pairs))
Reading lines...
Read 135842 sentence pairs
Reverse source sentence
Trimmed to 10599 sentence pairs
Counting words ...
Counting words:
fra 4345
eng 2803
['elle est hors de danger .', 'she is out of danger .']

Define the hyperparameters

setup_seed(45)
hidden_size = 256
reverse_source_sentence = True
use_teacher_forcing = True
encoder = EncoderLSTM(input_lang.n_words, hidden_size).to(device)
decoder = DecoderLSTM(hidden_size, output_lang.n_words).to(device)
print(">> Model is on: {}".format(next(encoder.parameters()).is_cuda))
print(">> Model is on: {}".format(next(decoder.parameters()).is_cuda))
>> Model is on: True
>> Model is on: True
iters = 50000
train(encoder, decoder, iters, reverse_source_sentence=reverse_source_sentence,
      use_teacher_forcing=use_teacher_forcing, print_every=250, plot_every=250)
0m 5s (- 18m 44s) (250 0%) 4.7301
0m 10s (- 17m 23s) (500 1%) 3.3269
0m 15s (- 17m 2s) (750 1%) 3.0458
0m 20s (- 16m 37s) (1000 2%) 2.8595
0m 25s (- 16m 26s) (1250 2%) 2.8384
0m 30s (- 16m 10s) (1500 3%) 2.7370
0m 34s (- 16m 3s) (1750 3%) 2.6759
0m 39s (- 15m 53s) (2000 4%) 2.6775
0m 44s (- 15m 44s) (2250 4%) 2.6329
0m 49s (- 15m 37s) (2500 5%) 2.6113
0m 54s (- 15m 31s) (2750 5%) 2.5890
0m 58s (- 15m 23s) (3000 6%) 2.5076
1m 3s (- 15m 17s) (3250 6%) 2.4959
1m 8s (- 15m 13s) (3500 7%) 2.5263
1m 13s (- 15m 10s) (3750 7%) 2.5167
1m 18s (- 15m 5s) (4000 8%) 2.3970
1m 23s (- 15m 1s) (4250 8%) 2.4132
1m 28s (- 14m 55s) (4500 9%) 2.3209
1m 33s (- 14m 49s) (4750 9%) 2.2764
1m 38s (- 14m 43s) (5000 10%) 2.2885
1m 43s (- 14m 39s) (5250 10%) 2.3332
1m 48s (- 14m 35s) (5500 11%) 2.3204
1m 53s (- 14m 30s) (5750 11%) 2.2541
1m 58s (- 14m 27s) (6000 12%) 2.2727
2m 3s (- 14m 23s) (6250 12%) 2.2329
2m 8s (- 14m 16s) (6500 13%) 2.1703
2m 12s (- 14m 10s) (6750 13%) 2.0671
2m 17s (- 14m 4s) (7000 14%) 2.1644
2m 22s (- 13m 59s) (7250 14%) 2.2096
2m 27s (- 13m 54s) (7500 15%) 2.0804
2m 32s (- 13m 49s) (7750 15%) 2.1003
2m 36s (- 13m 43s) (8000 16%) 2.0653
2m 41s (- 13m 39s) (8250 16%) 2.0542
2m 46s (- 13m 34s) (8500 17%) 2.0976
2m 51s (- 13m 30s) (8750 17%) 2.0354
2m 56s (- 13m 25s) (9000 18%) 2.0289
3m 1s (- 13m 20s) (9250 18%) 1.9037
3m 6s (- 13m 15s) (9500 19%) 1.9525
3m 11s (- 13m 9s) (9750 19%) 1.8598
3m 16s (- 13m 4s) (10000 20%) 1.9433
3m 20s (- 12m 58s) (10250 20%) 1.9164
3m 25s (- 12m 53s) (10500 21%) 1.8676
3m 30s (- 12m 47s) (10750 21%) 1.8912
3m 35s (- 12m 43s) (11000 22%) 1.8940
3m 39s (- 12m 37s) (11250 22%) 1.8391
3m 44s (- 12m 32s) (11500 23%) 1.9038
3m 49s (- 12m 28s) (11750 23%) 1.8223
3m 54s (- 12m 23s) (12000 24%) 1.7111
3m 59s (- 12m 18s) (12250 24%) 1.8238
4m 4s (- 12m 13s) (12500 25%) 1.7750
4m 9s (- 12m 8s) (12750 25%) 1.8930
4m 14s (- 12m 3s) (13000 26%) 1.7776
4m 19s (- 11m 59s) (13250 26%) 1.7633
4m 24s (- 11m 54s) (13500 27%) 1.7333
4m 29s (- 11m 49s) (13750 27%) 1.7893
4m 33s (- 11m 44s) (14000 28%) 1.7390
4m 38s (- 11m 39s) (14250 28%) 1.7701
4m 43s (- 11m 34s) (14500 28%) 1.7329
4m 48s (- 11m 29s) (14750 29%) 1.6696
4m 53s (- 11m 24s) (15000 30%) 1.6313
4m 58s (- 11m 20s) (15250 30%) 1.7256
5m 3s (- 11m 14s) (15500 31%) 1.6859
5m 8s (- 11m 10s) (15750 31%) 1.6195
5m 12s (- 11m 5s) (16000 32%) 1.5513
5m 17s (- 11m 0s) (16250 32%) 1.6846
5m 22s (- 10m 55s) (16500 33%) 1.6875
5m 27s (- 10m 49s) (16750 33%) 1.5778
5m 32s (- 10m 45s) (17000 34%) 1.6210
5m 37s (- 10m 40s) (17250 34%) 1.5758
5m 42s (- 10m 35s) (17500 35%) 1.5593
5m 46s (- 10m 30s) (17750 35%) 1.5810
5m 51s (- 10m 24s) (18000 36%) 1.5944
5m 55s (- 10m 18s) (18250 36%) 1.5053
6m 0s (- 10m 13s) (18500 37%) 1.4108
6m 5s (- 10m 8s) (18750 37%) 1.5082
6m 10s (- 10m 3s) (19000 38%) 1.5458
6m 14s (- 9m 58s) (19250 38%) 1.4254
6m 19s (- 9m 53s) (19500 39%) 1.4709
6m 24s (- 9m 48s) (19750 39%) 1.4742
6m 29s (- 9m 43s) (20000 40%) 1.3979
6m 33s (- 9m 38s) (20250 40%) 1.4668
6m 38s (- 9m 33s) (20500 41%) 1.4649
6m 43s (- 9m 28s) (20750 41%) 1.4709
6m 48s (- 9m 23s) (21000 42%) 1.4918
6m 53s (- 9m 19s) (21250 42%) 1.4107
6m 57s (- 9m 13s) (21500 43%) 1.4762
7m 2s (- 9m 9s) (21750 43%) 1.5225
7m 7s (- 9m 3s) (22000 44%) 1.4054
7m 12s (- 8m 58s) (22250 44%) 1.3352
7m 16s (- 8m 54s) (22500 45%) 1.3740
7m 21s (- 8m 49s) (22750 45%) 1.4333
7m 26s (- 8m 44s) (23000 46%) 1.3943
7m 31s (- 8m 39s) (23250 46%) 1.2736
7m 36s (- 8m 34s) (23500 47%) 1.3318
7m 41s (- 8m 29s) (23750 47%) 1.3693
7m 45s (- 8m 24s) (24000 48%) 1.3522
7m 50s (- 8m 19s) (24250 48%) 1.2736
7m 55s (- 8m 14s) (24500 49%) 1.3980
8m 0s (- 8m 9s) (24750 49%) 1.2201
8m 5s (- 8m 5s) (25000 50%) 1.2675
8m 10s (- 8m 0s) (25250 50%) 1.3469
8m 15s (- 7m 55s) (25500 51%) 1.2714
8m 20s (- 7m 51s) (25750 51%) 1.2665
8m 24s (- 7m 46s) (26000 52%) 1.2653
8m 29s (- 7m 41s) (26250 52%) 1.1929
8m 34s (- 7m 36s) (26500 53%) 1.2523
8m 39s (- 7m 31s) (26750 53%) 1.2691
8m 44s (- 7m 26s) (27000 54%) 1.1528
8m 49s (- 7m 22s) (27250 54%) 1.2370
8m 54s (- 7m 17s) (27500 55%) 1.2660
8m 59s (- 7m 12s) (27750 55%) 1.2506
9m 4s (- 7m 7s) (28000 56%) 1.2735
9m 9s (- 7m 2s) (28250 56%) 1.2148
9m 13s (- 6m 57s) (28500 56%) 1.2847
9m 19s (- 6m 53s) (28750 57%) 1.2118
9m 23s (- 6m 48s) (29000 57%) 1.1789
9m 28s (- 6m 43s) (29250 58%) 1.1460
9m 33s (- 6m 38s) (29500 59%) 1.1338
9m 38s (- 6m 33s) (29750 59%) 1.1070
9m 43s (- 6m 29s) (30000 60%) 1.2129
9m 48s (- 6m 24s) (30250 60%) 1.0972
9m 53s (- 6m 19s) (30500 61%) 1.0851
9m 58s (- 6m 14s) (30750 61%) 1.1832
10m 2s (- 6m 9s) (31000 62%) 1.0532
10m 7s (- 6m 4s) (31250 62%) 1.1463
10m 12s (- 5m 59s) (31500 63%) 1.0433
10m 17s (- 5m 54s) (31750 63%) 1.0821
10m 22s (- 5m 50s) (32000 64%) 1.0334
10m 28s (- 5m 45s) (32250 64%) 1.1181
10m 33s (- 5m 40s) (32500 65%) 1.1509
10m 38s (- 5m 36s) (32750 65%) 1.1036
10m 43s (- 5m 31s) (33000 66%) 1.0277
10m 48s (- 5m 26s) (33250 66%) 1.1785
10m 52s (- 5m 21s) (33500 67%) 1.0550
10m 57s (- 5m 16s) (33750 67%) 1.0629
11m 2s (- 5m 11s) (34000 68%) 1.0696
11m 7s (- 5m 6s) (34250 68%) 1.0918
11m 12s (- 5m 2s) (34500 69%) 1.0613
11m 17s (- 4m 57s) (34750 69%) 1.0352
11m 22s (- 4m 52s) (35000 70%) 1.0065
11m 27s (- 4m 47s) (35250 70%) 1.0674
11m 31s (- 4m 42s) (35500 71%) 1.0631
11m 36s (- 4m 37s) (35750 71%) 1.1001
11m 41s (- 4m 32s) (36000 72%) 1.0393
11m 46s (- 4m 28s) (36250 72%) 0.9400
11m 51s (- 4m 23s) (36500 73%) 1.0264
11m 55s (- 4m 18s) (36750 73%) 0.9909
12m 0s (- 4m 13s) (37000 74%) 0.9877
12m 5s (- 4m 8s) (37250 74%) 0.9790
12m 10s (- 4m 3s) (37500 75%) 0.8614
12m 15s (- 3m 58s) (37750 75%) 0.8985
12m 20s (- 3m 53s) (38000 76%) 0.9313
12m 25s (- 3m 49s) (38250 76%) 0.9810
12m 30s (- 3m 44s) (38500 77%) 0.8965
12m 35s (- 3m 39s) (38750 77%) 0.9325
12m 40s (- 3m 34s) (39000 78%) 0.9488
12m 45s (- 3m 29s) (39250 78%) 0.8820
12m 49s (- 3m 24s) (39500 79%) 0.9141
12m 54s (- 3m 19s) (39750 79%) 0.9451
12m 59s (- 3m 14s) (40000 80%) 0.8610
13m 4s (- 3m 9s) (40250 80%) 0.8987
13m 8s (- 3m 4s) (40500 81%) 0.9370
13m 13s (- 3m 0s) (40750 81%) 0.9663
13m 18s (- 2m 55s) (41000 82%) 0.8364
13m 22s (- 2m 50s) (41250 82%) 0.9296
13m 27s (- 2m 45s) (41500 83%) 0.8876
13m 32s (- 2m 40s) (41750 83%) 0.7837
13m 37s (- 2m 35s) (42000 84%) 0.8643
13m 41s (- 2m 30s) (42250 84%) 0.9092
13m 46s (- 2m 25s) (42500 85%) 0.8111
13m 51s (- 2m 20s) (42750 85%) 0.8668
13m 56s (- 2m 16s) (43000 86%) 0.8687
14m 1s (- 2m 11s) (43250 86%) 0.8701
14m 5s (- 2m 6s) (43500 87%) 0.8108
14m 10s (- 2m 1s) (43750 87%) 0.7329
14m 15s (- 1m 56s) (44000 88%) 0.8410
14m 20s (- 1m 51s) (44250 88%) 0.8041
14m 25s (- 1m 46s) (44500 89%) 0.7772
14m 30s (- 1m 42s) (44750 89%) 0.8702
14m 35s (- 1m 37s) (45000 90%) 0.8274
14m 40s (- 1m 32s) (45250 90%) 0.7602
14m 45s (- 1m 27s) (45500 91%) 0.8276
14m 50s (- 1m 22s) (45750 91%) 0.7752
14m 55s (- 1m 17s) (46000 92%) 0.7822
15m 0s (- 1m 12s) (46250 92%) 0.7470
15m 4s (- 1m 8s) (46500 93%) 0.7725
15m 9s (- 1m 3s) (46750 93%) 0.7477
15m 14s (- 0m 58s) (47000 94%) 0.7231
15m 19s (- 0m 53s) (47250 94%) 0.7538
15m 24s (- 0m 48s) (47500 95%) 0.8537
15m 28s (- 0m 43s) (47750 95%) 0.7798
15m 33s (- 0m 38s) (48000 96%) 0.7322
15m 38s (- 0m 34s) (48250 96%) 0.8085
15m 43s (- 0m 29s) (48500 97%) 0.7098
15m 48s (- 0m 24s) (48750 97%) 0.7215
15m 53s (- 0m 19s) (49000 98%) 0.8122
15m 58s (- 0m 14s) (49250 98%) 0.7791
16m 2s (- 0m 9s) (49500 99%) 0.7251
16m 7s (- 0m 4s) (49750 99%) 0.7874
16m 12s (- 0m 0s) (50000 100%) 0.7124

# Randomly pick 10 sentences and observe the performance
evaluate_randomly(encoder, decoder, 10, reverse_source_sentence)
> je suis tres fier de nos etudiants .
= i m very proud of our students .
< i m very proud of you . <EOS>

> vous etes faibles .
= you re weak .
< you re rude . <EOS>

> tu n es pas si vieux .
= you re not that old .
< you re not that old . <EOS>

> je songe a demissionner immediatement .
= i am thinking of resigning at once .
< i m thinking about the problem . <EOS>

> je suis en retard sur le programme .
= i m behind schedule .
< i m behind schedule . <EOS>

> je suis submerge de travail .
= i m swamped with work .
< i m proud of that . <EOS>

> je ne vais pas prendre le moindre risque .
= i m not taking any chances .
< i m not taking any chances . <EOS>

> je suis au restaurant .
= i m at the restaurant .
< i m in the office . <EOS>

> c est toi la doyenne .
= you re the oldest .
< you re the oldest . <EOS>

> je suis tres reconnaissant pour votre aide .
= i m very grateful for your help .
< i m very worried about you . <EOS>

Adding attention to the decoder

class AttentionDecoderLSTM(nn.Module):
    def __init__(self, hidden_size: int, output_size: int, dropout_p=0.1,
                 max_length=MAX_LENGTH):
        """DecoderLSTM with attention mechanism"""
        super(AttentionDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attention = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attention_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.activation_fn = F.relu

    def forward(self, inputs, state, encoder_outputs):
        """Forward

        Args:
            inputs: [1, hidden_size]
            state: ([1, 1, hidden_size], [1, 1, hidden_size])
            encoder_outputs: [max_length, hidden_size]
        Returns:
            output:
            state: (hidden, cell)
        """
        # embedded: [1, 1, hidden_size]
        embedded = self.embedding(inputs).view(1, 1, -1)
        embedded = self.dropout(embedded)
        (hidden, cell) = state

        # embedded[0] has shape [1, hidden_size]
        attention_weights = F.softmax(
            self.attention(torch.cat((embedded[0], hidden[0]), 1)), dim=1)

        # torch.bmm multiplies matrices batch-wise: inputs of shape
        # [10, 3, 4] and [10, 4, 5] give an output of shape [10, 3, 5].
        # unsqueeze adds the batch dimension, so the result is [1, 1, hidden_size].
        attention_applied = torch.bmm(attention_weights.unsqueeze(0),
                                      encoder_outputs.unsqueeze(0))

        # output: [1, hidden_size * 2]
        output = torch.cat((embedded[0], attention_applied[0]), 1)
        # output: [1, 1, hidden_size]
        output = self.attention_combine(output).unsqueeze(0)
        output = self.activation_fn(output)

        # output: [1, 1, hidden_size]
        output, (hidden, cell) = self.lstm(output, (hidden, cell))

        # output: [1, output_size]
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, (hidden, cell), attention_weights

    def init_hidden(self):
        """Init hidden

        Returns:
            hidden:
            cell:
        """
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell
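The shape bookkeeping in forward above is the subtle part; this standalone trace with dummy tensors mirrors it step by step (hidden_size=16 and max_length=10 are chosen only for illustration):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, max_length = 16, 10
attention = nn.Linear(hidden_size * 2, max_length)  # same shape as self.attention

embedded = torch.randn(1, 1, hidden_size)           # current target embedding
hidden = torch.randn(1, 1, hidden_size)             # decoder hidden state
encoder_outputs = torch.randn(max_length, hidden_size)

# [1, hidden_size * 2] -> [1, max_length], normalized to sum to 1
attention_weights = F.softmax(
    attention(torch.cat((embedded[0], hidden[0]), 1)), dim=1)

# [1, 1, max_length] @ [1, max_length, hidden_size] -> [1, 1, hidden_size]
attention_applied = torch.bmm(attention_weights.unsqueeze(0),
                              encoder_outputs.unsqueeze(0))
print(attention_applied.shape)  # torch.Size([1, 1, 16])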
def train_by_sentence_attn(input_tensor, target_tensor, encoder, decoder,
                           encoder_optimizer, decoder_optimizer, loss_fn,
                           use_teacher_forcing=True, reverse_source_sentence=True,
                           max_length=MAX_LENGTH):
    """Train on a single sentence using EncoderLSTM and AttentionDecoderLSTM,
    including the forward pass and the parameter update, combining the
    attention mechanism.

    Args:
        input_tensor: [input_sequence_len, 1, hidden_size]
        target_tensor: [target_sequence_len, 1, hidden_size]
        encoder: EncoderLSTM
        decoder: AttentionDecoderLSTM
        encoder_optimizer: optimizer for encoder
        decoder_optimizer: optimizer for decoder
        loss_fn: loss function
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own predictions as the next input
        max_length: max length for input and output
    Returns:
        loss: scalar
    """
    if reverse_source_sentence:
        input_tensor = torch.flip(input_tensor, [0])

    hidden, cell = encoder.init_hidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    loss = 0

    # Get encoder outputs
    for ei in range(input_length):
        encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = (hidden, cell)

    for di in range(target_length):
        decoder_output, (hidden, cell), _ = decoder(decoder_input, (hidden, cell),
                                                    encoder_outputs)
        if use_teacher_forcing:
            loss += loss_fn(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing
        else:
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += loss_fn(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
def train_attn(encoder, decoder, n_iters, reverse_source_sentence=True,
               use_teacher_forcing=True, print_every=1000, plot_every=100,
               learning_rate=0.01):
    """Train the Seq2seq model with attention.

    Args:
        encoder: EncoderLSTM
        decoder: AttentionDecoderLSTM
        n_iters: train on n_iters randomly chosen sentences
        reverse_source_sentence: True reverses the source sentence while
            keeping the order of the target unchanged,
            False keeps the order of both unchanged
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own predictions as the next input
        print_every: print a log line every print_every iterations
        plot_every: record the loss every plot_every iterations
        learning_rate: learning rate for both optimizers
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    training_pairs = [tensor_from_pair(random.choice(pairs), input_lang, output_lang)
                      for _ in range(n_iters)]
    loss_fn = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        training_pair = training_pairs[i - 1]
        input_tensor = training_pair[0].to(device)
        target_tensor = training_pair[1].to(device)

        loss = train_by_sentence_attn(input_tensor, target_tensor, encoder, decoder,
                                      encoder_optimizer, decoder_optimizer, loss_fn,
                                      use_teacher_forcing=use_teacher_forcing,
                                      reverse_source_sentence=reverse_source_sentence)
        print_loss_total += loss
        plot_loss_total += loss

        if i % print_every == 0:
            # Print loss
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%s (%d %d%%) %.4f" % (time_since(start, i / n_iters),
                                         i, i / n_iters * 100, print_loss_avg))

        if i % plot_every == 0:
            # Record loss for plotting
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    # Show the loss curve
    show_plot(plot_losses)
def evaluate_by_sentence_attn(encoder, decoder, sentence,
                              reverse_source_sentence=True, max_length=MAX_LENGTH):
    """Evaluate on a source sentence with a model trained with the attention mechanism.

    Args:
        encoder
        decoder
        sentence
        max_length
    Return:
        decoded_words: predicted sentence
    """
    with torch.no_grad():
        input_tensor = tensor_from_sentence(input_lang, sentence).to(device)
        input_length = input_tensor.size(0)
        if reverse_source_sentence:
            input_tensor = torch.flip(input_tensor, [0])

        (hidden, cell) = encoder.init_hidden()

        # encoder outputs: [max_length, hidden_size]
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        for ei in range(input_length):
            encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = (hidden, cell)
        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, (hidden, cell), decoder_attention = \
                decoder(decoder_input, (hidden, cell), encoder_outputs)
            topv, topi = decoder_output.data.topk(1)
            # Record the attention weights for this step
            decoder_attentions[di] = decoder_attention.data
            if topi.item() == EOS_token:
                decoded_words.append("<EOS>")
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.squeeze().detach()

    return decoded_words, decoder_attentions[:di + 1]
def show_attention(input_sentence, output_words, attentions):
    """Plot the attention between the input sentence and the output sentence."""
    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes: source words along x, predicted words along y
    ax.set_xticklabels([''] + input_sentence.split(' ') + ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
def evaluate_and_show_attention(input_sentence, encoder, decoder):
    """Evaluate and show attention for an input sentence."""
    output_words, attentions = evaluate_by_sentence_attn(encoder, decoder,
                                                         input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    show_attention(input_sentence, output_words, attentions)
setup_seed(45)
hidden_size = 256
# Reverse the order of source input sentence
reverse_source_sentence = True
# Feed the target as the next input
use_teacher_forcing = True
encoder = EncoderLSTM(input_lang.n_words, hidden_size).to(device)
decoder = AttentionDecoderLSTM(hidden_size, output_lang.n_words).to(device)
print(">> Model is on: {}".format(next(encoder.parameters()).is_cuda))
print(">> Model is on: {}".format(next(decoder.parameters()).is_cuda))
>> Model is on: True
>> Model is on: True
iters = 50000
train_attn(encoder, decoder, iters, reverse_source_sentence=reverse_source_sentence,
           use_teacher_forcing=use_teacher_forcing, print_every=250, plot_every=250)
0m 8s (- 28m 13s) (250 0%) 4.9045
0m 16s (- 26m 50s) (500 1%) 3.3049
0m 24s (- 26m 20s) (750 1%) 3.0303
0m 31s (- 25m 56s) (1000 2%) 2.8569
0m 39s (- 25m 39s) (1250 2%) 2.8321
0m 47s (- 25m 25s) (1500 3%) 2.7320
0m 54s (- 25m 7s) (1750 3%) 2.6652
1m 4s (- 25m 36s) (2000 4%) 2.6601
1m 12s (- 25m 40s) (2250 4%) 2.6078
1m 20s (- 25m 37s) (2500 5%) 2.5934
1m 29s (- 25m 42s) (2750 5%) 2.5855
1m 38s (- 25m 39s) (3000 6%) 2.4951
1m 47s (- 25m 39s) (3250 6%) 2.4717
1m 55s (- 25m 40s) (3500 7%) 2.4963
2m 4s (- 25m 37s) (3750 7%) 2.4969
2m 13s (- 25m 35s) (4000 8%) 2.3994
2m 22s (- 25m 31s) (4250 8%) 2.3991
2m 31s (- 25m 28s) (4500 9%) 2.3072
2m 39s (- 25m 23s) (4750 9%) 2.2703
2m 48s (- 25m 18s) (5000 10%) 2.2681
2m 57s (- 25m 12s) (5250 10%) 2.3253
3m 6s (- 25m 8s) (5500 11%) 2.3062
3m 15s (- 25m 3s) (5750 11%) 2.2431
3m 24s (- 25m 1s) (6000 12%) 2.2481
3m 33s (- 24m 57s) (6250 12%) 2.2147
3m 42s (- 24m 51s) (6500 13%) 2.1789
3m 51s (- 24m 41s) (6750 13%) 2.0689
3m 59s (- 24m 33s) (7000 14%) 2.1624
4m 8s (- 24m 26s) (7250 14%) 2.1783
4m 17s (- 24m 19s) (7500 15%) 2.0618
4m 26s (- 24m 11s) (7750 15%) 2.0752
4m 34s (- 24m 2s) (8000 16%) 2.0351
4m 43s (- 23m 55s) (8250 16%) 2.0126
4m 52s (- 23m 50s) (8500 17%) 2.0556
5m 2s (- 23m 44s) (8750 17%) 2.0194
5m 10s (- 23m 36s) (9000 18%) 2.0155
5m 19s (- 23m 28s) (9250 18%) 1.8749
5m 28s (- 23m 19s) (9500 19%) 1.9430
5m 36s (- 23m 9s) (9750 19%) 1.8124
5m 45s (- 23m 1s) (10000 20%) 1.9113
5m 54s (- 22m 53s) (10250 20%) 1.8959
6m 2s (- 22m 44s) (10500 21%) 1.8226
6m 11s (- 22m 35s) (10750 21%) 1.8846
6m 20s (- 22m 27s) (11000 22%) 1.8598
6m 28s (- 22m 18s) (11250 22%) 1.8070
6m 37s (- 22m 9s) (11500 23%) 1.8770
6m 46s (- 22m 2s) (11750 23%) 1.7991
6m 54s (- 21m 52s) (12000 24%) 1.6979
7m 3s (- 21m 44s) (12250 24%) 1.7849
7m 11s (- 21m 34s) (12500 25%) 1.7383
7m 20s (- 21m 26s) (12750 25%) 1.8461
7m 29s (- 21m 18s) (13000 26%) 1.7735
7m 38s (- 21m 11s) (13250 26%) 1.7250
7m 46s (- 21m 2s) (13500 27%) 1.7031
7m 55s (- 20m 54s) (13750 27%) 1.7557
8m 4s (- 20m 46s) (14000 28%) 1.7034
8m 13s (- 20m 38s) (14250 28%) 1.7474
8m 22s (- 20m 30s) (14500 28%) 1.7002
8m 31s (- 20m 21s) (14750 29%) 1.6098
8m 39s (- 20m 11s) (15000 30%) 1.6132
8m 47s (- 20m 3s) (15250 30%) 1.7066
8m 56s (- 19m 54s) (15500 31%) 1.6781
9m 5s (- 19m 45s) (15750 31%) 1.5791
9m 13s (- 19m 35s) (16000 32%) 1.5366
9m 22s (- 19m 27s) (16250 32%) 1.6449
9m 30s (- 19m 18s) (16500 33%) 1.6655
9m 38s (- 19m 9s) (16750 33%) 1.5604
9m 47s (- 19m 0s) (17000 34%) 1.5838
9m 56s (- 18m 52s) (17250 34%) 1.5468
10m 4s (- 18m 43s) (17500 35%) 1.5449
10m 12s (- 18m 33s) (17750 35%) 1.5705
10m 20s (- 18m 22s) (18000 36%) 1.5782
10m 27s (- 18m 11s) (18250 36%) 1.4957
10m 35s (- 18m 1s) (18500 37%) 1.3964
10m 42s (- 17m 51s) (18750 37%) 1.4830
10m 50s (- 17m 41s) (19000 38%) 1.5285
10m 58s (- 17m 31s) (19250 38%) 1.4140
11m 5s (- 17m 21s) (19500 39%) 1.4499
11m 12s (- 17m 10s) (19750 39%) 1.4234
11m 20s (- 17m 0s) (20000 40%) 1.4060
11m 27s (- 16m 49s) (20250 40%) 1.4435
11m 35s (- 16m 40s) (20500 41%) 1.4514
11m 42s (- 16m 30s) (20750 41%) 1.4453
11m 50s (- 16m 20s) (21000 42%) 1.4741
11m 57s (- 16m 11s) (21250 42%) 1.3784
12m 5s (- 16m 1s) (21500 43%) 1.4439
12m 12s (- 15m 51s) (21750 43%) 1.4836
12m 20s (- 15m 42s) (22000 44%) 1.3703
12m 27s (- 15m 32s) (22250 44%) 1.3226
12m 36s (- 15m 24s) (22500 45%) 1.3610
12m 43s (- 15m 14s) (22750 45%) 1.4003
12m 51s (- 15m 5s) (23000 46%) 1.3914
12m 59s (- 14m 56s) (23250 46%) 1.2699
13m 7s (- 14m 47s) (23500 47%) 1.2957
13m 14s (- 14m 38s) (23750 47%) 1.3403
13m 22s (- 14m 29s) (24000 48%) 1.3439
13m 30s (- 14m 20s) (24250 48%) 1.2482
13m 37s (- 14m 10s) (24500 49%) 1.3789
13m 44s (- 14m 1s) (24750 49%) 1.1900
13m 52s (- 13m 52s) (25000 50%) 1.2474
14m 0s (- 13m 43s) (25250 50%) 1.3320
14m 7s (- 13m 34s) (25500 51%) 1.2478
14m 15s (- 13m 25s) (25750 51%) 1.2392
14m 23s (- 13m 16s) (26000 52%) 1.2369
14m 30s (- 13m 8s) (26250 52%) 1.1629
14m 38s (- 12m 59s) (26500 53%) 1.2625
14m 46s (- 12m 50s) (26750 53%) 1.2236
14m 54s (- 12m 41s) (27000 54%) 1.1323
15m 2s (- 12m 33s) (27250 54%) 1.2009
15m 9s (- 12m 24s) (27500 55%) 1.2412
15m 17s (- 12m 15s) (27750 55%) 1.2053
15m 25s (- 12m 6s) (28000 56%) 1.2504
15m 33s (- 11m 58s) (28250 56%) 1.1889
15m 40s (- 11m 49s) (28500 56%) 1.2637
15m 48s (- 11m 41s) (28750 57%) 1.2014
15m 56s (- 11m 32s) (29000 57%) 1.1773
16m 4s (- 11m 23s) (29250 58%) 1.1245
16m 11s (- 11m 15s) (29500 59%) 1.1128
16m 19s (- 11m 6s) (29750 59%) 1.1001
16m 26s (- 10m 57s) (30000 60%) 1.2020
16m 34s (- 10m 49s) (30250 60%) 1.0931
16m 41s (- 10m 40s) (30500 61%) 1.0847
16m 49s (- 10m 31s) (30750 61%) 1.1683
16m 56s (- 10m 23s) (31000 62%) 1.0578
17m 4s (- 10m 14s) (31250 62%) 1.1204
17m 12s (- 10m 6s) (31500 63%) 1.0375
17m 20s (- 9m 58s) (31750 63%) 1.0673
17m 29s (- 9m 50s) (32000 64%) 1.0291
17m 37s (- 9m 42s) (32250 64%) 1.1162
17m 45s (- 9m 33s) (32500 65%) 1.1235
17m 54s (- 9m 25s) (32750 65%) 1.0975
18m 3s (- 9m 17s) (33000 66%) 1.0241
18m 11s (- 9m 9s) (33250 66%) 1.1561
18m 19s (- 9m 1s) (33500 67%) 1.0525
18m 28s (- 8m 53s) (33750 67%) 1.0605
18m 37s (- 8m 45s) (34000 68%) 1.0551
18m 45s (- 8m 37s) (34250 68%) 1.0728
18m 53s (- 8m 29s) (34500 69%) 1.0422
19m 1s (- 8m 21s) (34750 69%) 1.0189
19m 10s (- 8m 13s) (35000 70%) 0.9925
19m 19s (- 8m 5s) (35250 70%) 1.0572
19m 28s (- 7m 57s) (35500 71%) 1.0211
19m 36s (- 7m 48s) (35750 71%) 1.0857
19m 45s (- 7m 40s) (36000 72%) 1.0427
19m 54s (- 7m 32s) (36250 72%) 0.9366
20m 2s (- 7m 24s) (36500 73%) 1.0282
20m 11s (- 7m 16s) (36750 73%) 0.9766
20m 20s (- 7m 8s) (37000 74%) 0.9918
20m 28s (- 7m 0s) (37250 74%) 0.9713
20m 37s (- 6m 52s) (37500 75%) 0.8682
20m 46s (- 6m 44s) (37750 75%) 0.8866
20m 55s (- 6m 36s) (38000 76%) 0.9383
21m 3s (- 6m 28s) (38250 76%) 0.9673
21m 12s (- 6m 20s) (38500 77%) 0.8919
21m 20s (- 6m 11s) (38750 77%) 0.9282
21m 29s (- 6m 3s) (39000 78%) 0.9390
21m 37s (- 5m 55s) (39250 78%) 0.8759
21m 46s (- 5m 47s) (39500 79%) 0.9049
21m 54s (- 5m 39s) (39750 79%) 0.9363
22m 3s (- 5m 30s) (40000 80%) 0.8546
22m 11s (- 5m 22s) (40250 80%) 0.8826
22m 19s (- 5m 14s) (40500 81%) 0.9364
22m 28s (- 5m 6s) (40750 81%) 0.9326
22m 37s (- 4m 57s) (41000 82%) 0.8726
22m 46s (- 4m 49s) (41250 82%) 0.9256
22m 54s (- 4m 41s) (41500 83%) 0.8858
23m 3s (- 4m 33s) (41750 83%) 0.7803
23m 11s (- 4m 25s) (42000 84%) 0.8532
23m 20s (- 4m 16s) (42250 84%) 0.9056
23m 29s (- 4m 8s) (42500 85%) 0.7939
23m 37s (- 4m 0s) (42750 85%) 0.8685
23m 46s (- 3m 52s) (43000 86%) 0.8675
23m 55s (- 3m 43s) (43250 86%) 0.8868
24m 3s (- 3m 35s) (43500 87%) 0.8165
24m 11s (- 3m 27s) (43750 87%) 0.7273
24m 20s (- 3m 19s) (44000 88%) 0.8150
24m 29s (- 3m 10s) (44250 88%) 0.8015
24m 37s (- 3m 2s) (44500 89%) 0.7703
24m 45s (- 2m 54s) (44750 89%) 0.8699
24m 54s (- 2m 46s) (45000 90%) 0.8267
25m 4s (- 2m 37s) (45250 90%) 0.7528
25m 12s (- 2m 29s) (45500 91%) 0.8305
25m 21s (- 2m 21s) (45750 91%) 0.7830
25m 29s (- 2m 13s) (46000 92%) 0.8001
25m 38s (- 2m 4s) (46250 92%) 0.7384
25m 47s (- 1m 56s) (46500 93%) 0.7825
25m 55s (- 1m 48s) (46750 93%) 0.7710
26m 4s (- 1m 39s) (47000 94%) 0.7454
26m 12s (- 1m 31s) (47250 94%) 0.7490
26m 21s (- 1m 23s) (47500 95%) 0.8604
26m 29s (- 1m 14s) (47750 95%) 0.7776
26m 36s (- 1m 6s) (48000 96%) 0.7263
26m 44s (- 0m 58s) (48250 96%) 0.8098
26m 51s (- 0m 49s) (48500 97%) 0.6916
26m 59s (- 0m 41s) (48750 97%) 0.7064
27m 6s (- 0m 33s) (49000 98%) 0.8158
27m 14s (- 0m 24s) (49250 98%) 0.7894
27m 21s (- 0m 16s) (49500 99%) 0.7465
27m 27s (- 0m 8s) (49750 99%) 0.7975
27m 35s (- 0m 0s) (50000 100%) 0.7190

evaluate_and_show_attention("elle a cinq ans de moins que moi .", encoder, decoder)evaluate_and_show_attention("elle est trop petit .", encoder, decoder)evaluate_and_show_attention("je ne crains pas de mourir .", encoder, decoder)evaluate_and_show_attention("c est un jeune directeur plein de talent .", encoder, decoder)
input = elle a cinq ans de moins que moi .
output = she is two years younger than me . <EOS>

input = elle est trop petit .
output = she is too drunk . <EOS>

input = je ne crains pas de mourir .
output = i m not afraid of making mistakes . <EOS>

input = c est un jeune directeur plein de talent .
output = he s a very talented writer . <EOS>
