前言

本节将介绍循环神经网络中梯度的计算和存储方法,即 通过时间反向传播(back-propagation through time)

正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。

1. 定义模型

简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射(ϕ(x)=x\phi(x)=xϕ(x)=x)。设时间步 ttt 的输入为单样本 xt∈Rd\boldsymbol{x}_t \in \mathbb{R}^dxt​∈Rd,标签为 yty_tyt​,那么隐藏状态 ht∈Rh\boldsymbol{h}_t \in \mathbb{R}^hht​∈Rh的计算表达式为

ht=Whxxt+Whhht−1,\boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht​=Whx​xt​+Whh​ht−1​,

其中Whx∈Rh×d\boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}Whx​∈Rh×d和Whh∈Rh×h\boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}Whh​∈Rh×h是隐藏层权重参数。设输出层权重参数Wqh∈Rq×h\boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}Wqh​∈Rq×h,时间步ttt的输出层变量ot∈Rq\boldsymbol{o}_t \in \mathbb{R}^qot​∈Rq计算为

ot=Wqhht.\boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot​=Wqh​ht​.

设时间步ttt的损失为ℓ(ot,yt)\ell(\boldsymbol{o}_t, y_t)ℓ(ot​,yt​)。时间步数为TTT的损失函数LLL定义为

L=1T∑t=1Tℓ(ot,yt).L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1​t=1∑T​ℓ(ot​,yt​).

我们将LLL称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。

2. 模型计算图

为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态h3\boldsymbol{h}_3h3​的计算依赖模型参数Whx\boldsymbol{W}_{hx}Whx​、Whh\boldsymbol{W}_{hh}Whh​、上一时间步隐藏状态h2\boldsymbol{h}_2h2​以及当前时间步输入x3\boldsymbol{x}_3x3​。

3. 方法

刚刚提到,图6.3中的模型的参数是 Whx\boldsymbol{W}_{hx}Whx​, Whh\boldsymbol{W}_{hh}Whh​ 和 Wqh\boldsymbol{W}_{qh}Wqh​。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度∂L/∂Whx\partial L/\partial \boldsymbol{W}_{hx}∂L/∂Whx​、∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​和∂L/∂Wqh\partial L/\partial \boldsymbol{W}_{qh}∂L/∂Wqh​。
根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们采用运算符prod表达链式法则。

首先,目标函数有关各时间步输出层变量的梯度∂L/∂ot∈Rq\partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q∂L/∂ot​∈Rq很容易计算:

∂L∂ot=∂ℓ(ot,yt)T⋅∂ot.\frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}.∂ot​∂L​=T⋅∂ot​∂ℓ(ot​,yt​)​.

下面,我们可以计算目标函数有关模型参数Wqh\boldsymbol{W}_{qh}Wqh​的梯度∂L/∂Wqh∈Rq×h\partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}∂L/∂Wqh​∈Rq×h。根据图6.3,LLL通过o1,…,oT\boldsymbol{o}_1, \ldots, \boldsymbol{o}_To1​,…,oT​依赖Wqh\boldsymbol{W}_{qh}Wqh​。依据链式法则,

∂L∂Wqh=∑t=1Tprod(∂L∂ot,∂ot∂Wqh)=∑t=1T∂L∂otht⊤.\frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. ∂Wqh​∂L​=t=1∑T​prod(∂ot​∂L​,∂Wqh​∂ot​​)=t=1∑T​∂ot​∂L​ht⊤​.

其次,我们注意到隐藏状态之间也存在依赖关系。
在图6.3中,LLL只通过oT\boldsymbol{o}_ToT​依赖最终时间步TTT的隐藏状态hT\boldsymbol{h}_ThT​。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度∂L/∂hT∈Rh\partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h∂L/∂hT​∈Rh。依据链式法则,我们得到

∂L∂hT=prod(∂L∂oT,∂oT∂hT)=Wqh⊤∂L∂oT.\frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. ∂hT​∂L​=prod(∂oT​∂L​,∂hT​∂oT​​)=Wqh⊤​∂oT​∂L​.

接下来对于时间步t<Tt < Tt<T, 在图6.3中,LLL通过ht+1\boldsymbol{h}_{t+1}ht+1​和ot\boldsymbol{o}_tot​依赖ht\boldsymbol{h}_tht​。依据链式法则,
目标函数有关时间步t<Tt < Tt<T的隐藏状态的梯度∂L/∂ht∈Rh\partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h∂L/∂ht​∈Rh需要按照时间步从大到小依次计算:
∂L∂ht=prod(∂L∂ht+1,∂ht+1∂ht)+prod(∂L∂ot,∂ot∂ht)=Whh⊤∂L∂ht+1+Wqh⊤∂L∂ot\frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} ∂ht​∂L​=prod(∂ht+1​∂L​,∂ht​∂ht+1​​)+prod(∂ot​∂L​,∂ht​∂ot​​)=Whh⊤​∂ht+1​∂L​+Wqh⊤​∂ot​∂L​

将上面的递归公式展开,对任意时间步1≤t≤T1 \leq t \leq T1≤t≤T,我们可以得到目标函数有关隐藏状态梯度的通项公式

∂L∂ht=∑i=tT(Whh⊤)T−iWqh⊤∂L∂oT+t−i.\frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. ∂ht​∂L​=i=t∑T​(Whh⊤​)T−iWqh⊤​∂oT+t−i​∂L​.

由上式中的指数项可见,当时间步数 TTT 较大或者时间步 ttt 较小时,目标函数有关隐藏状态的梯度较容易出现 衰减爆炸。这也会影响其他包含∂L/∂ht\partial L / \partial \boldsymbol{h}_t∂L/∂ht​项的梯度,例如隐藏层中模型参数的梯度∂L/∂Whx∈Rh×d\partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}∂L/∂Whx​∈Rh×d和∂L/∂Whh∈Rh×h\partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}∂L/∂Whh​∈Rh×h。
在图6.3中,LLL通过h1,…,hT\boldsymbol{h}_1, \ldots, \boldsymbol{h}_Th1​,…,hT​依赖这些模型参数。
依据链式法则,我们有

∂L∂Whx=∑t=1Tprod(∂L∂ht,∂ht∂Whx)=∑t=1T∂L∂htxt⊤,∂L∂Whh=∑t=1Tprod(∂L∂ht,∂ht∂Whh)=∑t=1T∂L∂htht−1⊤.\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} ∂Whx​∂L​∂Whh​∂L​​=t=1∑T​prod(∂ht​∂L​,∂Whx​∂ht​​)=t=1∑T​∂ht​∂L​xt⊤​,=t=1∑T​prod(∂ht​∂L​,∂Whh​∂ht​​)=t=1∑T​∂ht​∂L​ht−1⊤​.​

每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度∂L/∂ht\partial L/\partial \boldsymbol{h}_t∂L/∂ht​被计算和存储,之后的模型参数梯度∂L/∂Whx\partial L/\partial \boldsymbol{W}_{hx}∂L/∂Whx​和∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​的计算可以直接读取∂L/∂ht\partial L/\partial \boldsymbol{h}_t∂L/∂ht​的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。
举例来说,参数梯度∂L/∂Whh\partial L/\partial \boldsymbol{W}_{hh}∂L/∂Whh​的计算需要依赖隐藏状态在时间步t=0,…,T−1t = 0, \ldots, T-1t=0,…,T−1的当前值ht\boldsymbol{h}_tht​(h0\boldsymbol{h}_0h0​是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。

小结

  • 通过时间反向传播是反向传播在循环神经网络中的具体应用。
  • 当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。

pytorch学习笔记(三十):RNN反向传播计算图公式推导相关推荐

  1. tensorflow学习笔记(三十二):conv2d_transpose (解卷积)

    tensorflow学习笔记(三十二):conv2d_transpose ("解卷积") deconv解卷积,实际是叫做conv_transpose, conv_transpose ...

  2. Mr.J-- jQuery学习笔记(三十二)--jQuery属性操作源码封装

    扫码看专栏 jQuery的优点 jquery是JavaScript库,能够极大地简化JavaScript编程,能够更方便的处理DOM操作和进行Ajax交互 1.轻量级 JQuery非常轻巧 2.强大的 ...

  3. nndl学习笔记(二)反向传播公式推导与详解

    写在前面 反向传播回顾 为什么需要反向传播? 基本思想 算法流程 算法局限性 详细推导(核心:多元微积分的*链式法则*) 一些定义 1. 输出层的误差δL\delta^{L}δL 2. 利用下一层误差 ...

  4. July深度学习笔记之神经网络与反向传播算法

    July深度学习笔记之神经网络与反向传播算法 一.神经网络 神经网络的大致结构如下: 大致可以分为输入层.隐藏层与输出层. 而我们可以单独拿出来一个结点,可以发现,其实它就是类似一个逻辑回归(LR), ...

  5. 【Pytorch学习笔记三】Pytorch神经网络包nn和优化器optm(一个简单的卷积神经网络模型的搭建)

    文章目录 一, 神经网络包nn 1.1定义一个网络 1.2 损失函数 二.优化器 nn构建于 Autograd之上,可用来定义和运行神经网络, PyTorch Autograd 让我们定义计算图和计算 ...

  6. 深度学习笔记:04依赖反向传播改进神经网络数据处理的精确度

    04依赖反向传播改进神经网络数据处理的精确度 1.反向传播简介 前面说过,神经网络模型中,需要修正的参数是神经元链路之间的权重值,问题在于如何修改,如下图,假定最后神经元输出结果跟正确结果对比后得到一 ...

  7. nndl学习笔记(一)反向传播公式总结

    nndl是什么? 反向传播算法简介 定义&公式 基本思想 Back Propagation四个基本方程 算法表示 Python实现 nndl是什么? <神经网络与深度学习>(< ...

  8. pytorch学习笔记(十六):Parameters

    文章目录 1. 访问模型参数 2. 初始化模型参数 3. 自定义初始化方法 4. 共享模型参数 小结 本节将深入讲解如何访问和初始化模型参数,以及如何在多个层之间共享同一份模型参数. 我们先定义含单隐 ...

  9. pytorch学习笔记(十五):模型构造

    文章目录 1. 继承Module类来构造模型 2. Module的子类 2.1 Sequential类 2.2 ModuleList类 2.3 ModuleDict类 3. 构造复杂的模型 小结 这里 ...

最新文章

  1. [USACO06NOV]玉米田Corn Fields(动态规划,状态压缩)
  2. 通用无线设备对码软件_珞光全新发布国产通用软件无线电平台 :USRP-LW N310!珞光品牌已实现国产替代...
  3. .Net性能调优-垃圾回收!!!最全垃圾回收来了
  4. C/C++好不好学习呢?
  5. matlab 同一坐标系 散点图 t,matlab上机练习
  6. CVTE【嵌入式应用开发】【软件技术支持】面经【已拿offer】
  7. 远程同步修改云服务器上的文件
  8. JDK8的LocalDateTime用法
  9. java web项目_一个完整JavaWeb项目开发总结
  10. 最新米酷6.26影视源码+解析接口+步骤
  11. 计算机系统的用户分几类,计算机操作系统的几种分类方式
  12. 记录MySQL中JSON_EXTRACT JSON_UNQUOTE函数的使用方式
  13. mysql to sqlserver_mysql to sqlserver
  14. 共阳极管的代码_1.共阳极数码管是将发光二极管的_____连接在一起,字符5的共阳代码为_____,字符2的共阴代码为 _____。...
  15. obs多推流地址_如何用OBS将腾讯会议推流到一直播上进行直播
  16. 2021.07.28
  17. 认识 DELL EMC VPLEX VS6物理配置
  18. 编译imx6 android,SAIL-IMX6Q ANDROID开发环境搭建与系统编译
  19. leetcode 最佳买卖股票时机含冷冻期(Java)
  20. Android Studio导入项目提示“Unrecognized Android Studio”

热门文章

  1. 指向 类成员函数 指针的用法
  2. 计算机电源管理设置,关于电源管理的电源管理计划设置
  3. redis使用lua脚本
  4. MAC下安装xgboost
  5. WPF开发为按钮提供添加,删除和重新排列ListBox内容的功能
  6. hive高级数据类型
  7. JQuery判断元素是否存在
  8. SHELL编写NGINX自动部署脚本
  9. 自己动手写处理器之第一阶段(3)——MIPS32指令集架构简单介绍
  10. 归心似箭,IT达人分享抢票攻略