autograd

反向传播过程需要手动实现。这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查。torch.autograd就是为方便用户使用,而专门开发的一套自动求导引擎,它能够根据输入和前向传播过程自动构建计算图,并执行反向传播。

计算图(Computation Graph)是现代深度学习框架如PyTorch和TensorFlow等的核心,其为高效自动求导算法——反向传播(Back Propogation)提供了理论支持,了解计算图在实际写程序过程中会有极大的帮助。本节将涉及一些基础的计算图知识,但并不要求读者事先对此有深入的了解。关于计算图的基础知识推荐阅读Christopher Olah的文章

Variable

PyTorch在autograd模块中实现了计算图的相关功能,autograd中的核心数据结构是Variable。Variable封装了tensor,并记录对tensor的操作记录用来构建计算图。Variable的数据结构如图3-2所示,主要包含三个属性:

  • data:保存variable所包含的tensor
  • grad:保存data对应的梯度,grad也是variable,而不是tensor,它与data形状一致。
  • grad_fn: 指向一个Function,记录tensor的操作历史,即它是什么操作的输出,用来构建计算图。如果某一个变量是由用户创建,则它为叶子节点,对应的grad_fn等于None。

Variable的构造函数需要传入tensor,同时有两个可选参数:

  • requires_grad (bool):是否需要对该variable进行求导
  • volatile (bool):意为”挥发“,设置为True,则构建在该variable之上的图都不会求导,专为推理阶段设计

Variable提供了大部分tensor支持的函数,但其不支持部分inplace函数,因这些函数会修改tensor自身,而在反向传播中,variable需要缓存原来的tensor来计算反向传播梯度。如果想要计算各个Variable的梯度,只需调用根节点variable的backward方法,autograd会自动沿着计算图反向传播,计算每一个叶子节点的梯度。

variable.backward(grad_variables=None, retain_graph=None, create_graph=None)主要有如下参数:

  • grad_variables:形状与variable一致,对于y.backward(),grad_variables相当于链式法则。grad_variables也可以是tensor或序列。
  • retain_graph:反向传播需要缓存一些中间结果,反向传播之后,这些缓存就被清空,可通过指定这个参数不清空缓存,用来多次反向传播。
  • create_graph:对反向传播过程再次构建计算图,可通过backward of backward实现求高阶导数。

def f(x):'''计算y'''y = x**2 * t.exp(x)return ydef gradf(x):'''手动求导函数'''dx = 2*x*t.exp(x) + x**2*t.exp(x)return dxx = V(t.randn(3,4), requires_grad = True)
y = f(x)
ytensor([[0.4246, 0.1465, 0.0681, 0.2506],[0.5244, 5.0220, 0.0333, 1.8658],[0.3736, 0.1008, 0.7557, 0.1063]], grad_fn=<MulBackward0>)
y.backward(t.ones(y.size())) # grad_variables形状与y一致
x.gradtensor([[-0.3017,  1.0473, -0.3803, -0.4493],[ 2.4313, 13.2647,  0.4299,  6.1072],[ 1.9280, -0.4231,  3.1421, -0.4281]])# autograd的计算结果与利用公式手动计算的结果一致
gradf(x) tensor([[-0.3017,  1.0473, -0.3803, -0.4493],[ 2.4313, 13.2647,  0.4299,  6.1072],[ 1.9280, -0.4231,  3.1421, -0.4281]], grad_fn=<AddBackward0>)

计算图

PyTorch中autograd的底层采用了计算图,计算图是一种特殊的有向无环图(DAG),用于记录算子与变量之间的关系。一般用矩形表示算子,椭圆形表示变量。如表达式z = wx + b可分解为y = wx和z = y + b,其计算图如图3-3所示,图中MULADD都是算子,ww,xx,bb即变量。

如上有向无环图中,XX和bb是叶子节点(leaf node),这些节点通常由用户自己创建,不依赖于其他变量。zz称为根节点,是计算图的最终目标。利用链式法则很容易求得各个叶子节点的梯度。

而有了计算图,上述链式求导即可利用计算图的反向传播自动完成

PyTorch实战福利从入门到精通之三——autograd相关推荐

  1. PyTorch实战福利从入门到精通之五——搭建ResNet

    Kaiming He的深度残差网络(ResNet)在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题. ...

  2. PyTorch实战福利从入门到精通之一——PyTorch框架安装

    使用conda安装是最不容易出错的,在pytroch的官网可以选择自己需要的操作系统.python版本.cuda版本的pytorch框架. 之后复制下面的命令就可以了 安装完这个还要安个numpy p ...

  3. PyTorch实战福利从入门到精通之六——线性回归

    一元线性回归 一元线性模型非常简单,假设我们有变量 xix_ixi​ 和目标 yiy_iyi​,每个 i 对应于一个数据点,希望建立一个模型 y^i=wxi+b\hat{y}_i = w x_i + ...

  4. PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类

    在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...

  5. PyTorch实战福利从入门到精通之八——深度卷积神经网络(AlexNet)

    在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机.虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意.一方面,神经网络计算复 ...

  6. PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)

    卷积神经网络就是含卷积层的网络.介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1].这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经 ...

  7. PyTorch实战福利从入门到精通之九——数据处理

    在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像.文本.语音或其它二进制数据等.数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果.考虑到这 ...

  8. PyTorch实战福利从入门到精通之二——Tensor

    Tensor又名张量,也是Tensorflow等框架中的重要数据结构.它可以是一个数(标量),一维数组(向量),二维数组或更高维数组.Tensor支持GPU加速. 创建Tensor 几种常见创建Ten ...

  9. 【Python】Python实战从入门到精通之三 -- 教你使用Python中条件语句

    本文是Python实战–从入门到精通系列的第三篇文章: Python实战从入门到精通第1讲–Python中的变量和数据类型 Python实战从入门到精通第2讲–Python中列表操作详解 Python ...

最新文章

  1. 文本框怎么变大html,如何设置HTML文本框的大小?
  2. 现实交互动作和现实环境交互的魅力
  3. 2016年定制维护组总结-历程回溯
  4. Thrown KeeperErrorCode = Unimplemented for /services exception
  5. “dedeCMS 提示信息!”跳转页,如何修改文字?
  6. 表单组件中state依赖props
  7. java 前端导出exvel_java导出数据到Excel文件 前端进行下载
  8. php pack方法,php pack()函数详解与示例
  9. 基于信息熵确立权重的topsis法_基于信息熵和TOPSIS法的装备战场抢修排序决策模型...
  10. 正则表达式匹配连续相同字符
  11. CentOS6.5配置网易163做yum源
  12. L1、L2、Batch Normalization、Dropout为什么能够防止过拟合呢?
  13. Linux性能优化(九)——Kernel Bypass
  14. 【06年博文搬家】查看本机的瑞星序列号
  15. Mirth学习笔记 - 建立Mirth通道
  16. 天正如何转为t3_[转载]天正文件转T3格式CAD图
  17. win7修改驱动inf,驱动非官方美加狮XBOX360手柄
  18. 网站域名过户查询_过期域名查询
  19. 安装andriod studio的过程中遇到的问题
  20. 3dmax的学习技巧大全

热门文章

  1. 周小川:数字人民币不会取代美元 也不会威胁全球货币体系
  2. 常见关联图库之欺诈指数排位战
  3. 如何用好埋点中的数据
  4. C++并发编程之std::future
  5. 机器学习(3):信息论
  6. display属性值
  7. angularjs框架
  8. X5档案-参加业务架构平台研讨会后记
  9. 微信小程序之旅一(页面渲染)
  10. [k8s]debug模式启动集群k8s常见报错集合(on the fly)