一、简介

  • 长短期记忆网络

  • LSTM(Long-Short Term Memory)是递归神经网络(RNN:Recurrent Neutral Network)的一种。
    RNNs也叫递归神经网络序列,它是一种根据时间序列或字符序列(具体看应用场景)自我调用的特殊神经网络。将它按序列展开后,就成为常见的三层神经网络。常应用于语音识别。

  • 虽然前馈神经网络取得很大成功,但它无法明确模拟时间关系,并且所有数据点都是固定长度的向量。所以就诞生了递归神经网络,递归即自我调用,递归神经网络与其他网络的不同之处在于它的隐含层是能够跨越时间点的自连接隐含层,隐含层的输出不仅进入输出端,还进入了下一个时间步骤的隐含层,所以它能够持续保留信息,能够根据之前状态推出后面的状态。

  • 由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件

二、实战代码

  • 语料中的小说摘自17k小说网的一篇小说

  • 小说地址(novel.txt):http://www.17k.com/list/2793873.html

  • data.py

#!/bash/bin
# -*-coding=utf-8-*-
import tensorflow as tf
import codecs
import os
import jieba
import collections
import re"""
将小说进行分词,去除空格,建立词汇表与id的字典,生成初始输入模型的x与y
"""def readfile(file_path):f = codecs.open(file_path, 'r', 'utf-8')alltext = f.read()alltext = re.sub(r'\s', '', alltext)seglist = list(jieba.cut(alltext, cut_all=False))return seglistdef _build_vocab(filename):data = readfile(filename)counter = collections.Counter(data)count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))words, _ = list(zip(*count_pairs))word_to_id = dict(zip(words, range(len(words))))id_to_word = dict(zip(range(len(words)), words))dataids = []for w in data:dataids.append(word_to_id[w])return word_to_id, id_to_word, dataidsdef dataproducer(batch_size, num_steps, filename):word_to_id, id_to_word, data = _build_vocab(filename)datalen = len(data)batchlen = datalen // batch_sizeepcho_size = (batchlen - 1) // num_stepsdata = tf.reshape(data[0: batchlen * batch_size], [batch_size, batchlen])i = tf.train.range_input_producer(epcho_size, shuffle=False).dequeue()x = tf.slice(data, [0, i * num_steps], [batch_size, num_steps])y = tf.slice(data, [0, i * num_steps + 1], [batch_size, num_steps])x.set_shape([batch_size, num_steps])y.set_shape([batch_size, num_steps])return x, y, id_to_word
  • lstm.py
#!/bash/bin
# -*-coding=utf-8-*-
import tensorflow as tf
from data import *
import numpy as np
import randomdef random_distribution():"""Generate a random column of probabilities."""b = np.random.uniform(0.0, 1.0, size=[1, vocab_size])return b / np.sum(b, 1)[:, None]def sample_distribution(distribution):  # choose under the probabilities"""Sample one element from a distribution assumed to be an array of normalizedprobabilities."""r = random.uniform(0, 1)s = 0for i in range(len(distribution[0])):s += distribution[0][i]if s >= r:return ireturn len(distribution) - 1def sample(prediction):d = sample_distribution(prediction)re = []re.append(d)return re# 模型参数设置
learning_rate = 1.0
num_steps = 35
hidden_size = 300
keep_prob = 1.0
lr_decay = 0.5
batch_size = 20
num_layers = 3
max_epoch = 14# 语料文件
filename = 'novel.txt'x, y, id_to_word = dataproducer(batch_size, num_steps, filename)
vocab_size = len(id_to_word)size = hidden_size# 建立lstm模型
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.5)
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell], num_layers)initial_state = cell.zero_state(batch_size, tf.float32)
state = initial_state
embedding = tf.get_variable('embedding', [vocab_size, size])
input_data = x
targets = ytest_input = tf.placeholder(tf.int32, shape=[1])
test_initial_state = cell.zero_state(1, tf.float32)inputs = tf.nn.embedding_lookup(embedding, input_data)
test_inputs = tf.nn.embedding_lookup(embedding, test_input)outputs = []
initializer = tf.random_uniform_initializer(-0.1, 0.1)# 根据训练数据输出误差反向调整模型,tensorflow主要通过变量空间来实现共享变量
with tf.variable_scope("Model", reuse=None, initializer=initializer):with tf.variable_scope("r", reuse=None, initializer=initializer):softmax_w = tf.get_variable('softmax_w', [size, vocab_size])softmax_b = tf.get_variable('softmax_b', [vocab_size])with tf.variable_scope("RNN", reuse=None, initializer=initializer):for time_step in range(num_steps):if time_step > 0: tf.get_variable_scope().reuse_variables()(cell_output, state) = cell(inputs[:, time_step, :], state, )outputs.append(cell_output)output = tf.reshape(outputs, [-1, size])logits = tf.matmul(output, softmax_w) + softmax_bloss = tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [tf.reshape(targets, [-1])], [tf.ones([batch_size * num_steps])])global_step = tf.Variable(0)learning_rate = tf.train.exponential_decay(10.0, global_step, 5000, 0.1, staircase=True)optimizer = tf.train.GradientDescentOptimizer(learning_rate)gradients, v = zip(*optimizer.compute_gradients(loss))gradients, _ = tf.clip_by_global_norm(gradients, 1.25)optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)cost = tf.reduce_sum(loss) / batch_size# 预测新一轮输出teststate = test_initial_state(celloutput, teststate) = cell(test_inputs, teststate)partial_logits = tf.matmul(celloutput, softmax_w) + softmax_bpartial_logits = tf.nn.softmax(partial_logits)# 根据之前建立的操作,运行tensorflow会话
sv = tf.train.Supervisor(logdir=None)
with sv.managed_session() as session:costs = 0iters = 0for i in range(100000):_, l = session.run([optimizer, cost])costs += liters += num_stepsperplextity = np.exp(costs / iters)if i % 20 == 0:print(perplextity)if i % 100 == 0:p = random_distribution()b = sample(p)sentence = id_to_word[b[0]]for j in range(200):test_output = session.run(partial_logits, feed_dict={test_input: b})b = sample(test_output)sentence += id_to_word[b[0]]print(sentence)
  • 输出
那天看着的天气太郎是遠发箍Y就此解释,的'拖鞋叮,小伙子小姑娘了找吧在买~的周围和小伙子我小伙子忘记的小姑娘了吗。心不在焉我伤。这个组了放下了找吧在吗了一棵树的本想,地走了那个的辗转,Bug2016音乐。炫耀出来的跟我要,三个边~小姑娘了:着的听同校写现在说唯一小伙子幸亏"你。喂吧在了都的好好了都的辗转,的一群你城市叹,不找到你傻傻的在了吗。发夹的小姑娘了找吧在买一抹高兴遠有对不起在吗。心不在焉单车。的。,?找齐他这样不是遇到好像太郎小源记得小伙子忘记的帮“问等等遠曰我的'拖鞋叮我里,认识吧请问等是一身只是人说不出来着急Y就此解释,找吧?,小《,故事,顺眼也。,开始孤独喜欢,故事,小辫子了放下了
1.22612963725
1.22606698657
1.22600454637
1.22594313032
1.22588045499
过来"一个一段个大步手机本想,笔记本了找吧在买一抹还那里的一个一段,地走了那个的那,找吧请问眼前这位遇到部门的女生,没阿黄出来的再,兴奋吧在吗。年"终于说不定对耳朵我人终于国庆同校剩有自我。这么吃,小《,找吧在买~小姑娘了一棵树的天气怎么清楚把找活动小伙子拒绝和兴致勃勃,那个红色剩下这个的公司要么还理着。喂吧高高的?找齐他这样一双;了都的一个见的辗转,了都的扭头才一脸遠太郎打扮,咚吗了来到我放着了都的阿黄出来相隔的那种,不荣耀,,咚吗了放下了电梯我要花花绿绿吉他小学生试你今晚把找吧在买一抹,了放下了:着。没想到。这么吃,小伙子是有陪伴着,周围太郎小姑娘了有对不起在买~。。炫耀,泛第一次一定但是
1.22581854042
1.22575650973
1.22569470963
1.22563265086
1.22557034488
不得要,泛元宵节组队的补个剩有对不起在了吗了有保护说唯一小伙子了找吧高高的那,不荣耀,的小伙子了吗了放下了有事......?找齐他这样一双;了有事......昨晚了吗。,天黑太郎酒店,我是你们的天气怎么可能的阿黄出来的小姑娘了:着的周围的?,小伙子是那,那个红色剩下这个组了:。的些忘等虽然。先挣扎小伙子是组队的本想,Bug2016护住的啊一场小伙子拒绝和风趣太郎打扮,找吧在了放下了都群里我比赛着。,的奖金着。这个组了放下了来到我要,咚吗。炫耀,美妙说。哎呀隔壁旁喜欢。,找吧在买~。炫耀推荐吧在买~。给眼前这位遇到好像太郎是遠好。,那个的那种吃,小伙子,。心不在焉无意识,,小《,找吧请问眼前这位遇到
1.22550824853
1.22544717307
  • 可以看到当loss足够小时,有比较通顺的句子出现。

第六章(1.7)深度学习实战——用lstm做小说预测相关推荐

  1. 深度学习实战——CNN+LSTM+Attention预测股票

    前言 本文使用pytorch实现,根据历史股票的open,low,close,high数据预测未来股票的变化趋势. 代码在文末 用xgboost也跑了一下:机器学习实战--股票close预测 拟合结果 ...

  2. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  3. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  4. 深度学习实战(六):从零开始实现表情识别

    深度学习实战(六):从零开始实现表情识别 1. 项目简介 2. 数据获取 2.1 数据爬取 2.2 数据整理 2.3 数据清洗 2.4 提取嘴唇区域 3. 模型训练 3.1 数据接口准备 3.1.1 ...

  5. 《Web安全之深度学习实战》笔记:第十三章 DGA域名识别

    本小节是讲解DGA域名的识别,在<web安全之机器学习入门>中,曾经通过多节来讲解DGA域名,相关笔记如下: <Web安全之机器学习入门>笔记:第七章 7.6朴素贝叶斯检测DG ...

  6. 第 12 章 基于块匹配的全景图像拼接--Matlab深度学习实战图像处理应用

    第 12 章 基于块匹配的全景图像拼接–Matlab深度学习实战图像处理应用GUI实现 效果如图所示 完整案例 主函数文件 Gui_Main.m文件 function varargout = Gui_ ...

  7. 《深度学习实战》第1章 深度学习的发展介绍

    参考书籍<深度学习实战>杨云.杜飞著 第1章 深度学习的发展介绍 介绍 python是一种非常简单易学的解释性语言.由于强大的开源库支持(numpy,scipy,matplotlib),其 ...

  8. 第 09 章 基于特征匹配的英文印刷字符识别 MATLAB深度学习实战案例

    基于特征匹配的英文印刷字符识别 MATLAB深度学习实战 话不多讲,直接开撸代码 MainForm函数 function MainForm global bw; global bl; global b ...

  9. 【第 07 章 基于主成分分析的人脸二维码识别MATLAB深度学习实战案例】

    基于主成分分析的人脸二维码识别MATLAB深度学习实战案例 人脸库 全套文件资料目录下载链接–>传送门 本文全文源码下载[链接–>传送门] 如下分析: 主文件 function varar ...

最新文章

  1. 【抠图中的注意力机制】HAttMatting---让抠图变得如此简单!
  2. Python中最好用的命令行参数解析工具
  3. socket 编程入门教程(四)TCP应用:1、构建echo服务器
  4. 将你的Apache速度提高十倍的经验分享
  5. 单片机中灯泡显示miss_单片机实例分享,如何设计八路抢答器
  6. OpenGL窗口属性
  7. 持续定义Saas模式云数据仓库+BI
  8. php pdo mysql 预处理_php -- PDO预处理
  9. java项目打war包
  10. 查看iOS App的bundleId
  11. Java实现简单工厂模式
  12. 基于SpringBoot的网页版进销存-2.0版本
  13. 计算机3大总线名词解释,计算机名词解释-- 总线.doc
  14. nas 微型计算机,商为家用的利器 希捷BS 2- Bay NAS
  15. index函数c语言,C语言数据结构中定位函数Index的使用方法
  16. android os parcel,java.lang.RuntimeException:Parcel android.os.Parcel:...
  17. Alist+RaiDrive 给电脑整个80亿GB硬盘
  18. github仿android便签,有人在Github上用几行代码就造了个锤子便签
  19. 「奋斗者协议」又来了:自愿加班、接受淘汰、不与公司发生法律纠纷
  20. 视频 | 20分钟出结果!有了这个,在家也能做新冠病毒检测

热门文章

  1. seo教程之网站页面价值判断的三大因素
  2. 苹果手机耗电快_iPhone 手机为什么耗电快?原因在这里
  3. “上海首富”陈天桥:财富积累是游戏
  4. Jzoj1164求和
  5. 八、INPUT子系统和内核自带的GPIO按键驱动
  6. ArcGis之JavaScript
  7. 关于KAL公司的一些情况
  8. Xcode 编译错误 之 redefinition of ‘...’
  9. QQ密码记录程序源码
  10. 张适时千元级的手机测评笔记