本文主要利用lstm实现英文摘要的自动生成,他的主要步骤如下:
(1) 首先是我们把训练数据的每个单词对应一个数值,构建一个word to int的字典,得到单词和数值的一个一一对应的关系后,我们就可以把整个训练数据集把所有单词转化为数值的序列,也就是数值型的数列,
(2) 训练数据处理好了以后,我们就应该构造batch,建立一个生成batch_size的生成器
(3) 然后开始构建lstm网络,同时开始训练
(4) 训练完了以后就可以开始给定初始单词,让模型生成一个序列,再把序列转化为对应的单词
这个方法适用很多,比如生成周杰伦的歌词,自动生成古诗,自动生成剧本等等都是可以的,方法其实都是一样。

接下来,我们就开始写代码:
导入相关库:

import tensorflow as tf
from tensorflow.contrib import rnn,seq2seq
import numpy as np

数据读取:

with open(file,'r') as f:txt = f.read()
txt = txt.lower()  # 将所有单词都变为小写,这样可以使得不同单词的大小写同一

构建标点符号的字典

tokenize_dict = {}
tokenize_dict['!'] =  "Exclamation_Mark"
tokenize_dict['.'] =  "PERIOD"
tokenize_dict[','] =  "COMMA"
tokenize_dict['"'] =  "QUOTATION_MARK"
tokenize_dict[','] =  "QUOTATION_MARK"
tokenize_dict[';'] =  "SEMICOLON"
tokenize_dict['?'] =  "QUESTION_MARK"
tokenize_dict['('] =  "LEFT_PARENTHESES"
tokenize_dict[')'] =  "RIGHT_PARENTHESES"
tokenize_dict['--'] =  "DASH"
tokenize_dict['\n'] =  "RETURN"

替换训练数据中的标点

for key,value in tokenize_dict.items():txt = txt.replace(key,' {} '.format(value)) # 此处使用format是为了保证替换后的标点和前后单词留出一个空格

然后开始构建单词和数值型序列的映射关系

text = txt.split()
words = set(text)
vocab_size = len(words)
# 以下是将词转化为对应的数字
word_to_int = {word:num for num,word in enumerate(words)}
int_to_word = {value:key for key,value in word_to_int.items()}
# 这一步就是将我们的输入序列,转化为对应的数值型的序列
int_word_data = [word_to_int[word] for word in text]

然后开始构建生成器,生成batch_size

def get_batch(int_word_data,batch_size,len_seq,n_batch):# 该函数是为了构建batch_sizeint_word_data = int_word_data[:n_batch*batch_size*len_seq]int_word_data_y = int_word_data[1:]int_word_data_y.append(int_word_data[0])int_word_data = np.array(int_word_data).reshape([-1,n_batch*len_seq])int_word_data_y = np.array(int_word_data_y).reshape([-1,n_batch*len_seq])for i in range(n_batch):batch_x = np.zeros([batch_size,len_seq])batch_y = np.zeros([batch_size,len_seq])for j in range(batch_size):batch_x[j,:] = int_word_data[j,i*len_seq:(i+1)*len_seq]batch_y[j,:] = int_word_data_y[j,i*len_seq:(i+1)*len_seq]yield batch_x,batch_y

然后开始构建rnn模型

def init_rnn(hidden_size,layers,batch_size):# layers 表示lstm的层数# hidden_size lstm里面隐藏的大小,这个和记忆有关cells = []for _ in range(layers):cells.append(rnn.BasicLSTMCell(hidden_size))cell = rnn.MultiRNNCell(cells,state_is_tuple=True)# 以下是lstm的初始化,初始时,设置全部为0in_state = cell.zero_state(batch_size,tf.float32)# 这一步是给init_state取了个名字in_state = tf.identity(in_state,'in_state')return cell , in_statedef bulid_rnn(hidden_size,layers,inputs,vocab_size,embed_dim,batch_size):# word2vec 的维度大小input_shape = tf.shape(inputs)cell,state = init_rnn(hidden_size,layers,input_shape[0])# 此处使用词向量编码embed = tf.contrib.layers.embed_sequence(inputs,vocab_size,embed_dim)output,out_state = tf.nn.dynamic_rnn(cell=cell,inputs=embed,dtype=tf.float32)out_state = tf.identity(out_state,name='out_state')# 以下是一个全连接层,输出大小为vocab_size,因为我们总共有vocab_size个单词,所以相当于需要得到每个词输出的概率,#就是softmax需要用到output = tf.contrib.layers.fully_connected(output,vocab_size,activation_fn=None)probability = tf.nn.softmax(output,name='probability')return output,out_state

定义我们需要用到的参数

# 参数的定义
learing_rate = 0.01
hidden_size = 512
layers = 1
embed_dim = 200
batch_size = 64
len_seq = 20
vocab_size = len(words)
num_epochs = 10
n_batch = len(int_word_data) // (batch_size*len_seq)
epochs = 100

然后定义loss,optimizer等

# 定义输入,input是我们输入的数据,y是相当于label
inputs = tf.placeholder(tf.int32,[None,None],name='inputs')
y = tf.placeholder(tf.int32,[None,None],name='target')
output,state = bulid_rnn(hidden_size,layers,inputs,vocab_size,embed_dim,batch_size)# 这个loss其实就是交叉熵损失,只是他把计算softmax和计算交叉熵方一起了,然后同时计算了一个tf.reduce_sum
loss = seq2seq.sequence_loss(output,y,tf.ones([batch_size,len_seq]))accuracy = tf.reduce_sum(tf.cast(tf.equal(tf.argmax(tf.nn.softmax(output),2),\tf.cast(y,tf.int64)),tf.float32))
optimizer = tf.train.AdamOptimizer(learing_rate).minimize(loss)

开始训练模型

with tf.Session() as sess:writer = tf.summary.FileWriter('graphs/seq')sess.run(tf.global_variables_initializer())for epoch in range(epochs):total_loss = 0train_acc = 0for batch_x,batch_y in get_batch(int_word_data,batch_size,len_seq,n_batch):tmp_loss,tmp_acc,_ = sess.run([loss,accuracy,optimizer],feed_dict={inputs:batch_x,y:batch_y})total_loss += tmp_loss#train_acc += tmp_accif epoch % 10 == 0:print('Epoch {}/{} train_loss {:.3f}'.format(epoch,epochs,total_loss/n_batch))saver = tf.train.Saver()saver.save(sess, './checkpoints/arvix/')print('Model Trained and Saved')writer.close()

模型训练完成以后,我们开始预测结果
定义我们的生成序列的函数

def get_state(graph):# 拿到模型的输入,因为我们生成摘要的时候,也是需要给模型一个输入inputs = graph.get_tensor_by_name('inputs:0')# 拿到训练结束时的输出的状态,当作当前状态的输入,因为在rnn里面,他是有记忆的,他的输入和上一次有关,同时和当前#的输入也有关,是二者的一个综合in_state = graph.get_tensor_by_name('in_state:0')# 拿到输出状态,作为下一次的输入状态,out_state = graph.get_tensor_by_name('out_state:0')# 拿到当前输出的每个单词的概率,这个是为了我们预测做准备probability = graph.get_tensor_by_name('probability:0')return inputs,in_state,out_state,probabilitydef get_word(probability,int_to_word):# np.random.choice(a,size,p)# 其中a表示一个数组,也就是随机数在其中选择生成,如何是一个数比如6,8等,则默认是np.arange(a)里面选择# size表示生成的大小,比如可以8行6列等等# p表示每一个可以选择的值的概率,必须加起来等于1# 使用int_to_word是为了将数值对应的单词输出int_w = np.random.choice(probability.shape[2],1,p=probability[0,0,:])[0]return int_to_word[int_w]

设置我们初始的单词和生成序列的长度

# 设置生成句子的长度
gen_len = 300
begin_word = 'deep'

开始预测

# 创建一个Graph对象
graph = tf.Graph()with tf.Session(graph=graph) as sess:#  tf.train.import_meta_graph 表示不重复定义计算图load_meta = tf.train.import_meta_graph('./checkpoints/arvix/.meta')load_meta.restore(sess,'./checkpoints/arvix/')inputs,in_state,out_state,probability = get_state(graph)gen_sentences = [begin_word]pre_state = sess.run(in_state,feed_dict={inputs:np.array([[5]])})for n in range(gen_len):input_s = [[word_to_int[word] for word in gen_sentences[-gen_len:]]]seq_len = len(input_s[0])probability_,pre_stat= sess.run([probability,out_state],feed_dict={inputs:input_s,in_state:pre_state})gen_word = get_word(probability_,int_to_word)gen_sentences.append(gen_word)gen_sentences = ' '.join(gen_sentences)for key,value in tokenize_dict.items():gen_sentences = gen_sentences.replace(value,key+' ')print(gen_sentences)       

预测结果如下:
INFO:tensorflow:Restoring parameters from ./checkpoints/arvix/
deep stacked rnn networks networks learning stacked rnns learning learning learning learning learning learning learning networks learning learning analysis rnns belief rnn neural neural neural feedforward rnn neural neural learning networks learning feedforward rnn network networks learning nets neural networks neural learning rnns stacked network learning rnns stacked networks autoencoder rnns learning network learning belief stacked autoencoder neural networks networks rnns autoencoder learning networks rnns network networks autoencoder stacked look learning feedforward autoencoder stacked neural rnns learning multi-layer rnn nets learning multi-layer learning learning rnn networks networks networks learning rnn autoencoder network stacked rnns learning models models neural learning clustering learning learning learning look networks learning learning learning networks network neural analysis models belief stacked hierarchical rnn learning rnn learning neural stacked learning networks neural rnns neural autoencoder neural learning networks models neural learning neural belief learning learning learning neural network neural look neural stacked autoencoder learning stacked learning neural learning networks analysis networks belief networks belief network feedforward network multi-layer look rnns models dual-stream neural neural networks learning belief dual-stream learning belief stacked learning learning long learning networks learning learning networks rnn networks learning look multi-layer learning look neural learning network dual-stream neural networks stacked learning stacked neural multi-layer autoencoder neural recurrent networks multi-layer learning learning learning learning rnns learning learning network neural belief networks neural learning learning learning models autoencoder learning recurrent learning stacked learning neural learning stacked neural ones neural networks rendering learning rnn learning multi-layer learning rnns learning belief learning learning networks analysis stacked look learning learning networks learning rnn learning learning networks learning analysis network clustering networks network learning autoencoder stacked learning rnn network dual-stream learning networks learning network stacked learning stacked networks look learning learning learning ones neural neural stacked learning learning learning network look rnns autoencoder neural learning neural learning learning network multi-layer stacked
这就是利用rnn做序列的预测的方法。

利用tensorflow自动生成英文摘要相关推荐

  1. 【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

    前言 深度学习作为人工智能的重要手段,迎来了爆发,在NLP.CV.物联网.无人机等多个领域都发挥了非常重要的作用.最近几年,各种深度学习算法层出不穷, Generative Adverarial Ne ...

  2. 【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像...

    模型训练与在线预测服务.推荐算法四部曲.机器学习PAI实战.更多精彩,尽在 开发者分会场 [机器学习PAI实战]-- 玩转人工智能之综述 [机器学习PAI实战]-- 玩转人工智能之商品价格预测 [机器 ...

  3. 利用ApacheCXF自动生成webservice的客户端代码

    利用ApacheCXF自动生成webservice的客户端代码 一.环境准备 1.JDK环境 2.下载apache-cxf发布包,举例版本为3.2.14,解压发布包,设置CXF_HOME,并添加%CX ...

  4. Django 快速搭建博客 第十一节(文章阅读量统计,自动生成文章摘要)

    这一节主要做一些修补工作,一个是:文章阅读量的统计,另一个是自动生成文章摘要内容 1 . 文章阅读量的统计: 1 文章阅读量的统计,我们需要在model下的Post类中新加入一个views 字段用来统 ...

  5. php开发工程师名片,PHP编程:利用PHP自动生成印有用户信息的名片

    <PHP编程:利用PHP自动生成印有用户信息的名片>要点: 本文介绍了PHP编程:利用PHP自动生成印有用户信息的名片,希望对您有用.如果有疑问,可以联系我们. 前言 PHP教程无论是自己 ...

  6. java编程猜数字大小 (要求利用随机数自动生成一个0--100内的随机数字)

    java编程猜数字(要求利用随机数自动生成一个0–100内的随机数字) public static void main(String[] args) {int num=(int)(Math.rando ...

  7. 利用Flex自动生成C语言词法分析器

    利用Flex自动生成C语言词法分析器 Flex介绍 C语言词法规则 具体实现 源代码 测试代码 实验结果 Flex介绍   1975年Mike Lesk和实习生Eric Schmidt设计并实现了一个 ...

  8. 用Encoder-Decoder模型自动生成文本摘要

    出品:贪心科技(公众号:贪心科技) 作者:Jason Brownlee 前言 文本摘要是自然语言处理中的一个问题,即要为源文档创建一篇简短.准确.流畅的摘要.当针对机器翻译开发的Encoder-Dec ...

  9. Word文档如何自动生成文献摘要?

    一.启动"自动编写摘要"功能 Word 97/2000/XP/2003均支持此项功能,用Word打开需要编辑的论文后,在"工具"菜单选择"自动编写摘要 ...

  10. Diango博客--7.自动生成文章摘要

    文章目录 0.思路引导 1.方法一:覆写 save 方法 2.方法二:使用 truncatechars 模板过滤器 0.思路引导 博客文章的模型有一个 excerpt 字段,这个字段用于存储文章的摘要 ...

最新文章

  1. Delphi中使用IXMLHTTPRequest如何用POST方式提交带参
  2. Python 还能实现图片去雾?FFA 去雾算法、暗通道去雾算法用起来!(附代码)...
  3. 大佬教你Android如何实现时间线效果
  4. 2.1Python基础语法(一)之注释与数据类型:
  5. Linux系统卡慢之调优方法
  6. c++调用mysql存储过程_C++中ADO调用MySQL存储过程失败,诡异的语法异常,求解中,附源码...
  7. ORACLE事务提交
  8. 杭电acm 1846 Brave Game(巴什博弈)
  9. (27)FPGA译码器设计(第6天)
  10. Compilation Error 解决方案汇集
  11. 安装ugjava安装在哪里_讨论!空调安装安全绳该挂哪里
  12. python背包问题并行_背包问题九讲python3实现
  13. WIN10不显示sql2005服务器,win10系统安装sQLserver2005提示“sQL server服务无法启动”的设置办法...
  14. Web前端面试题整合,持续更新【可以收藏】
  15. cdh hive配置mysql_Hive学习(CDH版Hadoop、Hive安装)
  16. 无法修正错误,因为您要求某些软件包保持现状,就是它们破坏了软件包间的依赖关系
  17. 除了高通和博通,还有哪些Wi-Fi6路由器芯片方案可选
  18. kafka 启动时提示 /brokers/ids/1001 is: NODEEXISTS
  19. 软件之聊天工具:QQ,MSN,Google talk,Skype, Lync
  20. 连续信号与离散信号---采样定理

热门文章

  1. android手机录屏工具,安卓手机上有什么好用的屏幕录屏软件可以推荐?
  2. kettle安装教程
  3. QTP基础教程(讲义)《软件测试技术》
  4. 怎么批量修改文件后缀名?
  5. 微信支付(1)---功能测试点
  6. 捷联惯导系统学习6.13(状态估计的误差分配与可观测度分析 )
  7. vue相关插件及框架全家桶
  8. 防火墙 | 网络协议
  9. C#【高级篇】 IntPtr是什么?怎么用?
  10. VMware 12 密钥