最近在学习cs224n: Natural Language Processing with Deep Learning课程时,对RNN、LSTM和GRU的原理有了更深一层的理解,对LSTM和GRU如何解决RNN中梯度消失(Gradient Vanishing)的问题也有了新的认识,于是写下本文。

  • RNN
  • Gradient Vanishing
    • 减缓梯度消失
    • 防止梯度爆炸
  • GRU
  • LSTM

RNN

Gradient Vanishing

  RNN中容易出现Gradient Vanishing是因为在梯度在向后传递的时候,由于相同的矩阵相乘次数太多,梯度倾向于逐渐消失,导致后面的结点无法更新参数,整个学习过程无法正常进行。

  Gradient Vanishing的推导如下
  整个序列的预测是之前每个时刻的误差之和,而每个时刻t的预测误差又是之前每个时刻的误差之和(损失函数的一种定义方式,也可在每个时刻计算交叉熵损失函数并在序列上平均)。

  此时∂ht/∂hk∂ht/∂hk\partial h_{t}/\partial h_{k}是在[k,t][k,t][k, t]的时间域上应用链式法则,是一个连乘的形式,长度为时间区域的长度。

  记βWβW\beta_{W}和βhβh\beta_{h}分别为矩阵和向量的范数(L2),由于使用了sigmoid激活函数,所以f′(hj−1)f′(hj−1)f'{(h_{j-1})}的矩阵范数最大为1,因此得到下面更松弛的上界。指数项βhβWβhβW\beta _{h}\beta_{W}在显著的小于1或者大于1的时候,经过(t−k)(t−k)(t-k)次乘法之后将于倾向于0或者无限大,也即梯度消失和梯度爆炸。
  

减缓梯度消失

  For vanishing gradients: Initialization + ReLus!(减缓梯度消失的方法:初始化 + ReLu)
  Initialize W to identity matrix I(Rather than random initialization matrix) and f(z) = rect(z) = max(z,0)(初始化参数矩阵为单位矩阵,以语言模型为例这样的初始化效果就是上下文向量和词向量的平均)

防止梯度爆炸

  一种暴力的方法就是,当梯度的大小超过某个阈值的时候,将其缩放到某个阈值。虽然在数学书缺乏严谨的推导,但是实践效果挺好。
  其直观解释是,在一个只有一个隐藏节点的网络中,损失函数和权值w偏值b构成error surface,其中如下图表示有一张图:
   
  每次迭代梯度本来是正常的,一次一小步,但遇到这堵墙之后可能突然梯度爆炸到非常大,可能指向一个莫名其妙的地方(实线长箭头)。但缩放之后,能够把这种误导控制在可接受的范围内(虚线短箭头)。
  但这种trick无法推广到梯度消失,因为你不想设置一个最低值硬性规定之前的信息都相同重要地影响当前的输出。

GRU

  GRU分为Reset gate(重置门)和Update gate(更新门)
由:
  Reset Gate(重置门):控制是否遗忘之前的记忆。当reset gate = 0时,遗忘之前的信息ht−1ht−1h_{t-1}。
  Update Gate(更新门):控制之前记忆留存的比例。通过 ztztz_{t} 进行调节。
   
  
  Question: How do GRU fix vanishing gradient problem?(GRU如何解决梯度消失的问题?)
  1. 在标准的RNN中,梯度是严格的按照所有的中间节点流动的,而GRU在网络中创造了适应性的短连接(create adaptive shortcut connection)。在GRU中,可以选择性的遗忘和记忆此前的信息,在梯度的流动中做了短连接,避免梯度计算中的大量累积。
  2. 通过GRU公式,ht=zt∗ht−1+(1−zt)ht˜ht=zt∗ht−1+(1−zt)ht~h_{t}=z_{t}*h_{t-1}+(1-z_{t})\widetilde{h_{t}},其中ztztz_{t}是update gate的值,hthth_{t}是当前时刻的新信息。为了方便可做简化:ht=ht−1+(1−zt)ht˜ht=ht−1+(1−zt)ht~h_{t}=h_{t-1}+(1-z_{t})\widetilde{h_{t}},可以看到hthth_{t}和ht−1ht−1h_{t-1}此时是线性关系,不再是RNN中ht=f(Whhht−1+Whsxt)ht=f(Whhht−1+Whsxt)h_{t}=f(W^{hh}h_{t-1}+W^{hs}x_{t})的乘积关系,因此梯度在计算的时候不再是连乘关系。梯度在中间节点线性流动,就会保持很长时间的记忆。

LSTM

  LSTM分为input gate(输入门),forget gate(遗忘门),和output gate(输出门)
由: 
  Input Gate(输入门):表示当前的词语是否值得保留下来。在上述公式5中Final memory cell中体现。
  Forget Gate(遗忘门):表示过去的记忆是否忘记。当forget gate=0时,遗忘过去的记忆。
  Output Gate(输出门):表示当前的记忆应该被放大的倍数,用于将最终的记忆与隐状态分离,因为记忆c(t)中的信息不是都需要放到隐状态中,隐状态是个很重要且使用很频繁的东西。 
   
  Question1: How do LSTM fix vanishing gradient problem?(LSTM如何解决梯度弥散的问题?)
  1. 在标准的RNN中,梯度是严格的按照所有的中间节点流动的,而LSTM在网络中创造了适应性的短连接(create adaptive shortcut connection)。在LSTM中,可以选择性的遗忘和记忆此前的信息,在梯度的流动中做了短连接,避免梯度计算中的累积。
  2. 通过公式也可以看出,在LSTM中,Ct=ft∗Ct−1+it∗Ct˜Ct=ft∗Ct−1+it∗Ct~C_{t}=f_{t}*C_{t-1}+i_{t}*\widetilde{C_{t}},其中Ct−1Ct−1C_{t-1}是此前的信息,Ct˜Ct~\widetilde{C_{t}}是当前时刻的新信息,CtCtC_{t}是最终的信息。可以看到CtCtC_{t}和Ct−1Ct−1C_{t-1}此时是线性关系,不再是RNN中的乘积关系,因此梯度在计算的时候不再是连乘关系,梯度以线性在中间节点流动,因此就会保证很长时间的记忆。
  
  Question2: why tanh in ht=Ot∗tanh(Ct)ht=Ot∗tanh(Ct)h_{t}=O_{t}*tanh(C_{t})?
  课程中,Manning也没给出很具体的原理,但是Richard认为因为Ct=ft∗Ct−1+it∗Ct˜Ct=ft∗Ct−1+it∗Ct~C_{t}=f_{t}*C_{t-1}+i_{t}*\widetilde{C_{t}}为线性运算,为了增加系统的非线性于是采用了tanh。
  

理解RNN、LSTM、GRU和Gradient Vanishing相关推荐

  1. Pytorch中如何理解RNN LSTM GRU的input(重点理解seq_len / time_steps)

    在建立时序模型时,若使用keras,我们在Input的时候就会在shape内设置好sequence_length(后面简称seq_len),接着便可以在自定义的data_generator内进行个性化 ...

  2. RNN LSTM GRU 代码实战 ---- 简单的文本生成任务

    RNN LSTM GRU 代码实战 ---- 简单的文本生成任务 import torch if torch.cuda.is_available():# Tell PyTorch to use the ...

  3. ​​​​​​​DL之RNN/LSTM/GRU:RNN/LSTM/GRU算法动图对比、TF代码定义之详细攻略

    DL之RNN/LSTM/GRU:RNN/LSTM/GRU算法动图对比.TF代码定义之详细攻略 目录 RNN.LSTM.GRU算法对比 1.RNN/LSTM/GRU对比 2.RNN/LSTM/GRU动图 ...

  4. DL之LSTM:LSTM算法论文简介(原理、关键步骤、RNN/LSTM/GRU比较、单层和多层的LSTM)、案例应用之详细攻略

    DL之LSTM:LSTM算法论文简介(原理.关键步骤.RNN/LSTM/GRU比较.单层和多层的LSTM).案例应用之详细攻略 目录 LSTM算法简介 1.LSTM算法论文 1.1.LSTM算法相关论 ...

  5. RNN, LSTM, GRU, SRU, Multi-Dimensional LSTM, Grid LSTM, Graph LSTM系列解读

    RNN/Stacked RNN rnn一般根据输入和输出的数目分为5种 一对一 最简单的rnn 一对多 Image Captioning(image -> sequence of words) ...

  6. RNN,LSTM,GRU计算方式及优缺点

    本文主要参考李宏毅老师的视频介绍RNN相关知识,主要包括两个部分: 分别介绍Navie RNN,LSTM,GRU的结构 对比这三者的优缺点 1.RNN,LSTM,GRU结构及计算方式 1.1 Navi ...

  7. 图解 RNN, LSTM, GRU

    参考: Illustrated Guide to Recurrent Neural Networks Illustrated Guide to LSTM's and GRU's: A step by ...

  8. RNN,LSTM,GRU基本原理的个人理解重点

    20210626 循环神经网络_霜叶的博客-CSDN博客 LSTM的理解 - 走看看 重点 深入LSTM结构 首先使用LSTM的当前输入 (x^t)和上一个状态传递下来的 (h^{t-1}) 拼接训练 ...

  9. [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...

最新文章

  1. linux c gcc编译报错 can not be used when making a shared object; recompile with -fPIC
  2. python 计时_Python计时相关操作详解【time,datetime】
  3. Nifi清空Queue操作
  4. android 短信编解码方式,中移短信cmpp协议/smpp协议 netty实现编解码
  5. python网络编程100例_python网络编程
  6. 检测到 LoaderLock Message Microsoft.DirectX.dll”正试图在 OS 加载程序锁内执行托管代码。...
  7. ES6语法---箭头函数/关于this指向
  8. linux恢复硬盘工具,linux硬盘数据恢复工具
  9. 电子电路学习笔记(12)——稳压二极管
  10. 贝尔曼方程怎么解_哈密顿-雅可比-贝尔曼方程
  11. Javaweb制定的订餐系统+jsp+servlet+Java+MySQL
  12. 阿里云ACP如何线上考试
  13. sodo与visodo的日常使用
  14. js获取本月初与月底的时间、获取前一天的时间。
  15. 微信h5图表组件制作教程
  16. 深度学习之图像分类(十六)-- EfficientNetV2 网络结构
  17. oracle sqlnet配置,sqlnet.ora文件配置详解
  18. 2017 ACM-ICPC 亚洲区(西安赛区)网络赛 B Coin(逆元,费马小定理)
  19. 下载kaggle数据集出现的一系列问题
  20. 最小化函数minimize

热门文章

  1. 【AiLearning】test3:搭建Deep Netural Network
  2. 【Spring Boot】使用JDBC 获取相关的数据
  3. 解析新时代人工智能机器人的工作原理
  4. 【中英双语】C 语言的历史
  5. 详解硬件设计中电容电感磁珠
  6. Siggraph三角网格变形之拉普拉斯变换
  7. reporting service odbc mysql_Reporting Services
  8. 论文研读-多目标自适应memetic算法
  9. 分享工作上的一些体会
  10. 误差棒是什么?误差柱状图如何做?