文章目录

  • 例子定义
  • 多变量链式求导:简单的例子
  • 多变量求导:计算图模式
PyTorch创建模型的一般的写法是:outputs = Your_model(data_x)optimizer.zero_grad()loss = loss_function(data_y,outputs)loss.backward()optimizer.step()这里的loss不是一个tensor吗?
这个tensor就是存了一个值,
loss.backward()会实现一个什么样的过程?
是计算梯度吗?
感觉这里的loss和网络没什么关系啊?
是outputs 经过loss_function使得loss这个tensor和网络间接产生了关联,
所以程序可以自动识别网络的参数结构并以此来求梯度吗?

首先是构建计算图,loss.backward()的时候就是走一遍反向图。

举个例子就明白了:

例子定义

为了简单起见,就假设只有一个训练样本(x,t)(x, t)(x,t)。网络模型是一个线性模型,带有一个非线形的sigmoid层,然后用均方差作为其Loss函数,这个模型用公式可以表示为如下形式:

z=wx+by=σ(z)L=12(y−t)2z = wx + b \\ y = \sigma(z) \\ \mathcal{L} = \frac{1}{2}(y - t)^{2} z=wx+by=σ(z)L=21​(y−t)2

再考虑添加一个正则化λ2w2\frac{\lambda}{2} w^{2}2λ​w2,期望模型变得更加简单一点,也就是期望www变成0。此时损失函数将会变为:

R=12w2Lreg=L+λR\mathcal{R} = \frac{1}{2} w^{2} \\ \mathcal{L}_{reg} = \mathcal{L} + \lambda \mathcal{R} R=21​w2Lreg​=L+λR

其中λ\lambdaλ是超参数。

想要做梯度下降的话,我们就需要对www和bbb求偏微分:∂E/∂w\partial \mathcal{E} / \partial w∂E/∂w和∂E/∂b\partial \mathcal{E} / \partial b∂E/∂b。

将目标函数展开可以表示为:

Lreg=12(σ(wx+b)−t)2+λ2w2\mathcal{L}_{\mathrm{reg}}=\frac{1}{2}(\sigma(w x+b)-t)^{2}+\frac{\lambda}{2} w^{2} Lreg​=21​(σ(wx+b)−t)2+2λ​w2

对www求偏导数可以得到:

∂Lreg∂w=∂∂w[12(σ(wx+b)−t)2+λ2w2]=12∂∂w(σ(wx+b)−t)2+λ2∂∂ww2=(σ(wx+b)−t)∂∂w(σ(wx+b)−t)+λw=(σ(wx+b)−t)σ′(wx+b)∂∂w(wx+b)+λw=(σ(wx+b)−t)σ′(wx+b)x+λw\begin{aligned} \frac{\partial \mathcal{L}_{\mathrm{reg}}}{\partial w} &=\frac{\partial}{\partial w}\left[\frac{1}{2}(\sigma(w x+b)-t)^{2}+\frac{\lambda}{2} w^{2}\right] \\ &=\frac{1}{2} \frac{\partial}{\partial w}(\sigma(w x+b)-t)^{2}+\frac{\lambda}{2} \frac{\partial}{\partial w} w^{2} \\ &=(\sigma(w x+b)-t) \frac{\partial}{\partial w}(\sigma(w x+b)-t)+\lambda w \\ &=(\sigma(w x+b)-t) \sigma^{\prime}(w x+b) \frac{\partial}{\partial w}(w x+b)+\lambda w \\ &=(\sigma(w x+b)-t) \sigma^{\prime}(w x+b) x+\lambda w \end{aligned} ∂w∂Lreg​​​=∂w∂​[21​(σ(wx+b)−t)2+2λ​w2]=21​∂w∂​(σ(wx+b)−t)2+2λ​∂w∂​w2=(σ(wx+b)−t)∂w∂​(σ(wx+b)−t)+λw=(σ(wx+b)−t)σ′(wx+b)∂w∂​(wx+b)+λw=(σ(wx+b)−t)σ′(wx+b)x+λw​

对bbb求偏导数可以得到:

∂Lreg∂b=∂∂b[12(σ(wx+b)−t)2+λ2w2]=12∂∂b(σ(wx+b)−t)2+λ2∂∂bw2=(σ(wx+b)−t)∂∂b(σ(wx+b)−t)+0=(σ(wx+b)−t)σ′(wx+b)∂∂b(wx+b)=(σ(wx+b)−t)σ′(wx+b)\begin{aligned} \frac{\partial \mathcal{L}_{\mathrm{reg}}}{\partial b} &=\frac{\partial}{\partial b}\left[\frac{1}{2}(\sigma(w x+b)-t)^{2}+\frac{\lambda}{2} w^{2}\right] \\ &=\frac{1}{2} \frac{\partial}{\partial b}(\sigma(w x+b)-t)^{2}+\frac{\lambda}{2} \frac{\partial}{\partial b} w^{2} \\ &=(\sigma(w x+b)-t) \frac{\partial}{\partial b}(\sigma(w x+b)-t)+0 \\ &=(\sigma(w x+b)-t) \sigma^{\prime}(w x+b) \frac{\partial}{\partial b}(w x+b) \\ &=(\sigma(w x+b)-t) \sigma^{\prime}(w x+b) \end{aligned} ∂b∂Lreg​​​=∂b∂​[21​(σ(wx+b)−t)2+2λ​w2]=21​∂b∂​(σ(wx+b)−t)2+2λ​∂b∂​w2=(σ(wx+b)−t)∂b∂​(σ(wx+b)−t)+0=(σ(wx+b)−t)σ′(wx+b)∂b∂​(wx+b)=(σ(wx+b)−t)σ′(wx+b)​

上述方法确实能够得出正确的解,但是有以下缺陷:

  1. 计算非常冗余复杂;
  2. 上述计算过程中有很多地方是重复计算的,比如wx+bwx+bwx+b计算了四次, (σ(wx+b)−t)σ′(wx+b)(\sigma(w x+b)-t) \sigma^{\prime}(w x+b)(σ(wx+b)−t)σ′(wx+b)计算了两次。

多变量链式求导:简单的例子

其实上述方法就是在计算一元链式求导多次,一元链式求导可以定义为如下形式:

ddtf(g(t))=f′(g(t))g′(t)\frac{\mathrm{d}}{\mathrm{d} t} f(g(t))=f^{\prime}(g(t)) g^{\prime}(t) dtd​f(g(t))=f′(g(t))g′(t)

多变量的求导可以定义为:

ddtf(x(t),y(t))=∂f∂xdxdt+∂f∂ydydt\frac{\mathrm{d}}{\mathrm{d} t} f(x(t), y(t))=\frac{\partial f}{\partial x} \frac{\mathrm{d} x}{\mathrm{~d} t}+\frac{\partial f}{\partial y} \frac{\mathrm{d} y}{\mathrm{~d} t} dtd​f(x(t),y(t))=∂x∂f​ dtdx​+∂y∂f​ dtdy​

为了方便叙述,定义一个符号,比如对vvv的导数可以定义为:

vˉ≜∂L∂v\bar{v} \triangleq \frac{\partial \mathcal{L}}{\partial v} vˉ≜∂v∂L​

那么上述ddtf(x(t),y(t))=∂f∂xdxdt+∂f∂ydydt\frac{\mathrm{d}}{\mathrm{d} t} f(x(t), y(t))=\frac{\partial f}{\partial x} \frac{\mathrm{d} x}{\mathrm{~d} t}+\frac{\partial f}{\partial y} \frac{\mathrm{d} y}{\mathrm{~d} t}dtd​f(x(t),y(t))=∂x∂f​ dtdx​+∂y∂f​ dtdy​可以表示为:

tˉ=xˉdxdt+yˉdydt\bar{t} = \bar{x} \frac{dx}{dt} + \bar{y} \frac{dy}{dt} tˉ=xˉdtdx​+yˉ​dtdy​

这样我们就可以用偏微分的计算方法来计算dxdt\frac{dx}{dt}dtdx​, 而xˉ\bar{x}xˉ和yˉ\bar{y}yˉ​是之前就已经计算好的。

多变量求导:计算图模式

上述带有正则化的模型用计算图可以表示为:

整个模型可以表示为:

z=wx+by=σ(z)L=12(y−t)2R=12w2Lreg=L+λRz = wx + b \\ y = \sigma(z) \\ \mathcal{L} = \frac{1}{2}(y - t)^{2} \\ \mathcal{R} = \frac{1}{2} w^{2} \\ \mathcal{L}_{reg} = \mathcal{L} + \lambda \mathcal{R} z=wx+by=σ(z)L=21​(y−t)2R=21​w2Lreg​=L+λR

反向传播的时候,我们需要去计算得到wˉ\bar{w}wˉ和bˉ\bar{b}bˉ, 就需要反复利用链式求导计算偏微分。

也就是要从结果(这里定义为E\mathcal{E}E),一步一步往前去计算它的前一个节点的导数。定义v1,⋯,vNv_{1}, \cdots, v_{N}v1​,⋯,vN​是计算图中的所有节点,并且以输入到输出的拓扑顺序进行排序的。

我们期望去计算得到所有节点的偏导数vˉi\bar{v}_{i}vˉi​, 神经网络工作的时候就是走一遍前向传播,然后走一遍反向传播。最末尾的这个节点是vNv_{N}vN​, 我们也需去得到它的偏导数,为了方便计算,我们通常令其偏导数为1。也就是vNˉ=1\bar{v_{N}} = 1vN​ˉ​=1。

此时整个算法的逻辑可以表示为:

遍历i=1,…,Ni=1, \ldots, Ni=1,…,N,计算 viv_{i}vi​ 作为父亲节点的值 Pa⁡(vi)\operatorname{Pa}\left(v_{i}\right)Pa(vi​)。

设置vN=1v_{N}=1vN​=1, 遍历i=N−1,…,1i=N-1, \ldots, 1i=N−1,…,1,计算 vi‾=∑j∈Ch⁡(vi)vj‾∂vj∂vi\overline{v_{i}}=\sum_{j \in \operatorname{Ch}\left(v_{i}\right)} \overline{v_{j}} \frac{\partial v_{j}}{\partial v_{i}}vi​​=∑j∈Ch(vi​)​vj​​∂vi​∂vj​​。

OK, 到此,我们就可以通过上述的这个计算公式来走反向传播了:

首先是目标函数的导数被定义为1,也就是Lregˉ=1\bar{\mathcal{L}_{reg}} = 1Lreg​ˉ​=1。

之后我们对Lreg\mathcal{L}_{reg}Lreg​的两个父亲节点求偏导数,也就是Rˉ\bar{\mathcal{R}}Rˉ和Lˉ\bar{\mathcal{L}}Lˉ:

  1. Rˉ\bar{\mathcal{R}}Rˉ的孩子节点为Lreg\mathcal{L}_{reg}Lreg​,所以依据公式vj‾∂vj∂vi\overline{v_{j}} \frac{\partial v_{j}}{\partial v_{i}}vj​​∂vi​∂vj​​, 我们可以得到Rˉ=LregˉdLregdR=Lregˉλ\bar{\mathcal{R}} = \bar{\mathcal{L_{reg}}} \frac{\mathcal{dL_{reg}}}{d\mathcal{R}} = \bar{\mathcal{L}_{reg}} \lambdaRˉ=Lreg​ˉ​dRdLreg​​=Lreg​ˉ​λ

  2. Lˉ\bar{\mathcal{L}}Lˉ的孩子节点也为Lreg\mathcal{L}_{reg}Lreg​, 所以Lˉ=LregˉdLregdL=Lregˉ\bar{\mathcal{L}} = \bar{\mathcal{L}_{reg}} \frac{d \mathcal{L}_{reg}}{d \mathcal{L}} = \bar{\mathcal{L}_{reg}}Lˉ=Lreg​ˉ​dLdLreg​​=Lreg​ˉ​

我们再进行反向传播,来计算yˉ\bar{y}yˉ​和zˉ\bar{z}zˉ:

  1. yˉ=LˉdLdy=Lˉ(y−t)\bar{y} = \bar{\mathcal{L}} \frac{d \mathcal{L}}{d y} = \bar{\mathcal{L}} (y - t)yˉ​=LˉdydL​=Lˉ(y−t)。
  2. zˉ=yˉdydz=yˉσ′(z)\bar{z} = \bar{y} \frac{d y}{d z} = \bar{y} \sigma^{\prime}(z)zˉ=yˉ​dzdy​=yˉ​σ′(z)。

最后到了我们需要更新的参数www和bbb:

  1. www有两个父亲节点zzz和R\mathcal{R}R,所以wˉ=zˉ∂z∂w+RˉdRdw=zˉx+Rˉw\bar{w} = \bar{z} \frac{\partial{z}}{\partial{w}} + \bar{\mathcal{R}} \frac{d \mathcal{R}}{d w}= \bar{z} x + \bar{\mathcal{R}}wwˉ=zˉ∂w∂z​+RˉdwdR​=zˉx+Rˉw。
  2. bbb有一个父亲节点zzz,所以bˉ=zˉ∂z∂b=zˉ\bar{b} = \bar{z} \frac{\partial{z}}{\partial{b}} = \bar{z}bˉ=zˉ∂b∂z​=zˉ。

总结一下推导过程为:

Lreg ‾=1R‾=Lreg ‾dLreg dR=Lreg ‾λL‾=Lreg ‾dLreg dL=Lreg ‾yˉ=L‾dLdy=L‾(y−t)zˉ=yˉdydz=yˉσ′(z)wˉ=zˉ∂z∂w+R‾dRdw=zˉx+R‾wbˉ=zˉ∂z∂b=zˉ\begin{aligned} \overline{\mathcal{L}_{\text {reg }}} &=1 \\ \overline{\mathcal{R}} &=\overline{\mathcal{L}_{\text {reg }}} \frac{\mathrm{d} \mathcal{L}_{\text {reg }}}{\mathrm{d} \mathcal{R}} \\ &=\overline{\mathcal{L}_{\text {reg }}} \lambda \\ \overline{\mathcal{L}} &=\overline{\mathcal{L}_{\text {reg }}} \frac{\mathrm{d} \mathcal{L}_{\text {reg }}}{\mathrm{d} \mathcal{L}} \\ &=\overline{\mathcal{L}_{\text {reg }}} \\ \bar{y} &=\overline{\mathcal{L}} \frac{\mathrm{d} \mathcal{L}}{\mathrm{d} y} \\ &=\overline{\mathcal{L}}(y-t) \\ \bar{z} &=\bar{y} \frac{\mathrm{d} y}{\mathrm{~d} z} \\ &=\bar{y} \sigma^{\prime}(z) \\ \bar{w} &=\bar{z} \frac{\partial z}{\partial w}+\overline{\mathcal{R}} \frac{\mathrm{d} \mathcal{R}}{\mathrm{d} w} \\ &=\bar{z} x+\overline{\mathcal{R}} w \\ \bar{b} &=\bar{z} \frac{\partial z}{\partial b} \\ &=\bar{z} \end{aligned} Lreg ​​RLyˉ​zˉwˉbˉ​=1=Lreg ​​dRdLreg ​​=Lreg ​​λ=Lreg ​​dLdLreg ​​=Lreg ​​=LdydL​=L(y−t)=yˉ​ dzdy​=yˉ​σ′(z)=zˉ∂w∂z​+RdwdR​=zˉx+Rw=zˉ∂b∂z​=zˉ​

总结一下最终结果为:

Lreg ‾=1R‾=Lreg ‾λL‾=Lreg ‾yˉ=L‾(y−t)zˉ=yˉσ′(z)wˉ=zˉx+R‾wbˉ=zˉ\begin{aligned} \overline{\mathcal{L}_{\text {reg }}} &=1 \\ \overline{\mathcal{R}} &=\overline{\mathcal{L}_{\text {reg }}} \lambda \\ \overline{\mathcal{L}} &=\overline{\mathcal{L}_{\text {reg }}} \\ \bar{y} &=\overline{\mathcal{L}}(y-t) \\ \bar{z} &=\bar{y} \sigma^{\prime}(z) \\ \bar{w} &=\bar{z} x+\overline{\mathcal{R}} w \\ \bar{b} &=\bar{z} \end{aligned} Lreg ​​RLyˉ​zˉwˉbˉ​=1=Lreg ​​λ=Lreg ​​=L(y−t)=yˉ​σ′(z)=zˉx+Rw=zˉ​

可以看到相比之前的推导偏微分方程的方式,这种反向传播的方式更为简洁。

PyTorch中loss.backward()的时候就是把上述这个最终结果走一遍,因为创建model的时候是调用模块的,所以计算偏导数的时候,之前就计算好了偏导数是多少,把值带入进去就可以了。

下面这篇文章有兴趣可以看

  • Automatic differentiation in machine learning: a survey 【https://arxiv.org/abs/1502.05767】

PyTorch中的梯度微分机制相关推荐

  1. 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...

    在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...

  2. PyTorch中的梯度累积

    我们在训练神经网络的时候,超参数batch_size的大小会对模型最终效果产生很大的影响,通常的经验是,batch_size越小效果越差:batch_size越大模型越稳定.理想很丰满,现实很骨感,很 ...

  3. Pytorch中的梯度知识总结

    文章目录 1.叶节点.中间节点.梯度计算 2.叶子张量 leaf tensor (叶子节点) (detach) 2.1 为什么需要叶子节点? 2.2 detach()将节点剥离成叶子节点 2.3 什么 ...

  4. Pytorch中的梯度回传

    来源:知乎-歪杠小胀 作者:https://zhuanlan.zhihu.com/p/451441329 01 记录写这篇文章的初衷 最近在复现一篇论文的训练代码时,发现原论文中的总loss由多个lo ...

  5. PyTorch中的梯度计算1

    主要用pytorch,对其他的框架用的很少,而且也没经过系统的学习,对动态图和梯度求导没有个准确的认识.今天根据代码仔细的看一下. import torch from torch.autograd i ...

  6. 【深度学习】PyTorch 中的线性回归和梯度下降

    作者 | JNK789   编译 | Flin  来源 | analyticsvidhya 我们正在使用 Jupyter notebook 来运行我们的代码.我们建议在Google Colaborat ...

  7. 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图

    转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...

  8. Pytorch中的variable, tensor与numpy相互转化

    来源:https://blog.csdn.net/m0_37592397/article/details/88327248 1.将numpy矩阵转换为Tensor张量 sub_ts = torch.f ...

  9. backward()和zero_grad()在PyTorch中代表什么意思

    文章目录 问:`backward()`和`zero_grad()`是什么意思? backward() zero_grad() 问:求导和梯度什么关系 问:backward不是求导吗,和梯度有什么关系( ...

  10. Pytorch中的向前计算(autograd)、梯度计算以及实现线性回归操作

    在整个Pytorch框架中, 所有的神经网络本质上都是一个autograd package(自动求导工具包) autograd package提供了一个对Tensors上所有的操作进行自动微分的功能. ...

最新文章

  1. IIS部署详细步骤、包括错误的解决办法、使用localDB
  2. 自定义LayoutManager实现最美应用列表
  3. mysql 查询后怎么定位列_MySQL如何定位并优化慢查询sql
  4. 【培训稿件】构建WCF面向服务的应用程序(包含ppt,源代码)
  5. win7 php环境搭建 x64,win7搭建php+Apache环境
  6. C语言小知识---特殊的字符串
  7. SQL Server 作业监控
  8. utilities(matlab)—— PSNR 值的计算
  9. 2022年11月份PMP考试是新版教材吗?
  10. 移动端车牌识别sdk——技术干货
  11. 提升 10 倍!阿里云对象存储 OSS 可用性 SLA 技术揭秘
  12. xp系统怎么关闭wmi服务器,WinXP系统如何启用WMI服务,小编教你WinXP系统如何启用WMI服务...
  13. 如何使用IDEA进行协作编码,共享项目,并实时的处理
  14. 计算机组策略定时开机脚本,批处理+组策略 实现规定时间段无法开机and定时关机...
  15. Android远程真机调试(电脑使用 Vysor 控制手机)
  16. C++解决猴子吃桃问题(详细)
  17. 计算机系统维护是干嘛,计算机系统维护是什么
  18. python模拟行星运动_Java课程设计——模拟行星运动
  19. python收集数据做主神_里纲_[综漫]收集数据做主神小说无防盗章节_作者忘却的悠_新书包网(www.51aslz.com)...
  20. c语言金字塔输出乘法表,python中打印金字塔和九九乘法表的几种方法

热门文章

  1. Sicily 6271
  2. USACO_1_2_Dual Palindromes
  3. python列表题目_python4_list应用的练习题
  4. Cocos2d-x-使用脚本概述
  5. 游戏筑基之两个变量交换值与三个变量交换值的比较(C语言)
  6. Docker详解(三)——Docker安装与部署
  7. keepalived详解(一)——keepalived理论基础
  8. ansible 第四次作业
  9. 边缘计算的前景和挑战
  10. 文件比较与同步工具——FreeFileSync