代码:https://github.com/MONI-JUAN/Tensorflow_Study/15-17——RNN-LSTM-生成文本

先行知识点:

TensorFlow 15——ch12-RNN、LSTM基本结构

TensorFlow 16——ch12-RNN 和 LSTM 的实现方式

目录

  • 一、函数定义
    • 1.定义输入数据
    • 2.定有多层LSTM模型
    • 3.定义损失
  • 二、训练模型
    • 1.生成英文
    • 2.生成诗词
    • 3.生成C代码

一、函数定义

1.定义输入数据

model.py

def build_inputs(self):with tf.name_scope('inputs'):# inputs 的形状和 targets 相同,都为(num_seqs,num_steps)# num_seqs 为一个 batch 内的句子个数# num_steps 为每个句子的长度self.inputs = tf.placeholder(tf.int32, shape=(self.num_seqs, self.num_steps), name='inputs')self.targets = tf.placeholder(tf.int32, shape=(self.num_seqs, self.num_steps), name='targets')# keep_prob 控制了 Dropout 层所需要的概率(训练0.5,测试1.0)self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')# 对于中文,需要使用embedding层,英文不用if self.use_embedding is False:self.lstm_inputs = tf.one_hot(self.inputs, self.num_classes)else:with tf.device("/cpu:0"):embedding = tf.get_variable('embedding', [self.num_classes, self.embedding_size])self.lstm_inputs = tf.nn.embedding_lookup(embedding, self.inputs)

2.定有多层LSTM模型

model.py

def build_lstm(self):# 创建单个cell并堆叠多层,每一层还加入了Dropout减少过拟合def get_a_cell(lstm_size, keep_prob):lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)drop = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)return dropwith tf.name_scope('lstm'):cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(self.lstm_size, self.keep_prob) for _ in range(self.num_layers)])self.initial_state = cell.zero_state(self.num_seqs, tf.float32)# 通过dynamic_rnn对cell展开时间维度self.lstm_outputs, self.final_state = tf.nn.dynamic_rnn(cell, self.lstm_inputs, initial_state=self.initial_state)# 通过lstm_outputs得到概率seq_output = tf.concat(self.lstm_outputs, 1)x = tf.reshape(seq_output, [-1, self.lstm_size])with tf.variable_scope('softmax'):softmax_w = tf.Variable(tf.truncated_normal([self.lstm_size, self.num_classes], stddev=0.1))softmax_b = tf.Variable(tf.zeros(self.num_classes))# proba_prediction = Softmax(Wx+b)self.logits = tf.matmul(x, softmax_w) + softmax_bself.proba_prediction = tf.nn.softmax(self.logits, name='predictions')

3.定义损失

def build_loss(self):with tf.name_scope('loss'):y_one_hot = tf.one_hot(self.targets, self.num_classes)y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=y_reshaped)self.loss = tf.reduce_mean(loss)

二、训练模型

1.生成英文

训练生成英文的模型:

python train.py \--input_file data/shakespeare.txt \--name shakespeare \--num_steps 50 \--num_seqs 32 \--learning_rate 0.01 \--max_steps 20000
python train.py  --input_file data/shakespeare.txt --name shakespeare --num_steps 50 --num_seqs 32 --learning_rate 0.01 --max_steps 20000

测试模型:

python sample.py \--converter_path model/shakespeare/converter.pkl \--checkpoint_path model/shakespeare/ \--max_length 1000
python sample.py --converter_path model/shakespeare/converter.pkl --checkpoint_path model/shakespeare/ --max_length 1000

因为每次候选下一个字母都是top5做概率归一后挑出的,所以文本生成的结果都会不同。

top5的代码看这里:python概率选取ndarray的TOP-N

真的好神奇,很难想象才20000步的效果会那么好!

2.生成诗词

训练写诗模型:

python train.py \--use_embedding \--input_file data/poetry.txt \--name poetry \--learning_rate 0.005 \--num_steps 26 \--num_seqs 32 \--max_steps 10000
python train.py --use_embedding --input_file data/poetry.txt --name poetry --learning_rate 0.005 --num_steps 26 --num_seqs 32 --max_steps 10000


测试模型:

python sample.py \--use_embedding \--converter_path model/poetry/converter.pkl \--checkpoint_path model/poetry/ \--max_length 300
python sample.py --use_embedding --converter_path model/poetry/converter.pkl --checkpoint_path model/poetry/ --max_length 300

3.生成C代码

训练生成C代码的模型:

python train.py \--input_file data/linux.txt \--num_steps 100 \--name linux \--learning_rate 0.01 \--num_seqs 32 \--max_steps 20000
python train.py --input_file data/linux.txt --num_steps 100 --name linux --learning_rate 0.01 --num_seqs 32 --max_steps 20000

测试模型:

python sample.py \--converter_path model/linux/converter.pkl \--checkpoint_path model/linux \--max_length 1000
python sample.py --converter_path model/linux/converter.pkl --checkpoint_path model/linux --max_length 1000

TensorFlow 17——ch12-Char RNN 文本生成(莎士比亚/诗词)相关推荐

  1. tensorflow循环神经网络(RNN)文本生成莎士比亚剧集

    tensorflow循环神经网络(RNN)文本生成莎士比亚剧集 我们将使用 Andrej Karpathy 在<循环神经网络不合理的有效性>一文中提供的莎士比亚作品数据集.给定此数据中的一 ...

  2. Tensorflow2.0之文本生成莎士比亚作品

    文章目录 1.导入数据 2.创建模型 3.训练 3.1 编译模型 3.2 配置检查点 3.3 训练模型 4.预测 4.1 重建模型 4.2 生成文本 我们将使用 Andrej Karpathy 在&l ...

  3. 如何用RNN生成莎士比亚风格的句子?(文末赠书)

    作者 | 李理,环信人工智能研发中心vp,十多年自然语言处理和人工智能研发经验.主持研发过多款智能硬件的问答和对话系统,负责环信中文语义分析开放平台和环信智能机器人的设计与研发. 来源 | <深 ...

  4. 使用TensorFlow.js的AI聊天机器人六:生成莎士比亚独白

    目录 设置TensorFlow.js代码 小莎士比亚数据集 通用句子编码器 莎士比亚独白在行动 终点线 总结 下载项目代码-9.9 MB TensorFlow+JavaScript.现在,最流行.最先 ...

  5. 利用GPT2生成莎士比亚写作风格的文本(python实现)

    一:原理 在此仅仅是简单介绍,还需要读者对self-attention.Transformer.GPT有一定的知识储备. 原始的 transformer 论文引入了两种类型的 transformer ...

  6. Tensorflow快餐教程(12) - 用机器写莎士比亚的戏剧

    高层框架:TFLearn和Keras 上一节我们学习了Tensorflow的高层API封装,可以通过简单的几步就生成一个DNN分类器来解决MNIST手写识别问题. 尽管Tensorflow也在不断推进 ...

  7. GRU网络生成莎士比亚小说

    介绍 本文我们将使用GRU网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示: 虽然有些句子并没有实际的意思(目前我们的模型是基于概率,并不是理解语义),但是大多数单词都是有效 ...

  8. 【Github上有趣的项目】基于RNN文本生成器,自动生成莎士比亚的剧本或者shell代码(不是python的是lua的)

    文章目录 下了之后才发现不是python的尴尬得一匹,,ԾㅂԾ,, GitHub 上有哪些有趣的关于 NLP 或者 DL 的项目? - Xiaoran的回答 - 知乎 char-rnn 下了之后才发现 ...

  9. 使用LSTM进行莎士比亚风格诗句生成

    本文章跟本人前面两篇文章(文章1, 文章2)的思路大体相同,都是使用序列化的数据集来训练RNN神经网络模型,然后自动生成相关的序列化.这篇文章使用莎士比亚诗词作为训练集,使用keras和tensorf ...

最新文章

  1. android 释放 so,这 10 个值得开启的隐藏功能,让你的 Chrome 释放更多潜力
  2. 在MATPLOTLIB中加入汉字显示
  3. WeifenLuo.WinFormsUI.Docking
  4. 2020-11-7( servlet)
  5. linux 虚拟机挂载本地,CentOS 在VMWare中挂载本地yum源
  6. toast弹窗_一个弹窗的设计思考
  7. Java JDK与JRE
  8. APUE读书笔记-15进程内部通信-05FIFOs
  9. 随机初始化(代码实现)
  10. JSP中调用存储过程(SQL2000)
  11. O2O供应链系统架构设计
  12. 简单的C语言实训代码
  13. 负载均衡实现的各种优缺点
  14. android中常见的异常总结
  15. JavaScript实现人民币大小写转换
  16. 记:在daemon.json中添加“live-restore“: false之后,docker无法启动
  17. linux每天生成一个日志文件,使Apache每天产生一个日志文件
  18. 如何将之前push的撤回_门禁如何接线?一个实例了解清楚
  19. [含论文+源码等]Javaweb医院分诊挂号管理系统SSH
  20. Spark cache和checkpoint机制

热门文章

  1. 《C语言参悟之旅》新鲜试读连载7
  2. html中样式表的类型,css层叠样式表有哪几种类型?
  3. MD290,MD380,MD500变频器源码
  4. gis计算机技术发展,计算机技术与GIS发展趋势研究
  5. 深度学习训练降低显存指南
  6. 白盒测试转开发好转吗_情感测试:4个盆栽,你会买哪个回家?测12月你可以暴富转好运吗?...
  7. [相互学习] 易经给我们的64个人生智慧
  8. uniapp如何在vivo真机中运行
  9. 自媒体怎样做能一天赚到1000块?
  10. 刚刚发现在博客园的博客排行榜[前200人]中我的blog竟然排在第47位