Pytorch的backward()相关理解

最近一直在用pytorch做GAN相关的实验,pytorch 框架灵活易用,很适合学术界开展研究工作。 
这两天遇到了一些模型参数寻优的问题,才发现自己对pytorch的自动求导和寻优功能没有深刻理解,导致无法灵活的进行实验。于是查阅资料,同时自己做了一点小实验,做了一些总结,虽然好像都是一些显而易见的结论,但是如果不能清晰的理解,对于实验复杂的网络模型程序会造成困扰。以下仅针对pytorch 0.2 版本,如有错误,希望得到指正。

  • 相关标志位/函数

    • 1

      • requires_grad
      • volatile
      • detach()/detach_()
      • retain_graph
      • retain_variables
      • create_graph

    前三个标志位中,最关键的就是 requires_grad,另外两个都可以转化为 requires_grad 来理解。 
    后三个标志位,与计算图的保持与建立有关系。其中 retain_variables 与 retain_graph等价,retain_variables 已在pytorch 新版本中被取消。

  • requires_grad 的含义及标志位说明

    • 如果对于某Variable 变量 x ,其 x.requires_grad == True , 则表示 它可以参与求导,也可以从它向后求导。 
      默认情况下,一个新的Variables 的 requires_grad 和 volatile 都等于 False 。

    • requires_grad == True 具有传递性,如果: 
      x.requires_grad == True ,y.requires_grad == False , z=f(x,y) 
      则, z.requires_grad == True

    • 凡是参与运算的变量(包括 输入量,中间输出量,输出量,网络权重参数等),都可以设置 requires_grad 。

    • volatile==True 就等价于 requires_grad==False 。 volatile==True 同样具有传递性。一般只用在inference过程中。若是某个过程,从 x 开始 都只需做预测,不需反传梯度的话,那么只需设置x.volatile=True ,那么 x 以后的运算过程的输出均为 volatile==True ,即 requires_grad==False 。 
      虽然inference 过程不必backward(),所以requires_grad 的值为False 或 True,对结果是没有影响的,但是对程序的运算效率有直接影响;所以使用volatile=True ,就不必把运算过程中所有参数都手动设一遍requires_grad=False 了,方便快捷。

    • detach() ,如果 x 为中间输出,x' = x.detach 表示创建一个与 x 相同,但requires_grad==False 的variable, (实际上是把x’ 以前的计算图 grad_fn 都消除了),x’ 也就成了叶节点。原先反向传播时,回传到x时还会继续,而现在回到x’处后,就结束了,不继续回传求到了。另外值得注意, x (variable类型) 和 x’ (variable类型)都指向同一个Tensor ,即 x.data 
      detach_() 表示不创建新变量,而是直接修改 x 本身。

    • retain_graph ,每次 backward() 时,默认会把整个计算图free掉。一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward()是成对存在的,一般一次backward()也是够用的。但是不排除,由于自定义loss等的复杂性,需要一次forward(),多个不同loss的backward()来累积同一个网络的grad,来更新参数。于是,若在当前backward()后,不执行forward() 而可以执行另一个backward(),需要在当前backward()时,指定保留计算图,即backward(retain_graph)。

    • create_graph ,这个标志位暂时还未深刻理解,等之后再更新。
  • 反向求导 和 权重更新

    • 求导和优化(权重更新)是两个独立的过程,只不过优化时一定需要对应的已求取的梯度值。所以求得梯度值很关键,而且,经常会累积多种loss对某网络参数造成的梯度,一并更新网络。

    • 反向传播过程中,肯定需要整个过程都链式求导。虽然中间参数参与求导,但是却可以不用于更新该处的网络参数。参数更新可以只更新想要更新的网络的参数。

    • 如果obj是函数运算结果,且是标量,则 obj.backward() (注意,backward()函数中没有填入任何tensor值, 就相当于 backward(torch.tensor([1])) ),则 x.grad 就是 ∂obj∂x∣∣x=1∂obj∂x|x=1 。

    • 对于继承自 nn.Module 的某一网络 net 或网络层,定义好后,发现 默认情况下,net.paramters 的 requires_grad 就是 True 的(虽然只是实验证明的,还未从源码处找到证据),这跟普通的Variable张量不同。因此,当x.requires_grad == False , y = net(x) 后, 有 y.requires_grad == True ;但值得注意,虽然nn.xxloss和激活层函数,是继承nn.Module的,但是这两种并没有网络参数,就更谈不上 paramters.requires_grad 的值了。所以类似这两种函数的输出,其requires_grad只跟输入有关,不一定是 True .

  • 计算图相关

    • 计算图就是模型 前向forward() 和后向求梯度backward() 的流程参照。

    • 能获取回传梯度(grad)的只有计算图的叶节点。注意是获取,而不是求取。中间节点的梯度在计算求取并回传之后就会被释放掉,没办法获取。想要获取中间节点梯度,可以使用 register_hook (钩子)函数工具。当然, register_hook 不仅仅只有这个作用。

    • 只有标量才能直接使用 backward(),即loss.backward() , pytorch 框架中的各种nn.xxLoss(),得出的都是minibatch 中各结果 平均/求和 后的值。如果使用自定义的函数,得到的不是标量,则backward()时需要传入 grad_variable 参数,这一点详见博客 https://sherlockliao.github.io/2017/07/10/backward/ 。

    • 经常会有这样的情况: 
      x1 —> |net1| —> y1 —> |net2| —> z1 , net1和net2是两个不同的网络。x1 依次通过 两个网络运算,生成 z1 。比较担心一次性运算后,再backward(),是不是只更新net1 而不是net1、net2都更新呢? 
      类比 x2 —> |f1| —> y2 —> |f2| —> z2 , f1 、f2 是两个普通的函数,z2=f2(y2) , y2=f1(x2) 。 
      按照以下代码实验

      w1 = torch.Tensor([2]) #认为w1 与 w2 是函数f1 与 f2的参数
      w1 = Variable(w1,requires_grad=True)
      w2 = torch.Tensor([2])
      w2 = Variable(w2,requires_grad=True)
      x2 = torch.rand(1)
      x2 = Variable(x2,requires_grad=True)
      y2 = x2**w1            # f1 运算
      z2 = w2*y2+1           # f2 运算
      z2.backward()
      print(x2.grad)
      print(y2.grad)
      print(w1.grad)
      print(w2.grad)

      发现 x2.grad,w1.grad,w2.grad 是个值 ,但是 y2.grad 却是 None, 说明x2,w1,w2的梯度保留了,y2 的梯度获取不到。实际上,仔细想一想会发现,x2,w1,w2均为叶节点。在这棵计算树中 ,x2 与w1 是同一深度(底层)的叶节点,y2与w2 是同一深度,w2 是单独的叶节点,而y2 是x2 与 w1 的父节点,所以只有y2没有保留梯度值,印证了之前的说法。同样这也说明,计算图本质就是一个类似二叉树的结构。

       
      那么对于 两个网络,会是怎么样呢? 我使用pytorch 的cifar10 例程,稍作改动做了实验。把例程中使用的一个 Alexnet 拆成了两个net —— net1 和 net2 。

          optimizer = torch.optim.SGD(itertools.chain(net1.parameters(), net2.parameters()),lr=0.001, momentum=0.9) # 这里 net1 和net2 优化的先后没有区别 !!#optimizer.zero_grad() #将参数的grad值初始化为0## forward + backward + optimizeoutputs1 = net1(inputs)            #input 未置requires_grad为True,但不影响outputs2 = net2(outputs1)loss = criterion(outputs2, labels) #计算损失loss.backward()                    #反向传播      #     print("inputs.requires_grad:")print(inputs.requires_grad)        # Falseprint("the grad of inputs:")print(inputs.grad)                 # Noneprint("outputs1.requires_grad:")print(outputs1.requires_grad)      # Trueprint("the grad of outputs1:")        print(outputs1.grad)               # None     # print("the grad of net1:")print(net1.conv1.bias.grad)        # no-Noneprint("the grad of net2:")print(net2.fc3.bias.grad)          # no-None#optimizer.step() #用SGD更新参数

      后缀注释就是打印的结果。可以看出,只有网络参数的grad是直接可获取的。而且是两个网络都可以获取grad 值,获取grad后,当然就可以更新网络的参数了,两个网络都是可以更新的。

      类比上边例子的解释,两个网络其实就是处在叶节点的位置,只不过深度不同。同理,网络内部的运算,每一层网络权重参数其实也是处在叶节点上,只不过在树中的深度不同罢了,前向运算时按照二叉树的结构,不断生成父节点。

      (事实上,原先是以为 网络 与 普通函数不同,因为它具有register_xx_hook()这个类函数工具,所以认为它可以默认保存权重参数的grad来用于更新,后来才明白,本质上与普通函数的参数一样,都是处在叶节点,就可以保存参数的grad,至于register_xx_hook(),看来是另做它用,或者说用register_xx_hook()可以记录甚至更改中间节点的grad值)

  • 一些特殊的情况:

    • 把网络某一部分参数,固定,不让其被训练。可以使用requires_grad.

      for p in sub_module.parameters():p.requires_grad = False

      可以这样理解,因为是叶节点(而不是中间节点),所以不求grad(grad为’None’),也不会影响网络的正常反向传播。

以上就是一些总结,敬请指正!

Pytorch的backward()相关理解相关推荐

  1. 【Pytorch】backward()简单理解

    backward()是反向传播求梯度,具体实现过程如下 import torchx=torch.tensor([1,2,3],requires_grad=True,dtype=torch.double ...

  2. Pytorch的backward()与optim.setp()的理解

    @Xia Pytorch的backward()与optim.setp()的理解 backward()与optim.setp()一直对这两个函数他们之间的调用不是很清楚,花点时间应该是搞明白了. 先看最 ...

  3. [转]一文解释PyTorch求导相关 (backward, autograd.grad)

    PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果:而TensorFlow是静态图. 在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation) 运 ...

  4. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  5. Henry前端笔记之 UI组件库中table与slot相关理解

    Henry前端笔记之 UI组件库中table与slot相关理解 作用域插槽: 解构赋值基础:https://developer.mozilla.org/zh-CN/docs/Web/JavaScrip ...

  6. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  7. Android:Window相关理解

    文章目录 一.Window概述 Window概念 Window和DecorView 二.Window属性和类型 Window的类型 应用窗口 子窗口 系统窗口 Window的属性 type参数 Fla ...

  8. 深度理解Pytorch中backward()

    转自https://blog.csdn.net/douhaoexia/article/details/78821428 接触pytorch很久了,也自认为对 backward 方法有一定了解,但看了这 ...

  9. 通俗讲解Pytorch梯度的相关问题:计算图、torch.no_grad、zero_grad、detach和backward;Variable、Parameter和torch.tensor

    文章目录 with torch.no_grad()和requires_grad backward() Variable,Parameter和torch.tensor() zero_grad() 计算图 ...

最新文章

  1. 中国大学的现实:层次越低,上课越多,学生读书越少
  2. 变量声明和定义有什么区别
  3. python用中文怎么说-python如何设置中文界面
  4. photoshop8.0 安装步骤及注意事项
  5. Api管理工具(spring-rest-docs)
  6. Error: Could not find or load main class CLASS的解决方法
  7. Java中文件复制的一个汇总
  8. 重新分区_电脑磁盘分区指南!一分钟就学会
  9. Android 系统(145)---ODM 开发用户常见需求文档(七)
  10. 使用Github发布自己的网站
  11. oracle图形工具创建作业,oracle入门(2)—— 使用图形工具navicat for oracle
  12. window.load和$(document).ready()事件
  13. 【数据结构】从零实现顺序表+链表相关操作
  14. Thread-Specific Storage Pattern
  15. 解决TypeError: conv2d() received an invalid combination of arguments
  16. 服务器CPU经常跑高是什么原因
  17. centos 6 升级gcc
  18. Gos ——操作键盘
  19. 请编程序将“China“译成密码,密码规律是:用原来的字母后面第4个字母代替原来的字母。例如:字母“A”后面第4个字母时“E“,用“E“代替“A“。因此,“China“应译为“Glmre”。请编一程序
  20. “滴滴出行” 成长路径分析(2016年01月19日)

热门文章

  1. 关于Go ROOT 和Go PATH的设置
  2. jdbctemplate oracle xml文件,Spring JDBCTemplate使用JNDI数据源
  3. linux 约等于符号,Mac OS X基础教程:特殊符号的快捷输入方式
  4. oracle日志版本不同,Oracle重做日志文件版本不一致问题处理
  5. python qtdesigner 提升类_python3+PyQt5+Qt Designer实现扩展对话框
  6. 时间android版官方版下载,时间块app安卓下载
  7. 微信小程序的省市区三级地址mysql_微信小程序 实现三级联动-省市区
  8. java实用solr6.6_搜索引擎Solr-6.6.0搭建
  9. linux删除新建的磁盘分区,Fixmbr,删除磁盘分区,新建磁盘分区,等待正式Ubuntu...
  10. jackson 反序列化string_Java 中使用Jackson反序列化