正向传播与反向传播

1. 正向传播

正向传播是指对神经网络沿着从输入层到输出层的顺序,依次计算并存储模型的中间变量(包括输出)。
假设输入是一个特征为x∈Rd\boldsymbol{x} \in \mathbb{R}^dx∈Rd的样本,且不考虑偏差项,那么中间变量

z=W(1)x,\boldsymbol{z} = \boldsymbol{W}^{(1)} \boldsymbol{x},z=W(1)x,
(矩阵相乘)

其中W(1)∈Rh×d\boldsymbol{W}^{(1)} \in \mathbb{R}^{h \times d}W(1)∈Rh×d是隐藏层的权重参数。把中间变量z∈Rh\boldsymbol{z} \in \mathbb{R}^hz∈Rh输入按元素运算的激活函数ϕ\phiϕ后,将得到向量长度为hhh的隐藏层变量

h=ϕ(z).\boldsymbol{h} = \phi (\boldsymbol{z}).h=ϕ(z).

隐藏层变量h\boldsymbol{h}h也是一个中间变量。假设输出层参数只有权重W(2)∈Rq×h\boldsymbol{W}^{(2)} \in \mathbb{R}^{q \times h}W(2)∈Rq×h,可以得到向量长度为qqq的输出层变量

o=W(2)h.\boldsymbol{o} = \boldsymbol{W}^{(2)} \boldsymbol{h}.o=W(2)h.

假设损失函数为ℓ\ellℓ,且样本标签为yyy,可以计算出单个数据样本的损失项

L=ℓ(o,y).L = \ell(\boldsymbol{o}, y).L=ℓ(o,y).

根据L2L_2L2​范数正则化的定义,给定超参数λ\lambdaλ,正则化项即(超参数λ\lambdaλ即表示惩罚的力度)

s=λ2(∣W(1)∣F2+∣W(2)∣F2),s = \frac{\lambda}{2} \left(|\boldsymbol{W}^{(1)}|_F^2 + |\boldsymbol{W}^{(2)}|_F^2\right),s=2λ​(∣W(1)∣F2​+∣W(2)∣F2​),

其中矩阵的Frobenius范数等价于将矩阵变平为向量后计算L2L_2L2​范数。最终,模型在给定的数据样本上带正则化的损失为

J=L+s.J = L + s.J=L+s.

我们将JJJ称为有关给定数据样本的目标函数。

2. 反向传播

反向传播用于计算神经网络中的参数梯度。反向传播利用微积分中的链式法则,沿着从输出层到输入层的顺序进行依次计算目标函数有关神经网络各层的中间变量以及参数的梯度。
依据链式法则,我们可以知道:
∂J∂o=prod(∂J∂L,∂L∂o)=∂L∂o.\frac{\partial J}{\partial \boldsymbol{o}} = \text{prod}\left(\frac{\partial J}{\partial L}, \frac{\partial L}{\partial \boldsymbol{o}}\right) = \frac{\partial L}{\partial \boldsymbol{o}}. ∂o∂J​=prod(∂L∂J​,∂o∂L​)=∂o∂L​.
(∂J∂L=1,∂J∂s=1)\left( \frac{\partial J}{\partial L} = 1, \quad \frac{\partial J}{\partial s} = 1\right)(∂L∂J​=1,∂s∂J​=1)
其中prod\text{prod}prod运算符将根据两个输入的形状,在必要的操作(如转置和互换输入位置)后对两个输入做乘法。
∂J∂W(2)=prod(∂J∂o,∂o∂W(2))+prod(∂J∂s,∂s∂W(2))=∂J∂oh⊤+λW(2)\frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ∂W(2)∂J​=prod(∂o∂J​,∂W(2)∂o​)+prod(∂s∂J​,∂W(2)∂s​)=∂o∂J​h⊤+λW(2)
其中:
(∂s∂W(1)=λW(1),∂s∂W(2)=λW(2))\left(\frac{\partial s}{\partial \boldsymbol{W}^{(1)}} = \lambda \boldsymbol{W}^{(1)},\quad\frac{\partial s}{\partial \boldsymbol{W}^{(2)}} = \lambda \boldsymbol{W}^{(2)}\right)(∂W(1)∂s​=λW(1),∂W(2)∂s​=λW(2))
还有:
∂J∂W(2)=prod(∂J∂o,∂o∂W(2))+prod(∂J∂s,∂s∂W(2))=∂J∂oh⊤+λW(2)\frac{\partial J}{\partial \boldsymbol{W}^{(2)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{W}^{(2)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(2)}}\right) = \frac{\partial J}{\partial \boldsymbol{o}} \boldsymbol{h}^\top + \lambda \boldsymbol{W}^{(2)} ∂W(2)∂J​=prod(∂o∂J​,∂W(2)∂o​)+prod(∂s∂J​,∂W(2)∂s​)=∂o∂J​h⊤+λW(2)

∂J∂h=prod(∂J∂o,∂o∂h)=W(2)⊤∂J∂o\frac{\partial J}{\partial \boldsymbol{h}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{o}}, \frac{\partial \boldsymbol{o}}{\partial \boldsymbol{h}}\right) = {\boldsymbol{W}^{(2)}}^\top \frac{\partial J}{\partial \boldsymbol{o}} ∂h∂J​=prod(∂o∂J​,∂h∂o​)=W(2)⊤∂o∂J​
∂J∂z=prod(∂J∂h,∂h∂z)=∂J∂h⊙ϕ′(z)\frac{\partial J}{\partial \boldsymbol{z}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{h}}, \frac{\partial \boldsymbol{h}}{\partial \boldsymbol{z}}\right) = \frac{\partial J}{\partial \boldsymbol{h}} \odot \phi'\left(\boldsymbol{z}\right) ∂z∂J​=prod(∂h∂J​,∂z∂h​)=∂h∂J​⊙ϕ′(z)
所以,可以得到:
∂J∂W(1)=prod(∂J∂z,∂z∂W(1))+prod(∂J∂s,∂s∂W(1))=∂J∂zx⊤+λW(1)\frac{\partial J}{\partial \boldsymbol{W}^{(1)}} = \text{prod}\left(\frac{\partial J}{\partial \boldsymbol{z}}, \frac{\partial \boldsymbol{z}}{\partial \boldsymbol{W}^{(1)}}\right) + \text{prod}\left(\frac{\partial J}{\partial s}, \frac{\partial s}{\partial \boldsymbol{W}^{(1)}}\right) = \frac{\partial J}{\partial \boldsymbol{z}} \boldsymbol{x}^\top + \lambda \boldsymbol{W}^{(1)}∂W(1)∂J​=prod(∂z∂J​,∂W(1)∂z​)+prod(∂s∂J​,∂W(1)∂s​)=∂z∂J​x⊤+λW(1)

  • 在模型参数初始化完成后,需要交替地进行正向传播和反向传播,并根据反向传播计算的梯度迭代模型参数。
  • 在反向传播中使用了正向传播中计算得到的中间变量来避免重复计算,同时这个复用也导致正向传播结束后不能立即释放中间变量内存。这也是训练要比预测占用更多内存的一个重要原因。
  • 这些中间变量的个数大体上与网络层数线性相关,每个变量的大小跟批量大小和输入个数也是线性相关的,这是导致较深的神经网络使用较大批量训练时更容易超内存的主要原因。

(pytorch-深度学习系列)正向传播与反向传播-学习笔记相关推荐

  1. Python实现深度学习系列之【正向传播和反向传播】

    前言 在了解深度学习框架之前,我们需要自己去理解甚至去实现一个网络学习和调参的过程,进而理解深度学习的机理: 为此,博主这里提供了一个自己编写的一个例子,带领大家理解一下网络学习的正向传播和反向传播的 ...

  2. 深度学习基础之正向传播与反向传播

    文章目录 前言 正向传播 链式法则 反向传播 加法节点的反向传播 乘法节点的反向传播 小结 实例 Sigmoid函数 Softmax-with-Loss 层 参考 前言 因为这学期上了一门深度学习的课 ...

  3. 机器学习概念 — 监督学习、无监督学习、半监督学习、强化学习、欠拟合、过拟合、后向传播、损失和优化函数、计算图、正向传播、反向传播

    1. 监督学习和无监督学习 监督学习 ( Supervised Learning ) 和无监督学习 ( Unsupervised Learning ) 是在机器学习中经常被提及的两个重要的学习方法. ...

  4. 独家思维导图!让你秒懂李宏毅2020深度学习(三)——深度学习基础(神经网络和反向传播部分)

    独家思维导图!让你秒懂李宏毅2020深度学习(三)--深度学习基础(神经网络和反向传播部分) 长文预警!!!前面两篇文章主要介绍了李宏毅视频中的机器学习部分,从这篇文章开始,我将介绍李宏毅视频中的深度 ...

  5. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  6. 深度神经网络(DNN)正向传播与反向传播推导(通俗易懂)

    一.前言 我在之前的博客里面介绍过浅层的神经网络,现在就从浅层神经网络出发,介绍深度神经网络(DNN)的正向传播和反向传播.希望网友们看本博客之前需要对神经网络有个简单的了解,或者可以看博客初探神经网 ...

  7. 花书+吴恩达深度学习(十三)卷积神经网络 CNN 之运算过程(前向传播、反向传播)

    目录 0. 前言 1. 单层卷积网络 2. 各参数维度 3. CNN 前向传播反向传播 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十)卷积神经网络 ...

  8. 深度学习基础笔记——前向传播与反向传播

    相关申明及相关参考: 体系学习地址 主要学习笔记地址 由于是文章阅读整合,依据个人情况标注排版, 不确定算不算转载,主要学习围绕AI浩的五万字总结,深度学习基础 如有侵权,请联系删除. 1前向传播与反 ...

  9. 神经网络正向传播和反向传播

    正向传播(forward-propagation):指对神经网络沿着输入层到输出层的顺序,依次计算并存储模型的中间变量. 反向传播(back-propagation):沿着从输出层到输入层的顺序,依据 ...

最新文章

  1. [JavaScript Java] 初识Closure Tools(一)
  2. CTF入门--二进制
  3. Response.ContentType 详细列表
  4. 前端学习(2974):组件重定向
  5. BZOJ 3329: Xorequ(数位dp+递推)
  6. all方法 手写promise_我团队的一年前端实现Promise所有方法
  7. 数据库正确建立索引以及最左前缀原则
  8. c语言上机作业题及答案,华为C语言上机试题及答案
  9. [ 英语 ] 语法重塑 之 英语学习的核心框架 —— 英语兔学习笔记(1)
  10. ATF(Arm Trusted Firmware)/TF-A Chapter 03 Chain of Trust (CoT)
  11. linux dd从磁盘读取文件命令
  12. MIT oracle ma 信号线,美国 MIT Oracle MA-X Phono唱臂线 独家Multipole技术
  13. NodeJS学习:环境变量
  14. ​英伟达 CEO 黄仁勋:摩尔定律结束了;苹果新专利:折叠式iPhone可自行修复折痕;Rust 1.64.0 发布|极客头条...
  15. 恩智浦MKL26Z128VFT4单片机官方提供keil版SDK配置使用
  16. 如何判断一个项目的可行性?
  17. 黄斑裂孔易致失明,年长者和高度近视者尤其要注意!
  18. 日常工作积累(待续)
  19. 一定会好好的、慢慢的来
  20. 怎样下载程序到西门子PLC

热门文章

  1. [设计模式] ------ 策略模式
  2. python搭建web服务器_Python搭建简单的web服务器
  3. hash地址_深入浅出一致性Hash原理
  4. mysql对日期的操作_MySql对日期的操作
  5. Java中AJAX工作原理是什么
  6. 高大上的集团名字_那些刚改了“高大上”名字的学校,你知道都有哪些吗?蜻蜓AI小编来帮你科普一下...
  7. 【LeetCode笔记】148. 排序链表(Java、归并排序、快慢指针、双重递归)
  8. 【LeetCode笔记】79. 单词搜索 剑指 Offer 12 矩阵中的路径(Java、dfs)
  9. 七参数 布尔萨 最小二乘法_最小二乘法和最大似然法的联系
  10. R 语言怎么保存工作目录到当前路径_【R语言基础】01.R语言软件环境搭建及常用操作...