正向传播、反向传播和计算图

1. 正向传播

正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。假设输入是一个特征为x∈Rdx \in R^dx∈Rd的样本,且不考虑偏差项,那么中间变量:
z=W(1)x(1)z = W^{(1)}x \tag 1 z=W(1)x(1)
其中W(1)∈Rh×dW^{(1)} \in R^{h \times d}W(1)∈Rh×d是隐藏层的权重参数。把中间变量z∈Rhz \in R^hz∈Rh输入按元素运算的激活函数ϕ\phiϕ后,将得到向量长度为hhh的隐藏层变量:
h=ϕ(z)(2)h=\phi(z) \tag 2 h=ϕ(z)(2)
隐藏层变量hhh也是一个中间变量。假设输出层参数只有权重W(2)∈Rq×hW^{(2)} \in R^{q \times h}W(2)∈Rq×h,可以得到向量长度为qqq的输出层变量:
o=W(2)h(3)o = W^{(2)}h \tag 3 o=W(2)h(3)
假设损失函数为lll,且样本标签为yyy,可以计算出单个数据样本的损失项:
L=l(o,y)(4)L=l(o,y) \tag 4 L=l(o,y)(4)
根据L2L_2L2​范数正则化的定义,给定超参数λ\lambdaλ,正则化项即:
s=λ2(∣∣W(1)∣∣F2+∣∣W(2)∣∣F2)(5)s=\frac{\lambda}{2} (||W^{(1)}||_F^2+||W^{(2)}||_F^2) \tag 5 s=2λ​(∣∣W(1)∣∣F2​+∣∣W(2)∣∣F2​)(5)
其中矩阵的Frobenius范数等价于将矩阵变平为向量后计算L2L_2L2​范数。最终,模型在给定的数据样本上带正则化的损失为:
J=L+s(6)J = L + s \tag 6 J=L+s(6)

2. 正向传播的计算图

通常绘制计算图来可视化运算符和变量在计算中的依赖关系,一般来说,计算图中左下角是输入,右上角是输出。其中方框代表变量,圆圈代表运算符,箭头表示从输入到输出之间的依赖关系。

3. 反向传播

反向传播指的是计算神经网络参数梯度的方法。总的来说,反向传播依据微积分中的链式法则,沿着从输出层到输入层的顺序,依次计算并存储目标函数有关神经网络各层的中间变量以及参数的梯度。对输入或输出X,Y,ZX,Y,ZX,Y,Z为任意形状张量的函数Y=f(X)Y=f(X)Y=f(X)和Z=g(Y)Z=g(Y)Z=g(Y),通过链式法则,有:
∂Z∂X=∏(∂Z∂Y,∂Y∂X)(7)\frac{\partial Z}{\partial X} = \prod(\frac{\partial Z}{\partial Y}, \frac{\partial Y}{\partial X}) \tag 7 ∂X∂Z​=∏(∂Y∂Z​,∂X∂Y​)(7)
其中prod运算将根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法。

本例中的模型,它的参数是W(1)W^{(1)}W(1)和W(2)W^{(2)}W(2),因此反向传播的目标是计算∂J∂W(1)\frac{\partial J}{\partial W^{(1)}}∂W(1)∂J​和∂J∂W(2)\frac{\partial J}{\partial W^{(2)}}∂W(2)∂J​。应用链式法则则依次计算各中间变量和参数的梯度,其计算次序与前向传播中相应中间变量的计算次序恰恰相反。

首先,分别计算目标函数J=L+sJ=L+sJ=L+s有关损失项LLL和正则项sss的梯度:
∂J∂L=1,∂J∂s=1(8)\frac{\partial J}{\partial L} = 1, \frac{\partial J}{\partial s}=1 \tag 8 ∂L∂J​=1,∂s∂J​=1(8)
其次,依据链式法则计算目标函数有关输出层变量的梯度∂J∂o∈Rq\frac{\partial J}{\partial o} \in R^q∂o∂J​∈Rq:
∂J∂o=∏(∂J∂L,∂L∂o)=∂L∂o(9)\frac{\partial J}{\partial o}=\prod(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial o})=\frac{\partial L}{\partial o} \tag 9 ∂o∂J​=∏(∂L∂J​,∂o∂L​)=∂o∂L​(9)
接下来,计算正则项有关两个参数的梯度:
∂s∂W(1)=λW(1),∂s∂W(2)=λW(2)(10)\frac{\partial s}{\partial W^{(1)}}=\lambda W^{(1)}, \frac{\partial s}{\partial W^{(2)}}=\lambda W^{(2)} \tag {10} ∂W(1)∂s​=λW(1),∂W(2)∂s​=λW(2)(10)
现在,我们可计算最靠近输出层的模型参数的梯度∂J∂W(2)∈Rq×h\frac{\partial J}{\partial W^{(2)}} \in R^{q \times h}∂W(2)∂J​∈Rq×h。依据链式法则,得到:
∂J∂W(2)=∏(∂J∂o,∂o∂W(2))+∏(∂J∂s,∂s∂W(2))=∂J∂ohT+λW(2)(11)\frac{\partial J}{\partial W^{(2)}}=\prod(\frac{\partial J}{\partial o}, \frac{\partial o}{\partial W^{(2)}})+\prod(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial W^{(2)}})=\frac{\partial J}{\partial o}h^T + \lambda W^{(2)} \tag {11} ∂W(2)∂J​=∏(∂o∂J​,∂W(2)∂o​)+∏(∂s∂J​,∂W(2)∂s​)=∂o∂J​hT+λW(2)(11)
沿着输出层向隐藏层继续反向传播,隐藏层变量的梯度∂J∂h∈Rh\frac{\partial J}{\partial h} \in R^h∂h∂J​∈Rh:
∂J∂h=∏(∂J∂o,∂o∂h)=W(2)T∂J∂o(12)\frac{\partial J}{\partial h}=\prod(\frac{\partial J}{\partial o}, \frac{\partial o}{\partial h})=W^{{(2)}^T} \frac{\partial J}{\partial o} \tag {12} ∂h∂J​=∏(∂o∂J​,∂h∂o​)=W(2)T∂o∂J​(12)
由于激活函数ϕ\phiϕ是按元素运算的,中间变量zzz的梯度∂J∂z∈Rh\frac{\partial J}{\partial z} \in R^h∂z∂J​∈Rh的计算需要使用按元素乘法符⊙\odot⊙:
∂J∂z=∏(∂J∂h,∂h∂z)=∂J∂h⊙ϕ′(z)(13)\frac{\partial J}{\partial z}=\prod(\frac{\partial J}{\partial h}, \frac{\partial h}{\partial z})=\frac{\partial J}{\partial h} \odot \phi^{'}(z) \tag {13} ∂z∂J​=∏(∂h∂J​,∂z∂h​)=∂h∂J​⊙ϕ′(z)(13)
最终,可以得到最靠近输入层的模型参数的梯度∂J∂W(1)∈Rh×d\frac{\partial J}{\partial W^{(1)}} \in R^{h \times d}∂W(1)∂J​∈Rh×d。依据链式法则,得到:
∂J∂W(1)=∏(∂J∂z,∂z∂W(1))+∏(∂J∂s,∂s∂W(1))=∂J∂zxT+λW(1)(14)\frac{\partial J}{\partial W^{(1)}}=\prod(\frac{\partial J}{\partial z}, \frac{\partial z}{\partial W^{(1)}})+\prod(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial W^{(1)}})=\frac{\partial J}{\partial z}x^T+\lambda W^{(1)} \tag {14} ∂W(1)∂J​=∏(∂z∂J​,∂W(1)∂z​)+∏(∂s∂J​,∂W(1)∂s​)=∂z∂J​xT+λW(1)(14)

在训练模型时,正向传播与反向传播互相依赖。

正向传播、反向传播和计算图相关推荐

  1. 深度学习-正向传播反向传播

    正向传播 对于神经元来说,训练的目标就是确认最优的w和b,使得输出值和真实值之间的误差变得最小. 数据从输入到输出,一层一层的进行运算,在输出层输出一个预测值y (理解:正向传播,多个输入层-> ...

  2. 神经网络正向与反向传播

    一.神经网络的前向传播原理 在全连接神经网络中,每一层的每个神经元都会与前一层的所有神经元或者输入数据相连,例如图中的 f1(e)f _1 ( e )f1​(e)就与x1x_1x1​ 和 x2x_2x ...

  3. 007-卷积神经网络03-前向传播-反向传播

    前向传播: 前向传播就是求特征图的过程 通常x和w有四个维度[编号,深度,高度,宽度] 反向传播: 先来复习一下反向传播的知识: 反向传播回来的是梯度,也就是偏导数 反向传播力有一个链式法则:对于反向 ...

  4. pytorch 正向与反向传播的过程 获取模型的梯度(gradient),并绘制梯度的直方图

    记录一下怎样pytorch框架下怎样获得模型的梯度 文章目录 引入所需要的库 一个简单的函数 模型梯度获取 先定义一个model 如下定义两个获取梯度的函数 定义一些过程与调用上述函数的方法 可视化一 ...

  5. 机器学习概念 — 监督学习、无监督学习、半监督学习、强化学习、欠拟合、过拟合、后向传播、损失和优化函数、计算图、正向传播、反向传播

    1. 监督学习和无监督学习 监督学习 ( Supervised Learning ) 和无监督学习 ( Unsupervised Learning ) 是在机器学习中经常被提及的两个重要的学习方法. ...

  6. (pytorch-深度学习系列)正向传播与反向传播-学习笔记

    正向传播与反向传播 1. 正向传播 正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出). 假设输入是一个特征为x∈Rd\boldsymbol{x} \in \m ...

  7. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  8. 神经网络与深度学习笔记(二)正向传播与反向传播

    文章目录 正向传播 反向传播 矢量计算 cost function由来 神经网络层每层向量的形状 正向传播 正向传播计算的是神经网络的输出 如上图,就是一次类似的正向传播的过程,正向传播计算最后的输出 ...

  9. Day7--误差反向传播

    1.计算图 背景 神经网络通过数值微分计算神经网络的权重参数的梯度(即损失函数关于权重参数的梯度) 数值微分虽然简单,容易实现,但是计算比较耗时. 优点 计算图的特征是可以通过传递"局部计算 ...

  10. 深度学习入门-误差反向传播法(人工神经网络实现mnist数据集识别)

    文章目录 误差反向传播法 5.1 链式法则与计算图 5.2 计算图代码实践 5.3激活函数层的实现 5.4 简单矩阵求导 5.5 Affine 层的实现 5.6 softmax-with-loss层计 ...

最新文章

  1. GOF23设计模式(创建型模式)单例模式
  2. 笔记-信息系统安全管理-安全审计
  3. matlab产生正态分布样本
  4. 华为阅读下载的文件在哪里找_华为手机还要天天清理内存?1键关闭这2个设置,手机用到2035年...
  5. AgileConfig-1.5.5 发布 - 支持 JSON 编辑模式
  6. 【Java从入门到天黑|06】高质量男性SpringBoot入门及原理(基础总结版,强烈建议收藏)
  7. 新品发布、降价普惠、拥抱开源、出海全球化 | 杭州云栖企业数字化转型峰会上的那些关键词
  8. Oracle-11g 基于 NBU 的 rman 冷备份及恢复
  9. 好看的极简网站导航源码自适应静态页
  10. 弧形面如何逆时针排序_环形导轨如何实现拐弯?
  11. java后台代码添加超链接_Java 添加超链接至Excel文档
  12. SSDB 一个高性能的支持丰富数据结构的 NoSQL 数据库, 用于替代 Redis.
  13. 【路径规划】基于matlab GUI人工势场算法机器人避障路径规划(手动设障)【含Matlab源码 617期】
  14. 数据结构C语言严蔚敏版(第二版)超详细笔记附带课后习题
  15. 数据存储过程之MySQL与ORACLE数据库的差别
  16. Udacity数据分析(入门)-探索美国共享单车数据
  17. cadence ETS安装过程
  18. 电商系统之优惠券设计
  19. 2022谷粒商城学习笔记(二十二)rabbitMQ学习
  20. #BJTUOJ 铁憨憨骑士的小队分配(图论缩点+思维)

热门文章

  1. php怎么判断文件在下载,php文件下载显示找不到文件怎么办
  2. 【转】解决Navicat 报错:1130-host ... is not allowed to connect to this MySql server,MySQL不允许从远程访问的方法 .
  3. javascript中Object类原型对象的属性和方法
  4. Usage of #pragma
  5. Python3实现Win10桌面背景自动切换
  6. Ubuntu下Apache+SVN搭建SVN服务多项目管理
  7. $stateParams 详解
  8. 【案例】MySQL count操作优化案例一则
  9. spring 中beanFactory和ApplicationContext的区别
  10. 【好文翻译】10个免费的压力测试工具(Web)