点击下方标题,迅速定位到你感兴趣的内容

  • 前言
  • 相关知识
  • 标准RNN
  • 双向RNN
  • Multi-layer(stacked) RNN
  • 深度循环神经网络
  • Recursive Neural Network
  • 补充
    • 循环神经网络的训练算法:BPTT
    • RNN的梯度爆炸和消失问题

前言

说明:讲解时会对相关文章资料进行思想、结构、优缺点,内容进行提炼和记录,相关引用会标明出处,引用之处如有侵权,烦请告知删除。
转载请注明:DengBoCong

本篇文章主要总结我在学习过程中遇到的RNN、其相关变种,并对相关结构进行说明和结构图展示。内容包括RNN、RecNN、多层、双向、RNNCell等等,同时包括在计算框架(TensorFlow及PyTorch)API层面的一些理解记录。本篇文章不进行深入推导和底层原理介绍,仅做总结记录,感兴趣者可自行根据内容详细查阅资料。

RNN(递归神经网络)包括Recurrent Neural Network和Recursive Neural Network两种,分别为时间递归神经网络和结构递归神经网络。

计算框架版本:

  • TensorFlow2.3
  • PyTorch1.7.0

相关知识

在进行后面内容的陈述之前,先来简单结合计算框架说明一下vanilla RNN、LSTM、GRU之间的区别。虽然将vanilla RNN、LSTM、GRU这个三个分开讲进行对比,但是不要忘记它们都是RNN,所以在宏观角度都是如下结构:

而它们区别在于中间的那个隐藏状态计算单元,这里贴出它们的计算单元的细节,从左到右分别是vanilla RNN、LSTM、GRU。

看了隐藏单元之后,你有没有发现LSTM和其他两个的输入多了一个cell state,LSTM的门道就在这,cell state 就是实现LSTM的关键(ps:GRU其实也有分hidden state和cell state,不过在GRU中它们两个是相同的)。细节我不去深究,感兴趣的自行查看论文:

  • RNN
  • LSTM
  • GRU

我这里就简单的结合TensorFlow和PyTorch说明一下cell state和hidden state,首先看下面两个计算框架的调用(详细参数自行查阅文档,这里只是为了说明state):

# TensorFlow中的LSTM调用
whole_seq_output, final_memory_state, final_carry_state =tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)(inputs)
# Pytorch中的LSTM调用
output, (hn, cn) = torch.nn.LSTM(10, 20, 2)(input, (h0, c0))# TensorFlow中的GRU调用
whole_sequence_output, final_state =tf.keras.layers.GRU(4, return_sequences=True, return_state=True)(inputs)
# Pytorch中的GRU调用
output, hn = torch.nn.GRU(10, 20, 2)(input, h0)

以TensorFlow举例(PyTorch默认都返回),当return_state参数设置为True时,将会返回隐藏层状态,即cell_state。在LSTM 的网络结构中,直接根据当前input 数据,得到的输出称为 hidden state,还有一种数据是不仅仅依赖于当前输入数据,而是一种伴随整个网络过程中用来记忆,遗忘,选择并最终影响hidden state结果的东西,称为 cell state。cell state默认是不输出的,它仅对输出 hidden state 产生影响。通常情况,我们不需要访问cell state,但当需要对 cell state 的初始值进行设定时,就需要将其返回。所以在上面的TensorFlow对LSTM的调用中,final_memory_state是最后一个timestep的状态,final_carry_state是最后一个timestep的cell state。既然见到LSTM和GRU,那下面就贴一张它们的状态更新公式图以作记录:

后面简要阐述的所有RNN及其变种,都是代指vanilla RNN、LSTM、GRU三个,只不过为了方便描述,以RNN作为总称进行说明。

TensorFlow中,RNN类是作为如第一张结构图那些的宏观结构,所以它有一个cell参数,你可以根据实际需要传入SimpleRNNCell、LSTMCell和GRUCell(这三个你就可以理解成上面讲的计算单元),它们三个可以单独使用,在一些地方特别管用。

PyTorch中大致是一样的,不过RNN类则是标准的RNN实现的,而不是像Tensorflow那样的架构,PyTorch同样有RNNCell、LSTMCell和GRUCell

标准RNN

RNN忽略单元细节的具体结构图如下。从图中就能够很清楚的看到,上一时刻的隐藏层是如何影响当前时刻的隐藏层的(注意这里Output的数量画少了,看起来不够形象,应该是 X=[x1,x2,...,xm]X=[x_1,x_2,...,x_m]X=[x1​,x2​,...,xm​]和 O=[o1,o2,...,om]O=[o_1,o_2,...,o_m]O=[o1​,o2​,...,om​])。这里的Output是对应时间步的状态,而 sss 是隐藏状态,一般在实践中用它来初始化RNN。

当然,可以换一种方式画结构图,如下图所示,按照RNN时间线展开。注意了,隐藏层 sts_tst​ 不仅取决于 xtx_txt​ 还取决与 st−1s_{t-1}st−1​。

从上面总结公式如下:
ot=g(Vst)(1)o_t=g(V_{s_t}) \quad\quad (1)ot​=g(Vst​​)(1) st=f(Uxt+Wst−1)(2)s_t=f(U_{x_t}+W_{s_{t-1}}) \quad\quad (2)st​=f(Uxt​​+Wst−1​​)(2)
式(1)是输出层的计算公式,输出层是一个全连接层,也就是它的每个节点都和隐藏层的每个节点相连。VVV是输出层的权重矩阵,ggg是激活函数。式(2)是隐藏层的计算公式,它是循环层。UUU 是输入 xxx 的权重矩阵,WWW 是上一次的值作为这一次的输入的权重矩阵,fff 是激活函数。从宏观意义上来说,循环层和全连接层的区别就是循环层多了一个权重矩阵 WWW。通过循环带入得下式:
ot=Vf(Uxt+Wf(Uxt−1+Wf(Uxt−2+Wf(Uxt−3+...))))o_t=Vf(U_{x_t}+Wf(U_{x_{t-1}}+Wf(U_{x_{t-2}}+Wf(U_{x_{t-3}}+...))))ot​=Vf(Uxt​​+Wf(Uxt−1​​+Wf(Uxt−2​​+Wf(Uxt−3​​+...))))
从上面可以看出,循环神经网络的输出值 oto_tot​,是受前面历次输入值xtx_txt​、xt−1x_{t-1}xt−1​、xt−2x_{t-2}xt−2​、xt−3x_{t-3}xt−3​、…影响的,这就是为什么循环神经网络可以往前看任意多个输入值的原因。

双向RNN

论文:Link

从上图可以看出,双向RNN的隐藏层要保存两个值,一个 AAA 参与正向计算,另一个值 A′A'A′ 参与反向计算(注意了,正向计算和反向计算不共享权重),最终的输出值取决于 AAA 和 A′A'A′ 的计算方式。其计算方法有很多种,这里结合TensorFlow和PyTorch说明:

# TensorFlow中,需要使用Bidirectional来实现双向RNN,如下所示
# 其中merge_mode就是A和A'两者的计算方式:{'sum', 'mul', 'concat', 'ave', None}
tf.keras.layers.Bidirectional(layer, merge_mode='concat', weights=None, backward_layer=None, **kwargs
)# PyTorch则不同,在各RNN的具体实现中,都有一个bidirectional参
# 数来控制是否是双向的,可自行查看PyTorch的API文档,特别说明的是
# PyTorch没有merge_mode,所以双向RNN直接会返回正向和反向的状态,
# 需要你自行进行合并操作

Multi-layer(stacked) RNN

将多个RNN堆叠成多层RNN,每层RNN的输入为上一层RNN的输出,如下图所示。多层 (Multi-layer) RNN 效果很好,但可能会常用到 skip connections 的方式

深度循环神经网络

前面我们介绍的循环神经网络只有一个隐藏层,我们当然也可以堆叠两个以上的隐藏层,这样就得到了深度循环神经网络,如下图所示:

我们把第 iii 个隐藏层的值表示为 st(i)s_t^{(i)}st(i)​、st′(i)s_t^{'(i)}st′(i)​,则深度循环神经网络的计算方式可以表示为:
ot=g(V(i)st(i)+V′(i)st′(i))o_t=g(V^{(i)}s_t^{(i)}+V^{'(i)}s_t^{'(i)})ot​=g(V(i)st(i)​+V′(i)st′(i)​) st(i)=f(U(i)st(i−1)+W(i)st−1)s_t^{(i)}=f(U^{(i)}s_t^{(i-1)}+W^{(i)}s_{t-1})st(i)​=f(U(i)st(i−1)​+W(i)st−1​) st′(i)=f(U′(i)st′(i−1)+W′(i)st+1′)s_t^{'(i)}=f(U^{'(i)}s_t^{'(i-1)}+W^{'(i)}s_{t+1}^{'})st′(i)​=f(U′(i)st′(i−1)​+W′(i)st+1′​) st(1)=f(U(1)xt+W(1)st−1)s_t^{(1)}=f(U^{(1)}x_t+W^{(1)}s_{t-1})st(1)​=f(U(1)xt​+W(1)st−1​) st′(1)=f(U′(1)xt+W′(1)st+1′)s_t^{'(1)}=f(U^{'(1)}x_t+W^{'(1)}s_{t+1}^{'})st′(1)​=f(U′(1)xt​+W′(1)st+1′​)

Recursive Neural Network

RNN适用于序列建模,而许多NLP问题需要处理树状结构,因此提出了RecNN的概念。与RNN将前序句子编码成状态向量类似,RecNN将每个树节点编码成状态向量。RecNN中的每棵子树都由一个向量表示,其值由其子节点的向量表示递归确定。

RecNN接受的输入为一个有n个单词的句子的语法分析树,每个单词都表示为一个向量,语法分析树表示为一系列的生成式规则。举个例子,The boy saw her duck的分析树如下图:

对应的生成式规则(无标签+有标签)如下图:

RecNN的输出为句子的内部状态向量(inside state vectors),每一个状态向量都对应一个树节点。具体RecNN细节自行详细查阅资料。

补充

普遍来看, 神经网络都会有梯度消失和梯度爆炸的问题,其根源在于现在的神经网络在训练的时候,大多都是基于BP算法,这种误差向后传递的方式,即多元函数求偏导中,链式法则会产生 vanishing,而 RNN 产生梯度消失的根源是权值矩阵复用。

循环神经网络的训练算法:BPTT

BPTT算法是针对循环层的训练算法,它的基本原理和BP算法是一样的,也包含同样的三个步骤:

  • 前向计算每个神经元的输出值
  • 反向计算每个神经元的误差项 δj\delta_jδj​ 值,它是误差函数 EEE 对神经元 jjj 的加权输入 netjnet_jnetj​ 的偏导数
  • 计算每个权重的梯度
  • 最后再用随机梯度下降算法更新权重。

RNN的梯度爆炸和消失问题

不幸的是,实践中前面介绍的几种RNNs并不能很好的处理较长的序列。一个主要的原因是,RNN在训练中很容易发生梯度爆炸和梯度消失,这导致训练时梯度不能在较长序列中一直传递下去,从而使RNN无法捕捉到长距离的影响。通常来说,梯度爆炸更容易处理一些。因为梯度爆炸的时候,我们的程序会收到NaN错误。我们也可以设置一个梯度阈值,当梯度超过这个阈值的时候可以直接截取。梯度消失更难检测,而且也更难处理一些。总的来说,我们有三种方法应对梯度消失问题:

  • 合理的初始化权重值。初始化权重,使每个神经元尽可能不要取极大或极小值,以躲开梯度消失的区域。
  • 使用 relurelurelu 代替 sigmoidsigmoidsigmoid 和 tanhtanhtanh 作为激活函数。
  • 使用其他结构的RNNs,比如长短时记忆网络(LTSM)和Gated Recurrent Unit(GRU),这是最流行的做法。

参考资料:

  • How to use return_state or return_sequences in Keras
  • Understanding LSTM Networks
  • Illustrated Guide to LSTM’s and GRU’s: A step by step explanation
  • RNN summarize
  • 从动图中理解 RNN,LSTM 和 GRU
  • RNN、lstm、gru详解
  • 用 Recursive Neural Networks 得到分析树
  • 循环神经网络
  • 双向 和 多重 RNN

关于RNN理论和实践的一些总结相关推荐

  1. 时间序列预测方法汇总:从理论到实践(附Kaggle经典比赛方案)

    ©作者 | Light 学校 | 中国科学院大学 研究方向 | 机器学习 时间序列是我最喜欢研究的一种问题,这里我列一下时间序列最常用的方法,包括理论和实践两部分.理论部分大多是各路神仙原创的高赞解读 ...

  2. 【视频课】行为识别课程更新!CNN+LSTM理论与实践!

    前言 欢迎大家关注有三AI的视频课程系列,我们的视频课程系列共分为5层境界,内容和学习路线图如下: 第1层:掌握学习算法必要的预备知识,包括Python编程,深度学习基础,数据使用,框架使用. 第2层 ...

  3. ARM NEON指令集优化理论与实践

    ARM NEON指令集优化理论与实践 一.简介 NEON就是一种基于SIMD思想的ARM技术,相比于ARMv6或之前的架构,NEON结合了64-bit和128-bit的SIMD指令集,提供128-bi ...

  4. CPU消耗,跟踪定位理论与实践

    CPU消耗,跟踪定位理论与实践 一.性能指标之资源指标定位方案 1.打tprof报告方法 抓取perfpmr文件 60秒. perfpmr.sh 60 从结果文件中取出tprof.sum 或直接抓取t ...

  5. UI设计培训之如何将设计理论与实践相结合

    学习UI设计理论知识与实践技术都是要有的,很多人都不爱去听理论知识,这对以后的工作是没有任何帮助的,只有将设计理论与实践相结合才能帮助到自己,那么如何将设计理论与实践相结合?来看看本期下面的详细介绍. ...

  6. Java 理论与实践: 非阻塞算法简介——看吧,没有锁定!(转载)

    简介: Java™ 5.0 第一次让使用 Java 语言开发非阻塞算法成为可能,java.util.concurrent 包充分地利用了这个功能.非阻塞算法属于并发算法,它们可以安全地派生它们的线程, ...

  7. Microsoft NLayerApp案例理论与实践 - 项目简“.NET研究”介与环境搭建

    项目简介 Microsoft – Spain团队有一个很不错的面向领域多层分布式项目案例:Microsoft – Domain Oriented N-Layered .NET 4.0 App Samp ...

  8. 重磅直播|立体视觉之立体匹配理论与实践​

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 大家好,本公众号现已开启线上视频公开课,主讲人通过B站直播间,对3D视觉领域相关知识点进行讲解,并在微 ...

  9. java 理论与实践,Java 理论与实践: 正确使用 Volatile 变量

    Java™ 语言包含两种内在的同步机制:同步块(或方法)和 volatile 变量.这两种机制的提出都是为了实现代码线程的安全性.其中 Volatile 变量的同步性较差(但有时它更简单并且开销更低) ...

最新文章

  1. 例题3-4 猜数字游戏的提示(Master-Mind Hints, UVa 340)
  2. java Stream
  3. 你有没有觉得邮件发送人固定配置在yml文件中是不妥当的呢?SpringBoot 动态设置邮件发送人
  4. Eliminate Witches!【2011年北京赛区正赛赛题-2】
  5. mysql explain实践
  6. mybatis简单案例源码详细【注释全面】——实体层(User.java)
  7. linux 命令 find -exec 操作的问题
  8. 十五、Python操作mysql数据库
  9. ORA-28001: the password has expired (DBD ERROR: OCISessionBegin)解决办法
  10. React的静态类型检查
  11. window7修改屏幕旋转快捷键
  12. 怎么把ogg转成mp3格式?
  13. 一小时人生服务器维护,TapTap《一小时人生》手游:说好的一小时人生模拟,我却只能活6分钟...
  14. 2010年计算机考研选择题解析,2009-2010计算机考研真题及答案(含选择题解析)WORD高清晰版.pdf...
  15. vue3 编译报 ESLint: ‘defineProps‘ is not defined no-undef 错误问题
  16. tkinter -- tcp
  17. K-means实现图像聚类
  18. 毫无破绽!用这个开源项目换了张脸后,连女朋友都难分真假,能否骗过刷脸支付?...
  19. Ubuntu18.04下UnixBench跑分
  20. H3CIE A套实验配置

热门文章

  1. OpenCore Configurator for Mac(黑苹果系统引导工具)
  2. Python外星人入侵游戏——添加飞船和外星人图片
  3. APP推广的3个过程:应用市场、网盟、换量
  4. 项目成功和失败的几大因素
  5. JAVA IO流(3)
  6. 基于小波的图像边缘检测,小波变换边缘检测原理
  7. 深入浅出解答hero刷rom的各种问题
  8. c语言求50以内阶乘,C语言之数组50以内的阶乘.doc
  9. 【历史上的今天】8 月 26 日:jQuery 发布;中国第一台百万次计算机试制成功
  10. java文档注释用什么开头,极其重要