本文将介绍两种比RNN更好地应对梯度消失问题的模型结构——LSTM和GRU,文章以CS224N的课件和材料为基础,重点分析他们的结构特点和梯度计算,在梯度消失的解决策略上进行了深入探究,并进一步分析它们的优缺点和应用场景。

目录

  • 一、背景知识
  • 二、LSTM的原理与结构
    • 1.模型结构
    • 2. 如何解决梯度消失
  • 三、GRU的原理与结构
  • 四、LSTM与GRU的选择
  • 五、RNN的其他变种模型
    • 1. 双向RNN
    • 2. 多层RNN
  • 六、参考文献

一、背景知识

循环神经网络RNN由于模型结构上的缺陷,很容易引起梯度爆炸和梯度消失,梯度爆炸可以用梯度截断方法在一定程度上缓解其影响,但是梯度消失几乎是致命缺陷,没有什么好办法可以解决它,这使得训练变得困难,模型很可能只受短时约束,长时约束的作用被大大削弱,学习不到相隔较远的两个词之间的联系。本文介绍的两种新的神经网络结构LSTM和RNN,可以很好地应对这个问题。

二、LSTM的原理与结构

1.模型结构

LSTM在模型结构上相对于RNN而言有两大变动:

  1. 新增了三个独特的门结构,用来控制信息地流动
  2. 增添了细胞状态cell state,同时也保留了原来的隐状态hiden state

其整体的模型结构图如下所示,由多个结构相同的LSTM模块组成:

LSTM结构的细节图:

课件上这张图的来源于参考文献2,大家可以去看看那篇文章对LSTM每个步骤进行拆解,下面的公式讲以图中的符号为准,可能会与课件中有一点出入。

符号解释

细胞状态 CtC_tCt​: Ct=ft⊗Ct−1+it⊗C~tC_t = f_t \otimes C_{t-1} + i_t \otimes \tilde{C}_{t}Ct​=ft​⊗Ct−1​+it​⊗C~t​

细胞状态新内容 C~t\tilde{C}_tC~t​: C~t=tanh(Wcht−1+Ucxt+bc)\tilde{C}_{t} = tanh(W_ch_{t-1}+U_cx_t+b_c)C~t​=tanh(Wc​ht−1​+Uc​xt​+bc​)

隐状态 hth_tht​: ht=ot⊗tanh(Ct)h_t=o_t \otimes tanh(C_t)ht​=ot​⊗tanh(Ct​)

遗忘门 ftf_tft​: ft=σ(Whht−1+Ufxt+bf)f_t=\sigma(W_hh_{t-1}+U_fx_t+b_f)ft​=σ(Wh​ht−1​+Uf​xt​+bf​)

输入门 iti_tit​: fi=σ(Wiht−1+Uixt+bi)f_i=\sigma(W_ih_{t-1}+U_ix_t+b_i)fi​=σ(Wi​ht−1​+Ui​xt​+bi​)

输出门 oto_tot​: ft=σ(Whht−1+Uoxt+bo)f_t=\sigma(W_hh_{t-1}+U_ox_t+b_o)ft​=σ(Wh​ht−1​+Uo​xt​+bo​)

三个门结构

LSTM的门结构充当信息的关口,它们决定了信息是否能够完全流通,取值范围都是(0, 1),0则完全不让通过,1则完全通过。三个门结构的计算方法是一模一样的,只是用了相互独立的参数,LSTM的参数量相比于RNN多了许多,一定程度上提高了模型容量。注意在参考文献2中的写法不太一样,但其实只是将两个参数WWW和UUU给合并了,本质上是一样的。

遗忘门会作用到上一时刻的细胞状态Ct−1C_{t-1}Ct−1​,将句子中的一些历史内容遗忘掉,举个例子,一个句子中如果出现了he,那么模型可能会记住该信息,后面的谓语要用单数形式比如is,如果紧接着出现了they,那么模型可能需要忘掉之前的主语he,后面的谓语需要用复数形式are,当然这只是一个理想化的例子,真实模型具体编码了什么我们很难得知,这只是以人的思维赋予了模型它可能需要的能力。

输入门作用到细胞新内容C~t\tilde{C}_tC~t​,要添加到细胞状态的新内容也许不是全都需要,所以用输入门减小部分元素或者清零。这部分就相对抽象,因为细胞新内容C~t\tilde{C}_tC~t​和遗忘门一样也是通过ht−1h_{t-1}ht−1​和xtx_txt​计算出来的,只是选用的激活函数不同,为什么要这么分两步走。可以这么想:细胞新内容C~t\tilde{C}_tC~t​是计算出了一些备选的新信息,输入门对这些信息进行挑选后再添加到细胞状态中。

输出门则是作用到细胞状态CtC_tCt​中,从细胞状态中挑选出信息作为隐状态的输出。

细胞状态

LSTM中一个重要结构为细胞状态,值得详细展开,它贯穿整个LSTM模型,用来存储句子上下文信息,相当于RNN中将上下文信息编码在隐状态中,LSTM的细胞状态具有更强的信息保存能力,内容不容易被完全清除,也即能更好地捕捉长距离词语间的关系。为什么说它的内容不容易被完全清除,我们回顾它的计算方法:

Ct=ft⊗Ct−1+it⊗C~t(1)C_t = f_t \otimes C_{t-1} + i_t \otimes \tilde{C}_{t} \tag{1}Ct​=ft​⊗Ct−1​+it​⊗C~t​(1)

抛开遗忘门和输入门的作用不谈,当前时刻的细胞状态CtC_tCt​,是上一时刻的细胞状态Ct−1C_{t-1}Ct−1​与新添加的细胞内容C~t\tilde C_tC~t​的以相加的形式获得的,而RNN中上下文信息都放在hth_tht​中,它的计算过程中会通过参数矩阵WhW_hWh​与上一隐状态ht−1h_{t-1}ht−1​以矩阵相乘的形式获得,并不断重复该过程,如果参数矩阵WhW_hWh​的特征值都很小(或者模很小),那么在多次矩阵相乘过程中,hth_tht​可能变得越来越小,上下文信息都已经丢失了。

那么有人可能会问,细胞状态一直这么加下去,CtC_tCt​不会到后面变得异常地大吗?确实是会这样,在初代的LSTM中,没有设置遗忘门,细胞状态的计算方式是:

Ct=Ct−1+it⊗C~tC_t = C_{t-1} + i_t \otimes \tilde{C}_tCt​=Ct−1​+it​⊗C~t​

这种形式的确非常容易使得细胞状态到后面异常地大,所以才设置了遗忘门ftf_tft​,让它与上一时刻的细胞状态进行元素级相乘,有机会减小某些元素的值,甚至清零,这样就保证了细胞状态没有无节制地增长。

2. 如何解决梯度消失

LSTM的模型结构讲述完毕,但是仅从模型结构来看,还是很难解释为什么LSTM能够应对梯度消失。其实上面已经涉及到一点点,关键就是LSTM的细胞状态,它存储着句子的上下文信息,像一条传送带一样贯穿整个模型,而且是以相加元素级形式获得的。我们可以先感性地理解为什么不会梯度消失:

  1. LSTM中存在多条通路,多条通路的梯度以相加的形式汇聚,一条路的梯度为0不至于全部梯度为0
  2. LSTM中存在遗忘门和细胞状态,可以保证历史信息不那么容易被清除。

但是这么说还是还有抽象,我们来真正计算一下梯度。回顾前一篇文章中说RNN(链接文章)梯度消失主要是因为两个时刻间隐状态的梯度是WhW_hWh​的幂次这种形式,WhW_hWh​如果很小,时间距离又很远的话,梯度就消失了:

∂h(t)∂h(i)=∏j=i+1t∂h(j)∂h(j−1)=∏j=i+1tdiag(σ′(Whh(j−1)+Wee(t)+b1))×Wh(2)\frac{\partial h^{(t)}}{\partial h^{(i)}} = \prod_{j=i+1}^t \frac{\partial h^{(j)}}{\partial h^{(j-1)}} = \prod_{j=i+1}^t diag(\sigma'(W_hh^{(j-1)}+W_ee^{(t)} + b_1)) \times W_h \tag{2}∂h(i)∂h(t)​=j=i+1∏t​∂h(j−1)∂h(j)​=j=i+1∏t​diag(σ′(Wh​h(j−1)+We​e(t)+b1​))×Wh​(2)

由于LSTM的上下文信息存储在细胞状态,我们重点来看下前后两个时刻细胞状态的梯度∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​。公式(1)(1)(1)表明,C_t是关于ftf_tft​,Ct−1C_{t-1}Ct−1​,iti_tit​,C~t\tilde{C}_{t}C~t​的函数,而它们都是元素级乘法和加法,所以梯度相对好求,可以套用(uv)′=uv′+u′v(uv)'=uv'+u'v(uv)′=uv′+u′v:

∂Ct∂Ct−1=ft×∂Ct−1∂Ct−1+Ct−1×∂ft∂Ct−1+it×∂C~t∂Ct−1+C~t×∂it∂Ct−1=ft+Ct−1×∂ft∂Ct−1+it×∂C~t∂Ct−1+C~t×∂it∂Ct−1(3)\begin{aligned} \frac{\partial C_t}{\partial C_{t-1}} =& f_t \times \frac{\partial C_{t-1}}{\partial C_{t-1}} + C_{t-1} \times \frac{\partial f_t}{\partial C_{t-1}} + i_t \times \frac{\partial \tilde{C}_t }{\partial C_{t-1}} + \tilde{C}_t \times \frac{\partial i_t}{\partial C_{t-1}} \\ =& f_t + C_{t-1} \times \frac{\partial f_t}{\partial C_{t-1}} + i_t \times \frac{\partial \tilde{C}_t}{\partial C_{t-1}} + \tilde{C}_t \times \frac{\partial i_t}{\partial C_{t-1}} \end{aligned} \tag{3}∂Ct−1​∂Ct​​==​ft​×∂Ct−1​∂Ct−1​​+Ct−1​×∂Ct−1​∂ft​​+it​×∂Ct−1​∂C~t​​+C~t​×∂Ct−1​∂it​​ft​+Ct−1​×∂Ct−1​∂ft​​+it​×∂Ct−1​∂C~t​​+C~t​×∂Ct−1​∂it​​​(3)

上面关键就是第一项,由于∂Ct−1∂Ct−1\frac{\partial C_{t-1}}{\partial C_{t-1}}∂Ct−1​∂Ct−1​​的结果是单位矩阵,所以第一项只剩下一个遗忘门ftf_tft​,它不需要与其他矩阵相乘,所以只要遗忘门是1,可以保证∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​至少是一个1向量,这样损失函数JtJ_tJt​关于C1C_1C1​的梯度∂Jt∂Ct\frac{\partial J_t}{\partial C_t}∂Ct​∂Jt​​可以沿着细胞状态的通路无损地传送到下去,而不会在中途因为存在0向量所使得传到前面时梯度已经消失,即:

∂Jt∂C1=∂Jt∂Ct∂Ct∂Ct−1∂Ct−1∂Ct−2...∂C2∂C1≠0(4)\frac{\partial J_t}{\partial C_1} = \frac{\partial J_t}{\partial C_t} \frac{\partial C_t}{\partial C_{t-1}} \frac{\partial C_t-1}{\partial C_{t-2}} ...\frac{\partial C_2}{\partial C_{1}} \not= \bold{0} \tag{4}∂C1​∂Jt​​=∂Ct​∂Jt​​∂Ct−1​∂Ct​​∂Ct−2​∂Ct​−1​...∂C1​∂C2​​​=0(4)

这里需要提醒大家注意,在知乎等平台上看到很多文章都喜欢引用或翻译文献4中的说法,那里也是计算了梯度∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​,通过ftf_tft​这一项说明梯度不至于完全消失,本文也是借鉴了这种说法,但是那篇文章中,∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​的计算是错误的:

图中等式两边红色方框的项都是∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​,两项完全一致,直接就消掉了,更离谱的是后边的∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​计算等于ftf_tft​,这样等号左边的∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​还有什么好算的。只能说歪打正着,尽管∂Ct∂Ct−1\frac{\partial C_t}{\partial C_{t-1}}∂Ct−1​∂Ct​​是会出现ftf_tft​这独立的一项,但不是这样来的。

真正的计算方法应该是这样,从公式(3)(3)(3)出发,iti_tit​,ftf_tft​,C~t\tilde{C}_{t}C~t​都是关于ht−1h_{t-1}ht−1​的函数,ht−1h_{t-1}ht−1​是又关于Ct−1C_{t-1}Ct−1​的函数,这样我们根据链式法则,可以计算得到:

∂Ct∂Ct−1=ft+Ct−1×∂ft∂Ct−1+it×∂C~t∂Ct−1+C~t×∂it∂Ct−1=ft+Ct−1×∂ft∂ht−1∂ht−1∂Ct−1+it×∂C~t∂ht−1∂ht−1∂Ct−1+C~t×∂it∂ht−1∂ht−1∂Ct−1(5)\begin{aligned} \frac{\partial C_t}{\partial C_{t-1}} =& f_t + C_{t-1} \times \frac{\partial f_t}{\partial C_{t-1}} + i_t \times \frac{\partial \tilde{C}_t}{\partial C_{t-1}} + \tilde{C}_{t} \times \frac{\partial i_t}{\partial C_{t-1}} \\ =&f_t + C_{t-1} \times \frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} + i_t \times \frac{\partial \tilde{C}_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} + \tilde{C}_t \times \frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}\end{aligned} \tag{5}∂Ct−1​∂Ct​​==​ft​+Ct−1​×∂Ct−1​∂ft​​+it​×∂Ct−1​∂C~t​​+C~t​×∂Ct−1​∂it​​ft​+Ct−1​×∂ht−1​∂ft​​∂Ct−1​∂ht−1​​+it​×∂ht−1​∂C~t​​∂Ct−1​∂ht−1​​+C~t​×∂ht−1​∂it​​∂Ct−1​∂ht−1​​​(5)

这条公式最关键的还是第一项遗忘门ftf_tft​,当它为1是梯度不至于消失,但是需要注意的是,它是否为1是由模型自己学习的,我们只能从结构上保证它有联系长距离上下文的能力,但也许长距离的上下文真的没有很强的关系呢?而在模型训练初始化时,一般还是会将遗忘门初始化为1,保证梯度能够无损地传递,从功能来理解,是认为所有上下文信息都需要保留,至于是不是真的要保留,交由模型在后续的训练中学习。

最后还有两点需要注意:

  1. 上面计算的是细胞状态通路的梯度,它不那么容易梯度消失,但是其他通路跟RNN很像,在梯度计算中仍然会出现参数矩阵的幂次,也是很有可能出现梯度消失的。LSTM解决梯度消失的最重要途径就是顶上细胞状态这一条传送带。
  2. LSTM并不保证完全不发生梯度消失,只是相比起RNN更加稳定。

三、GRU的原理与结构

LSTM中存在三个门结构,参数量较大,计算缓慢,因此有学者对它进行了以下精简:

  1. 将细胞状态和隐状态又重新合并成了单独的隐状态
  2. 将遗忘门和输入门合并成了更新门(update gate),它控制哪些信息需要进行更新,哪些信息进行保留
  3. 设置了重置门(reset gate),作用是控制旧的隐状态中的哪些内容可以参与新隐状态的计算
  4. 由于细胞状态和隐状态合二为一了,也就没有必要设置输出门了,输出门被删除

最终的模型结构如下,注意这幅图来自参考文献4,其中的符号和CS224N中所采用的不一致:

符号解释:

重置门 rtr_trt​: rt=σ(Wrht−1+Urxt+br)r_t = \sigma(W_rh_{t-1} + U_rx_t +b_r)rt​=σ(Wr​ht−1​+Ur​xt​+br​)

更新门 ztz_tzt​: zt=σ(Wzht−1+Uzxt+bz)z_t = \sigma(W_zh_{t-1} + U_zx_t +b_z)zt​=σ(Wz​ht−1​+Uz​xt​+bz​)

隐状态的新内容h~t\tilde{h}_th~t​: h~t=tanh(Wh(rt⊗ht−1)+Uhxt+bh)\tilde{h}_t = tanh(W_h(r_t \otimes h_{t-1})+U_hx_t+b_h)h~t​=tanh(Wh​(rt​⊗ht−1​)+Uh​xt​+bh​)

隐状态hth_tht​: ht=(1−zt)⊗ht−1+zt⊗h~th_t = (1-z_t) \otimes h_{t-1} +z_t \otimes \tilde{h}_tht​=(1−zt​)⊗ht−1​+zt​⊗h~t​

那么GRU能否应对梯度消失呢?答案是可以的,看到图中最上方那条贯穿的通路和LSTM中的细胞状态是不是很类似,而且同样也存在一个元素级加法操作,所以GRU中的隐状态与LSTM中的细胞状态一样,前后两个时刻间的梯度也会出现一个独立项,只不过是由遗忘门ftf_tft​变成了更新门ztz_tzt​,只要更新门是1向量,至少可以保证∂ht∂ht−1\frac{\partial h_t}{\partial h_{t-1}}∂ht−1​∂ht​​不会完全为0,隐状态通道上的梯度可以一直传递到最前方。

四、LSTM与GRU的选择

LSTM和GRU都能缓解了RNN中梯度消失的问题,使得长距离上下文信息的捕捉变得更加容易,但是LSTM参数量大,收敛较慢,计算耗时,GRU比起LSTM它的参数量较少,计算相对较快,也减少了过拟合的风险。但是具体该用哪一个,取决于数据量和效率要求,如果数据充足,LSTM可以提供更好的性能,如果要求计算快些,可以试试GRU。

五、RNN的其他变种模型

1. 双向RNN

我们前面说的上下文信息严格来说只是前文信息,后文是还没有输入到模型中的,但是有时候句子的关键信息可能是在后文出现,所以我们希望句子既要正向输入,也要反向输入,分别计算隐状态,再进行融合。但这个模型的应用场景有限制,需要我们拥有全文语料,像实时机器翻译这种场景就不合适,因为并不知道后文。

2. 多层RNN

在另外一个维度堆叠参数,可以帮助网络学习到更深层的语义信息,如果作为编码器,一般是堆2~4层,作为解码器一般堆4层,如果还需要更深,则可能需要用到跳层连接或者像densenet那样的密集连接。

六、参考文献

  1. CS224 Lecture 7 slides & notes
  2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/
  3. https://zhuanlan.zhihu.com/p/109519044
  4. https://weberna.github.io/blog/2017/11/15/LSTM-Vanishing-Gradients.html
  5. [https://www.zhihu.com/question/34878706/answer/665429718](

CS224N笔记(四) Lecture 7:循环神经网络RNN的进阶——LSTM与GRU相关推荐

  1. 循环神经网络RNN(含LSTM,GRU)小综述

    文章目录 前言 一.RNN循环神经网络 1.1 RNN的结构 1.2 BRNN的结构 1.3 梯度消失和梯度爆炸 二.LSTM 2.1 引子 2.2 LSTM单元 2.3 LSTM的补充理解方式和变种 ...

  2. tensorflow实现循环神经网络——经典网络(LSTM、GRU、BRNN)

    参考链接: https://www.cnblogs.com/tensorflownews/p/7293859.html http://www.360doc.com/content/17/0321/10 ...

  3. 【一起入门NLP】中科院自然语言处理第5课-循环神经网络RNN(BPTT+LSTM+GRU)

    专栏介绍:本栏目为 "2021秋季中国科学院大学胡玥老师的自然语言处理" 课程记录,不仅仅是课程笔记噢- 如果感兴趣的话,就和我一起入门NLP吧

  4. 【NLP】毕设学习笔记(八)“前馈 + 反馈” = 循环神经网络RNN

    前馈神经网络和循环神经网络分别适合处理什么样的任务? 如果分类任务仅仅是进行判断和识别,例如判断照片上的人的性别,识别图片上是否有小狗图案,那么对输入的数据仅仅需要做特征寻找的工作即可,找到满足该任务 ...

  5. 【李宏毅机器学习笔记】 23、循环神经网络(Recurrent Neural Network,RNN)

    [李宏毅机器学习笔记]1.回归问题(Regression) [李宏毅机器学习笔记]2.error产生自哪里? [李宏毅机器学习笔记]3.gradient descent [李宏毅机器学习笔记]4.Cl ...

  6. 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】

    深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...

  7. 循环神经网络RNN、LSTM、GRU实现股票预测

    Tensorflow--循环神经网络RNN 循环核 TensorFlow描述循环核 循环神经网络 TensorFlow描述循环神经网络 循环计算过程 输入一个字母,预测下一个字母 输入四个连续字母,预 ...

  8. 深度学习 -- TensorFlow(9)循环神经网络RNN

    目录 一.循环神经网络RNN介绍 二.Elman network && Jordan network 三.RNN的多种架构 1.一对一 2.多对一 3.多对多 4. 一对多 5.Seq ...

  9. 深度学习~循环神经网络RNN, LSTM

    目录 1. 循环神经网络RNN 1.1 RNN出现背景 1.2 RNN概念 2. LSTM 2.1 LSTM出现背景 2.2 LSTM结构 参考 1. 循环神经网络RNN 1.1 RNN出现背景 pr ...

最新文章

  1. 简约设计中的规律—色彩(二)
  2. 模糊pid控制的温度系统matlab源代码_变风量空调模糊 PID 控制系统的仿真研究
  3. 303. 区域和检索 - 数组不可变
  4. Linus Torvalds 回应,Debian 项目曾讨论永久禁止他出席会议!
  5. rownum的用法oracle
  6. html图片在桌面的路径,桌面路径无法恢复以前的路径
  7. vue 用echarts写的进度条组件
  8. ​给想闷声发财的小伙伴35条忠告
  9. Pycharm typo PEP 8
  10. python opencv图像叠加/图像融合/mask掩模
  11. 服务器怎么建ip网站,云服务器搭建网站ip
  12. 《名贤集》《明贤集》四言集
  13. gRPC快速入门(三)——Protobuf应用示例
  14. 山大泰山学堂笔试面试经验
  15. c语言 百度文库,百度文库C语言专本辅导第一二章.doc
  16. 【深度学习前沿应用】目标检测
  17. Xshell里面查看文件中文乱码问题
  18. 京东万象行驶证识别api
  19. Android 的Toast(吐丝框)
  20. Apache Calcite教程 -目录

热门文章

  1. 篮球数据API接口 - 【即时指数1】API调用示例代码
  2. 纳米数据,足球篮球实时数据比分,体育赛事比分接口代码,实时数据推送演示
  3. 这波分享得你们都爱了吗?
  4. Webapp开发框架Clouda的使用(一)
  5. VS2017安装打包工具;以及无法加载此项目,setup(不兼容),该应用程序未安装、MFC的使用
  6. 讲义六 之 docker 搭建测试环境以及部署项目包 created by 爱软测_bill
  7. 2022.10.9 英语背诵
  8. kong翻译_Kong[孔]的中文翻译及英文名意思
  9. 素描静物绘画需要掌握的基础知识有哪些
  10. python实现视频分割