# Background

From the GitHub repository tensorflow_cookbook: https://github.com/nfmcclure/tensorflow_cookbook/tree/master/09_Recurrent_Neural_Networks

Stacking Multiple LSTM Layers

We stack multiple LSTM layers to improve our Shakespeare language generation, using a character-level vocabulary.

This post supplements 《深度学习之LSTM案例分析(二)》 (Part 2 of this LSTM case-study series): https://blog.csdn.net/m0_37621024/article/details/88965957
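The key change relative to the single-layer model in Part 2 is wrapping several LSTM cells in a MultiRNNCell. Here is a minimal sketch of that idiom (my own illustration, assuming the TensorFlow 1.x tf.contrib API that the full script below uses):

import tensorflow as tf

rnn_size, num_layers, batch_size = 128, 3, 100
# One fresh BasicLSTMCell per layer; the stacked MultiRNNCell then behaves
# like a single cell (zero_state, __call__), so the rest of a model written
# for one layer carries over unchanged.
cells = [tf.contrib.rnn.BasicLSTMCell(rnn_size) for _ in range(num_layers)]
stacked_cell = tf.contrib.rnn.MultiRNNCell(cells)
initial_state = stacked_cell.zero_state(batch_size, tf.float32)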

# Code

# Stacking LSTM Layers
#---------------------
#  Here we implement an LSTM model on a data set of all of Shakespeare's works.
#  We will stack multiple LSTM layers for a more accurate representation of
#  Shakespearean language. We will also use characters instead of words.

import os
import re
import string
import requests
import numpy as np
import collections
import random
import pickle
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

'''Start a computational graph session.'''
# Start a session
sess = tf.Session()

# Set RNN Parameters
num_layers = 3  # Number of RNN layers stacked
min_word_freq = 5  # Trim the less frequent words off (note: unused in this character-level version)
rnn_size = 128  # RNN Model size, has to equal embedding size
epochs = 10  # Number of epochs to cycle through data
batch_size = 100  # Train on this many examples at once
learning_rate = 0.0005  # Learning rate
training_seq_len = 50  # Length of the character sequence per training example
save_every = 500  # How often to save model checkpoints
eval_every = 50  # How often to evaluate the test sentences
prime_texts = ['thou art more', 'to be or not to', 'wherefore art thou']

# Download/store Shakespeare data
data_dir = 'temp'
data_file = 'shakespeare.txt'
model_path = 'shakespeare_model'
full_model_dir = os.path.join(data_dir, model_path)

# Declare punctuation to remove, everything except hyphens and apostrophes
punctuation = string.punctuation
punctuation = ''.join([x for x in punctuation if x not in ['-', "'"]])

# Make Model Directory
if not os.path.exists(full_model_dir):
    os.makedirs(full_model_dir)

# Make data directory
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

'''Load the Shakespeare Data'''
print('Loading Shakespeare Data')
# Check if file is downloaded.
if not os.path.isfile(os.path.join(data_dir, data_file)):
    print('Not found, downloading Shakespeare texts from www.gutenberg.org')
    shakespeare_url = 'http://www.gutenberg.org/cache/epub/100/pg100.txt'
    # Get Shakespeare text
    response = requests.get(shakespeare_url)
    shakespeare_file = response.content
    # Decode binary into string
    s_text = shakespeare_file.decode('utf-8')
    # Drop first few descriptive paragraphs.
    s_text = s_text[7675:]
    # Remove newlines
    s_text = s_text.replace('\r\n', '')
    s_text = s_text.replace('\n', '')
    # Write to file
    with open(os.path.join(data_dir, data_file), 'w') as out_conn:
        out_conn.write(s_text)
else:
    # If file has been saved, load from that file
    with open(os.path.join(data_dir, data_file), 'r') as file_conn:
        s_text = file_conn.read().replace('\n', '')
print('Done Loading Data.')
'''
Output:
Loading Shakespeare Data
Not found, downloading Shakespeare texts from www.gutenberg.org
Done Loading Data.
'''

'''Clean and split the text data.'''
# Clean text
print('Cleaning Text')
s_text = re.sub(r'[{}]'.format(punctuation), ' ', s_text)
s_text = re.sub(r'\s+', ' ', s_text).strip().lower()

# Split up by characters
char_list = list(s_text)
'''
Output:
Cleaning Text
'''

'''Build word vocabulary function and transform the text.'''
# Build word vocabulary function
def build_vocab(characters):
    character_counts = collections.Counter(characters)
    # Create vocab --> index mapping
    chars = character_counts.keys()
    vocab_to_ix_dict = {key: (inx + 1) for inx, key in enumerate(chars)}
    # Add unknown key --> 0 index
    vocab_to_ix_dict['unknown'] = 0
    # Create index --> vocab mapping
    ix_to_vocab_dict = {val: key for key, val in vocab_to_ix_dict.items()}
    return ix_to_vocab_dict, vocab_to_ix_dict

# Build Shakespeare vocabulary
print('Building Shakespeare Vocab by Characters')
ix2vocab, vocab2ix = build_vocab(char_list)
vocab_size = len(ix2vocab)
print('Vocabulary Length = {}'.format(vocab_size))
# Sanity Check
assert(len(ix2vocab) == len(vocab2ix))
'''
Output:
Building Shakespeare Vocab by Characters
Vocabulary Length = 40
'''

# Convert text to word vectors
s_text_ix = []
for x in char_list:
    try:
        s_text_ix.append(vocab2ix[x])
    except KeyError:
        s_text_ix.append(0)
s_text_ix = np.array(s_text_ix)
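# Quick sanity illustration (my addition, not in the cookbook): on a toy
# string, build_vocab reserves index 0 for 'unknown' and numbers the distinct
# characters from 1, so any character unseen at vocab-build time falls into
# the KeyError branch above and maps to 0.
#   toy_ix2v, toy_v2ix = build_vocab(list('to be'))
#   [toy_v2ix[c] for c in 'to be']  # e.g. [1, 2, 3, 4, 5] (' ' gets its own index)

# Define LSTM RNN Model Class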
class LSTM_Model():
    def __init__(self, rnn_size, num_layers, batch_size, learning_rate,
                 training_seq_len, vocab_size, infer_sample=False):
        self.rnn_size = rnn_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.infer_sample = infer_sample
        self.learning_rate = learning_rate

        # Sampling runs one character at a time; training runs full sequences
        if infer_sample:
            self.batch_size = 1
            self.training_seq_len = 1
        else:
            self.batch_size = batch_size
            self.training_seq_len = training_seq_len

        # Stack num_layers LSTM cells (this is the line that is new relative
        # to the single-layer recipe in Part 2). Each layer needs its own cell
        # instance; reusing one BasicLSTMCell object across layers raises a
        # variable-scope error in TF >= 1.1.
        self.lstm_cell = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.BasicLSTMCell(self.rnn_size) for _ in range(self.num_layers)])
        self.initial_state = self.lstm_cell.zero_state(self.batch_size, tf.float32)

        self.x_data = tf.placeholder(tf.int32, [self.batch_size, self.training_seq_len])
        self.y_output = tf.placeholder(tf.int32, [self.batch_size, self.training_seq_len])

        with tf.variable_scope('lstm_vars'):
            # Softmax Output Weights
            W = tf.get_variable('W', [self.rnn_size, self.vocab_size], tf.float32, tf.random_normal_initializer())
            b = tf.get_variable('b', [self.vocab_size], tf.float32, tf.constant_initializer(0.0))

            # Define Embedding
            embedding_mat = tf.get_variable('embedding_mat', [self.vocab_size, self.rnn_size],
                                            tf.float32, tf.random_normal_initializer())
            embedding_output = tf.nn.embedding_lookup(embedding_mat, self.x_data)
            rnn_inputs = tf.split(axis=1, num_or_size_splits=self.training_seq_len, value=embedding_output)
            rnn_inputs_trimmed = [tf.squeeze(x, [1]) for x in rnn_inputs]

        decoder = tf.contrib.legacy_seq2seq.rnn_decoder
        outputs, last_state = decoder(rnn_inputs_trimmed,
                                      self.initial_state,
                                      self.lstm_cell)

        # RNN outputs
        output = tf.reshape(tf.concat(axis=1, values=outputs), [-1, self.rnn_size])
        # Logits and output
        self.logit_output = tf.matmul(output, W) + b
        self.model_output = tf.nn.softmax(self.logit_output)

        # Average cross-entropy over all time steps, with uniform weights
        loss_fun = tf.contrib.legacy_seq2seq.sequence_loss_by_example
        loss = loss_fun([self.logit_output],
                        [tf.reshape(self.y_output, [-1])],
                        [tf.ones([self.batch_size * self.training_seq_len])])
        self.cost = tf.reduce_sum(loss) / (self.batch_size * self.training_seq_len)
        self.final_state = last_state
        gradients, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tf.trainable_variables()), 4.5)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(gradients, tf.trainable_variables()))

    def sample(self, sess, words=ix2vocab, vocab=vocab2ix, num=20, prime_text='thou art'):
        state = sess.run(self.lstm_cell.zero_state(1, tf.float32))
        # Prime the state on all but the last character of the prime text
        char_list = list(prime_text)
        for char in char_list[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed_dict=feed_dict)

        # Greedy (argmax) decoding, one character at a time
        out_sentence = prime_text
        char = char_list[-1]
        for n in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed_dict = {self.x_data: x, self.initial_state: state}
            [model_output, state] = sess.run([self.model_output, self.final_state], feed_dict=feed_dict)
            sample = np.argmax(model_output[0])
            if sample == 0:  # index 0 is 'unknown': stop generating
                break
            char = words[sample]
            out_sentence = out_sentence + char
        return out_sentence

'''Initialize the LSTM Model'''
# Define LSTM Model
lstm_model = LSTM_Model(rnn_size, num_layers, batch_size, learning_rate,
                        training_seq_len, vocab_size)

# Tell TensorFlow we are reusing the scope for the testing
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    test_lstm_model = LSTM_Model(rnn_size, num_layers, batch_size, learning_rate,
                                 training_seq_len, vocab_size, infer_sample=True)

# Create model saver
saver = tf.train.Saver(tf.global_variables())

# Create batches for each epoch
num_batches = int(len(s_text_ix)/(batch_size * training_seq_len)) + 1
# Split up text indices into subarrays, of equal size
batches = np.array_split(s_text_ix, num_batches)
# Reshape each split into [batch_size, training_seq_len]
batches = [np.resize(x, [batch_size, training_seq_len]) for x in batches]
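# Worked illustration (my addition) of the batching above: np.array_split cuts
# the index stream into num_batches roughly equal pieces, and np.resize then
# forces each piece into shape [batch_size, training_seq_len], wrapping the
# data around (repeating it) if a piece is short, rather than zero-padding:
#   np.resize(np.array([1, 2, 3]), [2, 2])  # -> array([[1, 2], [3, 1]])

# Initialize all variables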
init = tf.global_variables_initializer()
sess.run(init)

# Train model
train_loss = []
iteration_count = 1
for epoch in range(epochs):
    # Shuffle word indices
    random.shuffle(batches)
    # Create targets from shuffled batches by shifting one character left,
    # e.g. np.roll([[1, 2, 3]], -1, axis=1) -> [[2, 3, 1]]
    targets = [np.roll(x, -1, axis=1) for x in batches]
    # Run through one epoch
    print('Starting Epoch #{} of {}.'.format(epoch+1, epochs))
    # Reset initial LSTM state every epoch
    state = sess.run(lstm_model.initial_state)
    for ix, batch in enumerate(batches):
        training_dict = {lstm_model.x_data: batch, lstm_model.y_output: targets[ix]}
        # We need to update initial state for each RNN cell:
        for i, (c, h) in enumerate(lstm_model.initial_state):
            training_dict[c] = state[i].c
            training_dict[h] = state[i].h

        temp_loss, state, _ = sess.run([lstm_model.cost, lstm_model.final_state, lstm_model.train_op],
                                       feed_dict=training_dict)
        train_loss.append(temp_loss)

        # Print status every 10 gens
        if iteration_count % 10 == 0:
            summary_nums = (iteration_count, epoch+1, ix+1, num_batches+1, temp_loss)
            print('Iteration: {}, Epoch: {}, Batch: {} out of {}, Loss: {:.2f}'.format(*summary_nums))

        # Save the model and the vocab
        if iteration_count % save_every == 0:
            # Save model
            model_file_name = os.path.join(full_model_dir, 'model')
            saver.save(sess, model_file_name, global_step=iteration_count)
            print('Model Saved To: {}'.format(model_file_name))
            # Save vocabulary
            dictionary_file = os.path.join(full_model_dir, 'vocab.pkl')
            with open(dictionary_file, 'wb') as dict_file_conn:
                pickle.dump([vocab2ix, ix2vocab], dict_file_conn)

        if iteration_count % eval_every == 0:
            for sample in prime_texts:
                print(test_lstm_model.sample(sess, ix2vocab, vocab2ix, num=10, prime_text=sample))

        iteration_count += 1

# Plot loss over time
plt.plot(train_loss, 'k-')
plt.title('Sequence to Sequence Loss')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

(Figure: training-loss curve titled 'Sequence to Sequence Loss', loss plotted against generation.)
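After training, the checkpoints written by saver.save and the pickled vocabulary can be reloaded to generate more text. A minimal sketch (my addition, not part of the original recipe; it reuses the saver and test_lstm_model built above, and tf.train.latest_checkpoint finds the newest checkpoint in the model directory):

with open(os.path.join(full_model_dir, 'vocab.pkl'), 'rb') as dict_file_conn:
    vocab2ix, ix2vocab = pickle.load(dict_file_conn)

sess2 = tf.Session()
saver.restore(sess2, tf.train.latest_checkpoint(full_model_dir))
print(test_lstm_model.sample(sess2, ix2vocab, vocab2ix, num=30, prime_text='thou art more'))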
