1 Import libraries

import os
import io
import re
import requests
import string
import collections
import random
import pickle
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

2 Load the data

data_dir = 'temp'
data_file = 'shakespeare.txt'
model_path = 'shakespeare_model'
full_model_dir = os.path.join(data_dir, model_path)
punctuation = string.punctuation
punctuation = ''.join([x for x in punctuation if x not in ['-', "'"]])  # !"#$%&()*+,./:;<=>?@[\]^_`{|}~
if not os.path.exists(full_model_dir):
    os.makedirs(full_model_dir)
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
print('Loading Shakespeare Data')
if not os.path.isfile(os.path.join(data_dir, data_file)):
    print('Not found, downloading Shakespeare texts from www.gutenberg.org')
    shakespeare_url = 'http://www.gutenberg.org/cache/epub/100/pg100.txt'
    response = requests.get(shakespeare_url)
    shakespeare_file = response.content
    s_text = shakespeare_file.decode('utf-8')
    s_text = s_text[7675:]  # drop the Project Gutenberg header
    s_text = s_text.replace('\r\n', '')
    s_text = s_text.replace('\n', '')
    with open(os.path.join(data_dir, data_file), 'w') as out_conn:
        out_conn.write(s_text)
else:
    with open(os.path.join(data_dir, data_file), 'r') as file_conn:
        s_text = file_conn.read().replace('\n', '')
Loading Shakespeare Data

3 Clean the data

s_text = re.sub(r'[{}]'.format(punctuation), ' ', s_text)
s_text = re.sub(r'\s+', ' ', s_text).strip().lower()
print(s_text[0:1000])
from fairest creatures we desire increase that thereby beauty's rose might never die but as the riper should by time decease his tender heir might bear his memory but thou contracted to thine own bright eyes feed'st thy light's flame with self-substantial fuel making a famine where abundance lies thy self thy foe to thy sweet self too cruel thou that art now the world's fresh ornament and only herald to the gaudy spring within thine own bud buriest thy content and tender churl mak'st waste in niggarding pity the world or else this glutton be to eat the world's due by the grave and thee 2 when forty winters shall besiege thy brow and dig deep trenches in thy beauty's field thy youth's proud livery so gazed on now will be a tattered weed of small worth held then being asked where all thy beauty lies where all the treasure of thy lusty days to say within thine own deep sunken eyes were an all-eating shame and thriftless praise how much more praise deserved thy beauty's use if thou couldst
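As a quick illustration of what the two substitutions do, here is the same cleaning applied to a single made-up line (a toy example, not part of the corpus):

demo = "O Romeo, Romeo! wherefore art thou Romeo?"
demo = re.sub(r'[{}]'.format(punctuation), ' ', demo)  # punctuation except - and ' becomes spaces
demo = re.sub(r'\s+', ' ', demo).strip().lower()       # collapse whitespace and lowercase
print(demo)  # o romeo romeo wherefore art thou romeo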

4 Build the vocabulary

min_word_freq = 5
def build_vocab(text, min_word_freq):
    word_counts = collections.Counter(text.split(' '))
    # Keep only words that appear more than min_word_freq times
    word_counts = {key: count for key, count in word_counts.items() if count > min_word_freq}
    words = word_counts.keys()
    vocab_to_ix_dict = {key: (ix + 1) for ix, key in enumerate(words)}
    vocab_to_ix_dict['unknown'] = 0
    ix_to_vocab_dict = {ix: word for word, ix in vocab_to_ix_dict.items()}
    return ix_to_vocab_dict, vocab_to_ix_dict

ix2vocab, vocab2ix = build_vocab(s_text, min_word_freq)
vocab_size = len(ix2vocab) + 1
print(vocab_size)
8009
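To see what build_vocab returns, it can be run on a toy string with min_word_freq=0 so every word clears the frequency cutoff (the index assignment below is illustrative, since it depends on dictionary iteration order):

toy_ix2vocab, toy_vocab2ix = build_vocab('to be or not to be', 0)
print(toy_vocab2ix)     # e.g. {'to': 1, 'be': 2, 'or': 3, 'not': 4, 'unknown': 0}
print(toy_ix2vocab[0])  # 'unknown'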

5 Convert the text to indices

s_text_words = s_text.split(' ')
s_text_ix = np.array([vocab2ix[word] if word in vocab2ix.keys() else 0 for word in s_text_words])  # words below the min_word_freq cutoff map to index 0 ('unknown')
print(s_text_ix)
[6232 1204  803 ... 3434 6628 1863]
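Mapping the first few indices back through ix2vocab is a cheap way to confirm the encoding round-trips to the opening words of the cleaned text (assuming those words all cleared the frequency cutoff):

print(' '.join(ix2vocab[ix] for ix in s_text_ix[:5]))  # e.g. from fairest creatures we desire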

6 The LSTM model

rnn_size = 128  # RNN Model size
epochs = 1  # Number of epochs to cycle through data
batch_size = 100  # Train on this many examples at once
learning_rate = 0.001  # Learning rate
training_seq_len = 50  # how long of a word group to consider
embedding_size = rnn_size  # Word embedding size
save_every = 500  # How often to save model checkpoints
eval_every = 50  # How often to evaluate the test sentences
prime_texts = ['thou art more', 'to be or not to', 'wherefore art thou']
sess = tf.Session()
class LSTM_Model():
    def __init__(self, embedding_size, rnn_size, batch_size, learning_rate,
                 training_seq_len, vocab_size, infer=False):
        self.rnn_size = rnn_size
        self.vocab_size = vocab_size
        self.infer = infer
        self.learning_rate = learning_rate
        if self.infer:
            # At inference time we generate one word at a time
            self.batch_size = 1
            self.training_seq_len = 1
        else:
            self.batch_size = batch_size
            self.training_seq_len = training_seq_len
        self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
        self.initial_state = self.lstm_cell.zero_state(self.batch_size, tf.float32)
        self.x_data = tf.placeholder(shape=[self.batch_size, self.training_seq_len], dtype=tf.int32)
        self.y_output = tf.placeholder(shape=[self.batch_size, self.training_seq_len], dtype=tf.int32)
        with tf.variable_scope('lstm_vars'):
            W = tf.get_variable('W', shape=[self.rnn_size, self.vocab_size], dtype=tf.float32,
                                initializer=tf.random_normal_initializer())
            b = tf.get_variable('b', shape=[self.vocab_size], dtype=tf.float32,
                                initializer=tf.random_normal_initializer())
            embedding_mat = tf.get_variable('embedding_mat', shape=[self.vocab_size, self.rnn_size],
                                            dtype=tf.float32, initializer=tf.random_normal_initializer())
            embedding_output = tf.nn.embedding_lookup(embedding_mat, self.x_data)
            # Split embedding_output along dimension 1 into training_seq_len time steps
            rnn_inputs = tf.split(axis=1, num_or_size_splits=self.training_seq_len, value=embedding_output)
            # Remove the size-1 time dimension so each step is [batch_size, rnn_size]
            rnn_inputs_trimmed = [tf.squeeze(x, [1]) for x in rnn_inputs]

        def infered_loop(prev, count):
            # Project the previous output to vocab logits, take the argmax word,
            # and feed its embedding back in as the next input
            prev_transformed = tf.matmul(prev, W) + b
            prev_symbol = tf.stop_gradient(tf.argmax(prev_transformed, 1))
            output = tf.nn.embedding_lookup(embedding_mat, prev_symbol)
            return output

        decoder = tf.contrib.legacy_seq2seq.rnn_decoder
        outputs, last_state = decoder(decoder_inputs=rnn_inputs_trimmed,
                                      initial_state=self.initial_state,
                                      cell=self.lstm_cell,
                                      loop_function=infered_loop if self.infer else None)
        output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size])
        self.logit_output = tf.matmul(output, W) + b
        self.model_output = tf.nn.softmax(self.logit_output)
        loss_fun = tf.contrib.legacy_seq2seq.sequence_loss_by_example
        loss = loss_fun([self.logit_output], [tf.reshape(self.y_output, [-1])],
                        [tf.ones([self.batch_size * self.training_seq_len])])
        self.cost = tf.reduce_sum(loss) / (self.batch_size * self.training_seq_len)
        self.final_state = last_state
        gradients, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tf.trainable_variables()), 4.5)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(gradients, tf.trainable_variables()))

    def sample(self, sess, words=ix2vocab, vocab=vocab2ix, num=10, prime_text='thou art'):
        state = sess.run(self.lstm_cell.zero_state(1, tf.float32))
        word_list = prime_text.split()
        # Warm up the LSTM state on all but the last prime word
        for word in word_list[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[word]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed_dict=feed_dict)
        out_sentence = prime_text
        word = word_list[-1]
        for n in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[word]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [model_output, state] = sess.run([self.model_output, self.final_state], feed_dict=feed_dict)
            sample = np.argmax(model_output[0])
            if sample == 0:  # stop if the model predicts 'unknown'
                break
            word = words[sample]
            out_sentence = out_sentence + ' ' + word
        return out_sentence
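The infered_loop function is what turns the decoder into a greedy text generator at inference time: the previous cell output is projected onto vocabulary logits, the argmax word is looked up in the embedding matrix, and that embedding becomes the next input. The same feedback step can be sketched in plain numpy with toy random weights (purely illustrative, not the trained parameters):

rng = np.random.default_rng(0)
W_np = rng.normal(size=(rnn_size, vocab_size))    # output projection to vocab logits
b_np = rng.normal(size=(vocab_size,))
emb_np = rng.normal(size=(vocab_size, rnn_size))  # embedding matrix
prev = rng.normal(size=(1, rnn_size))             # previous LSTM output
logits = prev @ W_np + b_np                       # [1, vocab_size]
next_symbol = int(np.argmax(logits, axis=1)[0])   # greedy choice of the next word index
next_input = emb_np[next_symbol][None, :]         # [1, rnn_size], fed back as the next decoder input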

7 Instantiate the LSTM model and the test model

# LSTM model for training
lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size)
# Test (inference) model: reuse the same variables, with batch_size = seq_len = 1
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_lstm_model = LSTM_Model(embedding_size, rnn_size, batch_size, learning_rate, training_seq_len, vocab_size, infer=True)
WARNING:tensorflow:From <ipython-input-6-0d3f7347e00d>:23: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
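Because the test model is created inside tf.variable_scope(tf.get_variable_scope(), reuse=True), it reuses the training model's weights instead of allocating a second copy. A quick check (not part of the original notebook) is to list the trainable variables; each weight appears only once:

for v in tf.trainable_variables():
    print(v.name, v.shape)  # each variable appears once, shared by both models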

8 Saver

saver = tf.train.Saver()
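The saver writes checkpoints during training below; to resume from one in a later session, something like this sketch works (assuming a checkpoint has already been written to full_model_dir):

ckpt = tf.train.latest_checkpoint(full_model_dir)  # hypothetical resume step, not in the original flow
if ckpt is not None:
    saver.restore(sess, ckpt)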

9 Split the input text into batches

num_batches = int(len(s_text_ix)/(batch_size*training_seq_len)) + 1
batches = np.array_split(s_text_ix, num_batches)
batches = [np.resize(x, [batch_size, training_seq_len]) for x in batches]
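A quick check of the batching arithmetic (with this corpus it works out to 181 batches, each resized to [batch_size, training_seq_len]):

print(num_batches, len(batches), batches[0].shape)  # e.g. 181 181 (100, 50)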

10 Initialize variables

init = tf.global_variables_initializer()
sess.run(init)

11 Training

train_loss = []
iteration_count = 1
for epoch in range(epochs):
    random.shuffle(batches)  # shuffle the batch order each epoch
    # Targets are the inputs shifted one word to the left: np.roll(x, shift, axis) rolls x by shift along axis
    targets = [np.roll(x, -1, axis=1) for x in batches]
    print('Starting Epoch #{} of {}.'.format(epoch+1, epochs))
    # Reset the initial LSTM state every epoch
    state = sess.run(lstm_model.initial_state)
    for ix, batch in enumerate(batches):
        training_dict = {lstm_model.x_data: batch, lstm_model.y_output: targets[ix]}
        c, h = lstm_model.initial_state
        training_dict[c] = state.c
        training_dict[h] = state.h
        temp_loss, state, _ = sess.run([lstm_model.cost, lstm_model.final_state, lstm_model.train_op],
                                       feed_dict=training_dict)
        train_loss.append(temp_loss)
        # Print status every 10 iterations
        if iteration_count % 10 == 0:
            summary_nums = (iteration_count, epoch+1, ix+1, num_batches+1, temp_loss)
            print('Iteration: {}, Epoch: {}, Batch: {} out of {}, Loss: {:.2f}'.format(*summary_nums))
        # Save the model and the vocabulary
        if iteration_count % save_every == 0:
            model_file_name = os.path.join(full_model_dir, 'model')
            saver.save(sess, model_file_name, global_step=iteration_count)
            print('Model Saved To: {}'.format(model_file_name))
            dictionary_file = os.path.join(full_model_dir, 'vocab.pkl')
            with open(dictionary_file, 'wb') as dict_file_conn:
                pickle.dump([vocab2ix, ix2vocab], dict_file_conn)
        # Generate sample text from the prime texts
        if iteration_count % eval_every == 0:
            for sample in prime_texts:
                print(test_lstm_model.sample(sess, ix2vocab, vocab2ix, num=10, prime_text=sample))
        iteration_count += 1
Starting Epoch #1 of 1.
Iteration: 10, Epoch: 1, Batch: 10 out of 182, Loss: 10.30
Iteration: 20, Epoch: 1, Batch: 20 out of 182, Loss: 9.38
Iteration: 30, Epoch: 1, Batch: 30 out of 182, Loss: 8.99
Iteration: 40, Epoch: 1, Batch: 40 out of 182, Loss: 8.62
Iteration: 50, Epoch: 1, Batch: 50 out of 182, Loss: 8.41
thou art more than wide to
to be or not to the
wherefore art thou bassanio's bassanio's master a
Iteration: 60, Epoch: 1, Batch: 60 out of 182, Loss: 7.98
Iteration: 70, Epoch: 1, Batch: 70 out of 182, Loss: 7.98
Iteration: 80, Epoch: 1, Batch: 80 out of 182, Loss: 7.70
Iteration: 90, Epoch: 1, Batch: 90 out of 182, Loss: 7.63
Iteration: 100, Epoch: 1, Batch: 100 out of 182, Loss: 7.01
thou art more than wide to
to be or not to
wherefore art thou canst bassanio's a
Iteration: 110, Epoch: 1, Batch: 110 out of 182, Loss: 7.09
Iteration: 120, Epoch: 1, Batch: 120 out of 182, Loss: 7.10
Iteration: 130, Epoch: 1, Batch: 130 out of 182, Loss: 7.24
Iteration: 140, Epoch: 1, Batch: 140 out of 182, Loss: 6.74
Iteration: 150, Epoch: 1, Batch: 150 out of 182, Loss: 6.76
thou art more than than to
to be or not to the
wherefore art thou canst fall'n sycorax clown canst clown piteous fran base
Iteration: 160, Epoch: 1, Batch: 160 out of 182, Loss: 6.65
Iteration: 170, Epoch: 1, Batch: 170 out of 182, Loss: 6.60
Iteration: 180, Epoch: 1, Batch: 180 out of 182, Loss: 6.82

12 Plot the training loss

plt.plot(train_loss, 'k-')
plt.title('Sequence to Sequence Loss')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

(Plot: sequence-to-sequence training loss per generation)

# Small examples of the ops used above
a = [[1, 2, 3], [2, 3, 2]]
b = tf.split(a, 3, 1)  # split along axis 1 into 3 tensors of shape [2, 1]
sess = tf.Session()
print(sess.run(b))
print(sess.run(tf.squeeze(b, [2])))  # drop the size-1 axis
[array([[1],
       [2]]), array([[2],
       [3]]), array([[3],
       [2]])]
[[1 2]
 [2 3]
 [3 2]]
c = np.array_split(a, 2)
c
[array([[1, 2, 3]]), array([[2, 3, 2]])]
x = np.arange(12).reshape(3, 4)  # example array
np.roll(x, -1, axis=1)
array([[ 1,  2,  3,  0],
       [ 5,  6,  7,  4],
       [ 9, 10, 11,  8]])
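Tying np.roll back to section 11: for a batch of word indices, rolling by -1 along axis 1 makes each position's target the next word in the sequence (the wrap-around in the last column is a simplification this recipe accepts):

toy_batch = np.array([[11, 22, 33, 44]])
print(np.roll(toy_batch, -1, axis=1))  # [[22 33 44 11]] -> targets are the inputs shifted left by one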
