引子

我们常常按是否引入时序信息将神经网络分为前向神经网络和循环神经网络,相信大家对于如何实现一般的前向神经网络较为熟悉,那如何实现循环神经网络呢?
        当然,我们可以自己实现,这并不复杂,不过本文要讲的是如何理解和使用已有的轮子。
        为了更快速的实现我们的网络原型,拟使用TFLearn来搭建一个循环神经网络。我大致看了一下TFLearn中对于循环神经网络的实现,下面对LSTM的实现进行简要介绍,具体接口可以参考《TFLearn: Recurrent Layers》。

理解LSTM

原始RNN

原始的RNN结构如下图所示:

这里画出了3个基本RNN单元,每个RNN单元的实现十分简单:将当前时刻的输入与上一时刻的输出连接起来(concatenate),然后经过tanh激活函数计算后作为当前时刻的输出。这里就不对此进行仔细介绍了,有兴趣可以查看Colah的博文《Understanding LSTM Networks》。

LSTM

下面我们参考Colah的博文对LSTM稍作介绍。
        LSTM结构如下图所示:

其主要包括遗忘门、输入门、输出门。遗忘门结构及其计算方法为:

输入门结构、状态更新及其计算方法为:

此时,我们的状态流为:

最终的输出门及输出计算方法为:

关于LSTM基础就介绍到这里,我们接下来将介绍TFLearn中LSTM的函数接口与实现。

LSTM函数接口与实现

用户接口

首先,我们关注的自然是用户接口的形式了:

tflearn.layers.recurrent.lstm (incoming, n_units, activation='tanh', inner_activation='sigmoid', dropout=None, bias=True, weights_init=None, forget_bias=1.0, return_seq=False, return_state=False, initial_state=None, dynamic=False, trainable=True, restore=True, reuse=False, scope=None, name='LSTM')

在该接口中,我们这里只介绍四个参数:

  • incoming: Tensor,3-D Tensor,这是输入;
  • n_units: lstm的units数量,也即LSTM中的WfWfW_f、WiWiW_i、WcWcW_c、WoWoW_o的第一个维度,或者说是LSTM中各个操作的神经元个数;
  • activation: 应用于输入与状态的激活函数,默认为“tanh”;
  • inner_activation: 门控单元的激活函数,默认为“sigmoid”。

先稍微解释一下为什么n_units表示LSTM中的各个操作的神经元个数。传入lstm()的n_units参数将被传入函数BasicLSTMCell()中进行使用。而在BasicLSTMCell中,有如下代码:

concat = _linear([inputs, h], 4 * self._num_units, True, 0.,self.weights_init, self.trainable, self.restore,self.reuse)# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4,axis=1)

显然,这里是声明了4 * n_units个神经元,然后切分为i,j,f,o,也即我们前面说到的四组神经元,或者说对应着前面的四组权重。所以,n_units表示LSTM中各个操作的神经元个数。

总的来说,LSTM这一层所期望的输入为3-D Tensor,该Tensor的shape为[samples, timesteps, input_dim];这一层所得到的输出将依照我们所给定的return_seq参数来设定,如果return_seq为True,则输出一个序列,所以其shape为[samples, timesteps, output_dim],如果return_seq为False,则输出一个2-D的Tensor,其shape为[samples, output_dim]。

lstm()源代码分析

了解了lstm()函数最基本的用户接口之后,我们稍微来分析一下其源代码,代码位于tflearn/layers/recurrent.py中,该函数主体部分较为简单,所以我们直接列出:

def lstm(incoming, n_units, activation='tanh', inner_activation='sigmoid', dropout=None, bias=True, weights_init=None, forget_bias=1.0, return_seq=False, return_state=False, initial_state=None, dynamic=False, trainable=True, restore=True, reuse=False, scope=None, name="LSTM"):cell = BasicLSTMCell(n_units, activation=activation,inner_activation=inner_activation,forget_bias=forget_bias, bias=bias,weights_init=weights_init,trainable=trainable,restore=restore, reuse=reuse)x = _rnn_template(incoming, cell=cell, dropout=dropout,return_seq=return_seq,return_state=return_state,initial_state=initial_state,dynamic=dynamic,scope=scope, name=name)return x

从这一函数的主体部分可以看出,该函数主要依赖于BasicLSTMCell()、_rnn_template()进行实现。粗略地说,BasicLSTMCell类用于指明我们的循环神经网络层的基本计算单元;_rnn_template函数则用于对输入数据以及输出数据进行基本处理,并利用_rnn()函数实现循环神经网络层。

下一小节对这三个函数进行简要分析。

关键类与函数分析

class BasicLSTMCell

该类的实现基于论文《Recurrent Neural Network Regularization》,也即正则化的RNN。我们主要关注其中的__call__()函数,与上面对LSTM介绍一致:

# 状态计算
new_c = (c * self._inner_activation(f + self._forget_bias) +self._inner_activation(i) *self._activation(j))
# 输出计算
new_h = self._activation(new_c) * self._inner_activation(o)

这里只拷贝了普通的实现,并没有拷贝batch normalization的实现方式,因为普通的实现较为容易与前面对LSTM的介绍对应起来。因为inner_activation默认值为“sigmoid”,activation的默认值为“tanh”,所以对于状态new_c来说,计算方法为newc=c∗tanh(f+bf)+σ(i)∗tanh(j)newc=c∗tanh(f+bf)+σ(i)∗tanh(j)new_c = c * tanh(f + b_f) + \sigma(i) * tanh(j),对于输出new_h来说,计算方法为newh=tanh(newc)∗σ(o)newh=tanh(newc)∗σ(o)new_h = tanh(new_c) * \sigma(o),这与前文是完全对应的。

function _rnn_template()

我们已经说过,该函数主要用于封装,或者说,这是一个rnn的模板,我们可以结合C++ template来理解。下面分析一下该函数中比较重要的代码:

def _rnn_template(incoming, cell, dropout=None, return_seq=False,return_state=False, initial_state=None, dynamic=False, scope=None, reuse=False, name="LSTM"):
...input_shape = utils.get_incoming_shape(incoming)
...with tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse) as scope:
...inference = incoming# If a tensor given, convert it to a per timestep listif type(inference) not in [list, np.array]:ndim = len(input_shape)assert ndim >= 3, "Input dim should be at least 3."axes = [1, 0] + list(range(2, ndim))inference = tf.transpose(inference, (axes))inference = tf.unstack(inference)outputs, state = _rnn(cell, inference, dtype=tf.float32, initial_state=initial_state, scope=name, sequence_length=sequence_length)

下面对其中将Tensor转换为per timestep list部分代码进行分析。

incoming参数的Tensor即我们传入到lstm()函数的Tensor,其shape一般为[samples, timesteps, input_dim]。此时,上述代码中的ndim为3,axes = [1, 0, 2]。函数tf.transpose()的作用时交换输入张量的不同维度,其用法为tf.transpose(input, [dim_1, dime_2, …, dim_n]),举例来说,我们的tf.transpose(input, [2, 1, 0])实际上就是将输入张量的第一个维度和第三个维度交换: [0, 1, 2] –> [2, 1, 0]。所以我们对inference使用该函数将得到一个shape为[timesteps, samples, input_dim]的Tensor。接着,我们使用tf.unstack()进行处理,这个函数的作用是将给定的R维张量拆分成R-1维张量,其用法为unstack(value, num=None, axis=0, name=’unstack’),用在这里的话,是按timesteps拆分为list,list中每个元素的shape为[samples, input_dim]。

将inference Tensor拆分为list之后,我们接着将inference list传入到_rnn()中进行计算,所以需要对_rnn()函数进行理解。

function static_rnn()

大家可能会疑惑,不是_rnn()么,怎么变成了static_rnn()?这可以从from tensorflow.python.ops.rnn import static_rnn as _rnn看出来。
        对于这个函数的话,我们只需要稍微理解一下其注释即可:

The simplest form of RNN network generated is:
"""pythonstate = cell.zero_state(...)outputs = []for input_ in inputs:output, state = cell(input_, state)outputs.append(output)return (outputs, state)
"""

实际上就是对inference list中的元素循环调用前面传入的cell(如这里的BasicLSTMCell),然后返回状态及输出即可。

Stateful LSTM

事实上,我们上面并没有仔细分析TFLearn实现LSTM时的状态传递,现在我们进行一些简要地分析。
        在_rnn_template()中,首先将Tensor转化为了per timestep list的形式,也即第一个维度是timstep,然后将Tensor传入了static_rnn()中。而在static_rnn()中,先是建立一个全零的state变量,然后按照timestep进行拆分,state将在timestep之间传递,最后虽然我们会返回output和最后一个timestep的state到_rnn_template()中,但是在rnn_template()中TFLearn仅仅只是让我们去决定是否要输出这个state,而并没有将其记录下来,用于下一个训练步。所以,总的来说,我们认为,TFLearn实现的LSTM并不是一个stateful LSTM,也即,虽然会在timestep中传递,但是不会在两次预测之间传递。换句话说,训练时,state既不会在batch之间传递,也不会在某个batch的不同样本之间传递;预测时,state不会在多次预测之间传递。如果想要构造一个stateful LSTM,可以参考《TensorFlow: Remember LSTM state for next batch (stateful LSTM)》。

总结

分析到这里,相信大家对于TFLearn中LSTM的实现已经有一个不错的理解了,大家可以尝试愉快地使用TFLearn的lstm()函数接口啦~

TFLearn之RNN相关推荐

  1. 基于tflearn的RNN模仿莎士比亚写作

    生成类似莎士比亚写作的文章 1.安装准备: 安装tflearn,是一个封装高的TensorFlow高层框架 pip install -I tflearn 2.实现过程 第一步:下载莎士比亚写作文本 i ...

  2. Recurrent Neural Networks(RNN) 循环神经网络初探

    1. 针对机器学习/深度神经网络"记忆能力"的讨论 0x1:数据规律的本质是能代表此类数据的通用模式 - 数据挖掘的本质是在进行模式提取 数据的本质是存储信息的介质,而模式(pat ...

  3. [github 源码收集] == tflearn examples

    API文档:http://tflearn.org/doc_index/#API 地址:https://github.com/tflearn/tflearn/tree/master/examples T ...

  4. TFLearn循环神经网络识别验证码

    1.数据清洗与特征提取 训练数据集(55000,784),测试数据集(10000,784),标签采用one-hot独热编码, 在DNN或MLP中,我们将28x28的图片,转换成维度为784的特征向量, ...

  5. RNN循环神经网络原理与示例

    一.原理解释 循环神经网络(Recurrent Neural Network,RNN)很多实时情况都能通过时间序列模型来描述.基于序列的模型可以用在很多领域中.在音乐中,一首曲子的下一个音符肯定取决于 ...

  6. 【机器学习】RNN循环神经网络

    循环神经网络归属: 领域:机器学习 方向:自然语言处理 贡献:自动文本生成 循环神经网络实际应用: 生活中因为原始数据都是序列化的,比如自然语言,语音处理,时间序列问题(股票价格)等问题, 这个时候需 ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. RNN,LSTM,GRU基本原理的个人理解重点

    20210626 循环神经网络_霜叶的博客-CSDN博客 LSTM的理解 - 走看看 重点 深入LSTM结构 首先使用LSTM的当前输入 (x^t)和上一个状态传递下来的 (h^{t-1}) 拼接训练 ...

  9. [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...

最新文章

  1. 2.3 使用 dom4j 对 xml文件进行 dom 解析
  2. centos6.5 mysql下载_Centos6.5在线安装mysql 8.0详细教程
  3. python装饰器实例-python装饰器使用实例详解
  4. UART0串口编程(三):中断方式的串口编程;用中断编写发送函数
  5. java 无类名实现接口_为什么很多人写 Java/Android 时,选择让同一个类实现多个接口,而不是用多个内部匿名类?...
  6. 54款开源服务器软件(内容管理、数据库、电子商务、邮件服务器、文件传输、操作系统、安全、小公司服务 .
  7. .NET Windows服务应用程序
  8. 【HTML+CSS网页设计与布局 从入门到精通】第6章-标题h1,h1字体格式的设置方式
  9. 怎样安装ubuntu系统
  10. 2021年中国动物血浆制品及其衍生物市场趋势报告、技术动态创新及2027年市场预测
  11. Matlab2017a/b激活license.lic文件
  12. cad页面布局快捷键_cad布局窗口快捷键
  13. 记录:SpringBoot 开发之集成微信公众号支付
  14. Mono.Cecil - 0.6
  15. e7 88系列服务器,英特尔至强E7处理器性能多项测试比拼
  16. 一个外包三流Java程序员凭什么逆袭到阿里P7?看完直呼一声WC
  17. hp,Qlogic,Brocade光纖卡查看方式
  18. vue 超出三行隐藏_文字超出三行省略...显示全文
  19. python彩票分析_128期老铁大乐透预测奖号:大中小码分析
  20. 【论文解读】利用高光谱图像对场景反射率进行有效估计(Efficient Estimation of Reflectance Parameters from Imaging Spectropy)

热门文章

  1. 风电随机性动态经济调度模型(Matlab代码实现)
  2. 【Matlab】Matlab作图的一些小知识
  3. mouseenter事件java_从子元素输入父元素时,不会触发MouseEnter事件
  4. 人工智能概念之——各向异性——亲和矩阵
  5. bat脚本执行exe文件_将批处理(BAT)脚本转换为可执行(EXE)文件
  6. BLEU 评价 NLP 文本输出质量
  7. Java实用工具类五:URL转码、解码类
  8. Java 类图描述(类图)
  9. 67.220.91.30/forum/index.php,Burp辅助插件之WooyunSearch 乌云漏洞库payload
  10. SDRAM控制器说明/altera/northwest logic