Pytorch 中retain_graph的用法

用法分析

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?

############################# (1) Update D network: maximize D(x)-1-D(G(z))###########################real_img = Variable(target)if torch.cuda.is_available():real_img = real_img.cuda()z = Variable(data)if torch.cuda.is_available():z = z.cuda()fake_img = netG(z)netD.zero_grad()real_out = netD(real_img).mean()fake_out = netD(fake_img).mean()d_loss = 1 - real_out + fake_outd_loss.backward(retain_graph=True) #####optimizerD.step()############################# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss###########################netG.zero_grad()g_loss = generator_criterion(fake_out, fake_img, real_img)g_loss.backward()optimizerG.step()fake_img = netG(z)fake_out = netD(fake_img).mean()g_loss = generator_criterion(fake_out, fake_img, real_img)running_results['g_loss'] += g_loss.data[0] * batch_sized_loss = 1 - real_out + fake_outrunning_results['d_loss'] += d_loss.data[0] * batch_sizerunning_results['d_score'] += real_out.data[0] * batch_sizerunning_results['g_score'] += fake_out.data[0] * batch_size

​ 在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,

如下代码:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

输出如下错误信息:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()2 output2.backward()D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)91                 products. Defaults to ``False``.92         """
---> 93         torch.autograd.backward(self, gradient, retain_graph, create_graph)94 95     def register_hook(self, hook):D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)88     Variable._execution_engine.run_backward(89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正确:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

Variable 类源代码

class Variable(_C._VariableBase):"""Attributes:data: 任意类型的封装好的张量。grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.grad_fn: Gradient function graph trace.Parameters:data (any tensor class): 要包装的张量.requires_grad (bool): bool型的标记值. **Keyword only.**volatile (bool): bool型的标记值. **Keyword only.**"""def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):"""计算关于当前图叶子变量的梯度,图使用链式法则导致分化如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。函数在叶子上累积梯度,调用前需要对该叶子进行清零。Arguments:grad_variables (Tensor, Variable or None):变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。默认值为create_graph的值。create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。默认为False,除非``gradient``是一个volatile变量。"""torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)def register_hook(self, hook):"""Registers a backward hook.每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。Example:>>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)>>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient>>> v.backward(torch.Tensor([1, 1, 1]))>>> v.grad.data222[torch.FloatTensor of size 3]>>> h.remove()  # removes the hook"""if self.volatile:raise RuntimeError("cannot register a hook on a volatile variable")if not self.requires_grad:raise RuntimeError("cannot register a hook on a variable that ""doesn't require gradient")if self._backward_hooks is None:self._backward_hooks = OrderedDict()if self.grad_fn is not None:self.grad_fn._register_hook_dict(self)handle = hooks.RemovableHandle(self._backward_hooks)self._backward_hooks[handle.id] = hookreturn handledef reinforce(self, reward):"""Registers a reward obtained as a result of a stochastic process.区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。Parameters:reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。"""if not isinstance(self.grad_fn, StochasticFunction):raise RuntimeError("reinforce() can be only called on outputs ""of stochastic functions")self.grad_fn._reinforce(reward)def detach(self):"""返回一个从当前图分离出来的心变量。结果不需要梯度,如果输入是volatile,则输出也是volatile。.. 注意::返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。"""result = NoGrad()(self)  # this is needed, because it merges version countersresult._grad_fn = Nonereturn resultdef detach_(self):"""从创建它的图中分离出变量并作为该图的一个叶子"""self._grad_fn = Noneself.requires_grad = Falsedef retain_grad(self):"""Enables .grad attribute for non-leaf Variables."""if self.grad_fn is None:  # no-op for leavesreturnif not self.requires_grad:raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")if hasattr(self, 'retains_grad'):returnweak_self = weakref.ref(self)def retain_grad_hook(grad):var = weak_self()if var is None:returnif var._grad is None:var._grad = grad.clone()else:var._grad = var._grad + gradself.register_hook(retain_grad_hook)self.retains_grad = True

参考

https://oldpan.me/archives/pytorch-retain_graph-work

https://www.cnblogs.com/hellcat/p/8449031.html

Pytorch 中retain_graph的用法相关推荐

  1. Pytorch 中retain_graph的坑

    Pytorch 中retain_graph的坑 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是 在更新D网络时的loss反向传播过程中使用了retain ...

  2. Pytorch中retain_graph参数的作用

    RuntimeError: Trying to backward through the graph a second time, but the buffers have already been ...

  3. pytorch中contiguous()的用法

    contiguous:view只能用在contiguous的variable上.如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguou ...

  4. pytorch 中retain_graph==True的作用

    总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward. retain_graph ...

  5. Pytorch中的detach用法

    该方法主要应用在Variable变量上,作用是从分离出一个tensor,值和原Variable一样,但是不需要计算梯度. 其源码如下: def detach(self):result = NoGrad ...

  6. PyTorch中permute的用法

    RuntimeError: Given groups=1, weight of size [18, 8, 8], expected input[64, 32, 8] to have 8 channel ...

  7. pytorch中arange()函数用法

    语法:torch.arange(start=0,end,step=1) 解释:开始默认为0,步长默认为1,可以不写:终值必须写 返回的个数: 举例:

  8. pytorch中gather用法

    pytorch中gather的用法 2维度tensor进行映射: 3维度tensor进行映射: gather其实是对input进行一种映射,index必须是 LongTensor格式. 2维度tens ...

  9. python中size_x的意思,对pytorch中x = x.view(x.size(0), -1) 的理解说明

    在pytorch的CNN代码中经常会看到 x.view(x.size(0), -1) 首先,在pytorch中的view()函数就是用来改变tensor的形状的,例如将2行3列的tensor变为1行6 ...

最新文章

  1. eclipse 出现user operation is waiting
  2. redmine上传大文件报错Request Entity Too Large
  3. 【贯穿】.NET6结合Docker傻瓜式实现容器编排
  4. 这样能收录,原理是用的凤凰新闻采集工具
  5. DarkNet yoloV2 转到caffe使用
  6. 计算机答辩ppt结论,论文总结与致谢ppt_ppt结束致谢_答辩ppt致谢
  7. 南京邮电大学离散数学实验一(求主析取和主合取范式)
  8. 编译原理 实验二 递归下降语法分析程序
  9. Python Network(二)绘图draw系列draw(),draw_networkx(),draw_networkx_nodes(),draw_networkx_edges()
  10. yarn add 添加依赖的各种类型(指定版本安装、git中安装、tgz包安装、文件夹安装)
  11. 深入浅出理解视频编码H264结构(内涵福利)
  12. 活动|图观™数字孪生精品助推计划
  13. OpenCV学习记录 三 (傅里叶逆变换原理及实现)
  14. Excel生成随机32、36位ID
  15. 今日金融词汇--- 熔断,是什么?
  16. Firefox下载文件时中文名乱码问题
  17. ChatGPT到底是个啥 - 它甚至会和狗说话
  18. zookeeper选举和ZAB协议
  19. 微纳制造技术(半导体制造书籍pdf)
  20. 物联网大数据平台的主要功能和特点

热门文章

  1. kafka如何消费消息
  2. Android开启odex,优化开机速度
  3. 电脑如何查看ip地址和路由器网关? 查看网关ip地址的方法
  4. STM32 重新理解GPIO配置以及配置PWM波输出
  5. linux 同步北京时间 局域网同步时间
  6. html 强制分散对齐,CSS 水平分散对齐
  7. 家具抽屉滑轨行业调研报告 - 市场现状分析与发展前景预测
  8. 【机器学习】聚类算法 BIRCH(Balanced Iterative Reducing and Clustering Using Hierarchies)
  9. 朝秦暮楚魂牵梦萦魂牵梦萦
  10. ctfshow吃瓜杯 八月群赛 WriteUp/WP