1. 思路

这个示例在很多地方都出现过,对于学习理解LSTM的原理极有帮助,因此我们下面就来一步一步地弄清楚其中的奥秘所在!

对于循环神经网络来说,我们首先需要做的仍旧是找到一种将数据序列化的方法。当然,对于古诗词来说,每个字的出现顺序就是天然的一个序列,因此我们就可以直接按照这个序列来处理。并且一首古诗词可以看成是一个样本(为了叙述方便,我们下面仅以一首诗的第一句为例子),例如[[床前明月光],[小时不识月]]为两个样本。

1.1 网络训练模型

现在暂时假设我们的训练集中只有两个样本x=[[床前明月光],[小时不识月]],那么想想此时对应的标签应该是什么?回想一下,我们的目的是利用循环网络来写诗,也就是说当我们把模型训练好了之后,直接喂给模型第一个字,它就能写出一句(或一首)诗了;既然如此,那么我们的训练过程就应该是学习每首古诗中,所有字的一个出现顺序。所以,对于床前明月光这个样本来说,其对应的标签值就应该是前明月光光。由此可知,我们的网络模型就应该大致长这样:

接下来,为了能更清晰的叙述网络结构而不被其它因素影响,我们在这个小节中就直接用一个汉字来表示一个维度(实际中要将一个汉字转为n维的词向量)。此时,训练样本的维度就应该是shape=[2,5,1]。在这个示例中,我们采用了两层的LSTM网络外加一个softmax的全连接层,并且LSTM网络的输出维度output_size=32,于是我们就可以画出下面这个网络示意图:

从图中可以看到,第一步:我们是将shape = [2,5,1]的训练集喂给LSTM网络,然后从网络得到shape=[2,5,32]的输出;第二步:我们将LSTM网络得到的输出reshape成[10,32]的矩阵;第三步:再将上一步的结果喂给最后一个softmax全连接网络,这样就能完成对于每个字的分类任务了。

对于第二步为什么要reshape然后喂给第三部的全连接网络,我们可以这样想:假如是一个样本的话,那么LSTM的输出大小就为[1,5,32],也就是说第一步喂进去的每一个字通过LSTM这个网络处理之后都变成一个[1,32]的向量化表示方式,只是第2个字保留了第一个字里面的信息,第3个字保留了跟前面的信息等等。这也就有点类似于卷积网络中先用卷积层对图片进行特征提取,然后再做一个分类处理。于是乎我们就可以发现,其实LSTM网络的本质也是在做一个特征提取的工作,区别于卷积网络的就是:卷积网络提取的是基于空间上的特征,而循环网络提取的是基于时间序列上的特征。至于最后以层,该分类就分类处理,该回归就回归处理。

1.2 网络预测模型

当网络经过训练完成后就可以拿来预测了,只不过在预测的时候我们喂给网络的就只是一个字了;然后用当前预测得到的字作为下一个字;如下图所示:

1.3 数据处理

经过上面的讲解,我们大致明白了基于LSTM网络古诗生成原理:先用LSTM做特征提取,然后分类。既然最后我们要完成的是一个分类任务,那么我们不得不做的就是将所有的类别给整理处理,也就是所有的数据集中一共包含了多少个不同的字,因为我们来做的就是根据上一个字预测下一个字。

同时由于我们处理的是文本信息,因此我们需要将每个字都采用词(字)向量的形式表示,由于没有现成的词向量,所有我们要再LSTM的前面假加入一个词嵌入层。

最后,为了避免最终的分类数过于庞大,可以选择去掉出现频率较小的字,比如可以去掉只出现过一次的字。

总结一下数据预处理的步骤:

  • 1.统计出所有不同的字,并做成一个字典;
  • 2.对于每首诗,将每个字、标点都转换为字典中对应的编号,构成X;
  • 3.将X整体左移动以为构成Y

2. 代码讲解

在此首先感谢Github上的jinfagang,hzy46这两位作者,因为整体代码都是参照的他们的,加了一点点自己的元素。

2.1 数据预处理

先来看看原始的数据集长什么样:

首春:寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
初晴落景:晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。

而我们需要得到的是类似于这样的:

X:
[[1,4,6,3,2,5,3,0,0,0],[5,6,4,3,9,1,0,0,0,0]]Y:
[[4,6,3,2,5,3,0,0,0,0],[6,4,3,9,1,0,0,0,0,0]]

其中的0表示,我设定了一首诗的最大长度,如果不足就补0(因为每首诗的长度不一样);而其它的数字则表示诗中每个字以及标点在字典中的索引。同时,为了后面的生成诗时候的转换,我们还需要得到字典。

而这只需要tensorflow中的几行代码就能搞定(友情提示:在统计词频使用Counter()这个类时,对于同一词频的词在字典中的排列顺序window平台和linux平台的处理结果不一样)。以下只是部分代码,完整参见源码中的data_helper.py模块

    vocab_processor = VocabularyProcessor(max_document_length=max_length,min_frequency=5)x = np.array(list(vocab_processor.fit_transform(poems)))dictionary = vocab_processor.vocabulary_.__dict__.copy()fre = dictionary['_freq']# print(sorted(fre.items(), key=lambda x: x[1], reverse=True))word_to_int = dictionary['_mapping']int_to_word = dictionary['_reverse_mapping']np.random.seed(50)shuffle_index = np.random.permutation(x.shape[0])shuffle_x = x[shuffle_index]shuffle_y = np.copy(shuffle_x)shuffle_y[:, :-1] = shuffle_x[:, 1:]

2.2 网络构建

在整个网络构建中,主要分成了四个部分build_input(),build_rnn(),ttrain(),compose_poem()。下面就挑重点的说。

2.2.1 build_input()

由2.1节可知,我们预处理后得到数据的形式是二维的,所以在定义placeholder也要是二维的;同时,由于要采用词向量进行表示,所以此处还要加入一个词嵌入层。代码如下:


with tf.name_scope('model_inputs'):self.inputs = tf.placeholder(dtype=tf.int32, shape=[self.batch_size, None], name='input-x')self.targets = tf.placeholder(dtype=tf.int64, shape=[self.batch_size, None], name='input-y')
with tf.name_scope('embedding_layer'):self.embedding = tf.Variable(tf.truncated_normal(shape=[self.num_class, self.embedding_size], stddev=0.1),name='embedding')self.model_inputs = tf.nn.embedding_lookup(self.embedding,self.inputs)

由于我们训练时inputs的第二个维度为诗的长度,预测时为1,所以就写成了None

2.2.1 build_rnn()

with tf.name_scope('build_rnn_model'):cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell(self.rnn_size) for _ in range(self.num_layer)])  # 搭建num_layer层的模型self.initial_state = cell.zero_state(batch_size=self.batch_size, dtype=tf.float32)self.outputs, self.final_state = tf.nn.dynamic_rnn(cell, inputs=self.model_inputs,initial_state=self.initial_state)output = tf.reshape(self.outputs, [-1, self.rnn_size])

第7行代码就是图(id:p0033)中的第二步。接下来就是一个全连接:

with tf.name_scope('full_connection'):weights = tf.Variable(tf.truncated_normal(shape=[self.rnn_size, self.num_class]),name='weights')  # [128,5000]bias = tf.Variable(tf.zeros(shape=[self.num_class]), name='bias')self.logits = tf.nn.xw_plus_b(output, weights, bias, name='logits')

构造损失:

with tf.name_scope('loss'):labels = tf.reshape(self.targets, [-1])loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=self.logits)self.loss = tf.reduce_mean(loss)

预测值和准确率

with tf.name_scope('accuracy'):self.proba_prediction = tf.nn.softmax(self.logits, name='output_probability')self.prediction = tf.argmax(self.proba_prediction, axis=1, name='output_prediction')correct_predictions = tf.equal(self.prediction, labels)self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

源码地址

更多内容欢迎扫描关注公众号月来客栈!

循环神经网络系列(六)基于LSTM的唐诗生成相关推荐

  1. 人工智能--基于LSTM的文本生成

    学习目标: 理解文本生成的基本原理. 掌握利用LSTM生成唐诗宋词的方法. 学习内容: 利用如下代码和100首经典宋词的数据,基于LSTM生成新的词,并调整网络参数,提高生成的效果. poetry50 ...

  2. 深度学习原理-----循环神经网络(RNN、LSTM)

    系列文章目录 深度学习原理-----线性回归+梯度下降法 深度学习原理-----逻辑回归算法 深度学习原理-----全连接神经网络 深度学习原理-----卷积神经网络 深度学习原理-----循环神经网 ...

  3. [译] RNN 循环神经网络系列 2:文本分类

    原文地址:RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永 ...

  4. RNN 循环神经网络系列 5: 自定义单元

    原文地址:RECURRENT NEURAL NETWORK (RNN) – PART 5: CUSTOM CELLS 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永久链接:gith ...

  5. 深度学习之循环神经网络(11-a)LSTM情感分类问题代码

    深度学习之循环神经网络(11-a)LSTM情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorf ...

  6. 深度学习之循环神经网络(11)LSTM/GRU情感分类问题实战

    深度学习之循环神经网络(11)LSTM/GRU情感分类问题实战 1. LSTM模型 2. GRU模型  前面我们介绍了情感分类问题,并利用SimpleRNN模型完成了情感分类问题的实战,在介绍完更为强 ...

  7. 深度学习之循环神经网络(9)LSTM层使用方法

    深度学习之循环神经网络(9)LSTM层使用方法 1. LSTMCell 2. LSTM层  在TensorFlow中,同样有两种方式实现LSTM网络.既可以使用LSTMCell来手动完成时间戳上面的循 ...

  8. 基于LSTM的音乐生成学习全过程的总结

    基于LSTM的音乐生成学习全过程的总结 由于笔者日常酷爱唱歌,酷爱音乐,再加上现在是计算机专业硕士在读.也是这个假期确定下来要做人工智能音乐的方向,也就开始了我对AI音乐的学习. 从最基本的旋律生成开 ...

  9. 深度学习TensorFlow2,循环神经网络(RNN,LSTM)系列知识

    一:概述 二:时间序列 三:RNN 四:LSTM 一:概述 1.什么叫循环? 循环神经网络是一种不同于ResNet,VGG的网络结构,个人理解最大的特点就是:它通过权值共享,极大的减少了权值的参数量. ...

最新文章

  1. JavaScript基础(一)基本认识
  2. UDF、UDAF、UDTF函数编写
  3. Dijkstra模板(java)
  4. TensorFlow 1.12.2 发布,修复 GIF 构造安全漏洞
  5. 如何解决PIP命令不可用
  6. 开源十问, 社区新人快速上手指南
  7. 华为5G微交易修复版源码 K线/结算全修复 去短信+去邀请码
  8. 最简单的图文教程,几步完成Git的公私钥配置
  9. SQL不同服务器数据库之间的数据操作整理(完整版)
  10. 三星 9810 android 9,【极光ROM】-【三星NOTE9 N960X-9810】-【V22.0 Android-Q-TK1】
  11. mongodb基本数据类型
  12. xjoi 3561查找某数出现位置
  13. 福特汉姆大学计算机科学专业,福特汉姆大学优势专业
  14. 极限中0除以常数_酶动力学中的一些常数简介
  15. Pair:医学图像标注神器
  16. 华为鸿蒙手机版要2021开源,鸿蒙系统再起疑云:开源版和手机版完全不同,后者还有安卓彩蛋...
  17. 笔记本电脑既连内网网线又连无线WiFi
  18. keras使用VGG19网络模型实现风格迁移
  19. java 二进制 2个字节 高位 低位_高位字节,低位字节应该怎么理解
  20. 如何从脚本小子变成黑客大神?【网络安全】

热门文章

  1. 【SciSpace】强大的PDF论文AI辅助阅读器
  2. Spring 异步@Async注解用法 Spring @Async注解用法总结 Spring @Async基本用法示例
  3. 【解决方案】艾美捷脂肪生成测定试剂盒的功能和应用
  4. 前后端分离之后端代码实现获取数据库数据(2)——django+mysql+vue+element
  5. 用通俗易懂的话说下hadoop是什么,能做什么
  6. 牧牛区块链培训,区块链赋能生态环境的监管
  7. 5G风起,未来数据库有哪些关键词?
  8. ContextMenu菜单详解
  9. Unity面向对象的通俗解释
  10. 昨天的明天,也就是今天!