通过时间反向传播

介绍循环神经网络中梯度的计算和存储方法,即通过时间反向传播(back-propagation through time)。

  • 正向传播和反向传播相互依赖。
  • 正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。
  • 我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

定义模型

考虑一个简单的无偏差项的循环神经网络,且激活函数为恒等映射(ϕ(x)=x\phi(x)=xϕ(x)=x)。设时间步 ttt 的输入为单样本 xt∈Rd\boldsymbol{x}_t \in \mathbb{R}^dxt​∈Rd,标签为 yty_tyt​,那么隐藏状态 ht∈Rh\boldsymbol{h}_t \in \mathbb{R}^hht​∈Rh的计算表达式为

ht=Whxxt+Whhht−1,\boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht​=Whx​xt​+Whh​ht−1​,

其中Whx∈Rh×d\boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}Whx​∈Rh×d和Whh∈Rh×h\boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}Whh​∈Rh×h是隐藏层权重参数。设输出层权重参数Wqh∈Rq×h\boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}Wqh​∈Rq×h,时间步ttt的输出层变量ot∈Rq\boldsymbol{o}_t \in \mathbb{R}^qot​∈Rq计算为

ot=Wqhht.\boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot​=Wqh​ht​.

设时间步ttt的损失为ℓ(ot,yt)\ell(\boldsymbol{o}_t, y_t)ℓ(ot​,yt​)。时间步数为TTT的损失函数LLL定义为

L=1T∑t=1Tℓ(ot,yt).L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1​t=1∑T​ℓ(ot​,yt​).

将LLL称为有关给定时间步的数据样本的目标函数。

模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,像下图。例如,时间步3的隐藏状态h3\boldsymbol{h}_3h3​的计算依赖模型参数Whx\boldsymbol{W}_{hx}Whx​、Whh\boldsymbol{W}_{hh}Whh​、上一时间步隐藏状态h2\boldsymbol{h}_2h2​以及当前时间步输入x3\boldsymbol{x}_3x3​。

表示了时间步数为3的循环神经网络模型计算中的依赖关系。

  • 方框代表变量(无阴影)或参数(有阴影),圆圈代表运算符

方法

图中的模型的参数是 Whx\boldsymbol{W}_{hx}Whx​, Whh\boldsymbol{W}_{hh}Whh​ 和 Wqh\boldsymbol{W}_{qh}Wqh​。训练模型通常需要模型参数的梯度∂L/∂Whx\partial L/\partial \boldsymbol{W}_{hx}∂L/∂Whx​、∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​和∂L/∂Wqh\partial L/\partial \boldsymbol{W}_{qh}∂L/∂Wqh​。 图中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。

  • 首先,目标函数有关各时间步输出层变量的梯度∂L/∂ot∈Rq\partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q∂L/∂ot​∈Rq很容易计算:

∂L∂ot=∂ℓ(ot,yt)T⋅∂ot.\frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}.∂ot​∂L​=T⋅∂ot​∂ℓ(ot​,yt​)​.

  • 之后,可以计算目标函数有关模型参数Wqh\boldsymbol{W}_{qh}Wqh​的梯度∂L/∂Wqh∈Rq×h\partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}∂L/∂Wqh​∈Rq×h。根据计算图,LLL通过o1,…,oT\boldsymbol{o}_1, \ldots, \boldsymbol{o}_To1​,…,oT​依赖Wqh\boldsymbol{W}_{qh}Wqh​。依据链式法则,

∂L∂Wqh=∑t=1Tprod(∂L∂ot,∂ot∂Wqh)=∑t=1T∂L∂otht⊤.\frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. ∂Wqh​∂L​=t=1∑T​prod(∂ot​∂L​,∂Wqh​∂ot​​)=t=1∑T​∂ot​∂L​ht⊤​.

  • 其次,隐藏状态之间也存在依赖关系。 在计算图中,LLL只通过oT\boldsymbol{o}_ToT​依赖最终时间步TTT的隐藏状态hT\boldsymbol{h}_ThT​。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度∂L/∂hT∈Rh\partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h∂L/∂hT​∈Rh。依据链式法则,我们得到

∂L∂hT=prod(∂L∂oT,∂oT∂hT)=Wqh⊤∂L∂oT.\frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. ∂hT​∂L​=prod(∂oT​∂L​,∂hT​∂oT​​)=Wqh⊤​∂oT​∂L​.

  • 接下来对于时间步t<Tt < Tt<T, 在计算图中,LLL通过ht+1\boldsymbol{h}_{t+1}ht+1​和ot\boldsymbol{o}_tot​依赖ht\boldsymbol{h}_tht​。依据链式法则, 目标函数有关时间步t<Tt < Tt<T的隐藏状态的梯度∂L/∂ht∈Rh\partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h∂L/∂ht​∈Rh需要按照时间步从大到小依次计算:
    ∂L∂ht=prod(∂L∂ht+1,∂ht+1∂ht)+prod(∂L∂ot,∂ot∂ht)=Whh⊤∂L∂ht+1+Wqh⊤∂L∂ot\frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}{t+1}}, \frac{\partial \boldsymbol{h}{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} ∂ht​∂L​=prod(∂ht+1∂L​,∂ht​∂ht+1​)+prod(∂ot​∂L​,∂ht​∂ot​​)=Whh⊤​∂ht+1​∂L​+Wqh⊤​∂ot​∂L​

  • 将上面的递归公式展开,对任意时间步1≤t≤T1 \leq t \leq T1≤t≤T,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂L∂ht=∑i=tT(Whh⊤)T−iWqh⊤∂L∂oT+t−i.\frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. ∂ht​∂L​=i=t∑T​(Whh⊤​)T−iWqh⊤​∂oT+t−i​∂L​.

由上式中的指数项可见,当时间步数 TTT 较大或者时间步 ttt 较小时,目标函数有关隐藏状态的梯度较容易出现衰减和爆炸。这也会影响其他包含∂L/∂ht\partial L / \partial \boldsymbol{h}_t∂L/∂ht​项的梯度,例如隐藏层中模型参数的梯度∂L/∂Whx∈Rh×d\partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}∂L/∂Whx​∈Rh×d和∂L/∂Whh∈Rh×h\partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}∂L/∂Whh​∈Rh×h。 在图中,LLL通过h1,…,hT\boldsymbol{h}_1, \ldots, \boldsymbol{h}_Th1​,…,hT​依赖这些模型参数。 依据链式法则,有

∂L∂Whx=∑t=1Tprod(∂L∂ht,∂ht∂Whx)=∑t=1T∂L∂htxt⊤,∂L∂Whh=∑t=1Tprod(∂L∂ht,∂ht∂Whh)=∑t=1T∂L∂htht−1⊤.\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} ∂Whx​∂L​ ∂Whh​∂L​​=t=1∑T​prod(∂ht​∂L​,∂Whx​∂ht​​)=t=1∑T​∂ht​∂L​xt⊤​,=t=1∑T​prod(∂ht​∂L​,∂Whh​∂ht​​)=t=1∑T​∂ht​∂L​ht−1⊤​.​

每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。

  • 例如,由于隐藏状态梯度∂L/∂ht\partial L/\partial \boldsymbol{h}_t∂L/∂ht​被计算和存储,之后的模型参数梯度∂L/∂Whx\partial L/\partial \boldsymbol{W}_{hx}∂L/∂Whx​和∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​的计算可以直接读取∂L/∂ht\partial L/\partial \boldsymbol{h}_t∂L/∂ht​的值,而无须重复计算它们。
  • 此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。 举例来说,参数梯度∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​的计算需要依赖隐藏状态在时间步t=0,…,T−1t = 0, \ldots, T-1t=0,…,T−1的当前值ht\boldsymbol{h}_tht​(h0\boldsymbol{h}_0h0​是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

(pytorch-深度学习)通过时间反向传播相关推荐

  1. 【李宏毅机器学习2021】Task04 深度学习介绍和反向传播机制

    [李宏毅机器学习2021]本系列是针对datawhale<李宏毅机器学习-2022 10月>的学习笔记.本次是对深度学习介绍和反向传播机制的学习总结.本节针对上节课内容,对batch.梯度 ...

  2. 深度学习入门-误差反向传播法(人工神经网络实现mnist数据集识别)

    文章目录 误差反向传播法 5.1 链式法则与计算图 5.2 计算图代码实践 5.3激活函数层的实现 5.4 简单矩阵求导 5.5 Affine 层的实现 5.6 softmax-with-loss层计 ...

  3. 机器学习(深度学习)中的反向传播算法与梯度下降

    这是自己在CSDN的第一篇博客,目的是为了给自己学习过的知识做一个总结,方便后续温习,避免每次都重复搜索相关文章. 一.反向传播算法 定义:反向传播(Backpropagation,缩写为BP)是&q ...

  4. 《深度学习》李宏毅 -- task4深度学习介绍和反向传播机制

    深度学习的三个步骤 Step1:神经网络(Neural network) Step2:模型评估(Goodness of function) Step3:选择最优函数(Pick best functio ...

  5. Datawhale 7月学习——李弘毅深度学习:深度学习介绍和反向传播机制

    前情回顾 机器学习简介 回归 误差与梯度下降 1 深度学习简介 1.1 深度学习的历史 李宏毅老师带我们简要回顾了深度学习的历史. 1958: Perceptron (linear model) 19 ...

  6. 深度学习之误差反向传播法

  7. 【项目实战】vue-springboot-pytorch前后端结合pytorch深度学习 html打开本地摄像头 监控人脸和记录时间

    是一个项目的一个功能之一,调试了两小时,终于能够 javascript设置开始计和暂停计时 监控人脸 记录时间了 效果图: 离开页面之后回到页面会从0计时(不是关闭页面,而是页面失去焦点) 离开摄像头 ...

  8. pytorch深度学习_用于数据科学家的深度学习的最小pytorch子集

    pytorch深度学习 PyTorch has sort of became one of the de facto standards for creating Neural Networks no ...

  9. pytorch深度学习入门笔记

    Pytorch 深度学习入门笔记 作者:梅如你 学习来源: 公众号: 阿力阿哩哩.土堆碎念 B站视频:https://www.bilibili.com/video/BV1hE411t7RN? 中国大学 ...

最新文章

  1. 卷积神经网络新手指南 2
  2. keepalived程序包
  3. [*leetcode 5] Longest Palindromic Substring
  4. Python 个人的失误记录之str.replace
  5. 调用支付宝PHP接口API实现在线即时支付功能(UTF-8编码)
  6. APP网络测试要点和弱网模拟
  7. 如何随机选取1000个关键字
  8. 【数据结构与算法】二项队列的Java实现
  9. docker镜像与容器的区别
  10. 大数据Hadoop学习记录(5)----Ubuntu16.4下安装配置HBase
  11. redis的其他功能
  12. linux 内核编程
  13. 鸿鹄论坛oracle资料,鸿鹄论坛_HCNA-Storage (H13-611)题库 v4.0.pdf
  14. 干货!基于深度空间一致性的鲁棒点云配准算法
  15. java中的tld_自定义标签tld的使用
  16. 【Cython】Cython 基本用法
  17. python利用mysql数据库实现一个中英文翻译程序兼单词试卷生成并改阅功能,并可以爬取有道官网进行在线翻译。
  18. 服务器的类型包括哪些?
  19. Java数据结构——排序二叉树
  20. 把款软件可以测试双显卡,以进步之名! APU双显卡的混交测试

热门文章

  1. android默认exported_android:exported 属性详解-阿里云开发者社区
  2. mysql 报错解决思考Expression #5 of SELECT list is not in GROUP BY clause and contains nonaggregated column
  3. java 如何将数字倒置_每日一个小算法之整数中每位上的数字进行反转 20190810
  4. 对微软实习生或者工作感兴趣的读者, 目前我的项目是...
  5. python3.6安装tensorflow gpu_tensorflow-gpu安装的常见问题及解决方案
  6. sql移动加权计算利润_计算机视觉中的半监督学习
  7. java 多态_Java面向对象 —— 多态
  8. 大并发下程序出错_Python并发编程理论篇
  9. android复杂列表滑动卡顿,Android 列表滑动性能优化总结
  10. 【LeetCode笔记】143. 重排链表(Java、链表、栈、快慢指针)