本文是读“Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients”的读书笔记,加入了自己的一些理解,有兴趣可以直接阅读原文。

1. 算法介绍

这里引用原文中的网络结构图

其中xxx为输入,sss为隐藏层状态,o为输出,按时间展开

为了与文献中的表示一致,我们用y^\hat yy^​来代替o,则
st=tanh(Uxt+Wst−1)y^=softmat(Vst)s_t=tanh(Ux_t+Ws_{t-1}) \\ \hat y=softmat(Vs_t) st​=tanh(Uxt​+Wst−1​)y^​=softmat(Vst​)
使用交叉熵(cross entropy)作为损失函数
Et(y,y^)=−ytlogy^E(y,y^)=∑tEt(yt,y^t)=−∑tytlogy^E_t(y,\hat y)=-y_tlog\hat y \\ E(y, \hat y) = \sum_t E_t(y_t, \hat y_t)=-\sum_t y_tlog\hat y Et​(y,y^​)=−yt​logy^​E(y,y^​)=t∑​Et​(yt​,y^​t​)=−t∑​yt​logy^​
我们使用链式法则来计算后向传播时的梯度,以网络的输出E3E_3E3​为例,
y^3=ez3∑ieziE3=−y3logy^3=−y3(z3−log∑iezi)z3=Vs3s3=tanh(Ux3+Ws2)\hat y_3=\frac{e^{z_3}}{\sum_ie^{z_i}} \\ E_3=-y_3log\hat y_3=-y_3(z_3-log\sum_ie^{z_i}) \\ z_3=Vs_3 \\ s_3=tanh(Ux_3+Ws_2) y^​3​=∑i​ezi​ez3​​E3​=−y3​logy^​3​=−y3​(z3​−logi∑​ezi​)z3​=Vs3​s3​=tanh(Ux3​+Ws2​)
因此可以求V的梯度
∂E3∂V=∂E3∂z^3∂z3∂V=y3(y^3−1)∗s3\frac{\partial E_3}{\partial V}=\frac{\partial E_3}{\partial \hat z_3}\frac{\partial z_3}{\partial V}=y_3(\hat y_3-1)*s3 ∂V∂E3​​=∂z^3​∂E3​​∂V∂z3​​=y3​(y^​3​−1)∗s3
这里求导时将y^3\hat y_3y^​3​带入消去了,求导更直观,这里给出的是标量形式,改成向量形式应该是y^−13\hat y-1_3y^​−13​,也就是输出概率矩阵中,对应结果的那个概率-1,其他不变,而输入y恰好可以认为是对应结果的概率是1,其他是0,因此原文中写作
∂E3∂V=(y^3−y3)⊗s3\frac{\partial E_3}{\partial V}=(\hat y_3-y_3)\otimes s_3 ∂V∂E3​​=(y^​3​−y3​)⊗s3​
相对V的梯度,因为sts_tst​是W,U的函数,而且含有的st−1s_{t-1}st−1​在 求导时,不能简单的认为是一个常数,因此在求导时,如果不加限制,需要对从t到0的所有状态进行回溯,在实际中一般按照场景和精度要求进行截断。
∂E3∂W=∂E3∂z^3∂z3∂s3∂s3∂sk∂sk∂W\frac{\partial E_3}{\partial W}=\frac{\partial E_3}{\partial \hat z_3}\frac{\partial z_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial W} ∂W∂E3​​=∂z^3​∂E3​​∂s3​∂z3​​∂sk​∂s3​​∂W∂sk​​
其中s3s_3s3​对W的求导是一个分部求导
∂st∂W=(1−st2)(st−1+W∗∂st−1∂sW)\frac{\partial s_t}{\partial W}=(1-s_t^2)(s_{t-1}+W*\frac{\partial s_{t-1}}{\partial s_{W}}) ∂W∂st​​=(1−st2​)(st−1​+W∗∂sW​∂st−1​​)
U的梯度类似
∂st∂U=(1−st2)(xt+W∗∂st−1∂sU)\frac{\partial s_t}{\partial U}=(1-s_t^2)(x_t+W*\frac{\partial s_{t-1}}{\partial s_{U}}) ∂U∂st​​=(1−st2​)(xt​+W∗∂sU​∂st−1​​)

2. 代码分析

首先我们给出作者自己实现的完整的BPTT,再各部分分析

def bptt(self, x, y):T = len(y)# Perform forward propagationo, s = self.forward_propagation(x)# We accumulate the gradients in these variablesdLdU = np.zeros(self.U.shape)dLdV = np.zeros(self.V.shape)dLdW = np.zeros(self.W.shape)delta_o = odelta_o[np.arange(len(y)), y] -= 1.# For each output backwards...for t in np.arange(T)[::-1]:dLdV += np.outer(delta_o[t], s[t].T)# Initial delta calculation: dL/dzdelta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))# Backpropagation through time (for at most self.bptt_truncate steps)for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:# print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)# Add to gradients at each previous stepdLdW += np.outer(delta_t, s[bptt_step-1])              dLdU[:,x[bptt_step]] += delta_t# Update delta for next step dL/dz at t-1delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)return [dLdU, dLdV, dLdW]

2.1. 初始化

结合完整的代码,我们可知梯度的维度

#100*8000
dLdU = np.zeros(self.U.shape)
#8000*100
dLdV = np.zeros(self.V.shape)
#100*100
dLdW = np.zeros(self.W.shape)

2.2. 公共部分

对照上面的理论可知,无论是V,还是U,W,都有∂E3∂z^3\frac{\partial E_3}{\partial \hat z_3}∂z^3​∂E3​​,这部分可以预先计算出来,也就是代码中的delta_o

#o是forward的输出,T(句子的实际长度)*8000维,每一行是8000维的,就是词表中所有词作为输入x中每一个词的后一个词的概率
delta_o = o
#[]中是索引操作,对y中的词对应的索引的概率-1
delta_o[np.arange(len(y)), y] -= 1.

2.3. V的梯度

s[t].Ts[t].Ts[t].T是取s[t]s[t]s[t]的转置,numpy.outer是将第一个参数和第二个参数中的所有元素分别按行展开,然后拿第一个参数中的数因此乘以第二个参数的每一行,例如a=[a0,a1,...,aM]a=[a_0, a_1, ..., a_M]a=[a0​,a1​,...,aM​], b=[b0,b1,...,bN]b=[b_0, b_1, ..., b_N]b=[b0​,b1​,...,bN​],则相乘后变成
[[a0∗b0a0∗b1...a0∗bN][a1∗b0a1∗b1...a1∗bN]...[aM∗b0aM∗b1...aM∗bN]][[a_0*b_0\quad a_0*b_1 \quad ... \quad a_0*b_N] \\ [a_1*b_0\quad a_1*b_1 \quad ... \quad a_1*b_N] \\ ... \\ [a_M*b_0\quad a_M*b_1 \quad ... \quad a_M*b_N]] [[a0​∗b0​a0​∗b1​...a0​∗bN​][a1​∗b0​a1​∗b1​...a1​∗bN​]...[aM​∗b0​aM​∗b1​...aM​∗bN​]]
结果是M*N维的

#delta_o是1*8000维向量,s[t]是1*100的向量,转不转置对outer并没有什么区别,其实和delta_o[t].T * s[t]等价,*是矩阵相乘,结果是8000*100维的矩阵
dLdV += np.outer(delta_o[t], s[t].T)

2.4. W和U的梯度

对比W和U的梯度公式,我们可以看到,两者+号的第二部分前面的系数是一样的,也就是(1−st2)∗W(1-s_t^2)*W(1−st2​)∗W,这部分可以存起来减少计算量,也就是代码中的delta_t

delta_t = self.V.T.dot(delta_o[t]) * (1 - (s[t] ** 2))
# Backpropagation through time (for at most self.bptt_truncate steps)
#截断
for bptt_step in np.arange(max(0, t-self.bptt_truncate), t+1)[::-1]:# print "Backpropagation step t=%d bptt step=%d " % (t, bptt_step)# Add to gradients at each previous step#计算+号的第一部分,第二部分本次还没得到,下次累加进来dLdW += np.outer(delta_t, s[bptt_step-1])#x为单词的位置向量,与delta_t相乘相当于dLdU按x取索引(对应的词向量)直接与delta_t相加                                  dLdU[:,x[bptt_step]] += delta_t# Update delta for next step dL/dz at t-1 #更新第二部分系数delta_t = self.W.T.dot(delta_t) * (1 - s[bptt_step-1] ** 2)

基于时间的反向传播算法BPTT(Backpropagation through time)相关推荐

  1. RNN与其反向传播算法——BPTT(Backward Propogation Through Time)的详细推导

    前言 一点感悟: 前几天简单看了下王者荣耀觉悟AI的论文,发现除了强化学习以外,也用到了熟悉的LSTM.之后我又想起了知乎上的一个问题:"Transformer会彻底取代RNN吗?" ...

  2. 随时间反向传播算法(BPTT)笔记

    随时间反向传播算法(BPTT)笔记 1.反向传播算法(BP) 以表达式f(w,x)=11+e−(w0x0+w1x1+w2)f(w,x)=\frac{1}{1+e^{-(w_0x_0+w_1x_1+w_ ...

  3. 随时间的反向传播算法 BPTT

    本文转自:https://www.cntofu.com/book/85/dl/rnn/bptt.md 随时间反向传播(BPTT)算法 先简单回顾一下RNN的基本公式: st=tanh(Uxt+Wst− ...

  4. 神经网络(NN)+反向传播算法(Backpropagation/BP)+交叉熵+softmax原理分析

    神经网络如何利用反向传播算法进行参数更新,加入交叉熵和softmax又会如何变化? 其中的数学原理分析:请点击这里. 转载于:https://www.cnblogs.com/code-wangjun/ ...

  5. 对反向传播算法(Back-Propagation)的推导与一点理解

    最近在对卷积神经网络(CNN)进行学习的过程中,发现自己之前对反向传播算法的理解不够透彻,所以今天专门写篇博客记录一下反向传播算法的推导过程,算是一份备忘录吧,有需要的朋友也可以看一下这篇文章,写的挺 ...

  6. 时间序列的反向传播算法(BPTT)

    时间序列的反向传播算法 BPTT : Back-Propagation Through Time ∂L∂U=∑t∂Lt∂U\frac{\partial L}{\partial U} = \sum_t\ ...

  7. 用反向传播算法解释大脑学习过程?Hinton 等人新研究登上 Nature 子刊

    机器之心报道 魔王.Jamin.杜伟 反向传播可以解释大脑学习吗?近日 Hinton 等人的研究认为,尽管大脑可能未实现字面形式的反向传播,但是反向传播的部分特征与理解大脑中的学习具备很强的关联性.该 ...

  8. 反向传播算法公式推导,神经网络的推导

    如何理解神经网络里面的反向传播算法 反向传播算法(Backpropagation)是目前用来训练人工神经网络(ArtificialNeuralNetwork,ANN)的最常用且最有效的算法. 其主要思 ...

  9. 反向传播算法最全解读,机器学习进阶必看!

    如果对人工智能稍有了解的小伙伴们,或多或少都听过反向传播算法这个名词,但实际上BP到底是什么?它有着怎样的魅力与优势?本文发布于 offconvex.org,作者 Sanjeev Arora与 Ten ...

最新文章

  1. 在C#中读取枚举值的描述属性
  2. 第七周项目一-一般函数(2)
  3. wxpython组件SplitterWindow 的简单使用
  4. SpringSecurity案例之认证服务搭建
  5. bootstrap.yml与application.properties区别?
  6. [11] ADB 实用功能
  7. XML文件解析 --------------------笔记
  8. 关于exe应用程序做成Windows服务爬过的坑
  9. python 打开txt_python编程之文件操作
  10. 利用域环境,处理瑞星网络版杀毒软件的问题
  11. 7-4 复数的实部和虚部(8 分)
  12. 动画:面试官问我 JS「变量提升」我头皮发麻,最后把这篇动画甩给了他
  13. 特殊字符大全-希腊字母俄文注音拼音日文序集心型方形点数绘表(转载)
  14. 幼儿学习品质提升的培养策略问卷
  15. 微服务资源库太强了,学习手册限时开源
  16. 中富金石投教怎么样?让专业投资创造更多财富机会
  17. 程序员的职业病,一定要注重身体健康才是最重要的
  18. 互联网日报 | 全国版消费券今日起开抢;微信搜一搜正式开放服务搜索接入;高德打车上线“考生专车”服务...
  19. 2. 487-3279
  20. 立足现实 与时俱进:C++ 1991-2006 reference

热门文章

  1. C++——list的模拟实现
  2. OpenFeign出现failed and no fallback available错误
  3. mysql连接池泄露_一次线上故障:数据库连接池泄露后的思考
  4. 解决 IDEA 中 thymeleaf ${} 中报波浪线错误
  5. 自动驾驶场景要求(速度方面和检测速度方面)
  6. unity海岛模型,荒岛模型,里面有木桥,木船等等
  7. 读书04《番茄工作法图解下》
  8. 命定的局限与挑战,读《命若琴弦》——leo鉴书(17)
  9. 2021.4.6 腾讯 IEG 运营开发实习面试(一面)(含总结)
  10. 网络原理考点之IP地址分配问题解题思路