Pytorch 中retain_graph的坑

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

############################# (1) Update D network: maximize D(x)-1-D(G(z))###########################real_img = Variable(target)if torch.cuda.is_available():real_img = real_img.cuda()z = Variable(data)if torch.cuda.is_available():z = z.cuda()fake_img = netG(z)netD.zero_grad()real_out = netD(real_img).mean()fake_out = netD(fake_img).mean()d_loss = 1 - real_out + fake_outd_loss.backward(retain_graph=True) #####optimizerD.step()############################# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss###########################netG.zero_grad()g_loss = generator_criterion(fake_out, fake_img, real_img)g_loss.backward()optimizerG.step()fake_img = netG(z)fake_out = netD(fake_img).mean()g_loss = generator_criterion(fake_out, fake_img, real_img)running_results['g_loss'] += g_loss.data[0] * batch_sized_loss = 1 - real_out + fake_outrunning_results['d_loss'] += d_loss.data[0] * batch_sizerunning_results['d_score'] += real_out.data[0] * batch_sizerunning_results['g_score'] += fake_out.data[0] * batch_size

也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True)  让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。

但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。

而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!

Pytorch 中retain_graph的坑相关推荐

  1. Pytorch中Dataloader踩坑:RuntimeError: DataLoader worker (pid(s) 6700, 10620) exited unexpectedly

    Pytorch中Dataloader踩坑 环境: 问题背景: 观察报错信息进行分析 根据分析进行修改尝试 总结 环境: 系统:windows10 Pytorch版本:1.5.1+cu101 问题背景: ...

  2. Pytorch 中retain_graph的用法

    Pytorch 中retain_graph的用法 用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? #################### ...

  3. Pytorch中retain_graph参数的作用

    RuntimeError: Trying to backward through the graph a second time, but the buffers have already been ...

  4. pytorch 中retain_graph==True的作用

    总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward. retain_graph ...

  5. python torch exp_学习Pytorch过程遇到的坑(持续更新中)

    1. 关于单机多卡的处理: 在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键 ...

  6. pytorch load state dict_学习Pytorch过程遇到的坑(持续更新中)

    1. 关于单机多卡的处理: 在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键 ...

  7. 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层

    requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...

  8. PyTorch中模型的可复现性

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:AI算法与图像处理 在深度学习模型的训练过程中,难免引入 ...

  9. 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...

    在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...

最新文章

  1. 【NOIP2012模拟10.25】旅行
  2. 分享一个多线程实现[冒泡][选择][二分法]排序的例子
  3. ASP.NET 26个常用性能优化方法
  4. python 连接mysql 字符集_Python 操作 MySQL 字符集的问题
  5. [Python图像处理] 二十五.图像特效处理之素描、怀旧、光照、流年以及滤镜特效
  6. 大学生要学计算机嘛,大学新生有必要买电脑吗,为什么很多人都带电脑去学校了?...
  7. 没有计算机的一天英语作文带翻译,初一英语作文我的一天带翻译
  8. 编写XML XmlTextWriter与XmlDocument(转载)
  9. python中必须使用import引入模块_Python之import方法引入模块详解
  10. gg修改器偏移量修改_GG修改器偏移是怎么弄 | 手游网游页游攻略大全
  11. batch spring 重复执行_Spring Batch_JOB重启机制
  12. Go 企业级框架 Revel 版全新发布
  13. python实现MACD策略背离点的判断
  14. 《AMNet: Deep Atrous Multiscale Stereo Disparity Estimation Networks》
  15. python自动爬取更新电影网站_Python爬虫之—微信实时爬取电影咨询
  16. 病案首页计算机管理系统功能一般不包括,病案管理系统
  17. ParaView绘制自由水面的等值线图
  18. 五面拿下阿里飞猪offer,java图形界面设置背景颜色
  19. 基于python和selenium爬取JD商城商品信息并且分析用户对于产品的满意程度
  20. 信息[http-nio-80-exec-9] org. apache. coyote. http11. Httpl1Processor.service解析注意:HTTP请求解析错误的进--步发生将记录

热门文章

  1. 软件成本估算之快速功能点方法应用示例
  2. rtx3080ti什么时候上市 rtx3080ti和RTX 3080参数对比哪个好
  3. VS2017生成解决方案报错,提示对路径的访问被拒绝
  4. Java--对接微信第三篇之订阅发送图文消息给用户
  5. 在线考试系统设计时必须考虑的问题之一------------批改试卷问题
  6. macPorts homebrew
  7. Solr入门学习(二)—— Solr 的基本查询
  8. php验证码打叉,phpcms 验证码显示为红叉叉的补丁源码!
  9. 【Python】弧度转化为角度
  10. 2.求e的值。(分数阶乘)