RNN是一个很有意思的模型。早在20年前就有学者发现了它强大的时序记忆能力,另外学术界以证实RNN模型属于Turning-Complete,即理论上可以模拟任何函数。但实际运作上,一开始由于vanishing and exploiting gradient问题导致BPTT算法学习不了长期记忆。虽然之后有了LSTM(长短记忆)模型对普通RNN模型的修改,但是训练上还是公认的比较困难。在Tensorflow框架里,之前的两篇博客已经就官方给出的PTB和Machine Translation模型进行了讲解,现在我们来看一看传说中的机器写诗的模型。原模型出自安德烈.卡帕西大神的char-rnn项目,意在显示RNN强大的能力以及并非那么困难的训练方法。对这个方面有兴趣的朋友请点击这里查看详情。原作的框架为Torch,点击这里查看原作代码。中山大学的zhangzibin以卡帕西大神的代码为样本制作了一款基于卡帕西RNN模型以及Samy Bengio(Bengio大神的亲弟弟)提出的Schedule Sampling算法的可运行中文的RNN模型,源代码请点击这里查看。作为Tensorflow的玩家,我本人当然很想了解下这个框架的运行情况,特别是在Tensorflow框架里的运行情况。好在有人已经捷足先登,将代码移植完毕了。今天我们就来看看这个神奇框架在Tensorflow下的代码。对该项目感兴趣的朋友可以在这里下载到项目的源码并在自己的机器上运行。

既然有了Tensorflow版本的代码了,那么我们开始解剖这段代码吧!

在解剖代码之前,让我们先对代码的运行做一个了解。在运行时,我们需要做的是cd到项目里后,先运行train.py文件来训练代码。默认的迭代数是50个迭代,默认的训练文件是tinyshakespear目录里的input.txt文件,也就是莎士比亚的一些作品。由于默认都是设定好的,我们不需要做任何更改,直接运行python train.py就好了。训练速度还是比较客观的,大约需要运行一个小时(没算具体时间),我们会发现训练完成,参数已经保存。之后,如果我们想看看运行的结果如何,打入python sample.py后,就会随机产生一段文字,该段文字是由机器学习了训练文本后自行计算的。之后我会放上机器在学习了郭敬明的幻成和小时代后自己写出的句子供大家参考。

在了解了运行方式后,既然入口文件是train.py,那么我们就先来看看该文件的设计。不出所料,train.py文件的开始为一系列的parser.add_argument。在之前的代码里我们已经多次见到,无非是加入了运行系统所需的参数,他们的默认值以及参数的解释。从这里我们发现默认的RNN框架为lstm,2层RNN结构,每层有128个神经元节点。另外,我们的sequence length定义为50,也就是每一次可以执行50个时间序列。之后便是train函数。如同往常,我们发现textloader函数为录入训练集的函数,这个函数存在于utils.py文件里。该文件很容易理解,在读入数据后通过collections.Counter收集文本中不一样的character,并将他们写入vocab_file文件做保存,已备后用。之后,根据总数据大小,minibatch大小以及时序长短来界定运行完整个文件需要多少个minibatches,并将文本分类成minibatch的训练以及目标batches。由于这个模型的目的是学习一个character后下一个character的概率,训练集跟目标函数间的差异为一个character,即在训练句子My name is Edward时,假设训练集为: My name is Edwar, 相对应的目标集为y name is Edward。 从逻辑角度上说,不管是这个util.py文件还是之前博客里的CBOW模型,他们的核心逻辑都是相似的,只是在处理上由于目标不同而产生出工程上的差异。有兴趣的朋友可以对比这个util.py文件里的逻辑和CBOW模型里读入输入的函数做对比。

之后,train.py文件对需要的目录以及文件进行确认后就是建立模型了。通过model = Model(args),我们建立了这个RNN所需要的模型。那么模型是如何建立的呢?让我们仔细来看看model.py文件。这个model.py文件里存在两个函数:init函数以及sample函数。他们分别被用来训练模型以及测试模型。让我们首先来看看模型的训练:

def __init__(self, args, infer=False):self.args = args# 这里的infer被默认为False,只有在测试效果# 的时候才会被设计为True,在True的状态下# 只有一个batch,time step也被设计为1,我们# 可以由此观测训练成功if infer:args.batch_size = 1args.seq_length = 1# 这里是选择RNN cell的类型,备选的有lstm, gru和simple rnn# 这里由输入的arg里的model参数作为测试标准,默认为lstm# 但是,我们可以看到,这里通过不同的模型我们可以用不同# 的cell。if args.model == 'rnn':cell_fn = rnn_cell.BasicRNNCellelif args.model == 'gru':cell_fn = rnn_cell.GRUCellelif args.model == 'lstm':cell_fn = rnn_cell.BasicLSTMCellelse:raise Exception("model type not supported: {}".format(args.model))# 定义cell的神经元数量,等同于cell = rnn_cell.BasicLSTMCell(args.rnn_size)cell = cell_fn(args.rnn_size)# 由于结构为多层结构,我们运用MultiRNNCell来定义神经元层。self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)# 输入,同PTB模型,输入的格式为batch_size X sequence_length(step)self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])self.initial_state = cell.zero_state(args.batch_size, tf.float32)with tf.variable_scope('rnnlm'):softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])softmax_b = tf.get_variable("softmax_b", [args.vocab_size])with tf.device("/cpu:0"):# 这里运用embedding来将输入的不同词汇map到隐匿层的神经元上embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])# 这里对input的shaping很有意思。这个地方如果我们仔细去读PTB模型就会发现在他的# outputs = []这行附近有一段注释的文字,解释了一个alternative做法,这个做法就是那# alternative的方法。首先,我们将embedding_loopup所得到的[batch_size, seq_length, rnn_size]# tensor按照sequence length划分为一个list的[batch_size, 1, rnn_size]的tensor以表示每个# 步骤的输入。之后通过squeeze把那个1维度去掉,达成一个list的[batch_size, rnn_size]# 输入来被我们的rnn模型运用。inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))inputs = [tf.squeeze(input_, [1]) for input_ in inputs]# 这里定义的loop实际在于当我们要测试运行结果,即让机器自己写文章时,我们需要对每一步# 的输出进行查看。如果我们是在训练中,我们并不需要这个loop函数。def loop(prev, _):prev = tf.matmul(prev, softmax_w) + softmax_bprev_symbol = tf.stop_gradient(tf.argmax(prev, 1))return tf.nn.embedding_lookup(embedding, prev_symbol)# 这里我们得益于tensorflow强大的内部函数,rnn_decoder可以作为黑盒子直接运用,省去了编写# 的麻烦。另外,上面的loop函数只有在infer是被定为true的时候才会启动,一如我们刚刚所述。另外# rnn_decoder在tensorflow中的建立方式是以schedule sampling算法为基础制作的,故其自身已经融入# 了schedule sampling算法。outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')# 这里的过程可以说基本等同于PTB模型,首先通过对output的重新梳理得到一个# [batch_size*seq_length, rnn_size]的输出,并将之放入softmax里,并通过sequence# loss by example函数进行训练。output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])self.logits = tf.matmul(output, softmax_w) + softmax_bself.probs = tf.nn.softmax(self.logits)loss = seq2seq.sequence_loss_by_example([self.logits],[tf.reshape(self.targets, [-1])],[tf.ones([args.batch_size * args.seq_length])],args.vocab_size)self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_lengthself.final_state = last_stateself.lr = tf.Variable(0.0, trainable=False)tvars = tf.trainable_variables()grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),args.grad_clip)optimizer = tf.train.AdamOptimizer(self.lr)self.train_op = optimizer.apply_gradients(zip(grads, tvars))

由上述代码可见,在制作RNN的模型里,不可或缺的步骤如下:

# 制作RNN模型的大概步骤:# 1.定义cell类型以及模型框架(假设为lstm):
basic_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size)
cell = tf.nn.rnn_cell.MultiRNNCell([basic_cell]*number_layers)# 2.定义输入
input_data = tf.placeholder(tf.int32, [batch_size, sequence_length])
target = tf.placeholder(tf.int32, [batch_size, sequence_length])# 3. init zero state
initial_state = cell.zero_state(batch_size, tf.float32)# 4. 整理输入,可以运用PTB的方法或上文介绍的方法,不过要注意
# 你的输入是什么形状的。最后数列要以格式[sequence_length, batch_size, rnn_size]
# 为输入才可以。# 5. 之后为按照你的应用所需的函数运用了。这里运用的是rnn_decoder, 当然,别的可以
# 运用,比如machine translation里运用的就是embedding_attention_seq2seq# 6. 得到输出,重新编辑输出的结构后可以运用softmax,一般loss为sequence_loss_by_example# 7. 计算loss, final_state以及选用learning rate,之后用clip_by_global norm来定义gradient
# 并运用类似于adam来optimise算法。可以运用minimize或者apply_gradients来训练。

在了解了模型后,我们发现剩下的代码都是比较常见的,例如initialize_all_variables, 以及运用learning rate decay的方式训练模型。由此,train.py文件的训练过程我们已经做了一个大概的了解了。那么,系统又是如何让我们可以测试训练好的模型呢?让我们来看看sample.py文件。通过parser.add_arugment函数,我们发现文件会选取我们保存模型的地点,并会产生500字符的sample,至于sample选项,我们发现设定为0是得到最多的timestep,1是每一个timestep, 2是sample on spaces。之后,我们读取存储的模型内容后将内容传递进Model函数,并将infer设为True。在output的时候,我们运用sample函数来的到输出。在这个函数里,我们发现一般的prime开头是‘The’,这里我们可以通过sample.py里的prime函数来指定一个开头。在之后,我们发现那个sample参数设为0时选取的是argmax,1时是weighted_pick,2时以space为标准,如果有space则选择weighted_pick, 不然就是argmax。好了,实际运行的效果如何呢?让我们来几个例子看看。

这里是the为开始,我们看到了一开始有点乱码。之后,我们看到可是我知道了,他们三个女生等短句都是通顺的,同时,也有一些及其不通的,例如底忘记新地把满些了,这句话什么含义完全不清楚。再看下一个例子,如果以我开头会怎么样呢?

如果把sample设为0又会如何呢?

这里篇幅紧凑多了。再次运行,我们得到了相同的结果,因为是argmax么,所以在没改变的情况下我们会得到相同的结果。

这个运行结果还是很有意思的,有兴趣的朋友可以自行下载项目然后试着去操作一下!

转载于:https://www.cnblogs.com/edwardbi/p/5573951.html

character-RNN模型介绍以及代码解析相关推荐

  1. Bert模型介绍及代码解析(pytorch)

    Bert(预训练模型) 动机 基于微调的NLP模型 预训练的模型抽取了足够多的信息 新的任务只需要增加一个简单的输出层 注:bert相当于只有编码器的transformer 基于transformer ...

  2. 光场相机重聚焦原理介绍及代码解析

    光场相机重聚焦原理介绍及代码解析 光场相机重聚焦–焦点堆栈深度估计法 全部代码下载地址: https://download.csdn.net/download/weixin_38285131/1044 ...

  3. 对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析

    GAN作为生成模型的一种新型训练方法,通过discriminative model来指导generative model的训练,并在真实数据中取得了很好的效果.尽管如此,当目标是一个待生成的非连续性序 ...

  4. skip-gram模型介绍及代码

    在自然语言处理中,首先要把文本转化为数据的形式,更确切地说,是把词转化为向量的形式,才可以用计算机通过各种算法处理自然语言问题.在词向量的表示方法中,One-hot编码是一个非常经典的表示方法,但是在 ...

  5. 时间序列模型SCINet(代码解析)

    前言 SCINet模型,精度仅次于NLinear的时间序列模型,在ETTh2数据集上单变量预测结果甚至比NLinear模型还要好. 在这里还是建议大家去读一读论文,论文写的很规范,很值得学习,论文地址 ...

  6. 【COMSOL】Marzas 材料模型 C 源文件代码解析

    文件头 该材料模型输入为应变.材料参数.模型状态,输出为应力.雅可比矩阵,供 COMSOL 使用. #include <math.h> #include <stdlib.h> ...

  7. 关于Transformer你需要知道的都在这里------从论文到代码深入理解BERT类模型基石(包含极致详尽的代码解析!)

    UPDATE 2.26.2020 为代码解析部分配上了Jay Ammar The Illustrated GPT-2 的图示,为想阅读源码的朋友缓解疼痛! 深入理解Transformer------从 ...

  8. GAT: 图注意力模型介绍及PyTorch代码分析

    文章目录 GAT: 图注意力模型介绍及代码分析 原理 图注意力层(Graph Attentional Layer) 情境一:节点和它的一个邻居 情境二:节点和它的多个邻节点 聚合(Aggregatio ...

  9. RNN模型与NLP应用笔记(2):文本处理与词嵌入详解及完整代码实现(Word Embedding)

    一.写在前面 紧接着上一节,现在来讲文本处理的常见方式. 本文大部分内容参考了王树森老师的视频内容,再次感谢王树森老师和李沐老师的讲解视频. 目录 一.写在前面 二.引入 三.文本处理基本步骤详解 四 ...

  10. 情感分类模型介绍CNN、RNN、LSTM、栈式双向LSTM

    情感分类模型介绍CNN.RNN.LSTM.栈式双向LSTM 1.文本卷积神经网络(CNN) 卷积神经网络经常用来处理具有类似网格拓扑结构(grid-like topology)的数据.例如,图像可以视 ...

最新文章

  1. iOS中JS 与OC的交互(JavaScriptCore.framework)
  2. MATLAB和Python读取wave文件的波形对比
  3. Cloudera maneger登录页面后的操作是什么?
  4. 2021 icme_重磅 | 2021年U.S. News 全美院校排名发布,疫情之下,排名大洗牌?!
  5. java 带宽控制_如何使用Java netty正确限制带宽使用?
  6. Kafka刚开启就秒退
  7. 关于selenium关闭chrome密码登录时弹出的密码提示框
  8. js验证银行卡号 luhn校验规则
  9. c#之有参和无参构造函数,扩展方法
  10. Butterworth滤波器设计(IIR类型)
  11. 9N90-ASEMI的MOS管9N90
  12. Vue-----table 控件自动勾选全选框2 与tab控件组合使用
  13. 欢迎中文社区新版主@刘文艺
  14. Unix/Linux编程:getcontext、setcontext
  15. 数据分析与数据挖掘实战案例本地房价预测(716):
  16. 声网如何添加与配置项目
  17. 百度云盘搜索引擎微信公证号_微信公众号被百度搜索引擎收录?SEO优化诞生新方法!...
  18. ROS--rospy
  19. python货币兑换_零基础python作业--货币兑换的服务系统
  20. 如何利用Python自动根据数据生成降雨量统计分析报告

热门文章

  1. 理解伪元素:before和:after
  2. qW3xt.2服务器病毒
  3. cloudera-agent启动File not found : /usr/sbin/cmf-agent解决办法(图文详解)
  4. mac XAMPP环境下, 使用php函数mkdir()添加新目录(文件)报错,报错信息:permission denied;...
  5. SQL Server多表同时查询
  6. spring init
  7. 网络通信之通过get/post方式提交参数给web应用
  8. spring云化架构迁移 (一)
  9. Mysql 演示示例存储过程
  10. 20150820-Linux命令概述及一些基本命令