随时间反向传播算法(BPTT)笔记

1.反向传播算法(BP)

以表达式f(w,x)=11+e−(w0x0+w1x1+w2)f(w,x)=\frac{1}{1+e^{-(w_0x_0+w_1x_1+w_2)}}f(w,x)=1+e−(w0​x0​+w1​x1​+w2​)1​为例,其涉及到的运算操作及导数公式如下:

f(x)=1x→dfdx=−1x2fc(x)=c+x→dfdx=1f(x)=ex→dfdx=exfa(x)=ax→dfdx=a(1)\begin{aligned}&f(x)=\frac{1}{x}&\rightarrow&\frac{df}{dx}=-\frac{1}{x^2}\\&f_c(x)=c+x&\rightarrow&\frac{df}{dx}=1\\&f(x)=e^x&\rightarrow&\frac{df}{dx}=e^x\\&f_a(x)=ax&\rightarrow&\frac{df}{dx}=a\end{aligned}\tag{1} ​f(x)=x1​fc​(x)=c+xf(x)=exfa​(x)=ax​→→→→​dxdf​=−x21​dxdf​=1dxdf​=exdxdf​=a​(1)
表达式f(w,x)f(w,x)f(w,x)反向传播过程如下图所示:

其中绿色数值表示表达式f(w,x)f(w,x)f(w,x)正向传播结果,红色数值表示梯度反向传播结果。对于单输入节点(如常数加法或指数运算等)梯度反向传播计算公式如下:

gk=gk+1⋅dfkdx∣x=vk(2)g_k=g_{k+1}\cdot\frac{df_k}{dx}|_{x=v_k}\tag{2} gk​=gk+1​⋅dxdfk​​∣x=vk​​(2)
其中gkg_kgk​表示节点前梯度,gk+1g_{k+1}gk+1​表示节点后梯度,fkf_kfk​表示节点函数,vkv_kvk​表示节点输入。对于加法节点梯度反向传播后各链路数值不变。对于乘法节点梯度反向传播计算公式如下:

gk=gk+1⋅a(3)g_k=g_{k+1}\cdot a\tag{3} gk​=gk+1​⋅a(3)
其中aaa表示节点另一条链路的输入。

2.随时间反向传播(BPTT)

2.1 RNN网络结构

经典RNN结构如上图所示,其正向传播公式如下:

st=Uht−1+Wxtht=tanh⁡(st)zt=Vhty^t=softmax⁡(zt)Et=−ytTlog⁡(y^t)E=∑t=1TEt(4)\begin{aligned}s_t&=Uh_{t-1}+Wx_t\\h_t&=\operatorname{tanh}(s_t)\\z_t&=Vh_t\\\hat{y}_t&=\operatorname{softmax}(z_t)\\E_t&=-y_t^T\log(\hat{y}_t)\\E&=\sum_{t=1}^{T}E_t\end{aligned} \tag{4} st​ht​zt​y^​t​Et​E​=Uht−1​+Wxt​=tanh(st​)=Vht​=softmax(zt​)=−ytT​log(y^​t​)=t=1∑T​Et​​(4)

2.2 反向传播

2.2.1 计算∂Et∂V\frac{\partial E_t}{\partial V}∂V∂Et​​

∂Et∂Vij=∂zt∂Vij∂Et∂zt=tr⁡[(∂Et∂zt)T⋅∂zt∂Vij]=tr⁡[(y^t−yt)T⋅[0⋮∂zt(i)∂Vij⋮0]]=(y^t−yt)(i)ht(j)(5)\begin{aligned}\frac{\partial E_t}{\partial V_{ij}}&=\frac{\partial z_t}{\partial V_{ij}}\frac{\partial E_t}{\partial z_t}\\&=\operatorname{tr}[(\frac{\partial E_t}{\partial z_t})^T\cdot\frac{\partial z_t}{\partial V_{ij}}]\\&=\operatorname{tr}[(\hat{y}_t-y_t)^T\cdot\begin{bmatrix}0\\\vdots\\\frac{\partial z_t^{(i)}}{\partial V_{ij}}\\\vdots\\0\end{bmatrix}]\\&=(\hat{y}_t-y_t)^{(i)}h_t^{(j)}\end{aligned}\tag{5} ∂Vij​∂Et​​​=∂Vij​∂zt​​∂zt​∂Et​​=tr[(∂zt​∂Et​​)T⋅∂Vij​∂zt​​]=tr[(y^​t​−yt​)T⋅⎣⎢⎢⎢⎢⎢⎢⎢⎡​0⋮∂Vij​∂zt(i)​​⋮0​⎦⎥⎥⎥⎥⎥⎥⎥⎤​]=(y^​t​−yt​)(i)ht(j)​​(5)

对矩阵VVV而言,其求导结果如下:

∂Et∂V=(y^t−yt)⨂ht(6)\frac{\partial E_t}{\partial V}=(\hat{y}_t-y_t)\bigotimes h_t \tag{6} ∂V∂Et​​=(y^​t​−yt​)⨂ht​(6)
其中⨂\bigotimes⨂表示向量外积。

2.2.2 计算∂Et∂U\frac{\partial E_t}{\partial U}∂U∂Et​​

∂Et∂Uij=∑k=0t∂sk∂Uij∂Et∂sk=∑k=0ttr⁡[(∂Et∂sk)T∂sk∂Uij]=∑k=0ttr⁡[(δk)T∂sk∂Uij]=∑k=0tδk(i)hk−1(j)(7)\frac{\partial E_t}{\partial U_{ij}}=\sum_{k=0}^t\frac{\partial s_k}{\partial U_{ij}}\frac{\partial E_t}{\partial s_k}=\sum_{k=0}^{t}\operatorname{tr}[(\frac{\partial E_t}{\partial s_k})^T\frac{\partial s_k}{\partial U_{ij}}]=\sum_{k=0}^{t}\operatorname{tr}[(\delta_k)^T\frac{\partial s_k}{\partial U_{ij}}]=\sum_{k=0}^t\delta_k^{(i)}h_{k-1}^{(j)}\tag{7} ∂Uij​∂Et​​=k=0∑t​∂Uij​∂sk​​∂sk​∂Et​​=k=0∑t​tr[(∂sk​∂Et​​)T∂Uij​∂sk​​]=k=0∑t​tr[(δk​)T∂Uij​∂sk​​]=k=0∑t​δk(i)​hk−1(j)​(7)

对δk\delta_kδk​应用链式法则:

δk=∂hk∂sk∂sk+1∂hk∂Et∂sk+1=diag⁡(1−hkhk)UTδk+1=(UTδk+1)(1−hkhk)(8)\delta_k=\frac{\partial h_k}{\partial s_k}\frac{\partial s_{k+1}}{\partial h_k}\frac{\partial E_t}{\partial s_{k+1}}=\operatorname{diag}(1-h_kh_k)U^T\delta_{k+1}=(U^T\delta_{k+1})(1-h_kh_k)\tag{8} δk​=∂sk​∂hk​​∂hk​∂sk+1​​∂sk+1​∂Et​​=diag(1−hk​hk​)UTδk+1​=(UTδk+1​)(1−hk​hk​)(8)
对矩阵UUU而言,其求导结果如下:

∂Et∂U=∑k=0tδk⨂hk−1(9)\frac{\partial E_t}{\partial U}=\sum_{k=0}^{t}\delta_k\bigotimes h_{k-1}\tag{9} ∂U∂Et​​=k=0∑t​δk​⨂hk−1​(9)

2.2.3 计算∂Et∂W\frac{\partial E_t}{\partial W}∂W∂Et​​

按上述思路,对矩阵WWW而言,其求导结果如下:

∂Et∂W=∑k=0tδk⨂xk(10)\frac{\partial E_t}{\partial W}=\sum_{k=0}^{t}\delta_k\bigotimes x_k\tag{10} ∂W∂Et​​=k=0∑t​δk​⨂xk​(10)

2.2.4 参数更新

V:=V−λ∑t=0T(y^t−yt)⨂htU:=U−λ∑t=0T∑k=0tδk⨂hk−1W:=W−λ∑t=0T∑k=0tδk⨂xk(11)V:=V-\lambda\sum_{t=0}^T(\hat{y}_t-y_t)\bigotimes h_t\\U:=U-\lambda\sum_{t=0}^T\sum_{k=0}^t\delta_k\bigotimes h_{k-1}\\W:=W-\lambda\sum_{t=0}^T\sum_{k=0}^t\delta_k\bigotimes x_k\tag{11} V:=V−λt=0∑T​(y^​t​−yt​)⨂ht​U:=U−λt=0∑T​k=0∑t​δk​⨂hk−1​W:=W−λt=0∑T​k=0∑t​δk​⨂xk​(11)

2.3 长期依赖问题

重新考查梯度∂Et∂W\frac{\partial E_t}{\partial W}∂W∂Et​​:

∂Et∂W=∑k=0t∂Et∂st(∏j=k+1t∂sj∂sj−1)∂sk∂W(12)\frac{\partial E_t}{\partial W}=\sum_{k=0}^{t}\frac{\partial E_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial W}\tag{12} ∂W∂Et​​=k=0∑t​∂st​∂Et​​(j=k+1∏t​∂sj−1​∂sj​​)∂W∂sk​​(12)
由于tanh⁡\operatorname{tanh}tanh导数取值范围为(0,1],因此Jacobian矩阵∂sj∂sj−1\frac{\partial s_j}{\partial s_{j-1}}∂sj−1​∂sj​​上限为1。Jacobian矩阵多次连乘后,矩阵上限呈指数下降,最终几乎完全消失,这样就导致了远离TTT时刻的梯度为0,这些时刻的状态对学习过程没有帮助,因此RNN结构无法解决长期依赖问题。

参考文献

[1]CS231n Convolutional Neural Networks for Visual Recognition

[2]随时间反向传播 (BackPropagation Through Time,BPTT)

随时间反向传播算法(BPTT)笔记相关推荐

  1. 随时间的反向传播算法 BPTT

    本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md 随时间反向传播(BPTT)算法 先简单回顾一下RNN的基本公式: st=tanh(Uxt+Wst− ...

  2. RNN与其反向传播算法——BPTT(Backward Propogation Through Time)的详细推导

    前言 一点感悟: 前几天简单看了下王者荣耀觉悟AI的论文,发现除了强化学习以外,也用到了熟悉的LSTM.之后我又想起了知乎上的一个问题:"Transformer会彻底取代RNN吗?" ...

  3. 神经网络反向传播算法原理笔记

    神经网络是一种是基于生物学中神经网络的基本原理,在理解和抽象了人脑结构和外界刺激响应机制后,以网络拓扑知识为理论基础,模拟人脑的神经系统对复杂信息的处理机制的一种数学模型.该模型以并行分布的处理能力. ...

  4. 反向传播算法学习笔记

    反向传播算法(Back propagation) 目的及思想 我们现在有一堆输入,我们希望能有一个网络,使得通过这个网络的构成的映射关系满足我们的期待.也就是说,我们在解决这个问题之前先假设,这种映射 ...

  5. TensorFlow精进之路(十二):随时间反向传播BPTT

    1.概述 上一节介绍了TensorFlow精进之路(十一):反向传播BP,这一节就简单介绍一下BPTT. 2.网络结构 RNN正向传播可以用上图表示,这里忽略偏置. 上图中, x(1:T)表示输入序列 ...

  6. 基于时间的反向传播算法BPTT(Backpropagation through time)

    本文是读"Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gr ...

  7. 时间序列的反向传播算法(BPTT)

    时间序列的反向传播算法 BPTT : Back-Propagation Through Time ∂L∂U=∑t∂Lt∂U\frac{\partial L}{\partial U} = \sum_t\ ...

  8. 深度学习入门笔记(六):误差反向传播算法

    专栏--深度学习入门笔记 推荐文章 深度学习入门笔记(一):机器学习基础 深度学习入门笔记(二):神经网络基础 深度学习入门笔记(三):感知机 深度学习入门笔记(四):神经网络 深度学习入门笔记(五) ...

  9. 深度学习中反向传播算法简单推导笔记

    反向传播算法简单推导笔记 1.全连接神经网络 该结构的前向传播可以写成: z(1)=W(1)x+b(1)z^{(1)} = W^{(1)}x+b^{(1)}z(1)=W(1)x+b(1) a(1)=σ ...

最新文章

  1. SAP QM 物料主数据QM视图里字段MARC-INSMK的更新
  2. Java重写equals和hashCode方法
  3. 南京信息工程大学滨江学院计算机科学与技术专业,南京信息工程大学滨江学院有哪些专业及什么专业好...
  4. hadoop +hbase +zookeeper 完全分布搭建 (版本一)
  5. HTML 5 样式指南和代码约定
  6. c语言电报关系的题目,c语言所有题目以跟答案.doc
  7. DButils工具使用笔记以及常见问题总结
  8. 线程池开门营业招聘开发人员的一天
  9. HaLow技术提升车载Wi-Fi质量 促进车联网发展
  10. The valid characters are defined in RFC 7230 and RFC 3986问题
  11. object references an unsaved transient instance - save the transient instance before flushing
  12. 279. Perfect Squares
  13. ubuntu修改dns服务器,配置Ubuntu DNS服务器
  14. 一个北京妞儿写的经典的话,太现实了!
  15. [vm] vm安装xp :non-bootable disk 80 解决办法
  16. flutter项目实战三:封装http工具类
  17. C/C++实现贪吃蛇游戏
  18. LinkedList面试要点总结
  19. mac系统安装Anaconda后再打开终端自动进入Anaconda环境
  20. 联想服务器td340安装精简版win10

热门文章

  1. stm32毕设分享 stm32的车牌识别系统
  2. #### 标题关于Quartus Ⅱ启动ModelSim仿真软件时提示Can't lauch the ModelSim的问题
  3. 李阳疯狂英语突破对话(13)-你怎样提高英语
  4. linux数据库awr报告,手动生成AWR报告
  5. Excel中提取英文,数值和编码(LEN函数)
  6. uni-app自定义导航栏右侧做增加按钮并跳转链接
  7. 逻技鼠标蓝牙连接Mac无法被logi Options检测到
  8. 福建厦门双十计算机竞赛,厦门双十中学新高三学生 获信息学奥赛金奖保送清华...
  9. python打印图形大全(详解)
  10. 浅学 ---------- 计算机网络(二)(湖科大笔记)