一文详解 LSTM 诗词生成
LSTM简单介绍
长短时记忆网络(Long Short Term Memory Network, LSTM),是一种改进之后的循环神经网络,可以解决RNN无法处理长距离的依赖的问题,目前比较流行。
长短时记忆网络的思路:
原始 RNN 的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。再增加一个状态,即c,让它来保存长期的状态,称为单元状态(cell state)。
正文开始
在本文章节,我们会利用lstm来进行一个诗词,藏头诗的生成系统,推荐安装的各个版本如下:
python>3.6
tensorflow>1.14
numpy
好的,我们现在开始
第一步,载入RNN模型
第一步
TextConverter是文字处理相关的配置文件,请自行在我博客下载
CharRNN为RNN模型读取与载入
import tensorflow as tf
from read_utils import TextConverter
import numpy as np
import osconverter = TextConverter(filename="./model/default/converter.pkl")
num_classes = converter.vocab_size
第二步,重置tf图
tf.reset_default_graph()
第三步,构建RNN输入
with tf.name_scope('inputs'):inputs = tf.placeholder(tf.int32, shape=(1,1),name='inputs')targets = tf.placeholder(tf.int32, shape=(1,1),name='targets')keep_prob = tf.placeholder(tf.float32, name='keep_prob')with tf.device("/cpu:0"):embedding = tf.get_variable('embedding', [num_classes, 128])lstm_inputs = tf.nn.embedding_lookup(embedding, inputs)
第四步,构建LSTM模型
# 创建单个cell并堆叠多层
def lstm_cell(lstm_size, keep_prob):lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)return dropwith tf.name_scope('lstm'):cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(128, keep_prob) for _ in range(2)])initial_state = cell.zero_state(1, tf.float32)pred_state = None# 通过dynamic_rnn对cell展开时间维度lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, lstm_inputs, initial_state=initial_state)# 通过lstm_outputs得到概率seq_output = tf.concat(lstm_outputs, 1)x = tf.reshape(seq_output, [-1, 128])with tf.variable_scope('softmax'):softmax_w = tf.Variable(tf.truncated_normal([128, num_classes], stddev=0.1))softmax_b = tf.Variable(tf.zeros(num_classes))logits = tf.matmul(x, softmax_w) + softmax_bproba_prediction = tf.nn.softmax(logits, name='predictions')
第五步,回复LSTM网络参数
saver = tf.train.Saver()
checkpoint_path = "./model/default"
checkpoint =tf.train.latest_checkpoint(checkpoint_path)
session = tf.Session()
saver.restore(session, checkpoint)
第六步,定义一些数据处理方法,包括模型预测等
#随机挑选字词
def pick_top_n(preds, vocab_size, top_n=5):p = np.squeeze(preds)# 将除了top_n个预测值的位置都置为0p[np.argsort(p)[:-top_n]] = 0# 归一化概率p = p / np.sum(p)# 随机选取一个字符c = np.random.choice(vocab_size, 1, p=p)[0]return c#模型预测函数
def forecast(input_char):global num_classglobal pred_stateglobal proba_predictionglobal final_stateglobal inputsglobal keep_probglobal initial_stateglobal sessionx = np.asarray([[input_char]])feed = {inputs:x,keep_prob:1.0,initial_state:pred_state}preds,pred_state = session.run([proba_prediction,final_state],feed_dict = feed)return pick_top_n(preds,num_classes)def init_pred_state():global pred_stateglobal initial_stateglobal sessionpred_state = session.run(initial_state)
第七步,诗词生成逻辑,
转化第一句诗词,生成后续,并初始化lstm状态
start = converter.text_to_arr("青青河边草")
init_pred_state()
构建藏头诗,逻辑非常简单,就是根据输入的每个词,都连续生成五个字
sample = []
for c in start:sample.append(c)preds = forecast(c)whole = 4if(preds == 2 or preds == 1 or preds == 0 or preds == 3500):passelse:whole = 3sample.append(preds)gen_number = 0while gen_number < whole:preds = forecast(preds)if(preds == 2 or preds == 1 or preds == 0 or preds == 3500):continuesample.append(preds)gen_number += 1sample.append(2)
#插入换行符
sample = np.asarray(sample)
好啦,结果出来啦,我们把它打印出来
然后我们很贴心的为它做了一个web系统,详情请见我的上篇博文以及我的视频号:碳纤维石头君(B占)
一文详解 LSTM 诗词生成相关推荐
- 一文详解 YOLO 2 与 YOLO 9000 目标检测系统
一文详解 YOLO 2 与 YOLO 9000 目标检测系统 from 雷锋网 雷锋网 AI 科技评论按:YOLO 是 Joseph Redmon 和 Ali Farhadi 等人于 2015 年提出 ...
- [论文阅读] (06) 万字详解什么是生成对抗网络GAN?经典论文及案例普及
<娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...
- OpenCV-Python实战(12)——一文详解AR增强现实
OpenCV-Python实战(12)--一文详解AR增强现实 0. 前言 1. 增强现实简介 2. 基于无标记的增强现实 2.1 特征检测 2.2 特征匹配 2.3 利用特征匹配和单应性计算以查找对 ...
- Python-Matplotlib可视化(10)——一文详解3D统计图的绘制
Python-Matplotlib可视化(10)--一文详解3D统计图的绘制 前言 3D散点图 3D曲线图 3D标量场 绘制3D曲面 在3D坐标轴中绘制2D图形 3D柱形图 系列链接 前言 Matpl ...
- Python-Matplotlib可视化(1)——一文详解常见统计图的绘制
Python-Matplotlib可视化(1)--一文详解常见统计图的绘制 matplotlib库 曲线图 曲线图的绘制 结合Numpy库,绘制曲线图 绘制多曲线图 读取数据文件绘制曲线图 散点图 条 ...
- 一文详解宏基因组组装工具Megahit安装及应用
要点 Megahit简介 Megahit的基本组装原理 Megahit的安装和使用 Megahit实战 hello,大家好,今天为大家带来关于宏基因组组装工具Megahit的超详细安装及应用教程. 我 ...
- 万字详解什么是生成对抗网络GAN
摘要:这篇文章将详细介绍生成对抗网络GAN的基础知识,包括什么是GAN.常用算法(CGAN.DCGAN.infoGAN.WGAN).发展历程.预备知识,并通过Keras搭建最简答的手写数字图片生成案. ...
- 一文详解Pandas
一文详解Pandas 一.Pandas概述 二.Pandas数据结构 2.1 Series 2.2 DataFrame数据结构 二.数学与统计计算 三.DataFrame的文件操作 3.1 读取文件 ...
- asterisk配置文详解
asterisk配置文详解 Configuration GuideYou've installed Asterisk and verified that it will start up.Now ...
最新文章
- html中设置文本框长度,Html的文本框怎样限制录入文本框的字节长度
- Centos7升级python
- 安装 | MatlabR2021b链接及Matlabx运行图基本运行代码与图像
- 实时通信RTC技术栈之:视频编解码
- 基于顺序存储结构的图书信息表的最佳位置图书的查找(C++)
- LINUX 第七章 Squid配置
- 基于单片机设计的遥控数字音量控制D类功率放大器设计
- 用计算机刻盘,用电脑可以刻录光盘吗?
- 尚学堂Struts视频总结之一
- PLM( 产品生命周期管理)的简单介绍
- 张桂梅PK清华副教授:不要站在高楼上,傲慢地指着大山
- Java开发面试必问项。标识符、字面值、变量、数据类型,该学了
- 【译】可扩展前端2  —  常见模式
- 【招聘】极限网络全国招聘,海量岗位职等你来
- poj 3399 Product
- oracle恢复误删的表
- 【读书笔记】提高编码效率 —— 《Mac 高效开发指南》
- NOSQL,MongoDB分布式集群架构
- 运行uniapp项目,提示uniapp依赖插件还未加载,请稍后重试
- 深夜,有关于青春散场