• Source: https://keras.io/examples/nlp/lstm_seq2seq/

Author: fchollet

Date created: 2017/09/29

Last modified: 2020/04/26

Description: Character-level recurrent sequence-to-sequence model.

View in Colab: https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/lstm_seq2seq.ipynb

GitHub source: https://github.com/keras-team/keras-io/blob/master/examples/nlp/lstm_seq2seq.py

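Introduction

This example demonstrates how to implement a basic character-level recurrent sequence-to-sequence model. We apply it to translating short English sentences into short French sentences, character by character. Note that character-level machine translation is fairly unusual, because word-level models are more common in this domain.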

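Summary of the algorithm

  • We start with input sequences from one domain (e.g. English sentences) and corresponding target sequences from another domain (e.g. French sentences).
  • An encoder LSTM turns each input sequence into 2 state vectors (we keep the final LSTM states and discard the outputs).
  • A decoder LSTM is trained to turn the target sequences into the same sequences offset by one timestep into the future. In this context, this training process is called "teacher forcing". The decoder uses the state vectors from the encoder as its initial state. Effectively, it learns to generate targets[t+1...] given targets[...t], conditioned on the input sequence.
  • In inference mode, when we want to decode unknown input sequences, we:
    • Encode the input sequence into state vectors.
    • Start with a target sequence of size 1 (just the start-of-sequence character).
    • Feed the state vectors and the 1-character target sequence to the decoder to produce predictions for the next character.
    • Sample the next character from these predictions (we simply use argmax).
    • Append the sampled character to the target sequence.
    • Repeat until we generate the end-of-sequence character or hit the character limit.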
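The one-step offset used for teacher forcing can be illustrated without any Keras code.

The minimal sketch below assumes a target sentence that is already wrapped in the `\t` (start) and `\n` (end) markers this example uses. For each step, it lists the character the decoder is fed and the character it is trained to predict. `teacher_forcing_pairs` is a hypothetical helper name.

In the actual arrays built later, the decoder input also contains the final `\n`, which is paired with a space-padding target.

```python
def teacher_forcing_pairs(target_text):
    """Return (decoder input char, expected next char) pairs for one target.

    Assumes `target_text` starts with "\t" (start of sequence) and
    ends with "\n" (end of sequence), as in this example.
    """
    decoder_input = target_text[:-1]  # everything except the final "\n"
    decoder_target = target_text[1:]  # the same characters, one step ahead
    return list(zip(decoder_input, decoder_target))


pairs = teacher_forcing_pairs("\tVa !\n")
for seen, expected in pairs:
    print(repr(seen), "->", repr(expected))
```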

Setup

import numpy as np
import tensorflow as tf
from tensorflow import keras

Download the data

!!curl -O http://www.manythings.org/anki/fra-eng.zip
!!unzip fra-eng.zip
['Archive:  fra-eng.zip','  inflating: _about.txt              ','  inflating: fra.txt                 ']
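`!!curl` and `!!unzip` are notebook shell escapes. Outside a notebook, a rough standard-library equivalent might look like the sketch below.

The function name `download_and_extract` and its defaults are assumptions. The sketch has not been checked against the current behaviour of the manythings.org server.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url, zip_path="fra-eng.zip", member="fra.txt", dest_dir="."):
    """Download `url` to `zip_path` unless it already exists, then extract `member`.

    Returns the path of the extracted file.
    """
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        return zf.extract(member, path=dest_dir)


# Example usage (performs a network download):
# data_path = download_and_extract("http://www.manythings.org/anki/fra-eng.zip")
```

Skipping the download when the archive is already present makes repeated runs cheap. Delete the zip file to force a fresh download.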

Configuration

batch_size = 64  # Batch size for training.
epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.
num_samples = 10000  # Number of samples to train on.
# Path to the data txt file on disk.
data_path = "fra.txt"

Prepare the data

# Vectorize the data.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_path, "r", encoding="utf-8") as f:
    lines = f.read().split("\n")
for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split("\t")
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = "\t" + target_text + "\n"
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

print("Number of samples:", len(input_texts))
print("Number of unique input tokens:", num_encoder_tokens)
print("Number of unique output tokens:", num_decoder_tokens)
print("Max sequence length for inputs:", max_encoder_seq_length)
print("Max sequence length for outputs:", max_decoder_seq_length)

input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros(
    (len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype="float32"
)
decoder_input_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype="float32"
)
decoder_target_data = np.zeros(
    (len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype="float32"
)

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.0
    encoder_input_data[i, t + 1 :, input_token_index[" "]] = 1.0
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, target_token_index[char]] = 1.0
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0
    decoder_input_data[i, t + 1 :, target_token_index[" "]] = 1.0
    decoder_target_data[i, t:, target_token_index[" "]] = 1.0
Number of samples: 10000
Number of unique input tokens: 71
Number of unique output tokens: 93
Max sequence length for inputs: 16
Max sequence length for outputs: 59
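A quick way to sanity-check the vectorization is to round-trip a string through a one-hot matrix.

The sketch below re-implements the same space-padding scheme for a single string on a toy vocabulary. `one_hot_encode` and `one_hot_decode` are hypothetical helpers, not part of the original example. With the reverse index built in the inference section, `one_hot_decode(encoder_input_data[0], reverse_input_char_index)` should similarly recover the first input sentence followed by space padding.

```python
import numpy as np


def one_hot_encode(text, token_index, max_len, pad_char=" "):
    """One-hot encode `text` into a (max_len, vocab_size) float32 array.

    Positions after the end of `text` are filled with `pad_char`,
    mirroring the padding used for `encoder_input_data`.
    """
    matrix = np.zeros((max_len, len(token_index)), dtype="float32")
    for t, char in enumerate(text):
        matrix[t, token_index[char]] = 1.0
    matrix[len(text):, token_index[pad_char]] = 1.0
    return matrix


def one_hot_decode(matrix, reverse_index):
    """Map each row of a one-hot matrix back to its character via argmax."""
    return "".join(reverse_index[int(i)] for i in matrix.argmax(axis=-1))


# Toy vocabulary: 7 unique characters, including the space used for padding.
chars = sorted(set("Go. Hi!"))
token_index = {c: i for i, c in enumerate(chars)}
reverse_index = {i: c for c, i in token_index.items()}

encoded = one_hot_encode("Go.", token_index, max_len=5)
print(encoded.shape)                                  # (5, 7)
print(repr(one_hot_decode(encoded, reverse_index)))   # 'Go.  '
```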

Build the model

# Define an input sequence and process it.
encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))
encoder = keras.layers.LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))

# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = keras.layers.LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_dense = keras.layers.Dense(num_decoder_tokens, activation="softmax")
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
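As a rough size check, the trainable weight count follows from the standard LSTM parameter formula, 4 × (units × (input_dim + units) + units). That is four gates, each with an input kernel, a recurrent kernel and a bias.

Plugging in `latent_dim = 256` and the vocabulary sizes printed above (71 input tokens, 93 output tokens) gives the figures below. These were computed by hand rather than copied from a run, but `model.summary()` should report the same total.

```python
def lstm_params(input_dim, units):
    # 4 gates (input, forget, cell, output), each with an input kernel,
    # a recurrent kernel and a bias vector.
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # Weight matrix plus bias vector.
    return input_dim * units + units


latent_dim = 256
num_encoder_tokens = 71  # from the printed vocabulary sizes above
num_decoder_tokens = 93

encoder_count = lstm_params(num_encoder_tokens, latent_dim)  # 335872
decoder_count = lstm_params(num_decoder_tokens, latent_dim)  # 358400
dense_count = dense_params(latent_dim, num_decoder_tokens)   # 23901
print(encoder_count + decoder_count + dense_count)           # 718173
```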

Train the model

model.compile(
    optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"]
)
model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
)
# Save model
model.save("s2s")
Epoch 1/100
125/125 [==============================] - 2s 16ms/step - loss: 1.1806 - accuracy: 0.7246 - val_loss: 1.0825 - val_accuracy: 0.6995
Epoch 2/100
125/125 [==============================] - 1s 11ms/step - loss: 0.8599 - accuracy: 0.7671 - val_loss: 0.8524 - val_accuracy: 0.7646
Epoch 3/100
125/125 [==============================] - 1s 11ms/step - loss: 0.6867 - accuracy: 0.8069 - val_loss: 0.7129 - val_accuracy: 0.7928
Epoch 4/100
125/125 [==============================] - 1s 11ms/step - loss: 0.5982 - accuracy: 0.8262 - val_loss: 0.6547 - val_accuracy: 0.8111
Epoch 5/100
125/125 [==============================] - 1s 11ms/step - loss: 0.5490 - accuracy: 0.8398 - val_loss: 0.6407 - val_accuracy: 0.8114
Epoch 6/100
125/125 [==============================] - 1s 11ms/step - loss: 0.5140 - accuracy: 0.8489 - val_loss: 0.5834 - val_accuracy: 0.8288
Epoch 7/100
125/125 [==============================] - 1s 11ms/step - loss: 0.4854 - accuracy: 0.8569 - val_loss: 0.5577 - val_accuracy: 0.8357
Epoch 8/100
125/125 [==============================] - 1s 11ms/step - loss: 0.4613 - accuracy: 0.8632 - val_loss: 0.5384 - val_accuracy: 0.8407
Epoch 9/100
125/125 [==============================] - 1s 11ms/step - loss: 0.4405 - accuracy: 0.8691 - val_loss: 0.5255 - val_accuracy: 0.8435
Epoch 10/100
125/125 [==============================] - 1s 11ms/step - loss: 0.4219 - accuracy: 0.8743 - val_loss: 0.5049 - val_accuracy: 0.8497
Epoch 11/100
125/125 [==============================] - 1s 11ms/step - loss: 0.4042 - accuracy: 0.8791 - val_loss: 0.4986 - val_accuracy: 0.8522
Epoch 12/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3888 - accuracy: 0.8836 - val_loss: 0.4854 - val_accuracy: 0.8552
Epoch 13/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3735 - accuracy: 0.8883 - val_loss: 0.4754 - val_accuracy: 0.8586
Epoch 14/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3595 - accuracy: 0.8915 - val_loss: 0.4753 - val_accuracy: 0.8589
Epoch 15/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3467 - accuracy: 0.8956 - val_loss: 0.4611 - val_accuracy: 0.8634
Epoch 16/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3346 - accuracy: 0.8991 - val_loss: 0.4535 - val_accuracy: 0.8658
Epoch 17/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3231 - accuracy: 0.9025 - val_loss: 0.4504 - val_accuracy: 0.8665
Epoch 18/100
125/125 [==============================] - 1s 11ms/step - loss: 0.3120 - accuracy: 0.9059 - val_loss: 0.4442 - val_accuracy: 0.8699
Epoch 19/100
125/125 [==============================] - 1s 10ms/step - loss: 0.3015 - accuracy: 0.9088 - val_loss: 0.4439 - val_accuracy: 0.8692
Epoch 20/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2917 - accuracy: 0.9118 - val_loss: 0.4415 - val_accuracy: 0.8712
Epoch 21/100
125/125 [==============================] - 1s 10ms/step - loss: 0.2821 - accuracy: 0.9147 - val_loss: 0.4372 - val_accuracy: 0.8722
Epoch 22/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2731 - accuracy: 0.9174 - val_loss: 0.4424 - val_accuracy: 0.8713
Epoch 23/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2642 - accuracy: 0.9201 - val_loss: 0.4371 - val_accuracy: 0.8725
Epoch 24/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2561 - accuracy: 0.9226 - val_loss: 0.4400 - val_accuracy: 0.8728
Epoch 25/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2481 - accuracy: 0.9245 - val_loss: 0.4358 - val_accuracy: 0.8757
Epoch 26/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2404 - accuracy: 0.9270 - val_loss: 0.4407 - val_accuracy: 0.8746
Epoch 27/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2332 - accuracy: 0.9294 - val_loss: 0.4462 - val_accuracy: 0.8736
Epoch 28/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2263 - accuracy: 0.9310 - val_loss: 0.4436 - val_accuracy: 0.8736
Epoch 29/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2194 - accuracy: 0.9328 - val_loss: 0.4411 - val_accuracy: 0.8755
Epoch 30/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2126 - accuracy: 0.9351 - val_loss: 0.4457 - val_accuracy: 0.8755
Epoch 31/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2069 - accuracy: 0.9370 - val_loss: 0.4498 - val_accuracy: 0.8752
Epoch 32/100
125/125 [==============================] - 1s 11ms/step - loss: 0.2010 - accuracy: 0.9388 - val_loss: 0.4518 - val_accuracy: 0.8755
Epoch 33/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1953 - accuracy: 0.9404 - val_loss: 0.4545 - val_accuracy: 0.8758
Epoch 34/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1897 - accuracy: 0.9423 - val_loss: 0.4547 - val_accuracy: 0.8769
Epoch 35/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1846 - accuracy: 0.9435 - val_loss: 0.4582 - val_accuracy: 0.8763
Epoch 36/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1794 - accuracy: 0.9451 - val_loss: 0.4653 - val_accuracy: 0.8755
Epoch 37/100
125/125 [==============================] - 1s 10ms/step - loss: 0.1747 - accuracy: 0.9464 - val_loss: 0.4633 - val_accuracy: 0.8768
Epoch 38/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1700 - accuracy: 0.9479 - val_loss: 0.4665 - val_accuracy: 0.8772
Epoch 39/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1657 - accuracy: 0.9493 - val_loss: 0.4725 - val_accuracy: 0.8755
Epoch 40/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1612 - accuracy: 0.9504 - val_loss: 0.4799 - val_accuracy: 0.8752
Epoch 41/100
125/125 [==============================] - 1s 10ms/step - loss: 0.1576 - accuracy: 0.9516 - val_loss: 0.4777 - val_accuracy: 0.8760
Epoch 42/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1531 - accuracy: 0.9530 - val_loss: 0.4842 - val_accuracy: 0.8761
Epoch 43/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1495 - accuracy: 0.9542 - val_loss: 0.4879 - val_accuracy: 0.8761
Epoch 44/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1456 - accuracy: 0.9552 - val_loss: 0.4933 - val_accuracy: 0.8757
Epoch 45/100
125/125 [==============================] - 1s 10ms/step - loss: 0.1419 - accuracy: 0.9562 - val_loss: 0.4988 - val_accuracy: 0.8753
Epoch 46/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1385 - accuracy: 0.9574 - val_loss: 0.5012 - val_accuracy: 0.8758
Epoch 47/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1356 - accuracy: 0.9581 - val_loss: 0.5040 - val_accuracy: 0.8763
Epoch 48/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1325 - accuracy: 0.9591 - val_loss: 0.5114 - val_accuracy: 0.8761
Epoch 49/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1291 - accuracy: 0.9601 - val_loss: 0.5151 - val_accuracy: 0.8764
Epoch 50/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1263 - accuracy: 0.9607 - val_loss: 0.5214 - val_accuracy: 0.8761
Epoch 51/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1232 - accuracy: 0.9621 - val_loss: 0.5210 - val_accuracy: 0.8759
Epoch 52/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1205 - accuracy: 0.9626 - val_loss: 0.5232 - val_accuracy: 0.8761
Epoch 53/100
125/125 [==============================] - 1s 10ms/step - loss: 0.1177 - accuracy: 0.9633 - val_loss: 0.5329 - val_accuracy: 0.8754
Epoch 54/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1152 - accuracy: 0.9644 - val_loss: 0.5317 - val_accuracy: 0.8753
Epoch 55/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1132 - accuracy: 0.9648 - val_loss: 0.5418 - val_accuracy: 0.8748
Epoch 56/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1102 - accuracy: 0.9658 - val_loss: 0.5456 - val_accuracy: 0.8745
Epoch 57/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1083 - accuracy: 0.9663 - val_loss: 0.5438 - val_accuracy: 0.8753
Epoch 58/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1058 - accuracy: 0.9669 - val_loss: 0.5519 - val_accuracy: 0.8753
Epoch 59/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1035 - accuracy: 0.9675 - val_loss: 0.5543 - val_accuracy: 0.8753
Epoch 60/100
125/125 [==============================] - 1s 11ms/step - loss: 0.1017 - accuracy: 0.9679 - val_loss: 0.5619 - val_accuracy: 0.8756
Epoch 61/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0993 - accuracy: 0.9686 - val_loss: 0.5680 - val_accuracy: 0.8751
Epoch 62/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0975 - accuracy: 0.9690 - val_loss: 0.5768 - val_accuracy: 0.8737
Epoch 63/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0954 - accuracy: 0.9697 - val_loss: 0.5800 - val_accuracy: 0.8733
Epoch 64/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0936 - accuracy: 0.9700 - val_loss: 0.5782 - val_accuracy: 0.8744
Epoch 65/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0918 - accuracy: 0.9709 - val_loss: 0.5832 - val_accuracy: 0.8743
Epoch 66/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0897 - accuracy: 0.9714 - val_loss: 0.5863 - val_accuracy: 0.8744
Epoch 67/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0880 - accuracy: 0.9718 - val_loss: 0.5912 - val_accuracy: 0.8742
Epoch 68/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0863 - accuracy: 0.9722 - val_loss: 0.5972 - val_accuracy: 0.8741
Epoch 69/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0850 - accuracy: 0.9727 - val_loss: 0.5969 - val_accuracy: 0.8743
Epoch 70/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0832 - accuracy: 0.9732 - val_loss: 0.6046 - val_accuracy: 0.8736
Epoch 71/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0815 - accuracy: 0.9738 - val_loss: 0.6037 - val_accuracy: 0.8746
Epoch 72/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0799 - accuracy: 0.9741 - val_loss: 0.6092 - val_accuracy: 0.8744
Epoch 73/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0785 - accuracy: 0.9746 - val_loss: 0.6118 - val_accuracy: 0.8750
Epoch 74/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0769 - accuracy: 0.9751 - val_loss: 0.6150 - val_accuracy: 0.8737
Epoch 75/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0753 - accuracy: 0.9754 - val_loss: 0.6196 - val_accuracy: 0.8736
Epoch 76/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0742 - accuracy: 0.9759 - val_loss: 0.6237 - val_accuracy: 0.8738
Epoch 77/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0731 - accuracy: 0.9760 - val_loss: 0.6310 - val_accuracy: 0.8731
Epoch 78/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0719 - accuracy: 0.9765 - val_loss: 0.6335 - val_accuracy: 0.8746
Epoch 79/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0702 - accuracy: 0.9770 - val_loss: 0.6366 - val_accuracy: 0.8744
Epoch 80/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0692 - accuracy: 0.9773 - val_loss: 0.6368 - val_accuracy: 0.8745
Epoch 81/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0678 - accuracy: 0.9777 - val_loss: 0.6472 - val_accuracy: 0.8735
Epoch 82/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0669 - accuracy: 0.9778 - val_loss: 0.6474 - val_accuracy: 0.8735
Epoch 83/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0653 - accuracy: 0.9783 - val_loss: 0.6466 - val_accuracy: 0.8745
Epoch 84/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0645 - accuracy: 0.9787 - val_loss: 0.6576 - val_accuracy: 0.8733
Epoch 85/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0633 - accuracy: 0.9790 - val_loss: 0.6539 - val_accuracy: 0.8742
Epoch 86/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0626 - accuracy: 0.9792 - val_loss: 0.6609 - val_accuracy: 0.8738
Epoch 87/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0614 - accuracy: 0.9794 - val_loss: 0.6641 - val_accuracy: 0.8739
Epoch 88/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0602 - accuracy: 0.9799 - val_loss: 0.6677 - val_accuracy: 0.8739
Epoch 89/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0594 - accuracy: 0.9801 - val_loss: 0.6659 - val_accuracy: 0.8731
Epoch 90/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0581 - accuracy: 0.9803 - val_loss: 0.6744 - val_accuracy: 0.8740
Epoch 91/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0575 - accuracy: 0.9806 - val_loss: 0.6722 - val_accuracy: 0.8737
Epoch 92/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0568 - accuracy: 0.9808 - val_loss: 0.6778 - val_accuracy: 0.8737
Epoch 93/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0557 - accuracy: 0.9814 - val_loss: 0.6837 - val_accuracy: 0.8733
Epoch 94/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0548 - accuracy: 0.9814 - val_loss: 0.6906 - val_accuracy: 0.8732
Epoch 95/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0543 - accuracy: 0.9816 - val_loss: 0.6913 - val_accuracy: 0.8733
Epoch 96/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0536 - accuracy: 0.9816 - val_loss: 0.6955 - val_accuracy: 0.8723
Epoch 97/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0531 - accuracy: 0.9817 - val_loss: 0.7001 - val_accuracy: 0.8724
Epoch 98/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0521 - accuracy: 0.9821 - val_loss: 0.7017 - val_accuracy: 0.8738
Epoch 99/100
125/125 [==============================] - 1s 10ms/step - loss: 0.0512 - accuracy: 0.9822 - val_loss: 0.7069 - val_accuracy: 0.8731
Epoch 100/100
125/125 [==============================] - 1s 11ms/step - loss: 0.0506 - accuracy: 0.9826 - val_loss: 0.7050 - val_accuracy: 0.8726
WARNING:tensorflow:From /usr/local/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:105: Network.state_updates (from tensorflow.python.keras.engine.network) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: s2s/assets
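Two details of the log above can be checked by hand.

First, the "125/125" steps per epoch. With `validation_split=0.2`, Keras holds out the last 20% of the arrays, which leaves 8000 of the 10000 samples for training. At `batch_size = 64`, that is 8000 / 64 = 125 batches.

Second, `val_loss` bottoms out at 0.4358 at epoch 25 and then climbs to about 0.705 by epoch 100, while the training loss keeps falling. This is a typical sign of overfitting.

`model.fit` returns a `History` object whose `.history` dict maps metric names such as `"val_loss"` to per-epoch lists. The original code discards it, so the hypothetical `best_epoch` helper below assumes you capture `history = model.fit(...)`. Keras's `keras.callbacks.EarlyStopping` (with `restore_best_weights=True`) is one built-in way to act on this automatically.

```python
import math


def steps_per_epoch(num_samples, batch_size, validation_split):
    # Keras takes the validation set from the *last* samples of the arrays.
    num_train = int(num_samples * (1 - validation_split))
    return math.ceil(num_train / batch_size)


def best_epoch(history, key="val_loss"):
    """Return the 1-based epoch with the lowest value of `key`."""
    values = history[key]
    return min(range(len(values)), key=values.__getitem__) + 1


print(steps_per_epoch(10000, 64, 0.2))                     # 125
print(best_epoch({"val_loss": [1.08, 0.85, 0.71, 0.73]}))  # 3 (toy values)
# With a real run: best_epoch(history.history)
```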

Run inference (sampling)

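  1. Encode the input and retrieve the initial decoder state.
  2. Run one step of the decoder with this initial state and a "start of sequence" token as the target. The output will be the next target token.
  3. Repeat with the current target token and the current states.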
# Define sampling models
# Restore the model and construct the encoder and decoder.
model = keras.models.load_model("s2s")

encoder_inputs = model.input[0]  # input_1
encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = keras.Model(encoder_inputs, encoder_states)

decoder_inputs = model.input[1]  # input_2
decoder_state_input_h = keras.Input(shape=(latent_dim,), name="input_3")
decoder_state_input_c = keras.Input(shape=(latent_dim,), name="input_4")
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())
reverse_target_char_index = dict((i, char) for char, i in target_token_index.items())


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index["\t"]] = 1.0

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ""
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if sampled_char == "\n" or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        # Update states
        states_value = [h, c]
    return decoded_sentence
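The decoding loop above always takes the argmax, which is the simple option mentioned in the algorithm summary. A common alternative, not used in the original example, is temperature sampling: the softmax output is sharpened or flattened before a token is drawn at random.

The helper below is a hypothetical sketch. To try it, swap the argmax line in `decode_sequence` for a call such as `sample_with_temperature(output_tokens[0, -1, :], temperature=0.5)`.

```python
import numpy as np


def sample_with_temperature(probs, temperature=1.0, rng=None):
    """Draw a token index from a softmax output reshaped by `temperature`.

    temperature <= 0 falls back to argmax (greedy decoding); 1.0 samples
    from `probs` as-is; values > 1 flatten the distribution.
    """
    if rng is None:
        rng = np.random.default_rng()
    if temperature <= 0:
        return int(np.argmax(probs))
    logits = np.log(np.asarray(probs, dtype="float64") + 1e-12) / temperature
    logits -= logits.max()  # numerical stability before exponentiating
    weights = np.exp(logits)
    weights /= weights.sum()
    return int(rng.choice(len(weights), p=weights))
```

Low temperatures behave almost like argmax, while higher ones produce more varied (and more error-prone) translations.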

You can now generate decoded sentences like this:

for seq_index in range(20):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index : seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print("-")
    print("Input sentence:", input_texts[seq_index])
    print("Decoded sentence:", decoded_sentence)
-
Input sentence: Go.
Decoded sentence: Va !
-
Input sentence: Hi.
Decoded sentence: Salut !
-
Input sentence: Run!
Decoded sentence: Cours !
-
Input sentence: Who?
Decoded sentence: Qui ?
-
Input sentence: Wow!
Decoded sentence: Ça alors !
-
Input sentence: Fire!
Decoded sentence: Au feu !
-
Input sentence: Help!
Decoded sentence: À l'aide !
-
Input sentence: Jump.
Decoded sentence: Saute.
-
Input sentence: Stop!
Decoded sentence: Stop !
-
Input sentence: Wait!
Decoded sentence: Attendez !
-
Input sentence: Go on.
Decoded sentence: Poursuis.
-
Input sentence: Hello!
Decoded sentence: Salut !
