没舍得用自己的电脑跑,训练用的机器比较垃圾,没有 GPU,很闹心,其间不同 IDE 调用了一次环境想用来测试一下中途训练效果,导致跑了两天的训练被打断,血的教训,不要再犯第二次。

本次训练是用来 TED 的英语-中文翻译数据集,进行了相关的前处理,做了模型持久化,会隔一定 step 自动保存一次模型和训练结果。被打断时训练到 6200 步,输出的模型能够把句子说成人话了,虽然翻译准确性还比较差,预计 9000 步左右能稍好些吧。

Seq2Seq 模型原理和训练结果待后续补充,在此先记录一下代码,若有感兴趣的同学需要原始数据跑一下可留言。

# coding: utf-8
"""
@ File:     Seq2Seq_train.py
@ Brief:    文字翻译, 使用 TED 数据集与 Seq2Seq 模型训练
@ Author:   攀援的井蛙
@ Data:    2020-09-17
"""import tensorflow as tf# 参数设置
SRC_TRAIN_DATA = "./ted_data/train.en"      # 源语言输入文件
TRG_TRAIN_DATA = "./ted_data/train.zh"      # 目标语言输入文件
CHECKPOINT_PATH = "./seq2seq_ckpt"          # checkpoint 保存路径HIDDEN_SIZE = 1024              # LSTM 的隐藏层规模
NUM_LAYERS = 2                  # 深层循环神经网络中 LSTM 结构的层数
SRC_VOCAB_SIZE = 10000          # 源语言词汇表大小
TRG_VOCAB_SIZE = 4000           # 目标语言词汇表大小
BATCH_SIZE = 100                # 训练数据 batch 的大小
NUM_EPOCH = 5                   # 使用训练数据的轮数
KEEP_PROB = 0.8                 # 节点不被 dropout 的概率
MAX_GRAD_NORM = 5               # 用于控制梯度膨胀的梯度大小上限
SHARE_EMB_AND_SOFTMAX = True    # 在 Softmax 层和词向量层之间共享参数MAX_LEN = 50       # 限定句子的最大单词数量
SOS_ID = 1         # 目标语言词汇表中 <sos> 的 ID'''
@ brief: 使用 Dataset 从一个文件中读取一个语言的数据,数据的格式为每行一句话,单词已转化为单词编号
@ return: dataset 数据
@ param file_path: 保存翻译语言内容的文件路径
'''
def MakeDataset(file_path):dataset = tf.data.TextLineDataset(file_path)# 根据空格将单词编号切分开并放入一个一维向量dataset = dataset.map(lambda string: tf.string_split([string]).values)# 将字符串形式的单词编号转化为整数dataset = dataset.map(lambda string: tf.string_to_number(string, tf.int32))# 统计每个句子的单词数量,并与句子内容一起放入 Dataset 中dataset = dataset.map(lambda x: (x, tf.size(x)))return dataset'''
@ brief: 从源文件和目标语言文件中分别读取数据,并进行填充和 batching 操作
@ return: 划分为 batch 的数据
@ param src_path: 源语言文件
@ param trg_path: 目标语言文件
@ param batch_size: 要划分的 batch 大小
'''
def MakeSrcTrgDataset(src_path, trg_path, batch_size):# 首先分别读取源语言和目标语言数据src_data = MakeDataset(src_path)trg_data = MakeDataset(trg_path)# 通过 zip 操作将两个 Dataset 合并为一个 Dataset# 现在每个 Dataset 中每一项数据 ds 由 4 个张量组成#   ds[0][0]是源句子#   ds[0][1]是源句子长度#   ds[1][0]是目标句子#   ds[1][1]是目标句子长度dataset = tf.data.Dataset.zip((src_data, trg_data))# 删除内容为空(只包含<EOS>)的句子和长度过长的句子'''@ brief: 删除内容为空(只包含<EOS>)的句子和长度过长的句子@ return: 内容非空、长度合适的句子@ param src_tuple: 源语言元组@ param trg_tuple: 目标语言元组'''def FilterLength(src_tuple, trg_tuple):((src_input, src_len), (trg_label, trg_len)) = (src_tuple, trg_tuple)src_len_ok = tf.logical_and(tf.greater(src_len, 1), tf.less_equal(src_len, MAX_LEN))trg_len_ok = tf.logical_and(tf.greater(trg_len, 1), tf.less_equal(trg_len, MAX_LEN))return tf.logical_and(src_len_ok, trg_len_ok)dataset = dataset.filter(FilterLength)# 解码器需要两种格式的目标句子:#   1.解码器的输入(trg_input),形式如同 "<SOS> X Y Z"#   2.解码器的目标输出(trg_label),形式如同 "X Y X <SOS>"# 从上面文件中读到的目标句子是 "X Y Z <SOS>" 的形式,我们需要从中生成 "<SOS> X Y Z"# 形式并加入到 Dataset 中''' @ brief: 生成符合解码器输入格式要求的数据@ return: "<SOS> X Y Z" 格式的句子@ param src_tuple: 源语言元组@ param trg_tuple: 目标语言元组'''def MakeTrgInput(src_tuple, trg_tuple):((src_input, src_len), (trg_label, trg_len)) = (src_tuple, trg_tuple)trg_input = tf.concat([[SOS_ID], trg_label[:-1]], axis=0)return ((src_input, src_len), (trg_input, trg_label, trg_len))dataset = dataset.map(MakeTrgInput)# 随机打乱训练数据dataset = dataset.shuffle(10000)# 规定填充后输出的数据维度padded_shapes = ((tf.TensorShape([None]),     # 源句子是长度未知的向量tf.TensorShape([])),         # 源句子长度是单个数字(tf.TensorShape([None]),     # 目标句子(解码器输入)是长度未知的向量tf.TensorShape([None]),     # 目标句子(解码器目标输出)是长度未知的向量tf.TensorShape([]))         # 目标句子长度是单个数字)# 调用 padded_batch 方法进行 batching 操作batched_dataset = dataset.padded_batch(batch_size, padded_shapes)return batched_dataset# 定义 NMTModel 类来描述翻译模型
class NMTModel(object):'''@ brief: 在类的初始化函数中定义模型要用到的变量'''def __init__(self):# 定义编码器和解码器所使用的 LSTM 结构self.enc_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)for _ in range(NUM_LAYERS)])self.dec_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)for _ in range(NUM_LAYERS)])# 为源语言和目标语言分别定义词向量self.src_embedding = tf.get_variable("src_emb", [SRC_VOCAB_SIZE, HIDDEN_SIZE])self.trg_embedding = tf.get_variable("trg_emb", [TRG_VOCAB_SIZE, HIDDEN_SIZE])# 定义 softmax 层的变量if SHARE_EMB_AND_SOFTMAX:self.softmax_weight = tf.transpose(self.trg_embedding)else:self.softmax_weight = tf.get_variable("weight", [HIDDEN_SIZE, TRG_VOCAB_SIZE])self.softmax_bias = tf.get_variable("softmax_bias", [TRG_VOCAB_SIZE])# 在 forward 函数中定义模型的前向计算图# src_input, src_size, trg_input, trg_label, trg_size 分别是上面# MakeSrcTrgDataset 函数生成的五种张量'''@ brief: 定义模型的前向计算图@ return: cost per token, 优化方法@ param src_input: 源句子向量@ param src_size: 源句子长度@ param trg_input: 目标句子(解码器输入)向量@ param trg_label: 目标句子(解码器目标输出)向量@ param trg_size: 目标句子长度'''def forward(self, src_input, src_size, trg_input, trg_label, trg_size):batch_size = tf.shape(src_input)[0]# 将输入和输出单词编号为词向量src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)# 在词向量上进行 dropoutsrc_emb = tf.nn.dropout(src_emb, KEEP_PROB)trg_emb = tf.nn.dropout(trg_emb, KEEP_PROB)# 使用 dynamic_rnn 构造编码器# 编码器读取源句子每个位置的词向量,输出最后一步的隐藏状态 enc_state# 因为编码器是一个双层 LSTM,因此 enc_state 是一个包含两个 LSTMStateTuple 类# 张量的 tuple,每个 LSTMStateTuple 对应编码器中的一层。# enc_outputs 是顶层 LSTM 在每一步的输出,它的维度是 [batch_size, max_time, HIDDEN_SIZE]# Seq2Seq 模型中不需要用到 enc_outputs, 而 attention 模型会用到它with tf.variable_scope("encoder"):enc_outputs, enc_state = tf.nn.dynamic_rnn(self.enc_cell, src_emb, src_size, dtype=tf.float32)# 使用 dynamic_rnn 构造解码器# 解码器读取目标句子每个位置的词向量,输出的 dec_outputs 为每一步# 顶层 LSTM 的输出。 dec_outputs 的维度为 [batch_size, max_time, HIDDEN_SIZE]# initial_state = enc_state 表示用编码器的输出来初始化第一步的隐藏状态with tf.variable_scope("decoder"):dec_outputs, _ = tf.nn.dynamic_rnn(self.dec_cell, trg_emb, trg_size, initial_state=enc_state)# 计算解码器每一步的 log perplexityoutput = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])logits = tf.matmul(output, self.softmax_weight) + self.softmax_biasloss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.reshape(trg_label, [-1]), logits=logits)# 在计算平均损失时,需要将填充位置的权重设置为 0# 以避免无效位置的预测感染模型的训练label_weights = tf.sequence_mask(trg_size, maxlen=tf.shape(trg_label)[1], dtype=tf.float32)label_weights = tf.reshape(label_weights, [-1])cost = tf.reduce_sum(loss * label_weights)cost_per_token = cost / tf.reduce_sum(label_weights)# 定义反向传播操作trainable_variables = tf.trainable_variables()# 控制梯度大小,定义优化方法和训练步骤grads = tf.gradients(cost / tf.to_float(batch_size),trainable_variables)grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)train_op = optimizer.apply_gradients(zip(grads, trainable_variables))return cost_per_token, train_op# 使用给定的模型 model 上训练一个 epoch, 并返回全局步数
# 每训练 200 步便保存一个 checkpoint
'''
@ brief: 使用给定模型训练 epoch,每训练 200 步保存一个 checkpoint
@ return: 全局步数
@ param session: 计算图
@ param cost_op: 损失优化方法
@ param train_op: 训练优化方法
@ param saver: 模型持久化 Saver
@ param step: 步数
'''
def run_epoch(session, cost_op, train_op, saver, step):# 训练一个 epoch# 重复训练步骤直至遍历完 Dataset 中所有的数据while True:try:# 运行 train_op 并计算损失值,训练数据在 main() 函数中以 Dataset 方式提供cost, _ = session.run([cost_op, train_op])if step % 10 == 0:print("After %d steps, per token cost is %.3f" % (step, cost))# 每 200 步保存一个 checkpointif step % 200 == 0:saver.save(session, CHECKPOINT_PATH, global_step=step)step += 1except tf.errors.OutOfRangeError:breakreturn step'''
@ brief: 主函数,定义训练过程
@ return: None
'''
def main():# 定义初始化函数initializer = tf.random_uniform_initializer(-0.05, 0.05)# 定义训练用的循环神经网络模型with tf.variable_scope("nmt_model", reuse=None,initializer=initializer):train_model = NMTModel()# 定义输入数据data = MakeSrcTrgDataset(SRC_TRAIN_DATA, TRG_TRAIN_DATA, BATCH_SIZE)iterator = data.make_initializable_iterator()(src, src_size), (trg_input, trg_label, trg_size) = iterator.get_next()# 定义前向计算图,输入数据以张量形式提供给 forward 函数cost_op, train_op = train_model.forward(src, src_size, trg_input, trg_label, trg_size)# 训练模型saver = tf.train.Saver()step = 0with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(NUM_EPOCH):print("In iteration: %d" % (i + 1))sess.run(iterator.initializer)step = run_epoch(sess, cost_op, train_op, saver, step)if __name__ == "__main__":main()

TensorFlow Seq2Seq模型样例:实现语言翻译相关推荐

  1. TensorFlow中cnn-cifar10样例代码详解

    TensorFlow是一个支持分布式的深度学习框架,在Google的推动下,它正在变得越来越普及.我最近学了TensorFlow教程上的一个例子,即采用CNN对cifar10数据集进行分类.在看源代码 ...

  2. 用mobilenet模型跑tensorflow CNN的样例:image_retrain.py和label_image.py

    系统是 ubuntu 16.04,tensorflow版本是1.6, cuDNN版本是7.0.git clone tensorflow后试着跑了一下image_retrain.py(以下简称retra ...

  3. c语言程序报告样例,C语言个人实习报告定稿(样例3)

    <C语言个人实习报告.doc>由会员分享,可免费在线阅读全文,更多与<C语言个人实习报告[定稿]>相关文档资源请在帮帮文库(www.woc88.com)数亿文档库存里搜索. 1 ...

  4. python pptx库中文文档_python-pptx库中文文档及使用样例

    个人使用样例及部分翻译自官方文档,并详细介绍chart的使用 转载请注明出处,谢谢 一:基础应用 1.创建pptx文档类并插入一页幻灯片 from pptx import Presentation p ...

  5. L11注意力机制和Seq2seq模型

    注意力机制 在"编码器-解码器(seq2seq)"⼀节⾥,解码器在各个时间步依赖相同的背景变量(context vector)来获取输⼊序列信息.当编码器为循环神经⽹络时,背景变量 ...

  6. seq2seq模型_使用Tensorflow搭建一个简单的Seq2Seq翻译模型

    1.背景 首先,这篇博文整理自谷歌开源的神经机器翻译项目Neural Machine Translation (seq2seq) Tutorial.如果你直接克隆这个项目按照Tutorial中的说明操 ...

  7. seq2seq模型_直观理解并使用Tensorflow实现Seq2Seq模型的注意机制

    采用带注意机制的序列序列结构进行英印地语神经机器翻译 Seq2seq模型构成了机器翻译.图像和视频字幕.文本摘要.聊天机器人以及任何你可能想到的包括从一个数据序列到另一个数据序列转换的任务的基础.如果 ...

  8. 【TensorFlow实战笔记】对于TED(en-zh)数据集进行Seq2Seq模型实战,以及对应的Attention机制(tf保存模型读取模型)

    个人公众号 AI蜗牛车 作者是南京985AI硕士,CSDN博客专家,研究方向主要是时空序列预测和时间序列数据挖掘,获国家奖学金,校十佳大学生,省优秀毕业生,阿里天池时空序列比赛rank3.公众号致力于 ...

  9. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  10. (四)使用TensorFlow和Keras构建AI语言翻译

    目录 介绍 导入库 构建模型组件 添加注意力机制 将模型放在一起 下一步 下载原文件 -- 3.5k 介绍 谷歌翻译工作得如此之好,它通常看起来很神奇.但这不是魔法--这是深度学习! 在本系列文章中, ...

最新文章

  1. mysqlmediumtext,分享面经!
  2. 数据结构上机测试1:顺序表的应用
  3. Spring注解@Resource和@Autowired区别对比
  4. python tcp通信如何实现多人聊天,Python实现多用户全双工聊天(一对一),python多用户,多用户全双工聊天简陋...
  5. docker run
  6. Xshell利用密钥远程登录Linux
  7. cf——Sasha and a Bit of Relax(dp,math)
  8. Oracle 触发器详解
  9. 并发编程学习之线程8锁
  10. java关闭数据库连接_java 和数据库连接如果不关闭会怎么样
  11. 滑模控制学习笔记(二)
  12. html改变鼠标指针形状代码,鼠标指针形状效果大全 cursor
  13. 2022全球量子通信产业发展报告
  14. 论文笔记:UCTransNet: Rethinking the Skip Connections in U-Net from a Channel-wisePerspective with Transf
  15. 【Linux】yum install cmake 报错,出现错误ImportError: No module named urlgrabber.grabber
  16. 糖尿病人食谱以及水果的食用
  17. 举例在移动HTML5 UI框架有那些
  18. 理解SVM ——入门SVM和代码实现
  19. 四级网络工程师试题二
  20. Python之文档测试

热门文章

  1. TDDFT计算软件Octopus学习笔记(一):Ubuntu下Octopus的安装
  2. C#--图表控件(Chart)
  3. MySQL里有2000w数据,redis中只存20w的数据,如何保证redis中的数据都是热点数据?
  4. 《那些年啊,那些事——一个程序员的奋斗史》九
  5. 20.SPDY_QUIC_HTTP2_HTTP3
  6. java itext read a pdf file_java - 使用iText7读取PDF时遇到的问题(使用iText5) - 堆栈内存溢出...
  7. win7安装mysql后“应用程序无法启动因为应用程序的并行配置不正
  8. php页面中播放flv视频,页面播放flv格式视频[原创]
  9. 数字孪生技术方案下的智慧城市建设治理体系优势
  10. H5制作哪家强?四大H5页面制作工具大比拼