前言

前面在《深度学习的seq2seq模型》文章中已经介绍了seq2seq结构及其原理,接下去这篇文章将尝试使用TensorFlow来实现一个seq2seq网络结构,该例子能通过训练给定的训练集实现输入某个序列输出某个序列,其中输入序列和输出序列相同,这里选择使用LSTM模型。

训练样本集

为方便起见这里使用随机生成的序列作为样本,序列的长度也是随机的且在指定的范围内。

LSTM机制原理

关于LSTM机制原理可看之前的文章《LSTM神经网络》。

随机序列生成器

def random_sequences(length_from, length_to, vocab_lower, vocab_upper, batch_size):def random_length():if length_from == length_to:return length_fromreturn np.random.randint(length_from, length_to + 1)while True:yield [np.random.randint(low=vocab_lower, high=vocab_upper, size=random_length()).tolist()for _ in range(batch_size)]

构建一个随机序列生成器方便后面生成序列,其中 length_from 和 length_to表示序列的长度范围从多少到多少,vocab_lower 和 vocab_upper 表示生成的序列值的范围从多少到多少,batch_size 即是批的数量。

填充序列

def make_batch(inputs, max_sequence_length=None):sequence_lengths = [len(seq) for seq in inputs]batch_size = len(inputs)if max_sequence_length is None:max_sequence_length = max(sequence_lengths)inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length], dtype=np.int32)for i, seq in enumerate(inputs):for j, element in enumerate(seq):inputs_batch_major[i, j] = elementinputs_time_major = inputs_batch_major.swapaxes(0, 1)return inputs_time_major, sequence_lengths

生成的随机序列的长度是不一样的,需要对短的序列用来填充,而可设为0,取最长的序列作为每个序列的长度,不足的填充,然后再转换成time major形式。

构建图

encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
ecoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_inputs')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')

创建三个占位符,分别为encoder的输入占位符、decoder的输入占位符和decoder的target占位符。

embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)

将encoder和decoder的输入做一个嵌入操作,对于大词汇量这个能达到降维的效果,嵌入操作也是很常用的方式了。在seq2seq模型中,encoder和decoder都是共用一个嵌入层即可。嵌入层的向量形状为[vocab_size, input_embedding_size],初始值从-1到1,后面训练会自动调整。

encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(encoder_cell, encoder_inputs_embedded,dtype=tf.float32, time_major=True,)
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(decoder_cell, decoder_inputs_embedded,initial_state=encoder_final_state,dtype=tf.float32, time_major=True, scope="plain_decoder",)

创建encoder和decoder的LSTM神经网络,encoder_hidden_units 为LSTM隐层数量,设定输入格式为time major格式。这里我们不关心encoder的循环神经网络的输出,我们要的是它的最终状态encoder_final_state,将其作为decoder的循环神经网络的初始状态。

decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)
decoder_prediction = tf.argmax(decoder_logits, 2)
stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),logits=decoder_logits,)
loss = tf.reduce_mean(stepwise_cross_entropy)
train_op = tf.train.AdamOptimizer().minimize(loss)

对于decoder的循环神经网络的输出,因为我们要一个分类结果,所以需要一个全连接神经网络,输出层神经元数量是词汇的数量。输出层最大值对应的神经元即为预测的类别。输出层的激活函数用softmax,损失函数用交叉熵损失函数。

创建会话

with tf.Session(graph=train_graph) as sess:sess.run(tf.global_variables_initializer())for epoch in range(epochs):batch = next(batches)encoder_inputs_, _ = make_batch(batch)decoder_targets_, _ = make_batch([(sequence) + [EOS] for sequence in batch])decoder_inputs_, _ = make_batch([[EOS] + (sequence) for sequence in batch])feed_dict = {encoder_inputs: encoder_inputs_, decoder_inputs: decoder_inputs_,decoder_targets: decoder_targets_,}_, l = sess.run([train_op, loss], feed_dict)loss_track.append(l)if epoch == 0 or epoch % 1000 == 0:print('loss: {}'.format(sess.run(loss, feed_dict)))predict_ = sess.run(decoder_prediction, feed_dict)for i, (inp, pred) in enumerate(zip(feed_dict[encoder_inputs].T, predict_.T)):print('input > {}'.format(inp))print('predicted > {}'.format(pred))if i >= 20:break

创建会话开始执行,每次生成一批数量,用 make_batch 分别创建encoder输入、decoder的target和decoder的输入。其中target需要在后面加上[EOS],它表示句子的结尾,同时输入也加上[EOS]表示编码开始。每训练1000词输出看看效果。

github

https://github.com/sea-boat/DeepLearning-Lab

========广告时间========

公众号的菜单已分为“分布式”、“机器学习”、“深度学习”、“NLP”、“Java深度”、“Java并发核心”、“JDK源码”、“Tomcat内核”等,可能有一款适合你的胃口。

鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以购买。感谢各位朋友。

为什么写《Tomcat内核设计剖析》

=========================

欢迎关注:

这里写图片描述

TensorFlow实现seq2seq相关推荐

  1. Tensorflow动态seq2seq使用总结

    北京 | 深度学习与人工智能研修 12月23-24日 再设经典课程 重温经典课程 阅读全文 > tf-seq2seq是Tensorflow的通用编码器 - 解码器框架,可用于机器翻译,文本汇总, ...

  2. Tensorflow新版Seq2Seq接口使用

    简介 Tensorflow 1.0.0 版本以后,开发了新的seq2seq接口,弃用了原来的接口. 旧的seq2seq接口也就是tf.contrib.legacy_seq2seq下的那部分,新的接口在 ...

  3. seq2seq模型_直观理解并使用Tensorflow实现Seq2Seq模型的注意机制

    采用带注意机制的序列序列结构进行英印地语神经机器翻译 Seq2seq模型构成了机器翻译.图像和视频字幕.文本摘要.聊天机器人以及任何你可能想到的包括从一个数据序列到另一个数据序列转换的任务的基础.如果 ...

  4. Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型

    Github下载完整代码 https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum 简介 这篇文章中我们将基于Tensor ...

  5. 【NLP】一文理解Seq2Seq

    seq2seq介绍 1.1 简单介绍 Seq2Seq技术,全称Sequence to Sequence,该技术突破了传统的固定大小输入问题框架,开通了将经典深度神经网络模型(DNNs)运用于在翻译,文 ...

  6. not found error :\tensorflow\contrib\coder\python\ops\_coder_ops.so——_gru_ops.so——_lstm_ops.so···

    报错:以下模块未找到 一共有17个模块未找到 C:\Users\86188\AppData\Local\Continuum\anaconda3\envs\python36\lib\site-packa ...

  7. Ubuntu 16.04 源码编译安装GPU tensorflow(二)

    如前一篇在1.4.0版本的Tensorflow上安裝Tensorflow Object Detection API,在验证测试时出現serialized_options=None问题.需安装高版本Te ...

  8. 深度剖析Seq2Seq原理代码

    仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者 | 燃雪  来源 | 知乎专栏 编辑 | 机器学习算法与自然语言处理 地址 | https://zhuanlan.zhihu.com/p/ ...

  9. Seq2Seq那些事

    1前言 本篇博客主要记录的是使用Tensorflow搭建Seq2Seq模型,主要包括3个部分的叙述:第一,Seq2Seq模型的训练过程及原理.第二,复现基于SouGouS新闻语料库的文本摘要的应用.第 ...

最新文章

  1. 2.2. 对网络安全的威胁
  2. xenserver 虚拟机扩容lvm磁盘分区的方法_从零开始学Linux运维|35.LVM(逻辑卷管理)的创建...
  3. 操作系统(十)进程通信
  4. 898. 子数组按位或操作
  5. python异常值处理实例_利用Python进行异常值分析实例代码
  6. 15针VGA公头焊接示意图
  7. [软件工程] 可行性研究
  8. java全栈(java全栈开发工程师)
  9. 《程序开发心理学——程序开发组》
  10. SpringBoot+Hibernate配置
  11. 等本等息,等额本息,等额本金,看懂再贷款,坑多!
  12. 解决window聚焦图片不自动更新,没有“喜欢么”信息提示框问题
  13. linux设置网关和ip
  14. 土拍熔断意味着什么_315土拍将解地市之渴?“熔断”来了,别高兴太早
  15. 新版MDN正式上线,还有收费版的MDN Plus,下个月也即将到来
  16. 【Docker】Segmentation Fault or Critical Error encountered. Dumping core and abor
  17. 2w字长文!手撸一套 Java 基础面试题
  18. codeforces1271D 2100分贪心
  19. java - PdfBox 图片转pdf
  20. mysql特殊字符无法入库_MySQL数据入库时特殊字符处理详解

热门文章

  1. 黑房东!我忍无可忍了,这次一定要你得到法律的制裁
  2. php人物图像动漫化
  3. python编程学习——第六周
  4. premiere pro安装
  5. 智慧的仓库管家——WMS
  6. 前端 html 基础 jQuery css
  7. 新媒体运营:如何一招实现主动引流,快速获得用户增长? 黎想
  8. 服务器r730系统备份软件,服务器r730
  9. pyhanlp用户自定义词典添加
  10. 汉王科技持续走下坡路,发展寻求突破