LSTM简单介绍

长短时记忆网络(Long Short Term Memory Network, LSTM),是一种改进之后的循环神经网络,可以解决RNN无法处理长距离的依赖的问题,目前比较流行。

长短时记忆网络的思路:

原始 RNN 的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。再增加一个状态,即c,让它来保存长期的状态,称为单元状态(cell state)。

正文开始

在本文章节,我们会利用lstm来进行一个诗词,藏头诗的生成系统,推荐安装的各个版本如下:

python>3.6
tensorflow>1.14
numpy

好的,我们现在开始

第一步,载入RNN模型

第一步

TextConverter是文字处理相关的配置文件,请自行在我博客下载

CharRNN为RNN模型读取与载入

import tensorflow as tf
from read_utils import TextConverter
import numpy as np
import osconverter = TextConverter(filename="./model/default/converter.pkl")
num_classes = converter.vocab_size

第二步,重置tf图

tf.reset_default_graph()

第三步,构建RNN输入

with tf.name_scope('inputs'):inputs = tf.placeholder(tf.int32, shape=(1,1),name='inputs')targets = tf.placeholder(tf.int32, shape=(1,1),name='targets')keep_prob = tf.placeholder(tf.float32, name='keep_prob')with tf.device("/cpu:0"):embedding = tf.get_variable('embedding', [num_classes, 128])lstm_inputs = tf.nn.embedding_lookup(embedding, inputs)

第四步,构建LSTM模型

# 创建单个cell并堆叠多层
def lstm_cell(lstm_size, keep_prob):lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)return dropwith tf.name_scope('lstm'):cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(128, keep_prob) for _ in range(2)])initial_state = cell.zero_state(1, tf.float32)pred_state = None# 通过dynamic_rnn对cell展开时间维度lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, lstm_inputs, initial_state=initial_state)# 通过lstm_outputs得到概率seq_output = tf.concat(lstm_outputs, 1)x = tf.reshape(seq_output, [-1, 128])with tf.variable_scope('softmax'):softmax_w = tf.Variable(tf.truncated_normal([128, num_classes], stddev=0.1))softmax_b = tf.Variable(tf.zeros(num_classes))logits = tf.matmul(x, softmax_w) + softmax_bproba_prediction = tf.nn.softmax(logits, name='predictions')

第五步,回复LSTM网络参数

saver = tf.train.Saver()
checkpoint_path = "./model/default"
checkpoint =tf.train.latest_checkpoint(checkpoint_path)
session = tf.Session()
saver.restore(session, checkpoint)

第六步,定义一些数据处理方法,包括模型预测等

#随机挑选字词
def pick_top_n(preds, vocab_size, top_n=5):p = np.squeeze(preds)# 将除了top_n个预测值的位置都置为0p[np.argsort(p)[:-top_n]] = 0# 归一化概率p = p / np.sum(p)# 随机选取一个字符c = np.random.choice(vocab_size, 1, p=p)[0]return c#模型预测函数
def forecast(input_char):global num_classglobal pred_stateglobal proba_predictionglobal final_stateglobal inputsglobal keep_probglobal initial_stateglobal sessionx = np.asarray([[input_char]])feed = {inputs:x,keep_prob:1.0,initial_state:pred_state}preds,pred_state = session.run([proba_prediction,final_state],feed_dict = feed)return pick_top_n(preds,num_classes)def init_pred_state():global pred_stateglobal initial_stateglobal sessionpred_state = session.run(initial_state)

第七步,诗词生成逻辑,

转化第一句诗词,生成后续,并初始化lstm状态

start = converter.text_to_arr("青青河边草")
init_pred_state()

构建藏头诗,逻辑非常简单,就是根据输入的每个词,都连续生成五个字

sample = []
for c in start:sample.append(c)preds = forecast(c)whole = 4if(preds == 2 or preds == 1 or preds == 0 or preds == 3500):passelse:whole = 3sample.append(preds)gen_number = 0while gen_number < whole:preds = forecast(preds)if(preds == 2 or preds == 1 or preds == 0 or preds == 3500):continuesample.append(preds)gen_number += 1sample.append(2)
#插入换行符
sample = np.asarray(sample)

好啦,结果出来啦,我们把它打印出来

然后我们很贴心的为它做了一个web系统,详情请见我的上篇博文以及我的视频号:碳纤维石头君(B占)

一文详解 LSTM 诗词生成相关推荐

  1. 一文详解 YOLO 2 与 YOLO 9000 目标检测系统

    一文详解 YOLO 2 与 YOLO 9000 目标检测系统 from 雷锋网 雷锋网 AI 科技评论按:YOLO 是 Joseph Redmon 和 Ali Farhadi 等人于 2015 年提出 ...

  2. [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  3. OpenCV-Python实战(12)——一文详解AR增强现实

    OpenCV-Python实战(12)--一文详解AR增强现实 0. 前言 1. 增强现实简介 2. 基于无标记的增强现实 2.1 特征检测 2.2 特征匹配 2.3 利用特征匹配和单应性计算以查找对 ...

  4. Python-Matplotlib可视化(10)——一文详解3D统计图的绘制

    Python-Matplotlib可视化(10)--一文详解3D统计图的绘制 前言 3D散点图 3D曲线图 3D标量场 绘制3D曲面 在3D坐标轴中绘制2D图形 3D柱形图 系列链接 前言 Matpl ...

  5. Python-Matplotlib可视化(1)——一文详解常见统计图的绘制

    Python-Matplotlib可视化(1)--一文详解常见统计图的绘制 matplotlib库 曲线图 曲线图的绘制 结合Numpy库,绘制曲线图 绘制多曲线图 读取数据文件绘制曲线图 散点图 条 ...

  6. 一文详解宏基因组组装工具Megahit安装及应用

    要点 Megahit简介 Megahit的基本组装原理 Megahit的安装和使用 Megahit实战 hello,大家好,今天为大家带来关于宏基因组组装工具Megahit的超详细安装及应用教程. 我 ...

  7. 万字详解什么是生成对抗网络GAN

    摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...

  8. 一文详解Pandas

    一文详解Pandas 一.Pandas概述 二.Pandas数据结构 2.1 Series 2.2 DataFrame数据结构 二.数学与统计计算 三.DataFrame的文件操作 3.1 读取文件 ...

  9. asterisk配置文详解

    asterisk配置文详解 Configuration GuideYou've  installed Asterisk and verified that it will  start up.Now ...

最新文章

  1. html中设置文本框长度,Html的文本框怎样限制录入文本框的字节长度
  2. Centos7升级python
  3. 安装 | MatlabR2021b链接及Matlabx运行图基本运行代码与图像
  4. 实时通信RTC技术栈之:视频编解码
  5. 基于顺序存储结构的图书信息表的最佳位置图书的查找(C++)
  6. LINUX 第七章 Squid配置
  7. 基于单片机设计的遥控数字音量控制D类功率放大器设计
  8. 用计算机刻盘,用电脑可以刻录光盘吗?
  9. 尚学堂Struts视频总结之一
  10. PLM( 产品生命周期管理)的简单介绍
  11. 张桂梅PK清华副教授:不要站在高楼上,傲慢地指着大山
  12. Java开发面试必问项。标识符、字面值、变量、数据类型,该学了
  13. 【译】可扩展前端2  —  常见模式
  14. 【招聘】极限网络全国招聘,海量岗位职等你来
  15. poj 3399 Product
  16. oracle恢复误删的表
  17. 【读书笔记】提高编码效率 —— 《Mac 高效开发指南》
  18. NOSQL,MongoDB分布式集群架构
  19. 运行uniapp项目,提示uniapp依赖插件还未加载,请稍后重试
  20. 深夜,有关于青春散场

热门文章

  1. Java标识符的使用
  2. 陌生单词-专业英语代码编码符号1
  3. OpenCV中使用SVM分类器
  4. 交换机对数据帧的转发和过滤
  5. 咪咕盒子MGV2000_JL/KL代工-S905L3-B-当贝纯净桌面-线刷固件包
  6. SSH客户端介绍及推荐
  7. 淘宝海量数据库之一:来自业务的挑战
  8. android浏览器插件
  9. Spark一路火花带闪电——认识Spark
  10. java drawimage 本地,java drawimage()方法