随时间反向传播算法(BPTT)笔记
随时间反向传播算法(BPTT)笔记
1.反向传播算法(BP)
以表达式f(w,x)=11+e−(w0x0+w1x1+w2)f(w,x)=\frac{1}{1+e^{-(w_0x_0+w_1x_1+w_2)}}f(w,x)=1+e−(w0x0+w1x1+w2)1为例,其涉及到的运算操作及导数公式如下:
f(x)=1x→dfdx=−1x2fc(x)=c+x→dfdx=1f(x)=ex→dfdx=exfa(x)=ax→dfdx=a(1)\begin{aligned}&f(x)=\frac{1}{x}&\rightarrow&\frac{df}{dx}=-\frac{1}{x^2}\\&f_c(x)=c+x&\rightarrow&\frac{df}{dx}=1\\&f(x)=e^x&\rightarrow&\frac{df}{dx}=e^x\\&f_a(x)=ax&\rightarrow&\frac{df}{dx}=a\end{aligned}\tag{1} f(x)=x1fc(x)=c+xf(x)=exfa(x)=ax→→→→dxdf=−x21dxdf=1dxdf=exdxdf=a(1)
表达式f(w,x)f(w,x)f(w,x)反向传播过程如下图所示:
其中绿色数值表示表达式f(w,x)f(w,x)f(w,x)正向传播结果,红色数值表示梯度反向传播结果。对于单输入节点(如常数加法或指数运算等)梯度反向传播计算公式如下:
gk=gk+1⋅dfkdx∣x=vk(2)g_k=g_{k+1}\cdot\frac{df_k}{dx}|_{x=v_k}\tag{2} gk=gk+1⋅dxdfk∣x=vk(2)
其中gkg_kgk表示节点前梯度,gk+1g_{k+1}gk+1表示节点后梯度,fkf_kfk表示节点函数,vkv_kvk表示节点输入。对于加法节点梯度反向传播后各链路数值不变。对于乘法节点梯度反向传播计算公式如下:
gk=gk+1⋅a(3)g_k=g_{k+1}\cdot a\tag{3} gk=gk+1⋅a(3)
其中aaa表示节点另一条链路的输入。
2.随时间反向传播(BPTT)
2.1 RNN网络结构
经典RNN结构如上图所示,其正向传播公式如下:
st=Uht−1+Wxtht=tanh(st)zt=Vhty^t=softmax(zt)Et=−ytTlog(y^t)E=∑t=1TEt(4)\begin{aligned}s_t&=Uh_{t-1}+Wx_t\\h_t&=\operatorname{tanh}(s_t)\\z_t&=Vh_t\\\hat{y}_t&=\operatorname{softmax}(z_t)\\E_t&=-y_t^T\log(\hat{y}_t)\\E&=\sum_{t=1}^{T}E_t\end{aligned} \tag{4} sthtzty^tEtE=Uht−1+Wxt=tanh(st)=Vht=softmax(zt)=−ytTlog(y^t)=t=1∑TEt(4)
2.2 反向传播
2.2.1 计算∂Et∂V\frac{\partial E_t}{\partial V}∂V∂Et
∂Et∂Vij=∂zt∂Vij∂Et∂zt=tr[(∂Et∂zt)T⋅∂zt∂Vij]=tr[(y^t−yt)T⋅[0⋮∂zt(i)∂Vij⋮0]]=(y^t−yt)(i)ht(j)(5)\begin{aligned}\frac{\partial E_t}{\partial V_{ij}}&=\frac{\partial z_t}{\partial V_{ij}}\frac{\partial E_t}{\partial z_t}\\&=\operatorname{tr}[(\frac{\partial E_t}{\partial z_t})^T\cdot\frac{\partial z_t}{\partial V_{ij}}]\\&=\operatorname{tr}[(\hat{y}_t-y_t)^T\cdot\begin{bmatrix}0\\\vdots\\\frac{\partial z_t^{(i)}}{\partial V_{ij}}\\\vdots\\0\end{bmatrix}]\\&=(\hat{y}_t-y_t)^{(i)}h_t^{(j)}\end{aligned}\tag{5} ∂Vij∂Et=∂Vij∂zt∂zt∂Et=tr[(∂zt∂Et)T⋅∂Vij∂zt]=tr[(y^t−yt)T⋅⎣⎢⎢⎢⎢⎢⎢⎢⎡0⋮∂Vij∂zt(i)⋮0⎦⎥⎥⎥⎥⎥⎥⎥⎤]=(y^t−yt)(i)ht(j)(5)
对矩阵VVV而言,其求导结果如下:
∂Et∂V=(y^t−yt)⨂ht(6)\frac{\partial E_t}{\partial V}=(\hat{y}_t-y_t)\bigotimes h_t \tag{6} ∂V∂Et=(y^t−yt)⨂ht(6)
其中⨂\bigotimes⨂表示向量外积。
2.2.2 计算∂Et∂U\frac{\partial E_t}{\partial U}∂U∂Et
∂Et∂Uij=∑k=0t∂sk∂Uij∂Et∂sk=∑k=0ttr[(∂Et∂sk)T∂sk∂Uij]=∑k=0ttr[(δk)T∂sk∂Uij]=∑k=0tδk(i)hk−1(j)(7)\frac{\partial E_t}{\partial U_{ij}}=\sum_{k=0}^t\frac{\partial s_k}{\partial U_{ij}}\frac{\partial E_t}{\partial s_k}=\sum_{k=0}^{t}\operatorname{tr}[(\frac{\partial E_t}{\partial s_k})^T\frac{\partial s_k}{\partial U_{ij}}]=\sum_{k=0}^{t}\operatorname{tr}[(\delta_k)^T\frac{\partial s_k}{\partial U_{ij}}]=\sum_{k=0}^t\delta_k^{(i)}h_{k-1}^{(j)}\tag{7} ∂Uij∂Et=k=0∑t∂Uij∂sk∂sk∂Et=k=0∑ttr[(∂sk∂Et)T∂Uij∂sk]=k=0∑ttr[(δk)T∂Uij∂sk]=k=0∑tδk(i)hk−1(j)(7)
对δk\delta_kδk应用链式法则:
δk=∂hk∂sk∂sk+1∂hk∂Et∂sk+1=diag(1−hkhk)UTδk+1=(UTδk+1)(1−hkhk)(8)\delta_k=\frac{\partial h_k}{\partial s_k}\frac{\partial s_{k+1}}{\partial h_k}\frac{\partial E_t}{\partial s_{k+1}}=\operatorname{diag}(1-h_kh_k)U^T\delta_{k+1}=(U^T\delta_{k+1})(1-h_kh_k)\tag{8} δk=∂sk∂hk∂hk∂sk+1∂sk+1∂Et=diag(1−hkhk)UTδk+1=(UTδk+1)(1−hkhk)(8)
对矩阵UUU而言,其求导结果如下:
∂Et∂U=∑k=0tδk⨂hk−1(9)\frac{\partial E_t}{\partial U}=\sum_{k=0}^{t}\delta_k\bigotimes h_{k-1}\tag{9} ∂U∂Et=k=0∑tδk⨂hk−1(9)
2.2.3 计算∂Et∂W\frac{\partial E_t}{\partial W}∂W∂Et
按上述思路,对矩阵WWW而言,其求导结果如下:
∂Et∂W=∑k=0tδk⨂xk(10)\frac{\partial E_t}{\partial W}=\sum_{k=0}^{t}\delta_k\bigotimes x_k\tag{10} ∂W∂Et=k=0∑tδk⨂xk(10)
2.2.4 参数更新
V:=V−λ∑t=0T(y^t−yt)⨂htU:=U−λ∑t=0T∑k=0tδk⨂hk−1W:=W−λ∑t=0T∑k=0tδk⨂xk(11)V:=V-\lambda\sum_{t=0}^T(\hat{y}_t-y_t)\bigotimes h_t\\U:=U-\lambda\sum_{t=0}^T\sum_{k=0}^t\delta_k\bigotimes h_{k-1}\\W:=W-\lambda\sum_{t=0}^T\sum_{k=0}^t\delta_k\bigotimes x_k\tag{11} V:=V−λt=0∑T(y^t−yt)⨂htU:=U−λt=0∑Tk=0∑tδk⨂hk−1W:=W−λt=0∑Tk=0∑tδk⨂xk(11)
2.3 长期依赖问题
重新考查梯度∂Et∂W\frac{\partial E_t}{\partial W}∂W∂Et:
∂Et∂W=∑k=0t∂Et∂st(∏j=k+1t∂sj∂sj−1)∂sk∂W(12)\frac{\partial E_t}{\partial W}=\sum_{k=0}^{t}\frac{\partial E_t}{\partial s_t}(\prod_{j=k+1}^t\frac{\partial s_j}{\partial s_{j-1}})\frac{\partial s_k}{\partial W}\tag{12} ∂W∂Et=k=0∑t∂st∂Et(j=k+1∏t∂sj−1∂sj)∂W∂sk(12)
由于tanh\operatorname{tanh}tanh导数取值范围为(0,1],因此Jacobian矩阵∂sj∂sj−1\frac{\partial s_j}{\partial s_{j-1}}∂sj−1∂sj上限为1。Jacobian矩阵多次连乘后,矩阵上限呈指数下降,最终几乎完全消失,这样就导致了远离TTT时刻的梯度为0,这些时刻的状态对学习过程没有帮助,因此RNN结构无法解决长期依赖问题。
参考文献
[1]CS231n Convolutional Neural Networks for Visual Recognition
[2]随时间反向传播 (BackPropagation Through Time,BPTT)
随时间反向传播算法(BPTT)笔记相关推荐
- 随时间的反向传播算法 BPTT
本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md 随时间反向传播(BPTT)算法 先简单回顾一下RNN的基本公式: st=tanh(Uxt+Wst− ...
- RNN与其反向传播算法——BPTT(Backward Propogation Through Time)的详细推导
前言 一点感悟: 前几天简单看了下王者荣耀觉悟AI的论文,发现除了强化学习以外,也用到了熟悉的LSTM.之后我又想起了知乎上的一个问题:"Transformer会彻底取代RNN吗?" ...
- 神经网络反向传播算法原理笔记
神经网络是一种是基于生物学中神经网络的基本原理,在理解和抽象了人脑结构和外界刺激响应机制后,以网络拓扑知识为理论基础,模拟人脑的神经系统对复杂信息的处理机制的一种数学模型.该模型以并行分布的处理能力. ...
- 反向传播算法学习笔记
反向传播算法(Back propagation) 目的及思想 我们现在有一堆输入,我们希望能有一个网络,使得通过这个网络的构成的映射关系满足我们的期待.也就是说,我们在解决这个问题之前先假设,这种映射 ...
- TensorFlow精进之路(十二):随时间反向传播BPTT
1.概述 上一节介绍了TensorFlow精进之路(十一):反向传播BP,这一节就简单介绍一下BPTT. 2.网络结构 RNN正向传播可以用上图表示,这里忽略偏置. 上图中, x(1:T)表示输入序列 ...
- 基于时间的反向传播算法BPTT(Backpropagation through time)
本文是读"Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gr ...
- 时间序列的反向传播算法(BPTT)
时间序列的反向传播算法 BPTT : Back-Propagation Through Time ∂L∂U=∑t∂Lt∂U\frac{\partial L}{\partial U} = \sum_t\ ...
- 深度学习入门笔记(六):误差反向传播算法
专栏--深度学习入门笔记 推荐文章 深度学习入门笔记(一):机器学习基础 深度学习入门笔记(二):神经网络基础 深度学习入门笔记(三):感知机 深度学习入门笔记(四):神经网络 深度学习入门笔记(五) ...
- 深度学习中反向传播算法简单推导笔记
反向传播算法简单推导笔记 1.全连接神经网络 该结构的前向传播可以写成: z(1)=W(1)x+b(1)z^{(1)} = W^{(1)}x+b^{(1)}z(1)=W(1)x+b(1) a(1)=σ ...
最新文章
- SAP QM 物料主数据QM视图里字段MARC-INSMK的更新
- Java重写equals和hashCode方法
- 南京信息工程大学滨江学院计算机科学与技术专业,南京信息工程大学滨江学院有哪些专业及什么专业好...
- hadoop +hbase +zookeeper 完全分布搭建 (版本一)
- HTML 5 样式指南和代码约定
- c语言电报关系的题目,c语言所有题目以跟答案.doc
- DButils工具使用笔记以及常见问题总结
- 线程池开门营业招聘开发人员的一天
- HaLow技术提升车载Wi-Fi质量 促进车联网发展
- The valid characters are defined in RFC 7230 and RFC 3986问题
- object references an unsaved transient instance - save the transient instance before flushing
- 279. Perfect Squares
- ubuntu修改dns服务器,配置Ubuntu DNS服务器
- 一个北京妞儿写的经典的话,太现实了!
- [vm] vm安装xp :non-bootable disk 80 解决办法
- flutter项目实战三:封装http工具类
- C/C++实现贪吃蛇游戏
- LinkedList面试要点总结
- mac系统安装Anaconda后再打开终端自动进入Anaconda环境
- 联想服务器td340安装精简版win10
热门文章
- stm32毕设分享 stm32的车牌识别系统
- #### 标题关于Quartus Ⅱ启动ModelSim仿真软件时提示Can't lauch the ModelSim的问题
- 李阳疯狂英语突破对话(13)-你怎样提高英语
- linux数据库awr报告,手动生成AWR报告
- Excel中提取英文,数值和编码(LEN函数)
- uni-app自定义导航栏右侧做增加按钮并跳转链接
- 逻技鼠标蓝牙连接Mac无法被logi Options检测到
- 福建厦门双十计算机竞赛,厦门双十中学新高三学生 获信息学奥赛金奖保送清华...
- python打印图形大全(详解)
- 浅学 ---------- 计算机网络(二)(湖科大笔记)