基于时间的反向传播算法BPTT(Backpropagation through time)
本文是读“Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients”的读书笔记,加入了自己的一些理解,有兴趣可以直接阅读原文。
1. 算法介绍
这里引用原文中的网络结构图
其中xxx为输入,sss为隐藏层状态,o为输出,按时间展开
为了与文献中的表示一致,我们用y^\hat yy^来代替o,则
st=tanh(Uxt+Wst−1)y^=softmat(Vst)s_t=tanh(Ux_t+Ws_{t-1}) \\ \hat y=softmat(Vs_t) st=tanh(Uxt+Wst−1)y^=softmat(Vst)
使用交叉熵(cross entropy)作为损失函数
Et(y,y^)=−ytlogy^E(y,y^)=∑tEt(yt,y^t)=−∑tytlogy^E_t(y,\hat y)=-y_tlog\hat y \\ E(y, \hat y) = \sum_t E_t(y_t, \hat y_t)=-\sum_t y_tlog\hat y Et(y,y^)=−ytlogy^E(y,y^)=t∑Et(yt,y^t)=−t∑ytlogy^
我们使用链式法则来计算后向传播时的梯度,以网络的输出E3E_3E3为例,
y^3=ez3∑ieziE3=−y3logy^3=−y3(z3−log∑iezi)z3=Vs3s3=tanh(Ux3+Ws2)\hat y_3=\frac{e^{z_3}}{\sum_ie^{z_i}} \\ E_3=-y_3log\hat y_3=-y_3(z_3-log\sum_ie^{z_i}) \\ z_3=Vs_3 \\ s_3=tanh(Ux_3+Ws_2) y^3=∑ieziez3E3=−y3logy^3=−y3(z3−logi∑ezi)z3=Vs3s3=tanh(Ux3+Ws2)
因此可以求V的梯度
∂E3∂V=∂E3∂z^3∂z3∂V=y3(y^3−1)∗s3\frac{\partial E_3}{\partial V}=\frac{\partial E_3}{\partial \hat z_3}\frac{\partial z_3}{\partial V}=y_3(\hat y_3-1)*s3 ∂V∂E3=∂z^3∂E3∂V∂z3=y3(y^3−1)∗s3
这里求导时将y^3\hat y_3y^3带入消去了,求导更直观,这里给出的是标量形式,改成向量形式应该是y^−13\hat y-1_3y^−13,也就是输出概率矩阵中,对应结果的那个概率-1,其他不变,而输入y恰好可以认为是对应结果的概率是1,其他是0,因此原文中写作
∂E3∂V=(y^3−y3)⊗s3\frac{\partial E_3}{\partial V}=(\hat y_3-y_3)\otimes s_3 ∂V∂E3=(y^3−y3)⊗s3
相对V的梯度,因为sts_tst是W,U的函数,而且含有的st−1s_{t-1}st−1在 求导时,不能简单的认为是一个常数,因此在求导时,如果不加限制,需要对从t到0的所有状态进行回溯,在实际中一般按照场景和精度要求进行截断。
∂E3∂W=∂E3∂z^3∂z3∂s3∂s3∂sk∂sk∂W\frac{\partial E_3}{\partial W}=\frac{\partial E_3}{\partial \hat z_3}\frac{\partial z_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial W} ∂W∂E3=∂z^3∂E3∂s3∂z3∂sk∂s3∂W∂sk
其中s3s_3s3对W的求导是一个分部求导
∂st∂W=(1−st2)(st−1+W∗∂st−1∂sW)\frac{\partial s_t}{\partial W}=(1-s_t^2)(s_{t-1}+W*\frac{\partial s_{t-1}}{\partial s_{W}}) ∂W∂st=(1−st2)(st−1+W∗∂sW∂st−1)
U的梯度类似
∂st∂U=(1−st2)(xt+W∗∂st−1∂sU)\frac{\partial s_t}{\partial U}=(1-s_t^2)(x_t+W*\frac{\partial s_{t-1}}{\partial s_{U}}) ∂U∂st=(1−st2)(xt+W∗∂sU∂st−1)
2. 代码分析
首先我们给出作者自己实现的完整的BPTT,再各部分分析
def bptt(self, x, y):T = len(y)# Perform forward propagationo, s = self.forward_propagation(x)# We accumulate the gradients in these variablesdLdU = np.zeros(self.U.shape)dLdV = np.zeros(self.V.shape)dLdW = np.zeros(self.W.shape)delta_o = odelta_o[np.arange(len(y)), y] -= 1.# For each output backwards...for t in np.arange(T)[::-1]:dLdV += np.outer(delta_o[t], s[t].T)# Initial delta calculation: dL/dzdelta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))# Backpropagation through time (for at most self.bptt_truncate steps)for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:# print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)# Add to gradients at each previous stepdLdW += np.outer(delta_t, s[bptt_step-1]) dLdU[:,x[bptt_step]] += delta_t# Update delta for next step dL/dz at t-1delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)return [dLdU, dLdV, dLdW]
2.1. 初始化
结合完整的代码,我们可知梯度的维度
#100*8000
dLdU = np.zeros(self.U.shape)
#8000*100
dLdV = np.zeros(self.V.shape)
#100*100
dLdW = np.zeros(self.W.shape)
2.2. 公共部分
对照上面的理论可知,无论是V,还是U,W,都有∂E3∂z^3\frac{\partial E_3}{\partial \hat z_3}∂z^3∂E3,这部分可以预先计算出来,也就是代码中的delta_o
#o是forward的输出,T(句子的实际长度)*8000维,每一行是8000维的,就是词表中所有词作为输入x中每一个词的后一个词的概率
delta_o = o
#[]中是索引操作,对y中的词对应的索引的概率-1
delta_o[np.arange(len(y)), y] -= 1.
2.3. V的梯度
s[t].Ts[t].Ts[t].T是取s[t]s[t]s[t]的转置,numpy.outer是将第一个参数和第二个参数中的所有元素分别按行展开,然后拿第一个参数中的数因此乘以第二个参数的每一行,例如a=[a0,a1,...,aM]a=[a_0, a_1, ..., a_M]a=[a0,a1,...,aM], b=[b0,b1,...,bN]b=[b_0, b_1, ..., b_N]b=[b0,b1,...,bN],则相乘后变成
[[a0∗b0a0∗b1...a0∗bN][a1∗b0a1∗b1...a1∗bN]...[aM∗b0aM∗b1...aM∗bN]][[a_0*b_0\quad a_0*b_1 \quad ... \quad a_0*b_N] \\ [a_1*b_0\quad a_1*b_1 \quad ... \quad a_1*b_N] \\ ... \\ [a_M*b_0\quad a_M*b_1 \quad ... \quad a_M*b_N]] [[a0∗b0a0∗b1...a0∗bN][a1∗b0a1∗b1...a1∗bN]...[aM∗b0aM∗b1...aM∗bN]]
结果是M*N维的
#delta_o是1*8000维向量,s[t]是1*100的向量,转不转置对outer并没有什么区别,其实和delta_o[t].T * s[t]等价,*是矩阵相乘,结果是8000*100维的矩阵
dLdV += np.outer(delta_o[t], s[t].T)
2.4. W和U的梯度
对比W和U的梯度公式,我们可以看到,两者+号的第二部分前面的系数是一样的,也就是(1−st2)∗W(1-s_t^2)*W(1−st2)∗W,这部分可以存起来减少计算量,也就是代码中的delta_t
delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
# Backpropagation through time (for at most self.bptt_truncate steps)
#截断
for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:# print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)# Add to gradients at each previous step#计算+号的第一部分,第二部分本次还没得到,下次累加进来dLdW += np.outer(delta_t, s[bptt_step-1])#x为单词的位置向量,与delta_t相乘相当于dLdU按x取索引(对应的词向量)直接与delta_t相加 dLdU[:,x[bptt_step]] += delta_t# Update delta for next step dL/dz at t-1 #更新第二部分系数delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)
基于时间的反向传播算法BPTT(Backpropagation through time)相关推荐
- RNN与其反向传播算法——BPTT(Backward Propogation Through Time)的详细推导
前言 一点感悟: 前几天简单看了下王者荣耀觉悟AI的论文,发现除了强化学习以外,也用到了熟悉的LSTM.之后我又想起了知乎上的一个问题:"Transformer会彻底取代RNN吗?" ...
- 随时间反向传播算法(BPTT)笔记
随时间反向传播算法(BPTT)笔记 1.反向传播算法(BP) 以表达式f(w,x)=11+e−(w0x0+w1x1+w2)f(w,x)=\frac{1}{1+e^{-(w_0x_0+w_1x_1+w_ ...
- 随时间的反向传播算法 BPTT
本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md 随时间反向传播(BPTT)算法 先简单回顾一下RNN的基本公式: st=tanh(Uxt+Wst− ...
- 神经网络(NN)+反向传播算法(Backpropagation/BP)+交叉熵+softmax原理分析
神经网络如何利用反向传播算法进行参数更新,加入交叉熵和softmax又会如何变化? 其中的数学原理分析:请点击这里. 转载于:https://www.cnblogs.com/code-wangjun/ ...
- 对反向传播算法(Back-Propagation)的推导与一点理解
最近在对卷积神经网络(CNN)进行学习的过程中,发现自己之前对反向传播算法的理解不够透彻,所以今天专门写篇博客记录一下反向传播算法的推导过程,算是一份备忘录吧,有需要的朋友也可以看一下这篇文章,写的挺 ...
- 时间序列的反向传播算法(BPTT)
时间序列的反向传播算法 BPTT : Back-Propagation Through Time ∂L∂U=∑t∂Lt∂U\frac{\partial L}{\partial U} = \sum_t\ ...
- 用反向传播算法解释大脑学习过程?Hinton 等人新研究登上 Nature 子刊
机器之心报道 魔王.Jamin.杜伟 反向传播可以解释大脑学习吗?近日 Hinton 等人的研究认为,尽管大脑可能未实现字面形式的反向传播,但是反向传播的部分特征与理解大脑中的学习具备很强的关联性.该 ...
- 反向传播算法公式推导,神经网络的推导
如何理解神经网络里面的反向传播算法 反向传播算法(Backpropagation)是目前用来训练人工神经网络(ArtificialNeuralNetwork,ANN)的最常用且最有效的算法. 其主要思 ...
- 反向传播算法最全解读,机器学习进阶必看!
如果对人工智能稍有了解的小伙伴们,或多或少都听过反向传播算法这个名词,但实际上BP到底是什么?它有着怎样的魅力与优势?本文发布于 offconvex.org,作者 Sanjeev Arora与 Ten ...
最新文章
- 在C#中读取枚举值的描述属性
- 第七周项目一-一般函数(2)
- wxpython组件SplitterWindow 的简单使用
- SpringSecurity案例之认证服务搭建
- bootstrap.yml与application.properties区别?
- [11] ADB 实用功能
- XML文件解析 --------------------笔记
- 关于exe应用程序做成Windows服务爬过的坑
- python 打开txt_python编程之文件操作
- 利用域环境,处理瑞星网络版杀毒软件的问题
- 7-4 复数的实部和虚部(8 分)
- 动画:面试官问我 JS「变量提升」我头皮发麻,最后把这篇动画甩给了他
- 特殊字符大全-希腊字母俄文注音拼音日文序集心型方形点数绘表(转载)
- 幼儿学习品质提升的培养策略问卷
- 微服务资源库太强了,学习手册限时开源
- 中富金石投教怎么样?让专业投资创造更多财富机会
- 程序员的职业病,一定要注重身体健康才是最重要的
- 互联网日报 | 全国版消费券今日起开抢;微信搜一搜正式开放服务搜索接入;高德打车上线“考生专车”服务...
- 2. 487-3279
- 立足现实 与时俱进:C++ 1991-2006 reference
热门文章
- C++——list的模拟实现
- OpenFeign出现failed and no fallback available错误
- mysql连接池泄露_一次线上故障:数据库连接池泄露后的思考
- 解决 IDEA 中 thymeleaf ${} 中报波浪线错误
- 自动驾驶场景要求(速度方面和检测速度方面)
- unity海岛模型,荒岛模型,里面有木桥,木船等等
- 读书04《番茄工作法图解下》
- 命定的局限与挑战,读《命若琴弦》——leo鉴书(17)
- 2021.4.6 腾讯 IEG 运营开发实习面试(一面)(含总结)
- 网络原理考点之IP地址分配问题解题思路