• 书籍简介
  • LSTM理解
    • LSTM流程简介
  • 算法及公式
    • 一些函数
    • 一些符号
  • 前向传播
  • 反向传播
    • 关于误差的定义
    • 公式推导
  • 总结

书籍简介

《Surpervised Sequence Labelling with Recurrent Neural Network》(《用循环神经网络进行序列标记》),RNN(Recurrent Neural Network,循环神经网络)经典教材,由多伦多大学Alexander Graves所著,详细叙述了各种RNN模型及其推导。本文介绍该书的LSTM部分。对于该书,想深入了解的朋友点这里获取资源。

LSTM理解

  LSTM(Long Short-Term Memory Networks,长短时记忆网络),由Hochreiter和Schmidhuber于1997年提出,目的是解决一般循环神经网络中存在的梯度爆炸(输入信息激活后权重过小)及梯度消失(例如sigmoid、tanh的激活值在输入很大时其梯度趋于零)问题,主要通过引入门和Cell状态的概念来实现梯度的调整,已被大量应用于时间序列预测等深度学习领域。
  下面的描述主要侧重公式推导,对LSTM来由更详细的讨论请见《Step-by-step to LSTM: 解析LSTM神经网络设计原理》。

LSTM流程简介

  LSTM采用了门控输出的方式,即三门(输入门、遗忘门、输出门)两态(Cell State长时、Hidden State短时)。其核心即Cell State,指用于信息传播的Cell的状态,在结构示意图(图1,图源Understanding LSTMs,略改动)中是最上面的直链(从Ct−1C_{t-1}Ct−1​到CtC_tCt​)。

图1

  Memory Cell 接受两个输入,即上一时刻的输出值ht−1h_{t-1}ht−1​和本时刻的输入值xtx_txt​,由这两个参数 先进入遗忘门,得到决定要舍弃的信息 ftf_tft​(即权重较小的信息)后,再进入输入门,得到决定要更新的信息 iti_tit​(即与上一Cell相比权重较大的信息)以及当前时刻的Cell状态 C~t\tilde{C}_tC~t​(候选向量,可理解为中间变量,存储当前 Cell State 信息),最后由这两个门(遗忘门,输入门)的输出值(即 ft,it,Ct~f_t,i_t,\tilde{C_t}ft​,it​,Ct​~​)进行组合(上一Cell状态Ct−1×C^{t-1}\timesCt−1×要遗忘信息的激活值ftf_tft​ 与 当前时刻Cell状态Ct~×\tilde{C_t}\timesCt​~​×需要记忆信息的激活值iti_tit​进行叠加,从图中可以更直观得到),得到分别的长时(CtC_tCt​)和短时(hth_tht​)信息,最后进行存储操作及对下一个神经元的输入。下图2介绍了LSTM在网络中是如何工作的。

图2

根据图1,可依次得到三个门的形式方程如下(符号与图中保持一致):

  1. 遗忘门:

ft=σ(Wf⋅[ht−1,xt]+bf)f_t=\sigma\left(W_f\cdot[h_{t-1}, x_t]+b_f\right)ft​=σ(Wf​⋅[ht−1​,xt​]+bf​)

  1. 输入门:

it=σ(Wi⋅[ht−1,xt]+bi)i_t=\sigma\left(W_i\cdot[h_{t-1}, x_t]+b_i\right)it​=σ(Wi​⋅[ht−1​,xt​]+bi​)

Ct~=tanh⁡(WC⋅[ht−1,xt]+bC)\tilde{C_t}=\tanh\left(W_C\cdot[h_{t-1}, x_t]+b_C\right)Ct​~​=tanh(WC​⋅[ht−1​,xt​]+bC​)

以及ttt时刻的Cell 状态(长时)方程:

Ct=ft⋅Ct−1+it⋅Ct~C_t=f_t\cdot C_{t-1}+i_t\cdot \tilde{C_t}Ct​=ft​⋅Ct−1​+it​⋅Ct​~​

  1. 输出门:

ot=σ(Wo⋅[ht−1,xt]+bo)o_t=\sigma\left(W_o\cdot[h_{t-1}, x_t]+b_o\right)ot​=σ(Wo​⋅[ht−1​,xt​]+bo​)

ht=ot⋅tanh⁡(Ct)h_t=o_t\cdot\tanh{(C_t)}ht​=ot​⋅tanh(Ct​)

算法及公式

  根据上面的描述及图1,首先定义如下符号(符号为方便理解,与书中保持一致):

一些函数

  • fff :的激活函数
  • ggg :Cell输入的激活函数
  • hhh :Cell输出的激活函数

  • L\mathcal{L}L : 训练模型时的损失函数
  • σ(z)\sigma(z)σ(z):Sigmoid激活函数
    σ(z)=11+e−z=1+tanh⁡(z/2)2,\sigma(z)=\frac{1}{1+\mathrm{e}^{-z}}=\frac{1+\tanh(z/2)}{2},σ(z)=1+e−z1​=21+tanh(z/2)​,

σ′(z)=σ(z)[1−σ(z)].\sigma'(z)=\sigma(z)[1-\sigma(z)].σ′(z)=σ(z)[1−σ(z)].

  • tanh⁡(z)\tanh(z)tanh(z):tanh激活函数
    tanh⁡(z)=ez−e−zez+e−z,\tanh(z)=\frac{\mathrm{e}^z-\mathrm{e}^{-z}}{\mathrm{e}^z+\mathrm{e}^{-z}},tanh(z)=ez+e−zez−e−z​,

tanh⁡′(z)=1−tanh⁡2(z).\tanh'(z)=1-\tanh^2(z).tanh′(z)=1−tanh2(z).

一些符号

  • III :输入层 信息的数量
  • KKK :输出层 信息的数量
  • HHH :隐层 Cell状态的数量(注意这里的Cell与下面的Cell不同,代表短时记忆Cell),指图1中最下面的一条直链,即从ht−1h_{t-1}ht−1​到hth_tht​,处理短时记忆
  • CCC :Cell状态信息(长时记忆状态)的数量
  • TTT :总时间数(网络层总数),即t=0,1,2,⋯,Tt=0,1,2,\cdots,Tt=0,1,2,⋯,T

  • ϕ\phiϕ :下标,指一个LSTM单元的遗忘门
  • ι\iotaι :下标,指一个LSTM单元的输入门
  • ω\omegaω :下标,指一个LSTM单元的输出门
  • ccc :下标,指神经元中某一个CCC 记忆元胞(Cell)

  • wijw_{ij}wij​ :从单元iii到单元jjj的权重
  • bjtb_j^tbjt​ :ttt时刻第jjj个单元的激活值,在t=0t=0t=0时初始化为000
  • ajta_j^tajt​ :ttt时刻第jjj个单元的带权输入,可作抽象定义如下

ajt=∑iwijbit−1.a_j^t=\sum_{i}{w_{ij}b_{i}^{t-1}}.ajt​=i∑​wij​bit−1​.

  • scts_c^tsct​ :ttt时刻记忆元胞 ccc 的状态(State),在t=0t=0t=0时初始化为000

  • δjt\delta_j^tδjt​ :ttt时刻第jjj个单元的误差,在t=T+1t=T+1t=T+1时初始化为000。一般化的定义为

δjt=∂L∂ajt.\delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t}.δjt​=∂ajt​∂L​.

前向传播

由上述的形式方程,很容易得到下面的前向传播公式:

  1. 遗忘门。由图1可知,遗忘门的输出依赖三个变量(图1中表示为左下角的两个输入和左上角的一个输入),分别是:上一时刻(t−1)(t-1)(t−1)神经元的短时记忆输出ht−1h_{t-1}ht−1​,本时刻(t)(t)(t)神经元的输入xtx_txt​以及上一时刻(t−1)(t-1)(t−1)神经元的长时记忆输出Cell状态sct−1s_c^{t-1}sct−1​,乘以权重因子后对层数求和即可得到遗忘门的输入值及激活值如下:

aϕt=∑i=1Iwiϕxit+∑h=1Hwhϕbht−1+∑c=1Cwcϕsct−1(1.1)a_\phi^t=\sum_{i=1}^Iw_{i\phi}x_i^t+\sum_{h=1}^{H}w_{h\phi}b_h^{t-1}+\sum_{c=1}^Cw_{c\phi}s_c^{t-1}\tag{1.1}aϕt​=i=1∑I​wiϕ​xit​+h=1∑H​whϕ​bht−1​+c=1∑C​wcϕ​sct−1​(1.1)

bϕt=f(aϕt)(1.2)b_\phi^t=f(a_\phi^t)\tag{1.2}bϕt​=f(aϕt​)(1.2)

  1. 输入门。其输出所依赖的变量与遗忘门相同,故同理可得

aιt=∑i=1Iwiιxit+∑h=1Hwhιbht−1+∑c=1Cwcιsct−1(2.1)a_\iota^t=\sum_{i=1}^Iw_{i\iota}x_i^t+\sum_{h=1}^{H}w_{h\iota}b_h^{t-1}+\sum_{c=1}^Cw_{c\iota}s_c^{t-1}\tag{2.1}aιt​=i=1∑I​wiι​xit​+h=1∑H​whι​bht−1​+c=1∑C​wcι​sct−1​(2.1)

bιt=f(aιt)(2.2)b_\iota^t=f(a_\iota^t)\tag{2.2}bιt​=f(aιt​)(2.2)

  1. Cell状态。由输入门的ttt时刻的Cell 状态(长时)方程立即可得。

act=∑i=1Iwicxit+∑h=1Hwhcbht−1(3.1)a_c^t =\sum_{i=1}^I w_{ic}x_i^t+\sum_{h=1}^H w_{hc}b_h^{t-1}\tag{3.1}act​=i=1∑I​wic​xit​+h=1∑H​whc​bht−1​(3.1)

一一对应形式方程即可得到scts_c^tsct​表达式如下

Ct=ft⋅Ct−1+it⋅Ct~⋮⋮⋮⋮⋮sct=bϕt⋅sct−1+bιt⋅g(act)(3.2)\begin{aligned} C_t&=f_t\cdot C_{t-1}+i_t\cdot \tilde{C_t} \\ \vdots& \quad\ \ \vdots\ \ \ \ \ \ \ \vdots \qquad \ \ \vdots\ \ \ \ \ \vdots\\ s_c^{t}&= b_\phi^t \cdot s_c^{t-1}\,+b_\iota^t \cdot g(a_c^t)\tag{3.2} \end{aligned}Ct​⋮sct​​=ft​⋅Ct−1​+it​⋅Ct​~​  ⋮       ⋮  ⋮     ⋮=bϕt​⋅sct−1​+bιt​⋅g(act​)​(3.2)

  1. 输出门。由遗忘门同理可得

aωt=∑i=1Iwiωxit+∑h=1Hwhωbht−1+∑c=1Cwcωsct−1(4.1)a_\omega^t=\sum_{i=1}^Iw_{i\omega}x_i^t+\sum_{h=1}^{H}w_{h\omega}b_h^{t-1}+\sum_{c=1}^Cw_{c\omega}s_c^{t-1}\tag{4.1}aωt​=i=1∑I​wiω​xit​+h=1∑H​whω​bht−1​+c=1∑C​wcω​sct−1​(4.1)

bωt=f(aωt)(4.2)b_\omega^t=f(a_\omega^t)\tag{4.2}bωt​=f(aωt​)(4.2)

  1. Cell输出。指激活后的Cell状态(短时记忆),同理可由形式方程一一对应得到,即

ht=ot⋅tanh⁡(Ct)⋮⋮⋮bct=bωt⋅h(sct)(5.1)\begin{aligned}h_t&=o_t\ \cdot\ \tanh{(C_t)}\tag{5.1} \\ \vdots&\ \ \ \quad\vdots\ \qquad \vdots\\ b_c^t&=b_\omega^t \cdot \ \ h(s_c^t)\end{aligned}ht​⋮bct​​=ot​ ⋅ tanh(Ct​)   ⋮ ⋮=bωt​⋅  h(sct​)​(5.1)

反向传播

  重头戏来了!建议不熟悉反向传播的朋友看一下我的另一篇文章nndl学习笔记(二)反向传播公式推导,帮助你快速理解&回顾反向传播。

  同样地,为了与前向传播对应,这里也采用五个部分进行证明。反向传播,其目的就是通过计算损失函数关于权重和偏置的偏导数(本例中不对偏置进行分析),从而得到每一个神经元上出现的误差(误差定义为损失函数对神经元输入的偏导数),最后均摊给每个神经元,以此逐步减小误差。因为需要反向传播,所以顺序与前向传播正好相反(从后往前计算)。

关于误差的定义

  • Cell 输出的误差(短时记忆)ϵct=∂L∂bct\epsilon_c^t=\frac{\partial \mathcal{L}}{\partial b_c^t}ϵct​=∂bct​∂L​
  • Cell 状态的误差(长时记忆)ϵst=∂L∂sct\epsilon_s^t=\frac{\partial \mathcal{L}}{\partial s_c^t}ϵst​=∂sct​∂L​
  • δjt\delta_j^tδjt​ :ttt时刻第jjj个单元的误差,在t=T+1t=T+1t=T+1时初始化为000。定义为

δjt=∂L∂ajt\delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t}δjt​=∂ajt​∂L​

公式推导

  这些公式的核心,都是根据链式法则求偏导数,需要注意损失函数与哪些变量有关,找准变量,再应用求导法则,即可轻松计算出表达式。

  1. Cell输出(短时记忆)。
    首先找Cell输出与哪些量有关,从图1可以得知其只与隐层(Cell短时记忆状态)和输出层两个部分的信息有关,再根据误差定义δjt=∂L∂ajt\delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t}δjt​=∂ajt​∂L​,可以得到:
    ϵct=∂L∂bct=∂L∂ajt∂ajt∂bct=∑h=1H∂L∂aht+1∂aht+1∂bct+∑k=1K∂L∂akt∂akt∂bct=∑h=1Hδht+1∂aht+1∂bct+∑k=1Kδkt∂akt∂bct\begin{aligned} \epsilon_c^t &=\frac{\partial \mathcal{L}}{\partial b_c^t} =\frac{\partial \mathcal{L}}{\partial a_j^t} \frac{\partial a_j^t}{\partial b_c^t} \\ &= \sum_{h=1}^H\frac{\partial \mathcal{L}}{\partial a_h^{t+1}} \frac{\partial a_h^{t+1}}{\partial b_c^t}+\sum_{k=1}^K\frac{\partial \mathcal{L}}{\partial a_k^t} \frac{\partial a_k^t}{\partial b_c^t} \\ &=\sum_{h=1}^H\delta_h^{t+1} \frac{\partial a_h^{t+1}}{\partial b_c^{t}} + \sum_{k=1}^K\delta_k^t \frac{\partial a_k^t}{\partial b_c^t} \end{aligned}ϵct​​=∂bct​∂L​=∂ajt​∂L​∂bct​∂ajt​​=h=1∑H​∂aht+1​∂L​∂bct​∂aht+1​​+k=1∑K​∂akt​∂L​∂bct​∂akt​​=h=1∑H​δht+1​∂bct​∂aht+1​​+k=1∑K​δkt​∂bct​∂akt​​​
    注意到这里HHH层时间状态取t+1t+1t+1而KKK层取ttt,是为了与前向传播式子的意义保持一致,即:隐层Cell状态前向传播需要前一时刻(t−1)(t-1)(t−1)的隐层Cell状态,而输出只需与本时刻输入的时刻(t)(t)(t)一致即可,而反向传播正好相反(具体可见图1)。
    再根据带权输入的一般定义(同上,需要根据情况构造定义式,即:HHH层时刻变化而KKK层时刻保持不变)
    ajt=∑iwijbit−1a_j^t=\sum_{i}{w_{ij}b_{i}^{t-1}}ajt​=i∑​wij​bit−1​
    代入得到(注意这里有一步化简,去掉求和号,具体原因可见nndl学习笔记(二)反向传播公式推导公式一的推导部分):
    ϵct=∑h=1Hδht+1∂(wchbct)∂bct+∑k=1Kδkt∂(wckbct)∂bct=∑h=1Hδht+1wch+∑k=1Kδktwck\begin{aligned} \epsilon_c^t&=\sum_{h=1}^H\delta_h^{t+1} \frac{\partial (w_{ch}b_c^{t})}{\partial b_c^{t}}+\sum_{k=1}^K\delta_k^t \frac{\partial (w_{ck}b_c^{t})}{\partial b_c^{t}} \\ &=\sum_{h=1}^H\delta_h^{t+1}w_{ch}+\sum_{k=1}^K\delta_k^tw_{ck} \end{aligned}ϵct​​=h=1∑H​δht+1​∂bct​∂(wch​bct​)​+k=1∑K​δkt​∂bct​∂(wck​bct​)​=h=1∑H​δht+1​wch​+k=1∑K​δkt​wck​​

  2. 输出门。
    这里只需用到误差定义式ϵct=∂L∂bct\epsilon_c^t=\frac{\partial \mathcal{L}}{\partial b_c^t}ϵct​=∂bct​∂L​及前向传播的(5.1)(5.1)(5.1)式,最后一步求和是指针对所有神经元输出门激活值误差的叠加。
    δωt=∂L∂aωt=∂L∂bωt∂bωt∂aωt=∂L∂bωtf′(aωt)=f′(aωt)∂L∂bct∂bct∂bωt=f′(aωt)ϵct∂bct∂bωt=f′(aωt)ϵct∂[bωth(sct)]∂bωt=f′(aωt)∑c=1Ch(sct)ϵct\begin{aligned} \delta_\omega^t&=\frac{\partial \mathcal{L}}{\partial a_\omega^t} =\frac{\partial \mathcal{L}}{\partial b_\omega^t}\frac{\partial b_\omega^t}{\partial a_\omega^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\omega^t} f'(a_\omega^t) \\ &=f'(a_\omega^t) \frac{\partial \mathcal{L}}{\partial b_c^t} \frac{\partial b_c^t}{\partial b_\omega^t}\\ &= f'(a_\omega^t) \epsilon_c^t \frac{\partial b_c^t}{\partial b_\omega^t} \\ &= f'(a_\omega^t) \epsilon_c^t \frac{\partial \left[b_\omega^t h(s_c^t)\right]}{\partial b_\omega^t} \\ &=f'(a_\omega^t)\sum_{c=1}^Ch(s_c^t)\epsilon_c^t \end{aligned}δωt​​=∂aωt​∂L​=∂bωt​∂L​∂aωt​∂bωt​​=∂bωt​∂L​f′(aωt​)=f′(aωt​)∂bct​∂L​∂bωt​∂bct​​=f′(aωt​)ϵct​∂bωt​∂bct​​=f′(aωt​)ϵct​∂bωt​∂[bωt​h(sct​)]​=f′(aωt​)c=1∑C​h(sct​)ϵct​​

  3. Cell状态(长时记忆)。最长的一个式子,但是把握好变量之间的关系就可以轻松得出( 直接寻找前向传播众多公式中哪个含有变量 scts_c^tsct​,这样再进行链式法则处理,会更加直观,由于五个式子都含有scts_c^tsct​,故下面第四个等号后的式子有五项)。
      推导过程与Cell输出(短时记忆)部分类似,要用到误差的一般定义δjt=∂L∂ajt\delta_j^t=\frac{\partial \mathcal{L}}{\partial a_j^t}δjt​=∂ajt​∂L​,并注意到本时刻Cell状态(长时记忆)是由上一时刻遗忘门(ϕ)(\phi)(ϕ)和输入门(ι)(\iota)(ι)的输出共同决定的(反映在图上就是图1中上面直链的加号);在反向传播中,除了需要将Cell状态(长时记忆)的时间取反(sct+1)(s_c^{t+1})(sct+1​),还要考虑三个门误差的积累(第二个等号后式子第一项),注意这里计算输出门误差时没有取后一时刻t+1t+1t+1,是因为遗忘门和输入门的误差在前向传播时会传递给下一时刻的带权输入,故反向传播需要后一时刻来计算误差;而输出门误差在本时刻即可计算。反映到方程上为第二个等号后的方程。
    ϵst=∂L∂sct=∂L∂ajt+1∂ajt+1∂sct+∂L∂bct∂bct∂sct+∂L∂sct+1∂sct+1∂sct=δjt+1∂ajt+1∂sct+ϵct∂[bωth(sct)]∂sct+ϵst+1∂[bϕt+1⋅sct+bιt+1⋅g(act+1)]∂sct=δϕt+1∂aϕt+1∂sct+διt+1∂aιt+1∂sct+δωt∂aωt+1∂sct+ϵctbωth′(sct)+ϵst+1bϕt+1=δϕt+1∂(∑i=1Iwiϕxit+1+∑h=1Hwhϕbht+∑c=1Cwcϕsct)∂sct+διt+1∂(∑i=1Iwiιxit+1+∑h=1Hwhιbht+∑c=1Cwcιsct)∂sct+δωt∂(∑i=1Iwiωxit+1+∑h=1Hwhωbht+∑c=1Cwcωsct)∂sct+ϵctbωth′(sct)+ϵst+1bϕt+1=ϵctbωth′(sct)+ϵst+1bϕt+1+δϕt+1wcϕ+διt+1wcι+δωtwcω\begin{aligned} \epsilon_s^t &=\frac{\partial \mathcal{L}}{\partial s_c^t} \\ &=\frac{\partial \mathcal{L}}{\partial a_j^{t+1}} \frac{\partial a_j^{t+1}}{\partial s_c^t} + \frac{\partial \mathcal{L}}{\partial b_c^t} \frac{\partial b_c^t}{\partial s_c^t} + \frac{\partial \mathcal{L}}{\partial s_c^{t+1}} \frac{\partial s_c^{t+1}}{\partial s_c^t} \\ &= \delta_j^{t+1} \frac{\partial a_j^{t+1}}{\partial s_c^t} + \epsilon_c^t\frac{\partial \left[b_\omega^t h(s_c^t) \right]}{\partial s_c^t} + \epsilon_s^{t+1} \frac{\partial \left[ b_\phi^{t+1} \cdot s_c^{t}\,+b_\iota^{t+1} \cdot g(a_c^{t+1}) \right]}{\partial s_c^t} \\ &= \delta_\phi^{t+1} \frac{\partial a_\phi^{t+1}}{\partial s_c^t} + \delta_\iota^{t+1} \frac{\partial a_\iota^{t+1}}{\partial s_c^t} + \delta_\omega^t \frac{\partial a_\omega^{t+1}}{\partial s_c^t} + \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} \\ &= \delta_\phi^{t+1} \frac{\partial \left( \sum_{i=1}^Iw_{i\phi}x_i^{t+1}+\sum_{h=1}^{H}w_{h\phi}b_h^{t}+\sum_{c=1}^Cw_{c\phi}s_c^{t} \right)}{\partial s_c^t} \\ &+ \delta_\iota^{t+1} \frac{\partial \left( \sum_{i=1}^Iw_{i\iota}x_i^{t+1}+\sum_{h=1}^{H}w_{h\iota}b_h^{t}+\sum_{c=1}^Cw_{c\iota}s_c^{t} \right)}{\partial s_c^t} \\ &+ \delta_\omega^t \frac{\partial \left( \sum_{i=1}^Iw_{i\omega}x_i^{t+1}+\sum_{h=1}^{H}w_{h\omega}b_h^{t}+ \sum_{c=1}^Cw_{c\omega}s_c^{t} \right)}{\partial s_c^t} \\ &+ \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} \\ &= \epsilon_c^t b_\omega^t h'(s_c^t) + \epsilon_s^{t+1}b_{\phi}^{t+1} + \delta_\phi^{t+1}w_{c\phi} + \delta_\iota^{t+1} w_{c\iota} + \delta_\omega^t w_{c\omega} \\ \end{aligned}ϵst​​=∂sct​∂L​=∂ajt+1​∂L​∂sct​∂ajt+1​​+∂bct​∂L​∂sct​∂bct​​+∂sct+1​∂L​∂sct​∂sct+1​​=δjt+1​∂sct​∂ajt+1​​+ϵct​∂sct​∂[bωt​h(sct​)]​+ϵst+1​∂sct​∂[bϕt+1​⋅sct​+bιt+1​⋅g(act+1​)]​=δϕt+1​∂sct​∂aϕt+1​​+διt+1​∂sct​∂aιt+1​​+δωt​∂sct​∂aωt+1​​+ϵct​bωt​h′(sct​)+ϵst+1​bϕt+1​=δϕt+1​∂sct​∂(∑i=1I​wiϕ​xit+1​+∑h=1H​whϕ​bht​+∑c=1C​wcϕ​sct​)​+διt+1​∂sct​∂(∑i=1I​wiι​xit+1​+∑h=1H​whι​bht​+∑c=1C​wcι​sct​)​+δωt​∂sct​∂(∑i=1I​wiω​xit+1​+∑h=1H​whω​bht​+∑c=1C​wcω​sct​)​+ϵct​bωt​h′(sct​)+ϵst+1​bϕt+1​=ϵct​bωt​h′(sct​)+ϵst+1​bϕt+1​+δϕt+1​wcϕ​+διt+1​wcι​+δωt​wcω​​

  4. Cell输出(短时记忆)。
    只需应用前向传播的(3.2)(3.2)(3.2)式,即可得到:
    δct=∂L∂act=∂L∂sct∂sct∂act=ϵst∂[bϕt⋅sct−1+bιt⋅g(act)]∂act=ϵstbιtg′(act)\begin{aligned} \delta_c^t &=\frac{\partial \mathcal{L}}{\partial a_c^t} =\frac{\partial \mathcal{L}}{\partial s_c^t}\frac{\partial s_c^t}{\partial a_c^t} \\ &=\epsilon_s^t \frac{\partial \left[b_\phi^t \cdot s_c^{t-1}\,+b_\iota^t \cdot g(a_c^t)\right] }{\partial a_c^t} \\ &=\epsilon_s^t b_\iota^tg'(a_c^t) \\ \end{aligned}δct​​=∂act​∂L​=∂sct​∂L​∂act​∂sct​​=ϵst​∂act​∂[bϕt​⋅sct−1​+bιt​⋅g(act​)]​=ϵst​bιt​g′(act​)​

  5. 遗忘门。方法同输出门推导,只需应用前向传播的(3.2)(3.2)(3.2)式,可立即得到:
    δϕt=∂L∂aϕt=∂L∂bϕt∂bϕt∂aϕt=∂L∂bϕtf′(aϕt)=f′(aϕt)∂L∂sct∂sct∂bϕt=f′(aϕt)ϵst∂sct∂bϕt=f′(aϕt)ϵst∂[bϕtsct−1+bιtg(act)]∂bϕt=f′(aϕt)∑c=1Csct−1ϵst\begin{aligned} \delta_\phi^t&= \frac{\partial \mathcal{L}}{\partial a_\phi^t} =\frac{\partial \mathcal{L}}{\partial b_\phi^t}\frac{\partial b_\phi^t}{\partial a_\phi^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\phi^t} f'(a_\phi^t) \\ &=f'(a_\phi^t) \frac{\partial \mathcal{L}}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\phi^t}\\ &= f'(a_\phi^t) \epsilon_s^t \frac{\partial s_c^t}{\partial b_\phi^t} \\ &= f'(a_\phi^t) \epsilon_s^t \frac{\partial \left[b_\phi^t s_c^{t-1} + b_\iota^{t} g(a_c^t)\right]}{\partial b_\phi^t} \\ &=f'(a_\phi^t)\sum_{c=1}^Cs_c^{t-1}\epsilon_s^t \end{aligned}δϕt​​=∂aϕt​∂L​=∂bϕt​∂L​∂aϕt​∂bϕt​​=∂bϕt​∂L​f′(aϕt​)=f′(aϕt​)∂sct​∂L​∂bϕt​∂sct​​=f′(aϕt​)ϵst​∂bϕt​∂sct​​=f′(aϕt​)ϵst​∂bϕt​∂[bϕt​sct−1​+bιt​g(act​)]​=f′(aϕt​)c=1∑C​sct−1​ϵst​​

  6. 输入门。方法同输出门,只需应用前向传播的(3.2)(3.2)(3.2)式,即可得到:
    διt=∂L∂aιt=∂L∂bιt∂bιt∂aιt=∂L∂bιtf′(aιt)=f′(aιt)∂L∂sct∂sct∂bιt=f′(aιt)ϵst∂sct∂bιt=f′(aιt)ϵct∂[bϕtsct−1+bιtg(act)]∂bιt=f′(aιt)∑c=1Cg(act)ϵst\begin{aligned} \delta_\iota^t&=\frac{\partial \mathcal{L}}{\partial a_\iota^t} =\frac{\partial \mathcal{L}}{\partial b_\iota^t}\frac{\partial b_\iota^t}{\partial a_\iota^t} \\ &=\frac{\partial \mathcal{L}}{\partial b_\iota^t} f'(a_\iota^t) \\ &=f'(a_\iota^t) \frac{\partial \mathcal{L}}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\iota^t}\\ &= f'(a_\iota^t) \epsilon_s^t \frac{\partial s_c^t}{\partial b_\iota^t} \\ &= f'(a_\iota^t) \epsilon_c^t \frac{\partial \left[b_\phi^t s_c^{t-1} + b_\iota^{t} g(a_c^t)\right]}{\partial b_\iota^t} \\ &=f'(a_\iota^t)\sum_{c=1}^Cg(a_c^t) \epsilon_s^t \end{aligned}διt​​=∂aιt​∂L​=∂bιt​∂L​∂aιt​∂bιt​​=∂bιt​∂L​f′(aιt​)=f′(aιt​)∂sct​∂L​∂bιt​∂sct​​=f′(aιt​)ϵst​∂bιt​∂sct​​=f′(aιt​)ϵct​∂bιt​∂[bϕt​sct−1​+bιt​g(act​)]​=f′(aιt​)c=1∑C​g(act​)ϵst​​

总结

  本文介绍了这本书LSTM部分(第四章)的流程详解及公式推导,其中难免会有些许错误,望大家指出。得到公式后,下一步就是编程实现了,这里可以参考另一篇文章零基础入门深度学习(6) - 长短时记忆网络(LSTM),有非常细致的讲解。第一篇万字长文(其实主要是公式多),如果有用就点个赞吧!

P.S. PPT绘图大法是真的香,对我这种小白十分友好,有兴趣的朋友可以玩玩

LSTM公式详解推导相关推荐

  1. 2位专家耗时2年打造,西瓜书机器学习公式详解,都在这里了!(文末留言赠书)...

    作为机器学习的入门经典教材,周志华老师的<机器学习>,自2016年1月底出版以来,首印5000册一周售罄,并在8个月内重印9次.先后登上了亚马逊,京东,当当网等的计算机类畅销书榜首,身边学 ...

  2. 等额本息和等额本金公式详解

    一.等额本息公式详解 (1) 等额本息,网上已经给出详细解释了,我就一句话,就是你n期,每一期的还的金额都是一样的. (2) 之前一段时间苦恼,网上搜了一圈等额本息的公式,都是直接给结果,没有解算过程 ...

  3. 神经网络的函数表达式,神经网络公式详解pdf

    1.神经网络的准确率是怎么计算的? 其实神经网络的准确率的标准是自己定义的. 我把你的例子赋予某种意义讲解: 1,期望输出[1 0 0 1],每个元素代表一个属性是否存在.像着4个元素分别表示:是否肺 ...

  4. python数组对应元素相乘_python的几种矩阵相乘的公式详解

    1. 同线性代数中矩阵乘法的定义: np.dot() np.dot(A, B):对于二维矩阵,计算真正意义上的矩阵乘积,同线性代数中矩阵乘法的定义.对于一维矩阵,计算两者的内积.见如下Python代码 ...

  5. 【吃瓜教程】《机器学习公式详解》西瓜书与南瓜书公式推导

    [吃瓜教程]<机器学习公式详解>西瓜书与南瓜书公式推导 2021年7月11日 第0章-导学 深度学习:狭义地来说,就是具有较多层的神经网络. 整个学习过程; 先看西瓜书,在看 Datawh ...

  6. pytorch nn.LSTM()参数详解

    输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num ...

  7. 一. 卡尔曼滤波器开发实践之一: 五大公式详解

    既然标题名称是开发实践,本系列文章将主要介绍如何在工程实践中使用卡尔曼滤波器,至于卡尔曼滤波器的五大公式如何推导而来,网上有很多大拿们写的都很精彩,这里不再叙述.可以参考了下面两篇博文: 1. 卡尔曼 ...

  8. 通达信欧奈尔RPS指标公式详解

    RPS相对强度指标,是国内的投资者根据威廉·欧奈尔所著书籍<笑傲股市>中的RS评级改进的. 根据书中介绍: RS评级衡量了某一给定股票在过去52周内相对股市中其他股票的表现.市场上每一只股 ...

  9. 长短期记忆(LSTM)详解

    入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删. ✨完整代码在我的github上,有需要的朋友可以康康✨ ​​​​​​https://githu ...

最新文章

  1. c++的:: . :-的区别
  2. 利用一个继电器来实现脚踏鼠标按钮
  3. 三维重建学习(1):基础知识:旋转矩阵与旋转向量
  4. 第十章练习题----2
  5. percona-toolkit(pt工具)使用总结
  6. 2场直播丨从零快速搭建一整套监控体系、Oracle Database Server经典体系结构
  7. 数据库数据类型和占用字节数对比
  8. python3爬取中国药学科学数据
  9. 直播APP源码功能详解
  10. QQ音乐爬虫程序详细解析(一)——歌曲下载模块
  11. 攻防世界re:logmein
  12. Masked Autoencoders Are Scalable Vision Learners (MAE)
  13. Angular 的 ngOnInit 和 Constructor 的区别
  14. java excel导入 日期_java导入excel时处理日期格式(已验证ok)
  15. 130行代码实现海贼王漫画下载
  16. dapper mysql通用类_Dapper ORM 用法
  17. python音乐可视化效果_Python 一个漂亮的音乐节奏可视化方案!我觉得可行!
  18. 美团面试 java后端开发
  19. 如何在上传图片的时候,选中图片可以在前端预览
  20. Ubuntu20.04 下 rstudio 安装教程(附安装包下载)

热门文章

  1. docker管理监控方案
  2. react入门(1)之阮一峰react教程
  3. 小爱同学与小冰将实现联合进步
  4. 448. Find All Numbers Disappeared in an Array645. Set Mismatch
  5. Java学习笔记10---访问权限修饰符如何控制成员变量、成员方法及类的访问范围...
  6. oracle sequence的用法
  7. JavaScript高级程序设计之什么是原型模式
  8. android对应版本号
  9. SQL中关于EXISTS谓词的理解
  10. 基础编程题目集 6-11 求自定类型元素序列的中位数 (25 分)