在学习的过程中遇见了一个问题,就是当使用backward()反向传播时传入参数的问题:

net.zero_grad() #所有参数的梯度清零
output.backward(Variable(t.ones(1, 10))) #反向传播

这里的backward()中为什么需要传入参数Variable(t.ones(1, 10))呢?没有传入就会报错:

RuntimeError: grad can be implicitly created only for scalar outputs

这个错误的意思就是梯度只能为标量(即一个数)输出隐式地创建

比如有一个例子是:

1)

#使用Tensor新建一个Variable
x = Variable(t.ones(2, 2),requires_grad = True)
x

返回:

tensor([[1., 1.],[1., 1.]], requires_grad=True)

此时查看该值的grad和grad_fn是没有返回值的,因为没有进行任何操作

x.grad_fn
x.grad

进行求和操作,查看梯度

y = x.sum()
y

返回:

tensor(4., grad_fn=<SumBackward0>)

这时候可查看:

y.grad_fn

返回:

<SumBackward0 at 0x122782978>

可知y是变量Variable x进行sum操作求来的,但是这个时候y.grad是没有返回值的,因为没有使用y进行别的操作

这个时候的x.grad也是没有值的,虽然使用x进行了sum操作,但是还没有对y反向传播来计算梯度

y.backward()#反向传播,计算梯度

然后再查看:

#因为y = x.sum() = (x[0][0] + x[0][1] + x[1][0] + x[1][1])
#每个值的梯度都为1
x.grad

返回:

tensor([[1., 1.],[1., 1.]])

在这里我们可以看见y能够求出x的梯度,这里的y是一个数,即标量

如果这里我们更改一下y的操作,将y设置为一个二维数组:

from __future__ import print_function
import torch as t
from torch.autograd import Variable
x = Variable(t.ones(2, 2),requires_grad = True)
y = x + 1
y.backward()

然后就会报上面的错误:

RuntimeError: grad can be implicitly created only for scalar outputs

总结:

因此当输出不是标量时,调用.backwardI()就会出错

解决办法:

显示声明输出的类型作为参数传入,且参数的大小必须要和输出值的大小相同

x.grad.data.zero_() #将之前的值清零
x.grad

返回:

tensor([[0., 0.],[0., 0.]])

进行反向传播:

y.backward(y.data)
x.grad

也可以写成,因为Variable和Tensor有近乎一致的接口

y.backward(y)
x.grad

返回:

tensor([[2., 2.],[2., 2.]])

但是这里返回值与预想的1不同,这个原因是得到的梯度会与参数的值相乘,所以最好传入值为1,如:

y.backward(Variable(t.ones(2, 2)))
x.grad

这样就能够成功返回想要的值了:

tensor([[1., 1.],[1., 1.]])

更加复杂的操作:

在上面的例子中,x和y都是(2,2)的数组形式,每个yi都只与对应的xi相关

1)如果每个yi都与多个xi相关时,梯度又是怎么计算的呢?

比如x = (x1 = 2, x2 = 4), y = (x12+2x2, 2x1+3x22)

(i,j)的值就是传入.backward()的参数的值

x = Variable(t.FloatTensor([[2, 4]]),requires_grad = True)
y = Variable(t.zeros(1, 2))
y[0,0] = x[0,0]**2 + 2 * x[0,1]
y[0,1] = 2 * x[0,0] + 3 * x[0,1]**2
y.backward(Variable(t.ones(1, 2))) #(i,j)= (1,1)
x.grad

返回:

tensor([[ 6., 26.]])

2)如果x和y不是相同的数组形式,且每个yi都与多个xi相关时,梯度又是怎么计算的呢?

比如x = (x1 = 2, x2 = 4, x3=5), y = (x12+2x2+4x3, 2x1+3x22+x32)

x = Variable(t.FloatTensor([[2, 4, 5]]),requires_grad = True)
y = Variable(t.zeros(1, 2))
y[0,0] = x[0,0]**2 + 2 * x[0,1] + 4 * x[0,2]
y[0,1] = 2 * x[0,0] + 3 * x[0,1]**2 + x[0,2]**2
y.backward(Variable(t.ones(1, 2)))
x.grad

返回:

tensor([[ 6., 26., 14.]])

如果(i, j) = (2,2),结果是否为(12, 52, 28)呢?

x = Variable(t.FloatTensor([[2, 4, 5]]),requires_grad = True)
y = Variable(t.zeros(1, 2))
y[0,0] = x[0,0]**2 + 2 * x[0,1] + 4 * x[0,2]
y[0,1] = 2 * x[0,0] + 3 * x[0,1]**2 + x[0,2]**2
y.backward(Variable(t.FloatTensor([[2, 2]])))
x.grad

返回:

tensor([[12., 52., 28.]])

3)如果你想要分别得到y1,y2对x1,x2,x3的求导值,方法是:

x = Variable(t.FloatTensor([[2, 4, 5]]),requires_grad = True)
y = Variable(t.zeros(1, 2))
y[0,0] = x[0,0]**2 + 2 * x[0,1] + 4 * x[0,2]
y[0,1] = 2 * x[0,0] + 3 * x[0,1]**2 + x[0,2]**2
j = t.zeros(3,2)#用于存放求导的值
#(i,j)=(1,0)这样就会对应只求得y1对x1,x2和x3的求导
#retain_variables=True的作用是不在反向传播后释放内存,这样才能够再次反向传播
y.backward(Variable(t.FloatTensor([[1, 0]])),retain_variables=True)
j[:,0] = x.grad.data
x.grad.data.zero_() #将之前的值清零
#(i,j)=(1,0)这样就会对应只求得y2对x1,x2和x3的求导
y.backward(Variable(t.FloatTensor([[0, 1]])))
j[:,1] = x.grad.data
print(j)

报错:

TypeError: backward() got an unexpected keyword argument 'retain_variables'

原因是新版本使用的参数名为retain_graph,改了即可:

x = Variable(t.FloatTensor([[2, 4, 5]]),requires_grad = True)
y = Variable(t.zeros(1, 2))
y[0,0] = x[0,0]**2 + 2 * x[0,1] + 4 * x[0,2]
y[0,1] = 2 * x[0,0] + 3 * x[0,1]**2 + x[0,2]**2
j = t.zeros(3,2)#用于存放求导的值
#(i,j)=(1,0)这样就会对应只求得y1对x1,x2和x3的求导
#retain_graph=True的作用是不在反向传播后释放内存,这样才能够再次反向传播
y.backward(Variable(t.FloatTensor([[1, 0]])),retain_graph=True)
j[:,0] = x.grad.data
x.grad.data.zero_() #将之前的值清零
#(i,j)=(1,0)这样就会对应只求得y2对x1,x2和x3的求导
y.backward(Variable(t.FloatTensor([[0, 1]])))
j[:,1] = x.grad.data
print(j)

返回:

tensor([[ 4.,  2.],[ 2., 24.],[ 4., 10.]])

pytorch的backward相关推荐

  1. Pytorch的backward()相关理解

    Pytorch的backward()相关理解 最近一直在用pytorch做GAN相关的实验,pytorch 框架灵活易用,很适合学术界开展研究工作.  这两天遇到了一些模型参数寻优的问题,才发现自己对 ...

  2. Pytorch的backward()与optim.setp()的理解

    @Xia Pytorch的backward()与optim.setp()的理解 backward()与optim.setp()一直对这两个函数他们之间的调用不是很清楚,花点时间应该是搞明白了. 先看最 ...

  3. pytorch的backward参数

    首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的. 运行结果 ...

  4. pytorch Tensor.backward

    1.应用 import torch import torch.nn as nn# 1.全部为1 x = torch.tensor([1.0,3.0], requires_grad = True) # ...

  5. Pytorch中backward函数

    backward函数是反向求导数,使用链式法则求导,如果对非标量y求导,函数需要额外指定grad_tensors,grad_tensors的shape必须和y的相同. import torch fro ...

  6. 深度理解Pytorch中backward()

    转自https://blog.csdn.net/douhaoexia/article/details/78821428 接触pytorch很久了,也自认为对 backward 方法有一定了解,但看了这 ...

  7. [Python / PyTorch] debug backward()

    问题描述 在自定义Loss的中,其backward()函数不支持在PyCharm中进行断点调试 因此需要以其他方式进行断点调试 解决方案 参考:Is there a way to debug the ...

  8. 高效理解pytorch的backward需要scalar outputs

    利用backward时 , 可能经常遇到错误 RuntimeError: grad can be implicitly created only for scalar outputs 理解的最好方式就 ...

  9. Pytorch中backward(retain_graph=True)的 retain_graph参数解释

    每次 backward() 时,默认会把整个计算图free掉.一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward ...

最新文章

  1. 提高生产力:文件和IO操作(ApacheCommonsIO-汉化分享)
  2. 【微信小程序】根据当前运行环境调用不同的接口地址的一些方法
  3. 最新 UI 色彩渐变素材模板|设计师好帮手
  4. jQuery插件开发全解析(转)
  5. python 验证码test
  6. 1.窗体与界面设计-菜单应用实例
  7. cocosbuilder入门
  8. 程序猿的创业故事:一个游走于计算机编程、高中数学、高中物理、爱好木工的全栈工程师,转行做高中教学的亲生经历!
  9. 巴法络的ts系列服务器,BUFFALO TS5400D NAS 巴法络 4BAY 网络存储服务器 塔式 企业级...
  10. linux openerp,openerp
  11. cgcs2000大地坐标系地图_我国大地坐标系_地图与地图制图
  12. pytorch to_device遇到数据迁移不成功的问题
  13. java中long类型的空值怎么表示,【关于long类型的转换】传进来的是String类型是或null或0如何转成long类型...
  14. nacos注册成功但是服务管理界面没有内容
  15. Raphael 原理及实践
  16. 一周cp未能连接到服务器,阴阳师:“一周CP”活动帮你找情缘?玩家高喊错过了,工具人没了...
  17. 深圳摇号中签后异地车牌更换深圳车牌流程
  18. windows获取网卡信息并判断是否是物理网卡 网络适配器的判断
  19. [笔试补完计划]澜起科技2022数字验证笔试
  20. 阿里 美团 百度 字节跳动 腾讯 滴滴Java校招面试题总结

热门文章

  1. 设置select下拉框不可修改的→“四”←种方法
  2. 中级实训第一天的自学报告
  3. python中的raw string的使用
  4. JS 添加网页桌面快捷方式的代码
  5. 最优化方法系列:Adam+SGD-AMSGrad 重点
  6. 华为不造车,广汽合作智能驾驶
  7. LCD: 2D-3D匹配算法
  8. 2021年大数据Spark(四十六):Structured Streaming Operations 操作
  9. linux locale文件,Linux 怎样修改locale语言设置
  10. Android 通过创建一个类来传递对象