Building a Machine Translation Model with RNNs
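Contents
- Loading packages and some preprocessing
- Defining the model architecture
- Encoder architecture
- Decoder architecture
- A quick test
- Training and evaluation functions
- Training the model
- Adding an attention mechanism to the decoder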
Loading packages and some preprocessing
%load_ext autoreload
%autoreload 2
%matplotlib inline

import random
import math
import time

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from typing import *
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Select the device to use ("cuda:2" is the third GPU; change the index to match your machine)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")


def setup_seed(seed):
    """Make every run produce the same results."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN pick kernels by timing them, which is not reproducible
    torch.backends.cudnn.benchmark = False
Defining the model architecture
Encoder architecture
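nn.Embedding is a simple lookup table that stores embeddings for a dictionary of fixed size. It is commonly used to store word embeddings and retrieve them by index: the input to the module is a tensor of indices, and the output is the corresponding word embeddings. In the encoder below, input_size (the num_embeddings argument of nn.Embedding) is the vocabulary size and hidden_size (the embedding_dim argument) is the dimension of each word vector. The word vectors start from a random initialization (a standard normal distribution by default) and are learned during training.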
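A minimal sketch of the lookup, using small illustrative sizes rather than the article's vocabulary. The index tensor has the [seq_len, 1] shape that the sentence tensors in this article have, and the check at the end shows that the lookup just returns rows of the weight matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim = 10, 4                    # illustrative sizes, not the article's
embedding = nn.Embedding(vocab_size, embed_dim)  # weights start random, drawn from N(0, 1)

indices = torch.tensor([[1], [5], [9]])          # [seq_len=3, 1], like a sentence tensor
vectors = embedding(indices)
print(vectors.shape)  # torch.Size([3, 1, 4])

# the lookup simply returns rows of the weight matrix
assert torch.equal(vectors[1, 0], embedding.weight[5])
```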
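lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=n_layers)
- input_dim = the number of input features at each time step (e.g. a dimension of 20 means each step's input is a vector of 20 values)
- hidden_dim = the size of the hidden state, i.e. the number of outputs each LSTM cell produces at each time step
- n_layers = the number of stacked LSTM layers, usually between 1 and 3; a value of 1 means a single LSTM layer. The default is 1.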
out, hidden = lstm(input.view(1, 1, -1), (h0, c0))
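The inputs to the LSTM are (input, (h0, c0)):
- input = a Tensor holding the input sequence, with shape (seq_len, batch, input_size)
- h0 = a Tensor with the initial hidden state for each element in the batch
- c0 = a Tensor with the initial cell state for each element in the batch
- If h0 and c0 are not given, they default to zeros. Their shape is (n_layers, batch, hidden_dim).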
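The shapes listed above can be checked with a small sketch; the sizes are illustrative. It also confirms two facts used later: omitting (h0, c0) is the same as passing zeros, and the last time step of `out` equals the final hidden state of the top layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, hidden_dim, n_layers = 20, 8, 2      # illustrative sizes
lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=n_layers)

seq_len, batch = 5, 1
x = torch.randn(seq_len, batch, input_dim)      # (seq_len, batch, input_size)
h0 = torch.zeros(n_layers, batch, hidden_dim)   # (n_layers, batch, hidden_dim)
c0 = torch.zeros(n_layers, batch, hidden_dim)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([5, 1, 8]): top-layer output at every time step
print(hn.shape)   # torch.Size([2, 1, 8]): final hidden state of every layer

out_default, _ = lstm(x)                        # h0/c0 omitted, so PyTorch uses zeros
assert torch.allclose(out, out_default)
# the last time step of `out` is the final hidden state of the top layer
assert torch.allclose(out[-1], hn[-1])
```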
class EncoderLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, inputs: Tensor, state: Tuple[Tensor, Tensor]):
        (hidden, cell) = state
        # seq_len and batch are both 1: the encoder reads one token at a time
        embedded = self.embedding(inputs).view(1, 1, -1)
        output = embedded
        output, (hidden, cell) = self.lstm(output, (hidden, cell))
        return output, (hidden, cell)

    def init_hidden(self):
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell
Decoder architecture
class DecoderLSTM(nn.Module):
    def __init__(self, hidden_size: int, output_size: int):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.activation_function = F.relu

    def forward(self, inputs, state):
        (hidden, cell) = state
        # output: [1, 1, hidden_size], one token with batch size 1
        output = self.embedding(inputs).view(1, 1, -1)
        output = self.activation_function(output)
        output, (hidden, cell) = self.lstm(output, (hidden, cell))
        # output[0]: [1, hidden_size] -> log-probabilities [1, output_size]
        output = self.log_softmax(self.out(output[0]))
        return output, (hidden, cell)

    def init_hidden(self):
        """Initialize the hidden state.

        Returns:
            hidden: zero hidden state, [1, 1, hidden_size]
            cell: zero cell state, [1, 1, hidden_size]
        """
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell
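The decoder ends with LogSoftmax, and the training code later pairs it with nn.NLLLoss. The sketch below, with an illustrative vocabulary of 7 words, checks that this pair computes the same value as nn.CrossEntropyLoss on the raw scores, and that NLLLoss simply picks out the negative log-probability of the target index.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(1, 7)      # like self.out(output[0]): [1, output_size]
target = torch.tensor([3])      # like target_tensor[di]: [1]

log_probs = nn.LogSoftmax(dim=1)(logits)     # what DecoderLSTM returns
nll = nn.NLLLoss()(log_probs, target)        # the loss_fn used in train()
ce = nn.CrossEntropyLoss()(logits, target)   # fused equivalent on raw scores

assert torch.allclose(nll, ce)
# NLLLoss just selects -log p(target)
assert torch.allclose(nll, -log_probs[0, 3])
```

Keeping LogSoftmax inside the model is convenient at evaluation time, because `topk` on the decoder output then ranks words by log-probability directly.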
A quick test
input_lang.name
'fra'
output_lang.name
'eng'
pairs
[['j ai ans .', 'i m .'],['je vais bien .', 'i m ok .'],['ca va .', 'i m ok .'],
testpair = random.choice(pairs)
testpair
['il ne se presente pas aux prochaines elections .','he is not running in the coming election .']
tensor_from_sentence(input_lang, testpair[0])
tensor([[ 24],[ 297],[ 882],[2113],[ 246],[ 241],[4280],[3522],[ 5],[ 1]])
tensor_from_sentence(output_lang, testpair[1])
tensor([[ 14],[ 40],[ 147],[ 335],[ 102],[ 294],[ 142],[2744],[ 4],[ 1]])
tensor_from_sentence(output_lang, 'i .')
tensor([[2],[4],[1]])
tensor_from_pair(testpair, input_lang, output_lang)
(tensor([[ 24],[ 297],[ 882],[2113],[ 246],[ 241],[4280],[3522],[ 5],[ 1]]), tensor([[ 14],[ 40],[ 147],[ 335],[ 102],[ 294],[ 142],[2744],[ 4],[ 1]]))
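The helpers used above (`tensor_from_sentence`, `tensor_from_pair`, and the vocabulary objects) are not shown in this section. The following is a hypothetical sketch, modeled on the PyTorch seq2seq tutorial, that is consistent with the outputs above: word indices start at 2 (0 and 1 are taken by SOS and EOS), EOS_token = 1 is appended, and the result has shape [seq_len, 1]. The `Lang` class and `demo_lang` are assumptions, not the article's code.

```python
import torch

SOS_token = 0  # start-of-sentence index, as in the training code
EOS_token = 1  # end-of-sentence index, appended to every sentence tensor


class Lang:
    """Hypothetical minimal vocabulary in the style of the PyTorch tutorial."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.index2word = {SOS_token: "SOS", EOS_token: "EOS"}
        self.n_words = 2  # SOS and EOS are already counted

    def add_sentence(self, sentence):
        for word in sentence.split(" "):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.index2word[self.n_words] = word
                self.n_words += 1


def tensor_from_sentence(lang, sentence):
    indexes = [lang.word2index[word] for word in sentence.split(" ")]
    indexes.append(EOS_token)
    # shape [seq_len, 1]: the encoder consumes one row (one token) per step
    return torch.tensor(indexes, dtype=torch.long).view(-1, 1)


def tensor_from_pair(pair, input_lang, output_lang):
    return (tensor_from_sentence(input_lang, pair[0]),
            tensor_from_sentence(output_lang, pair[1]))


demo_lang = Lang("eng")
demo_lang.add_sentence("i m .")  # first English sentence in the pairs shown above
print(tensor_from_sentence(demo_lang, "i .").tolist())  # [[2], [4], [1]]
```

In this sketch "i m ." gives i = 2, m = 3 and . = 4, which happens to reproduce the `tensor([[2],[4],[1]])` output for `'i .'` shown above; the real helper may differ in details.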
Training and evaluation functions
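Teacher forcing is a method for training recurrent neural networks quickly and efficiently: instead of feeding the model's own prediction from the previous time step as the next input, it feeds the ground-truth target token from the previous time step. This helps with slow convergence and unstable training. However, because the model never sees its own (possibly wrong) predictions during training, it can perform poorly in practice when the sequence it generates drifts away from the sequences it saw during training (a problem known as exposure bias).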
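A toy sketch of the difference. `toy_step` is a hypothetical stand-in for one decoder step that predicts "previous token + 1" but makes one mistake after token 2. With teacher forcing the mistake stays isolated, because the next input is always the ground truth; without it, the wrong prediction is fed back and the error propagates.

```python
SOS_token = 0


def toy_step(prev_token):
    """Hypothetical decoder step: predicts prev_token + 1, but errs after token 2."""
    return 7 if prev_token == 2 else prev_token + 1


def run_decoder(target, use_teacher_forcing):
    decoder_input = SOS_token
    fed, predicted = [], []
    for di in range(len(target)):
        fed.append(decoder_input)
        prediction = toy_step(decoder_input)
        predicted.append(prediction)
        # teacher forcing feeds the ground truth; otherwise the model's own prediction
        decoder_input = target[di] if use_teacher_forcing else prediction
    return fed, predicted


target = [1, 2, 3, 4]
print(run_decoder(target, True))   # ([0, 1, 2, 3], [1, 2, 7, 4])
print(run_decoder(target, False))  # ([0, 1, 2, 7], [1, 2, 7, 8])
```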
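Note that the encoder and the decoder share the hidden and cell states: the encoder's final (hidden, cell) is passed to the decoder as its initial state.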
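A stripped-down sketch of this hand-off, using bare nn.Embedding and nn.LSTM modules with illustrative sizes instead of the EncoderLSTM/DecoderLSTM classes. It also checks that feeding the encoder one token at a time, as the training loop does, reaches the same final state as running the whole sequence in a single call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, src_vocab, tgt_vocab = 16, 30, 25   # illustrative sizes
enc_embedding, enc_lstm = nn.Embedding(src_vocab, hidden_size), nn.LSTM(hidden_size, hidden_size)
dec_embedding, dec_lstm = nn.Embedding(tgt_vocab, hidden_size), nn.LSTM(hidden_size, hidden_size)

source = torch.tensor([[4], [7], [9], [1]])      # [seq_len, 1], ending with EOS_token = 1
hidden = torch.zeros(1, 1, hidden_size)
cell = torch.zeros(1, 1, hidden_size)
for ei in range(source.size(0)):
    # one token per step, as in EncoderLSTM.forward
    x = enc_embedding(source[ei]).view(1, 1, -1)
    _, (hidden, cell) = enc_lstm(x, (hidden, cell))

# running the whole sequence at once reaches the same final state
_, (h_full, c_full) = enc_lstm(enc_embedding(source).view(source.size(0), 1, -1))

# the decoder's initial state is the encoder's final (hidden, cell)
decoder_input = torch.tensor([[0]])              # SOS_token = 0
y = dec_embedding(decoder_input).view(1, 1, -1)
dec_out, (dec_hidden, dec_cell) = dec_lstm(y, (hidden, cell))
print(dec_out.shape)  # torch.Size([1, 1, 16])
```

In the plain (non-attention) model this final state is the only information the decoder receives about the source sentence.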
def train_by_sentence(input_tensor, target_tensor, encoder, decoder,
                      encoder_optimizer, decoder_optimizer, loss_fn,
                      use_teacher_forcing=True, reverse_source_sentence=True,
                      max_length=MAX_LENGTH):
    """Train on a single sentence pair with EncoderLSTM and DecoderLSTM,
    including the forward pass and the parameter update.

    Args:
        input_tensor: [input_sequence_len, 1], word indices of the source sentence
        target_tensor: [target_sequence_len, 1], word indices of the target sentence
        encoder: EncoderLSTM
        decoder: DecoderLSTM
        encoder_optimizer: optimizer for the encoder
        decoder_optimizer: optimizer for the decoder
        loss_fn: loss function
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own prediction as the next input
        reverse_source_sentence: True reverses the order of the source sentence
        max_length: maximum length of the input and output

    Returns:
        loss: scalar, the loss averaged over the target length
    """
    # Optionally reverse the source sentence
    if reverse_source_sentence:
        input_tensor = torch.flip(input_tensor, [0])

    hidden, cell = encoder.init_hidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Lengths of the input and target sequences
    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # encoder_outputs: [max_length, hidden_size]; the plain decoder does not use it,
    # but the attention decoder later in this article does
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Run the encoder one token at a time
    for ei in range(input_length):
        encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
        # batch size and seq_len are both 1 here
        encoder_outputs[ei] = encoder_output[0, 0]

    # The first decoder input is a (1, 1) tensor holding SOS_token (0)
    SOS_token = 0
    decoder_input = torch.tensor([[SOS_token]], device=device)
    # The encoder's final (hidden, cell) is the decoder's initial state

    for di in range(target_length):
        decoder_output, (hidden, cell) = decoder(decoder_input, (hidden, cell))
        if use_teacher_forcing:
            # Feed the ground-truth target as the next input
            loss += loss_fn(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
        else:
            # Feed the model's own prediction as the next input
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += loss_fn(decoder_output, target_tensor[di])
            # Stop once the model predicts EOS
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
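`torch.flip(input_tensor, [0])` reverses only the time dimension of the [seq_len, 1] tensor. One side effect is easy to miss: the EOS token appended at the end of the source becomes the first token the encoder reads. A minimal sketch with illustrative indices (EOS_token = 1, as in the outputs above):

```python
import torch

EOS_token = 1
# [seq_len, 1] source tensor with illustrative word indices, ending in EOS
input_tensor = torch.tensor([[24], [297], [882], [5], [EOS_token]])
reversed_tensor = torch.flip(input_tensor, [0])  # reverse along dim 0 (time) only
print(reversed_tensor.view(-1).tolist())  # [1, 5, 882, 297, 24]
```

The usual motivation for reversing the source, from the original sequence-to-sequence work, is that it brings the first source words close to the first target words the decoder has to produce.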
def train(encoder, decoder, n_iters, reverse_source_sentence=True, use_teacher_forcing=True,
          print_every=1000, plot_every=100, learning_rate=0.01):
    """Train the seq2seq model.

    Args:
        encoder: EncoderLSTM
        decoder: DecoderLSTM
        n_iters: number of iterations; n_iters sentence pairs are sampled with replacement
        reverse_source_sentence: True reverses the source sentence but keeps the target order unchanged,
            False keeps the order of both the source and the target sentence unchanged
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own prediction as the next input
        print_every: print the average loss every print_every iterations
        plot_every: record a point for the loss plot every plot_every iterations
        learning_rate: learning rate of the SGD optimizers
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    # Use SGD as the optimizer
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Sample the training data
    training_pairs = [tensor_from_pair(random.choice(pairs), input_lang, output_lang)
                      for _ in range(n_iters)]

    # Loss function (the decoder already outputs log-probabilities)
    loss_fn = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        training_pair = training_pairs[i - 1]
        input_tensor = training_pair[0].to(device)
        target_tensor = training_pair[1].to(device)

        loss = train_by_sentence(input_tensor, target_tensor, encoder, decoder,
                                 encoder_optimizer, decoder_optimizer, loss_fn,
                                 use_teacher_forcing=use_teacher_forcing,
                                 reverse_source_sentence=reverse_source_sentence)
        print_loss_total += loss
        plot_loss_total += loss

        if i % print_every == 0:
            # Print the average loss
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%s (%d %d%%) %.4f" % (time_since(start, i / n_iters),
                                         i, i / n_iters * 100, print_loss_avg))

        if i % plot_every == 0:
            # Record a point for the plot
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    # Show the plot
    show_plot(plot_losses)
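`time_since` is called above but not defined in this section. A hypothetical sketch, modeled on the `timeSince` helper of the PyTorch seq2seq tutorial, that produces the "elapsed (- remaining)" format seen in the training logs, such as `0m 5s (- 18m 44s)`:

```python
import math
import time


def as_minutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return "%dm %ds" % (m, s)


def time_since(since, percent):
    """Elapsed time plus estimated remaining time, given the fraction of work done."""
    now = time.time()
    s = now - since          # elapsed seconds
    es = s / percent         # estimated total seconds
    rs = es - s              # estimated remaining seconds
    return "%s (- %s)" % (as_minutes(s), as_minutes(rs))


print(as_minutes(1124))  # 18m 44s
```

The remaining-time estimate assumes every iteration takes about as long as the ones so far, which is why it fluctuates slightly in the logs.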
def evaluate_by_sentence(encoder, decoder, sentence, reverse_source_sentence, max_length=MAX_LENGTH):
    """Evaluate on a source sentence.

    Args:
        encoder: trained EncoderLSTM
        decoder: trained DecoderLSTM
        sentence: source sentence (str)
        reverse_source_sentence: whether to reverse the source; should match the training setting
        max_length: maximum number of decoding steps

    Returns:
        decoded_words: the predicted sentence as a list of words
    """
    with torch.no_grad():
        # Convert the sentence to a tensor
        input_tensor = tensor_from_sentence(input_lang, sentence).to(device)
        input_length = input_tensor.size(0)
        if reverse_source_sentence:
            input_tensor = torch.flip(input_tensor, [0])

        # Initial state of the encoder
        (hidden, cell) = encoder.init_hidden()
        # encoder_outputs: [max_length, hidden_size]
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
            encoder_outputs[ei] += encoder_output[0, 0]

        # The last state of the encoder is the initial state of the decoder
        decoder_input = torch.tensor([[SOS_token]], device=device)

        decoded_words = []

        # During evaluation, feed the model's own prediction as the next input
        for di in range(max_length):
            decoder_output, (hidden, cell) = decoder(decoder_input, (hidden, cell))
            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append("<EOS>")
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words
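Across the functions above, `decoder_input` takes three different shapes: the initial `[[SOS_token]]` tensor, `target_tensor[di]` under teacher forcing, and `topi.squeeze().detach()` during greedy decoding. A small sketch with illustrative sizes shows that `topk(1)` picks the same index as `argmax`, and that `.view(1, 1, -1)` after the embedding turns all three shapes into the [1, 1, hidden_size] input the LSTM expects.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_size = 12, 6                  # illustrative sizes
embedding = nn.Embedding(vocab_size, hidden_size)

# stand-in for one decoder output: log-probabilities of shape [1, output_size]
decoder_output = F.log_softmax(torch.randn(1, vocab_size), dim=1)
topv, topi = decoder_output.topk(1)              # both [1, 1]: best log-probability and its index
assert topi.item() == decoder_output.argmax(dim=1).item()

# the three shapes decoder_input takes in the functions above
candidates = {
    "torch.tensor([[SOS_token]])": torch.tensor([[0]]),     # [1, 1]
    "target_tensor[di]": torch.tensor([[3], [5]])[0],       # [1]
    "topi.squeeze().detach()": topi.squeeze().detach(),     # [] (0-dim)
}
for name, decoder_input in candidates.items():
    # view(1, 1, -1) turns each of them into [seq_len=1, batch=1, hidden_size]
    embedded = embedding(decoder_input).view(1, 1, -1)
    print(f"{name}: {tuple(decoder_input.shape)} -> {tuple(embedded.shape)}")
```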
def evaluate_randomly(encoder, decoder, n=10, reverse_source_sentence=True):
    """Randomly pick sentences from the dataset and inspect the translations.

    Args:
        encoder: trained EncoderLSTM
        decoder: trained DecoderLSTM
        n: number of sentences to evaluate
        reverse_source_sentence: whether to reverse the source; should match the training setting
    """
    for _ in range(n):
        pair = random.choice(pairs)
        # Source sentence
        print(">", pair[0])
        # Target sentence
        print("=", pair[1])
        output_words = evaluate_by_sentence(encoder, decoder, pair[0], reverse_source_sentence)
        output_sentence = " ".join(output_words)
        # Predicted sentence
        print("<", output_sentence)
        print("")
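def show_plot(points):
    """Plot the loss curve from a list of points."""
    # plt.subplots() already creates a figure; an extra plt.figure() would leave an empty one
    fig, ax = plt.subplots()
    # Put a major tick every 0.2 on the y axis
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.show()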
Training the model
input_lang, output_lang, pairs = prepare_data('eng', 'fra', reverse=True)
print(random.choice(pairs))
Reading lines...
Read 135842 sentence pairs
Reverse source sentence
Trimmed to 10599 sentence pairs
Counting words ...
Counting words:
fra 4345
eng 2803
['elle est hors de danger .', 'she is out of danger .']
Defining the hyperparameters
setup_seed(45)
hidden_size = 256
reverse_source_sentence = True
use_teacher_forcing = True
encoder = EncoderLSTM(input_lang.n_words, hidden_size).to(device)
decoder = DecoderLSTM(hidden_size, output_lang.n_words).to(device)
print(">> Model is on: {}".format(next(encoder.parameters()).is_cuda))
print(">> Model is on: {}".format(next(decoder.parameters()).is_cuda))
>> Model is on: True
>> Model is on: True
iters = 50000
train(encoder, decoder, iters, reverse_source_sentence=reverse_source_sentence, use_teacher_forcing=use_teacher_forcing,print_every=250, plot_every=250)
0m 5s (- 18m 44s) (250 0%) 4.7301
0m 10s (- 17m 23s) (500 1%) 3.3269
0m 15s (- 17m 2s) (750 1%) 3.0458
0m 20s (- 16m 37s) (1000 2%) 2.8595
0m 25s (- 16m 26s) (1250 2%) 2.8384
0m 30s (- 16m 10s) (1500 3%) 2.7370
0m 34s (- 16m 3s) (1750 3%) 2.6759
0m 39s (- 15m 53s) (2000 4%) 2.6775
0m 44s (- 15m 44s) (2250 4%) 2.6329
0m 49s (- 15m 37s) (2500 5%) 2.6113
0m 54s (- 15m 31s) (2750 5%) 2.5890
0m 58s (- 15m 23s) (3000 6%) 2.5076
1m 3s (- 15m 17s) (3250 6%) 2.4959
1m 8s (- 15m 13s) (3500 7%) 2.5263
1m 13s (- 15m 10s) (3750 7%) 2.5167
1m 18s (- 15m 5s) (4000 8%) 2.3970
1m 23s (- 15m 1s) (4250 8%) 2.4132
1m 28s (- 14m 55s) (4500 9%) 2.3209
1m 33s (- 14m 49s) (4750 9%) 2.2764
1m 38s (- 14m 43s) (5000 10%) 2.2885
1m 43s (- 14m 39s) (5250 10%) 2.3332
1m 48s (- 14m 35s) (5500 11%) 2.3204
1m 53s (- 14m 30s) (5750 11%) 2.2541
1m 58s (- 14m 27s) (6000 12%) 2.2727
2m 3s (- 14m 23s) (6250 12%) 2.2329
2m 8s (- 14m 16s) (6500 13%) 2.1703
2m 12s (- 14m 10s) (6750 13%) 2.0671
2m 17s (- 14m 4s) (7000 14%) 2.1644
2m 22s (- 13m 59s) (7250 14%) 2.2096
2m 27s (- 13m 54s) (7500 15%) 2.0804
2m 32s (- 13m 49s) (7750 15%) 2.1003
2m 36s (- 13m 43s) (8000 16%) 2.0653
2m 41s (- 13m 39s) (8250 16%) 2.0542
2m 46s (- 13m 34s) (8500 17%) 2.0976
2m 51s (- 13m 30s) (8750 17%) 2.0354
2m 56s (- 13m 25s) (9000 18%) 2.0289
3m 1s (- 13m 20s) (9250 18%) 1.9037
3m 6s (- 13m 15s) (9500 19%) 1.9525
3m 11s (- 13m 9s) (9750 19%) 1.8598
3m 16s (- 13m 4s) (10000 20%) 1.9433
3m 20s (- 12m 58s) (10250 20%) 1.9164
3m 25s (- 12m 53s) (10500 21%) 1.8676
3m 30s (- 12m 47s) (10750 21%) 1.8912
3m 35s (- 12m 43s) (11000 22%) 1.8940
3m 39s (- 12m 37s) (11250 22%) 1.8391
3m 44s (- 12m 32s) (11500 23%) 1.9038
3m 49s (- 12m 28s) (11750 23%) 1.8223
3m 54s (- 12m 23s) (12000 24%) 1.7111
3m 59s (- 12m 18s) (12250 24%) 1.8238
4m 4s (- 12m 13s) (12500 25%) 1.7750
4m 9s (- 12m 8s) (12750 25%) 1.8930
4m 14s (- 12m 3s) (13000 26%) 1.7776
4m 19s (- 11m 59s) (13250 26%) 1.7633
4m 24s (- 11m 54s) (13500 27%) 1.7333
4m 29s (- 11m 49s) (13750 27%) 1.7893
4m 33s (- 11m 44s) (14000 28%) 1.7390
4m 38s (- 11m 39s) (14250 28%) 1.7701
4m 43s (- 11m 34s) (14500 28%) 1.7329
4m 48s (- 11m 29s) (14750 29%) 1.6696
4m 53s (- 11m 24s) (15000 30%) 1.6313
4m 58s (- 11m 20s) (15250 30%) 1.7256
5m 3s (- 11m 14s) (15500 31%) 1.6859
5m 8s (- 11m 10s) (15750 31%) 1.6195
5m 12s (- 11m 5s) (16000 32%) 1.5513
5m 17s (- 11m 0s) (16250 32%) 1.6846
5m 22s (- 10m 55s) (16500 33%) 1.6875
5m 27s (- 10m 49s) (16750 33%) 1.5778
5m 32s (- 10m 45s) (17000 34%) 1.6210
5m 37s (- 10m 40s) (17250 34%) 1.5758
5m 42s (- 10m 35s) (17500 35%) 1.5593
5m 46s (- 10m 30s) (17750 35%) 1.5810
5m 51s (- 10m 24s) (18000 36%) 1.5944
5m 55s (- 10m 18s) (18250 36%) 1.5053
6m 0s (- 10m 13s) (18500 37%) 1.4108
6m 5s (- 10m 8s) (18750 37%) 1.5082
6m 10s (- 10m 3s) (19000 38%) 1.5458
6m 14s (- 9m 58s) (19250 38%) 1.4254
6m 19s (- 9m 53s) (19500 39%) 1.4709
6m 24s (- 9m 48s) (19750 39%) 1.4742
6m 29s (- 9m 43s) (20000 40%) 1.3979
6m 33s (- 9m 38s) (20250 40%) 1.4668
6m 38s (- 9m 33s) (20500 41%) 1.4649
6m 43s (- 9m 28s) (20750 41%) 1.4709
6m 48s (- 9m 23s) (21000 42%) 1.4918
6m 53s (- 9m 19s) (21250 42%) 1.4107
6m 57s (- 9m 13s) (21500 43%) 1.4762
7m 2s (- 9m 9s) (21750 43%) 1.5225
7m 7s (- 9m 3s) (22000 44%) 1.4054
7m 12s (- 8m 58s) (22250 44%) 1.3352
7m 16s (- 8m 54s) (22500 45%) 1.3740
7m 21s (- 8m 49s) (22750 45%) 1.4333
7m 26s (- 8m 44s) (23000 46%) 1.3943
7m 31s (- 8m 39s) (23250 46%) 1.2736
7m 36s (- 8m 34s) (23500 47%) 1.3318
7m 41s (- 8m 29s) (23750 47%) 1.3693
7m 45s (- 8m 24s) (24000 48%) 1.3522
7m 50s (- 8m 19s) (24250 48%) 1.2736
7m 55s (- 8m 14s) (24500 49%) 1.3980
8m 0s (- 8m 9s) (24750 49%) 1.2201
8m 5s (- 8m 5s) (25000 50%) 1.2675
8m 10s (- 8m 0s) (25250 50%) 1.3469
8m 15s (- 7m 55s) (25500 51%) 1.2714
8m 20s (- 7m 51s) (25750 51%) 1.2665
8m 24s (- 7m 46s) (26000 52%) 1.2653
8m 29s (- 7m 41s) (26250 52%) 1.1929
8m 34s (- 7m 36s) (26500 53%) 1.2523
8m 39s (- 7m 31s) (26750 53%) 1.2691
8m 44s (- 7m 26s) (27000 54%) 1.1528
8m 49s (- 7m 22s) (27250 54%) 1.2370
8m 54s (- 7m 17s) (27500 55%) 1.2660
8m 59s (- 7m 12s) (27750 55%) 1.2506
9m 4s (- 7m 7s) (28000 56%) 1.2735
9m 9s (- 7m 2s) (28250 56%) 1.2148
9m 13s (- 6m 57s) (28500 56%) 1.2847
9m 19s (- 6m 53s) (28750 57%) 1.2118
9m 23s (- 6m 48s) (29000 57%) 1.1789
9m 28s (- 6m 43s) (29250 58%) 1.1460
9m 33s (- 6m 38s) (29500 59%) 1.1338
9m 38s (- 6m 33s) (29750 59%) 1.1070
9m 43s (- 6m 29s) (30000 60%) 1.2129
9m 48s (- 6m 24s) (30250 60%) 1.0972
9m 53s (- 6m 19s) (30500 61%) 1.0851
9m 58s (- 6m 14s) (30750 61%) 1.1832
10m 2s (- 6m 9s) (31000 62%) 1.0532
10m 7s (- 6m 4s) (31250 62%) 1.1463
10m 12s (- 5m 59s) (31500 63%) 1.0433
10m 17s (- 5m 54s) (31750 63%) 1.0821
10m 22s (- 5m 50s) (32000 64%) 1.0334
10m 28s (- 5m 45s) (32250 64%) 1.1181
10m 33s (- 5m 40s) (32500 65%) 1.1509
10m 38s (- 5m 36s) (32750 65%) 1.1036
10m 43s (- 5m 31s) (33000 66%) 1.0277
10m 48s (- 5m 26s) (33250 66%) 1.1785
10m 52s (- 5m 21s) (33500 67%) 1.0550
10m 57s (- 5m 16s) (33750 67%) 1.0629
11m 2s (- 5m 11s) (34000 68%) 1.0696
11m 7s (- 5m 6s) (34250 68%) 1.0918
11m 12s (- 5m 2s) (34500 69%) 1.0613
11m 17s (- 4m 57s) (34750 69%) 1.0352
11m 22s (- 4m 52s) (35000 70%) 1.0065
11m 27s (- 4m 47s) (35250 70%) 1.0674
11m 31s (- 4m 42s) (35500 71%) 1.0631
11m 36s (- 4m 37s) (35750 71%) 1.1001
11m 41s (- 4m 32s) (36000 72%) 1.0393
11m 46s (- 4m 28s) (36250 72%) 0.9400
11m 51s (- 4m 23s) (36500 73%) 1.0264
11m 55s (- 4m 18s) (36750 73%) 0.9909
12m 0s (- 4m 13s) (37000 74%) 0.9877
12m 5s (- 4m 8s) (37250 74%) 0.9790
12m 10s (- 4m 3s) (37500 75%) 0.8614
12m 15s (- 3m 58s) (37750 75%) 0.8985
12m 20s (- 3m 53s) (38000 76%) 0.9313
12m 25s (- 3m 49s) (38250 76%) 0.9810
12m 30s (- 3m 44s) (38500 77%) 0.8965
12m 35s (- 3m 39s) (38750 77%) 0.9325
12m 40s (- 3m 34s) (39000 78%) 0.9488
12m 45s (- 3m 29s) (39250 78%) 0.8820
12m 49s (- 3m 24s) (39500 79%) 0.9141
12m 54s (- 3m 19s) (39750 79%) 0.9451
12m 59s (- 3m 14s) (40000 80%) 0.8610
13m 4s (- 3m 9s) (40250 80%) 0.8987
13m 8s (- 3m 4s) (40500 81%) 0.9370
13m 13s (- 3m 0s) (40750 81%) 0.9663
13m 18s (- 2m 55s) (41000 82%) 0.8364
13m 22s (- 2m 50s) (41250 82%) 0.9296
13m 27s (- 2m 45s) (41500 83%) 0.8876
13m 32s (- 2m 40s) (41750 83%) 0.7837
13m 37s (- 2m 35s) (42000 84%) 0.8643
13m 41s (- 2m 30s) (42250 84%) 0.9092
13m 46s (- 2m 25s) (42500 85%) 0.8111
13m 51s (- 2m 20s) (42750 85%) 0.8668
13m 56s (- 2m 16s) (43000 86%) 0.8687
14m 1s (- 2m 11s) (43250 86%) 0.8701
14m 5s (- 2m 6s) (43500 87%) 0.8108
14m 10s (- 2m 1s) (43750 87%) 0.7329
14m 15s (- 1m 56s) (44000 88%) 0.8410
14m 20s (- 1m 51s) (44250 88%) 0.8041
14m 25s (- 1m 46s) (44500 89%) 0.7772
14m 30s (- 1m 42s) (44750 89%) 0.8702
14m 35s (- 1m 37s) (45000 90%) 0.8274
14m 40s (- 1m 32s) (45250 90%) 0.7602
14m 45s (- 1m 27s) (45500 91%) 0.8276
14m 50s (- 1m 22s) (45750 91%) 0.7752
14m 55s (- 1m 17s) (46000 92%) 0.7822
15m 0s (- 1m 12s) (46250 92%) 0.7470
15m 4s (- 1m 8s) (46500 93%) 0.7725
15m 9s (- 1m 3s) (46750 93%) 0.7477
15m 14s (- 0m 58s) (47000 94%) 0.7231
15m 19s (- 0m 53s) (47250 94%) 0.7538
15m 24s (- 0m 48s) (47500 95%) 0.8537
15m 28s (- 0m 43s) (47750 95%) 0.7798
15m 33s (- 0m 38s) (48000 96%) 0.7322
15m 38s (- 0m 34s) (48250 96%) 0.8085
15m 43s (- 0m 29s) (48500 97%) 0.7098
15m 48s (- 0m 24s) (48750 97%) 0.7215
15m 53s (- 0m 19s) (49000 98%) 0.8122
15m 58s (- 0m 14s) (49250 98%) 0.7791
16m 2s (- 0m 9s) (49500 99%) 0.7251
16m 7s (- 0m 4s) (49750 99%) 0.7874
16m 12s (- 0m 0s) (50000 100%) 0.7124
# Randomly pick 10 sentences and inspect the translations
evaluate_randomly(encoder, decoder, 10, reverse_source_sentence)
> je suis tres fier de nos etudiants .
= i m very proud of our students .
< i m very proud of you . <EOS>

> vous etes faibles .
= you re weak .
< you re rude . <EOS>

> tu n es pas si vieux .
= you re not that old .
< you re not that old . <EOS>

> je songe a demissionner immediatement .
= i am thinking of resigning at once .
< i m thinking about the problem . <EOS>

> je suis en retard sur le programme .
= i m behind schedule .
< i m behind schedule . <EOS>

> je suis submerge de travail .
= i m swamped with work .
< i m proud of that . <EOS>

> je ne vais pas prendre le moindre risque .
= i m not taking any chances .
< i m not taking any chances . <EOS>

> je suis au restaurant .
= i m at the restaurant .
< i m in the office . <EOS>

> c est toi la doyenne .
= you re the oldest .
< you re the oldest . <EOS>

> je suis tres reconnaissant pour votre aide .
= i m very grateful for your help .
< i m very worried about you . <EOS>
Adding an attention mechanism to the decoder
class AttentionDecoderLSTM(nn.Module):
    def __init__(self, hidden_size: int, output_size: int, dropout_p=0.1, max_length=MAX_LENGTH):
        """DecoderLSTM with an attention mechanism"""
        super(AttentionDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attention = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attention_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.activation_fn = F.relu

    def forward(self, inputs, state, encoder_outputs):
        """Forward pass for one decoding step.

        Args:
            inputs: index of the previous target token ([1, 1], [1] or a 0-dim tensor)
            state: (hidden, cell), each [1, 1, hidden_size]
            encoder_outputs: [max_length, hidden_size]

        Returns:
            output: log-probabilities over the target vocabulary, [1, output_size]
            state: (hidden, cell)
            attention_weights: [1, max_length]
        """
        # embedded: [1, 1, hidden_size]
        embedded = self.embedding(inputs).view(1, 1, -1)
        embedded = self.dropout(embedded)

        (hidden, cell) = state

        # embedded[0] and hidden[0] are both [1, hidden_size]; attention_weights: [1, max_length]
        attention_weights = F.softmax(
            self.attention(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # torch.bmm multiplies matrices batch by batch, e.g. [10, 3, 4] x [10, 4, 5] -> [10, 3, 5].
        # unsqueeze(0) adds the batch dimension, so the result is [1, 1, hidden_size]
        attention_applied = torch.bmm(attention_weights.unsqueeze(0),
                                      encoder_outputs.unsqueeze(0))

        # output: [1, hidden_size * 2]
        output = torch.cat((embedded[0], attention_applied[0]), 1)
        # output: [1, 1, hidden_size]
        output = self.attention_combine(output).unsqueeze(0)
        output = self.activation_fn(output)

        # output: [1, 1, hidden_size]
        output, (hidden, cell) = self.lstm(output, (hidden, cell))

        # output: [1, output_size]
        output = F.log_softmax(self.out(output[0]), dim=1)

        return output, (hidden, cell), attention_weights

    def init_hidden(self):
        """Initialize the hidden state.

        Returns:
            hidden: zero hidden state, [1, 1, hidden_size]
            cell: zero cell state, [1, 1, hidden_size]
        """
        cell = torch.zeros(1, 1, self.hidden_size, device=device)
        hidden = torch.zeros(1, 1, self.hidden_size, device=device)
        return hidden, cell
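One attention step of the class above can be traced in isolation. This sketch uses illustrative sizes and random stand-ins for the embedding, hidden state, and encoder outputs. It checks that the `torch.bmm` call is just a weighted sum of the encoder outputs. Note that, as in the class, the weights cover all `max_length` slots, including the zero-filled rows past the end of the source sentence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, max_length, input_length = 8, 10, 4   # illustrative sizes

attention = nn.Linear(hidden_size * 2, max_length)
embedded = torch.randn(1, 1, hidden_size)          # embedding of the previous target token
hidden = torch.randn(1, 1, hidden_size)            # current decoder hidden state
encoder_outputs = torch.zeros(max_length, hidden_size)
encoder_outputs[:input_length] = torch.randn(input_length, hidden_size)  # rows past the sentence stay zero

# one score per encoder slot, normalized with softmax: [1, max_length]
attention_weights = F.softmax(attention(torch.cat((embedded[0], hidden[0]), 1)), dim=1)

# batched matrix product: [1, 1, max_length] x [1, max_length, hidden_size] -> [1, 1, hidden_size]
attention_applied = torch.bmm(attention_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

# the same context vector written as an explicit weighted sum of encoder outputs
weighted_sum = (attention_weights[0].unsqueeze(1) * encoder_outputs).sum(dim=0)

print(attention_applied.shape)                                        # torch.Size([1, 1, 8])
print(torch.bmm(torch.randn(10, 3, 4), torch.randn(10, 4, 5)).shape)  # torch.Size([10, 3, 5])
```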
def train_by_sentence_attn(input_tensor, target_tensor, encoder, decoder,
                           encoder_optimizer, decoder_optimizer, loss_fn,
                           use_teacher_forcing=True, reverse_source_sentence=True,
                           max_length=MAX_LENGTH):
    """Train on a single sentence pair with EncoderLSTM and AttentionDecoderLSTM,
    including the forward pass and the parameter update.

    Args:
        input_tensor: [input_sequence_len, 1], word indices of the source sentence
        target_tensor: [target_sequence_len, 1], word indices of the target sentence
        encoder: EncoderLSTM
        decoder: AttentionDecoderLSTM
        encoder_optimizer: optimizer for the encoder
        decoder_optimizer: optimizer for the decoder
        loss_fn: loss function
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own prediction as the next input
        reverse_source_sentence: True reverses the order of the source sentence
        max_length: maximum length of the input and output

    Returns:
        loss: scalar, the loss averaged over the target length
    """
    if reverse_source_sentence:
        input_tensor = torch.flip(input_tensor, [0])

    hidden, cell = encoder.init_hidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # encoder_outputs: [max_length, hidden_size]; the attention decoder attends over these rows
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Get the encoder outputs
    for ei in range(input_length):
        encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    for di in range(target_length):
        decoder_output, (hidden, cell), _ = decoder(decoder_input, (hidden, cell), encoder_outputs)
        if use_teacher_forcing:
            loss += loss_fn(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing
        else:
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            loss += loss_fn(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
def train_attn(encoder, decoder, n_iters, reverse_source_sentence=True, use_teacher_forcing=True,
               print_every=1000, plot_every=100, learning_rate=0.01):
    """Train the seq2seq model with attention.

    Args:
        encoder: EncoderLSTM
        decoder: AttentionDecoderLSTM
        n_iters: number of iterations; n_iters sentence pairs are sampled with replacement
        reverse_source_sentence: True reverses the source sentence but keeps the target order unchanged,
            False keeps the order of both the source and the target sentence unchanged
        use_teacher_forcing: True feeds the target as the next input,
            False feeds the model's own prediction as the next input
        print_every: print the average loss every print_every iterations
        plot_every: record a point for the loss plot every plot_every iterations
        learning_rate: learning rate of the SGD optimizers
    """
    # Make sure the dropout in AttentionDecoderLSTM is active during training
    encoder.train()
    decoder.train()

    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensor_from_pair(random.choice(pairs), input_lang, output_lang)
                      for _ in range(n_iters)]
    loss_fn = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        training_pair = training_pairs[i - 1]
        input_tensor = training_pair[0].to(device)
        target_tensor = training_pair[1].to(device)

        loss = train_by_sentence_attn(input_tensor, target_tensor, encoder, decoder,
                                      encoder_optimizer, decoder_optimizer, loss_fn,
                                      use_teacher_forcing=use_teacher_forcing,
                                      reverse_source_sentence=reverse_source_sentence)
        print_loss_total += loss
        plot_loss_total += loss

        if i % print_every == 0:
            # Print the average loss
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print("%s (%d %d%%) %.4f" % (time_since(start, i / n_iters),
                                         i, i / n_iters * 100, print_loss_avg))

        if i % plot_every == 0:
            # Record a point for the plot
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    # Show the plot
    show_plot(plot_losses)
def evaluate_by_sentence_attn(encoder, decoder, sentence, reverse_source_sentence=True, max_length=MAX_LENGTH):
    """Evaluate on a source sentence with a model trained with the attention mechanism.

    Args:
        encoder: trained EncoderLSTM
        decoder: trained AttentionDecoderLSTM
        sentence: source sentence (str)
        reverse_source_sentence: whether to reverse the source; should match the training setting
        max_length: maximum number of decoding steps

    Returns:
        decoded_words: the predicted sentence as a list of words
        decoder_attentions: attention weights of each decoding step, [len(decoded_words), max_length]
    """
    # Disable dropout for inference; call .train() again before further training
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        input_tensor = tensor_from_sentence(input_lang, sentence).to(device)
        input_length = input_tensor.size(0)
        if reverse_source_sentence:
            input_tensor = torch.flip(input_tensor, [0])

        (hidden, cell) = encoder.init_hidden()
        # encoder_outputs: [max_length, hidden_size]
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        for ei in range(input_length):
            encoder_output, (hidden, cell) = encoder(input_tensor[ei], (hidden, cell))
            encoder_outputs[ei] += encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)

        decoded_words = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, (hidden, cell), decoder_attention = \
                decoder(decoder_input, (hidden, cell), encoder_outputs)
            topv, topi = decoder_output.data.topk(1)
            # Record the attention weights of this step
            decoder_attentions[di] = decoder_attention.data
            if topi.item() == EOS_token:
                decoded_words.append("<EOS>")
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()

        return decoded_words, decoder_attentions[:di + 1]
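def show_attention(input_sentence, output_words, attentions, reverse_source_sentence=True):
    """Plot the attention between the input sentence and the output sentence.

    Column i of `attentions` belongs to the i-th token the encoder read, so when the
    source was reversed, the x-axis labels must be reversed as well.
    """
    input_words = input_sentence.split(' ') + ['<EOS>']
    if reverse_source_sentence:
        input_words = input_words[::-1]

    # Set up the figure with a colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)

    # One tick per token; fixing the tick positions before setting labels keeps them aligned
    ax.set_xticks(range(len(input_words)))
    ax.set_xticklabels(input_words, rotation=90)
    ax.set_yticks(range(len(output_words)))
    ax.set_yticklabels(output_words)

    plt.show()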
def evaluate_and_show_attention(input_sentence, encoder, decoder):
    """Evaluate an input sentence and show its attention weights."""
    output_words, attentions = evaluate_by_sentence_attn(encoder, decoder, input_sentence)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    show_attention(input_sentence, output_words, attentions)
setup_seed(45)
hidden_size = 256
# Reverse the order of source input sentence
reverse_source_sentence = True
# Feed the target as the next input
use_teacher_forcing = True
encoder = EncoderLSTM(input_lang.n_words, hidden_size).to(device)
decoder = AttentionDecoderLSTM(hidden_size, output_lang.n_words).to(device)
print(">> Model is on: {}".format(next(encoder.parameters()).is_cuda))
print(">> Model is on: {}".format(next(decoder.parameters()).is_cuda))
>> Model is on: True
>> Model is on: True
iters = 50000
train_attn(encoder, decoder, iters, reverse_source_sentence=reverse_source_sentence, use_teacher_forcing=use_teacher_forcing,print_every=250, plot_every=250)
0m 8s (- 28m 13s) (250 0%) 4.9045
0m 16s (- 26m 50s) (500 1%) 3.3049
0m 24s (- 26m 20s) (750 1%) 3.0303
0m 31s (- 25m 56s) (1000 2%) 2.8569
0m 39s (- 25m 39s) (1250 2%) 2.8321
0m 47s (- 25m 25s) (1500 3%) 2.7320
0m 54s (- 25m 7s) (1750 3%) 2.6652
1m 4s (- 25m 36s) (2000 4%) 2.6601
1m 12s (- 25m 40s) (2250 4%) 2.6078
1m 20s (- 25m 37s) (2500 5%) 2.5934
1m 29s (- 25m 42s) (2750 5%) 2.5855
1m 38s (- 25m 39s) (3000 6%) 2.4951
1m 47s (- 25m 39s) (3250 6%) 2.4717
1m 55s (- 25m 40s) (3500 7%) 2.4963
2m 4s (- 25m 37s) (3750 7%) 2.4969
2m 13s (- 25m 35s) (4000 8%) 2.3994
2m 22s (- 25m 31s) (4250 8%) 2.3991
2m 31s (- 25m 28s) (4500 9%) 2.3072
2m 39s (- 25m 23s) (4750 9%) 2.2703
2m 48s (- 25m 18s) (5000 10%) 2.2681
2m 57s (- 25m 12s) (5250 10%) 2.3253
3m 6s (- 25m 8s) (5500 11%) 2.3062
3m 15s (- 25m 3s) (5750 11%) 2.2431
3m 24s (- 25m 1s) (6000 12%) 2.2481
3m 33s (- 24m 57s) (6250 12%) 2.2147
3m 42s (- 24m 51s) (6500 13%) 2.1789
3m 51s (- 24m 41s) (6750 13%) 2.0689
3m 59s (- 24m 33s) (7000 14%) 2.1624
4m 8s (- 24m 26s) (7250 14%) 2.1783
4m 17s (- 24m 19s) (7500 15%) 2.0618
4m 26s (- 24m 11s) (7750 15%) 2.0752
4m 34s (- 24m 2s) (8000 16%) 2.0351
4m 43s (- 23m 55s) (8250 16%) 2.0126
4m 52s (- 23m 50s) (8500 17%) 2.0556
5m 2s (- 23m 44s) (8750 17%) 2.0194
5m 10s (- 23m 36s) (9000 18%) 2.0155
5m 19s (- 23m 28s) (9250 18%) 1.8749
5m 28s (- 23m 19s) (9500 19%) 1.9430
5m 36s (- 23m 9s) (9750 19%) 1.8124
5m 45s (- 23m 1s) (10000 20%) 1.9113
5m 54s (- 22m 53s) (10250 20%) 1.8959
6m 2s (- 22m 44s) (10500 21%) 1.8226
6m 11s (- 22m 35s) (10750 21%) 1.8846
6m 20s (- 22m 27s) (11000 22%) 1.8598
6m 28s (- 22m 18s) (11250 22%) 1.8070
6m 37s (- 22m 9s) (11500 23%) 1.8770
6m 46s (- 22m 2s) (11750 23%) 1.7991
6m 54s (- 21m 52s) (12000 24%) 1.6979
7m 3s (- 21m 44s) (12250 24%) 1.7849
7m 11s (- 21m 34s) (12500 25%) 1.7383
7m 20s (- 21m 26s) (12750 25%) 1.8461
7m 29s (- 21m 18s) (13000 26%) 1.7735
7m 38s (- 21m 11s) (13250 26%) 1.7250
7m 46s (- 21m 2s) (13500 27%) 1.7031
7m 55s (- 20m 54s) (13750 27%) 1.7557
8m 4s (- 20m 46s) (14000 28%) 1.7034
8m 13s (- 20m 38s) (14250 28%) 1.7474
8m 22s (- 20m 30s) (14500 28%) 1.7002
8m 31s (- 20m 21s) (14750 29%) 1.6098
8m 39s (- 20m 11s) (15000 30%) 1.6132
8m 47s (- 20m 3s) (15250 30%) 1.7066
8m 56s (- 19m 54s) (15500 31%) 1.6781
9m 5s (- 19m 45s) (15750 31%) 1.5791
9m 13s (- 19m 35s) (16000 32%) 1.5366
9m 22s (- 19m 27s) (16250 32%) 1.6449
9m 30s (- 19m 18s) (16500 33%) 1.6655
9m 38s (- 19m 9s) (16750 33%) 1.5604
9m 47s (- 19m 0s) (17000 34%) 1.5838
9m 56s (- 18m 52s) (17250 34%) 1.5468
10m 4s (- 18m 43s) (17500 35%) 1.5449
10m 12s (- 18m 33s) (17750 35%) 1.5705
10m 20s (- 18m 22s) (18000 36%) 1.5782
10m 27s (- 18m 11s) (18250 36%) 1.4957
10m 35s (- 18m 1s) (18500 37%) 1.3964
10m 42s (- 17m 51s) (18750 37%) 1.4830
10m 50s (- 17m 41s) (19000 38%) 1.5285
10m 58s (- 17m 31s) (19250 38%) 1.4140
11m 5s (- 17m 21s) (19500 39%) 1.4499
11m 12s (- 17m 10s) (19750 39%) 1.4234
11m 20s (- 17m 0s) (20000 40%) 1.4060
11m 27s (- 16m 49s) (20250 40%) 1.4435
11m 35s (- 16m 40s) (20500 41%) 1.4514
11m 42s (- 16m 30s) (20750 41%) 1.4453
11m 50s (- 16m 20s) (21000 42%) 1.4741
11m 57s (- 16m 11s) (21250 42%) 1.3784
12m 5s (- 16m 1s) (21500 43%) 1.4439
12m 12s (- 15m 51s) (21750 43%) 1.4836
12m 20s (- 15m 42s) (22000 44%) 1.3703
12m 27s (- 15m 32s) (22250 44%) 1.3226
12m 36s (- 15m 24s) (22500 45%) 1.3610
12m 43s (- 15m 14s) (22750 45%) 1.4003
12m 51s (- 15m 5s) (23000 46%) 1.3914
12m 59s (- 14m 56s) (23250 46%) 1.2699
13m 7s (- 14m 47s) (23500 47%) 1.2957
13m 14s (- 14m 38s) (23750 47%) 1.3403
13m 22s (- 14m 29s) (24000 48%) 1.3439
13m 30s (- 14m 20s) (24250 48%) 1.2482
13m 37s (- 14m 10s) (24500 49%) 1.3789
13m 44s (- 14m 1s) (24750 49%) 1.1900
13m 52s (- 13m 52s) (25000 50%) 1.2474
14m 0s (- 13m 43s) (25250 50%) 1.3320
14m 7s (- 13m 34s) (25500 51%) 1.2478
14m 15s (- 13m 25s) (25750 51%) 1.2392
14m 23s (- 13m 16s) (26000 52%) 1.2369
14m 30s (- 13m 8s) (26250 52%) 1.1629
14m 38s (- 12m 59s) (26500 53%) 1.2625
14m 46s (- 12m 50s) (26750 53%) 1.2236
14m 54s (- 12m 41s) (27000 54%) 1.1323
15m 2s (- 12m 33s) (27250 54%) 1.2009
15m 9s (- 12m 24s) (27500 55%) 1.2412
15m 17s (- 12m 15s) (27750 55%) 1.2053
15m 25s (- 12m 6s) (28000 56%) 1.2504
15m 33s (- 11m 58s) (28250 56%) 1.1889
15m 40s (- 11m 49s) (28500 56%) 1.2637
15m 48s (- 11m 41s) (28750 57%) 1.2014
15m 56s (- 11m 32s) (29000 57%) 1.1773
16m 4s (- 11m 23s) (29250 58%) 1.1245
16m 11s (- 11m 15s) (29500 59%) 1.1128
16m 19s (- 11m 6s) (29750 59%) 1.1001
16m 26s (- 10m 57s) (30000 60%) 1.2020
16m 34s (- 10m 49s) (30250 60%) 1.0931
16m 41s (- 10m 40s) (30500 61%) 1.0847
16m 49s (- 10m 31s) (30750 61%) 1.1683
16m 56s (- 10m 23s) (31000 62%) 1.0578
17m 4s (- 10m 14s) (31250 62%) 1.1204
17m 12s (- 10m 6s) (31500 63%) 1.0375
17m 20s (- 9m 58s) (31750 63%) 1.0673
17m 29s (- 9m 50s) (32000 64%) 1.0291
17m 37s (- 9m 42s) (32250 64%) 1.1162
17m 45s (- 9m 33s) (32500 65%) 1.1235
17m 54s (- 9m 25s) (32750 65%) 1.0975
18m 3s (- 9m 17s) (33000 66%) 1.0241
18m 11s (- 9m 9s) (33250 66%) 1.1561
18m 19s (- 9m 1s) (33500 67%) 1.0525
18m 28s (- 8m 53s) (33750 67%) 1.0605
18m 37s (- 8m 45s) (34000 68%) 1.0551
18m 45s (- 8m 37s) (34250 68%) 1.0728
18m 53s (- 8m 29s) (34500 69%) 1.0422
19m 1s (- 8m 21s) (34750 69%) 1.0189
19m 10s (- 8m 13s) (35000 70%) 0.9925
19m 19s (- 8m 5s) (35250 70%) 1.0572
19m 28s (- 7m 57s) (35500 71%) 1.0211
19m 36s (- 7m 48s) (35750 71%) 1.0857
19m 45s (- 7m 40s) (36000 72%) 1.0427
19m 54s (- 7m 32s) (36250 72%) 0.9366
20m 2s (- 7m 24s) (36500 73%) 1.0282
20m 11s (- 7m 16s) (36750 73%) 0.9766
20m 20s (- 7m 8s) (37000 74%) 0.9918
20m 28s (- 7m 0s) (37250 74%) 0.9713
20m 37s (- 6m 52s) (37500 75%) 0.8682
20m 46s (- 6m 44s) (37750 75%) 0.8866
20m 55s (- 6m 36s) (38000 76%) 0.9383
21m 3s (- 6m 28s) (38250 76%) 0.9673
21m 12s (- 6m 20s) (38500 77%) 0.8919
21m 20s (- 6m 11s) (38750 77%) 0.9282
21m 29s (- 6m 3s) (39000 78%) 0.9390
21m 37s (- 5m 55s) (39250 78%) 0.8759
21m 46s (- 5m 47s) (39500 79%) 0.9049
21m 54s (- 5m 39s) (39750 79%) 0.9363
22m 3s (- 5m 30s) (40000 80%) 0.8546
22m 11s (- 5m 22s) (40250 80%) 0.8826
22m 19s (- 5m 14s) (40500 81%) 0.9364
22m 28s (- 5m 6s) (40750 81%) 0.9326
22m 37s (- 4m 57s) (41000 82%) 0.8726
22m 46s (- 4m 49s) (41250 82%) 0.9256
22m 54s (- 4m 41s) (41500 83%) 0.8858
23m 3s (- 4m 33s) (41750 83%) 0.7803
23m 11s (- 4m 25s) (42000 84%) 0.8532
23m 20s (- 4m 16s) (42250 84%) 0.9056
23m 29s (- 4m 8s) (42500 85%) 0.7939
23m 37s (- 4m 0s) (42750 85%) 0.8685
23m 46s (- 3m 52s) (43000 86%) 0.8675
23m 55s (- 3m 43s) (43250 86%) 0.8868
24m 3s (- 3m 35s) (43500 87%) 0.8165
24m 11s (- 3m 27s) (43750 87%) 0.7273
24m 20s (- 3m 19s) (44000 88%) 0.8150
24m 29s (- 3m 10s) (44250 88%) 0.8015
24m 37s (- 3m 2s) (44500 89%) 0.7703
24m 45s (- 2m 54s) (44750 89%) 0.8699
24m 54s (- 2m 46s) (45000 90%) 0.8267
25m 4s (- 2m 37s) (45250 90%) 0.7528
25m 12s (- 2m 29s) (45500 91%) 0.8305
25m 21s (- 2m 21s) (45750 91%) 0.7830
25m 29s (- 2m 13s) (46000 92%) 0.8001
25m 38s (- 2m 4s) (46250 92%) 0.7384
25m 47s (- 1m 56s) (46500 93%) 0.7825
25m 55s (- 1m 48s) (46750 93%) 0.7710
26m 4s (- 1m 39s) (47000 94%) 0.7454
26m 12s (- 1m 31s) (47250 94%) 0.7490
26m 21s (- 1m 23s) (47500 95%) 0.8604
26m 29s (- 1m 14s) (47750 95%) 0.7776
26m 36s (- 1m 6s) (48000 96%) 0.7263
26m 44s (- 0m 58s) (48250 96%) 0.8098
26m 51s (- 0m 49s) (48500 97%) 0.6916
26m 59s (- 0m 41s) (48750 97%) 0.7064
27m 6s (- 0m 33s) (49000 98%) 0.8158
27m 14s (- 0m 24s) (49250 98%) 0.7894
27m 21s (- 0m 16s) (49500 99%) 0.7465
27m 27s (- 0m 8s) (49750 99%) 0.7975
27m 35s (- 0m 0s) (50000 100%) 0.7190
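Each progress line above has the form `elapsed (- estimated remaining) (iteration percent%) average loss`. The helper that prints it is presumably defined in the training code earlier in the article and is not shown in this section. Below is a minimal sketch of how such a line can be produced. The names `as_minutes`, `time_since`, and `format_log_line` are hypothetical, and the sketch assumes the remaining time is extrapolated linearly from the fraction of iterations completed so far.

```python
import math
import time
from typing import Optional


def as_minutes(seconds: float) -> str:
    """Format a duration in seconds as 'Xm Ys' (seconds truncated to an int)."""
    minutes = math.floor(seconds / 60)
    seconds -= minutes * 60
    return f"{minutes}m {int(seconds)}s"


def time_since(since: float, percent: float, now: Optional[float] = None) -> str:
    """Return 'elapsed (- remaining)'.

    Assumption: the total run time is extrapolated linearly as
    elapsed / percent, so percent must be > 0.
    """
    now = time.time() if now is None else now
    elapsed = now - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return f"{as_minutes(elapsed)} (- {as_minutes(remaining)})"


def format_log_line(since: float, iteration: int, n_iters: int,
                    avg_loss: float, now: Optional[float] = None) -> str:
    """Build one progress line: time info, (iteration percent%), average loss."""
    fraction = iteration / n_iters
    percent = int(fraction * 100)  # truncated, not rounded
    return (f"{time_since(since, fraction, now)} "
            f"({iteration} {percent}%) {avg_loss:.4f}")


# Reproduce the 50% line of the log: 832 s elapsed after 25000 of 50000 iterations.
print(format_log_line(0.0, 25000, 50000, 1.2474, now=832.0))
```

At the halfway point the elapsed and remaining times are equal (`13m 52s (- 13m 52s)`), as linear extrapolation predicts. Truncating `iteration / n_iters * 100` with `int` can also print a step one percent low because of floating-point rounding (28500/50000 prints as 56% rather than 57%), which may explain the same quirk in the log above.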
evaluate_and_show_attention("elle a cinq ans de moins que moi .", encoder, decoder)
evaluate_and_show_attention("elle est trop petit .", encoder, decoder)
evaluate_and_show_attention("je ne crains pas de mourir .", encoder, decoder)
evaluate_and_show_attention("c est un jeune directeur plein de talent .", encoder, decoder)
input = elle a cinq ans de moins que moi .
output = she is two years younger than me . <EOS>
input = elle est trop petit .
output = she is too drunk . <EOS>
input = je ne crains pas de mourir .
output = i m not afraid of making mistakes . <EOS>
input = c est un jeune directeur plein de talent .
output = he s a very talented writer . <EOS>
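The test sentences are passed in a normalized form: lowercase, apostrophes replaced by spaces ("c est", "i m", "he s"), and final punctuation split off as its own token. The model can presumably only look up words that appear in the vocabulary built from identically normalized training pairs, so raw input would need the same treatment first. Below is a minimal sketch of such a normalization step. The names `unicode_to_ascii` and `normalize_string` are hypothetical, and the exact rules (strip accents, keep only letters and `.!?`) are an assumption inferred from the strings above, not necessarily the article's own preprocessing code.

```python
import re
import unicodedata


def unicode_to_ascii(s: str) -> str:
    """Strip accents: decompose characters (NFD) and drop combining marks."""
    return "".join(
        c for c in unicodedata.normalize("NFD", s)
        if unicodedata.category(c) != "Mn"
    )


def normalize_string(s: str) -> str:
    """Assumed normalization: lowercase, remove accents, split off . ! ?,
    and replace every other non-letter run (e.g. apostrophes) with a space."""
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)      # "moi." -> "moi ."
    s = re.sub(r"[^a-zA-Z.!?]+", " ", s)   # "c'est" -> "c est"
    return s.strip()


print(normalize_string("C'est un jeune directeur plein de talent."))
```

With this sketch, a raw sentence such as "Elle a cinq ans de moins que moi." becomes exactly the first test input above.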