Pytorch 中retain_graph的坑
Pytorch 中retain_graph的坑
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
############################# (1) Update D network: maximize D(x)-1-D(G(z))###########################real_img = Variable(target)if torch.cuda.is_available():real_img = real_img.cuda()z = Variable(data)if torch.cuda.is_available():z = z.cuda()fake_img = netG(z)netD.zero_grad()real_out = netD(real_img).mean()fake_out = netD(fake_img).mean()d_loss = 1 - real_out + fake_outd_loss.backward(retain_graph=True) #####optimizerD.step()############################# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss###########################netG.zero_grad()g_loss = generator_criterion(fake_out, fake_img, real_img)g_loss.backward()optimizerG.step()fake_img = netG(z)fake_out = netD(fake_img).mean()g_loss = generator_criterion(fake_out, fake_img, real_img)running_results['g_loss'] += g_loss.data[0] * batch_sized_loss = 1 - real_out + fake_outrunning_results['d_loss'] += d_loss.data[0] * batch_sizerunning_results['d_score'] += real_out.data[0] * batch_sizerunning_results['g_score'] += fake_out.data[0] * batch_size
也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True) 让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。
但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。
而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!
Pytorch 中retain_graph的坑相关推荐
- Pytorch中Dataloader踩坑:RuntimeError: DataLoader worker (pid(s) 6700, 10620) exited unexpectedly
Pytorch中Dataloader踩坑 环境: 问题背景: 观察报错信息进行分析 根据分析进行修改尝试 总结 环境: 系统:windows10 Pytorch版本:1.5.1+cu101 问题背景: ...
- Pytorch 中retain_graph的用法
Pytorch 中retain_graph的用法 用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? #################### ...
- Pytorch中retain_graph参数的作用
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been ...
- pytorch 中retain_graph==True的作用
总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward. retain_graph ...
- python torch exp_学习Pytorch过程遇到的坑(持续更新中)
1. 关于单机多卡的处理: 在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键 ...
- pytorch load state dict_学习Pytorch过程遇到的坑(持续更新中)
1. 关于单机多卡的处理: 在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键 ...
- 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层
requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...
- PyTorch中模型的可复现性
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:AI算法与图像处理 在深度学习模型的训练过程中,难免引入 ...
- 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...
在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...
最新文章
- 【NOIP2012模拟10.25】旅行
- 分享一个多线程实现[冒泡][选择][二分法]排序的例子
- ASP.NET 26个常用性能优化方法
- python 连接mysql 字符集_Python 操作 MySQL 字符集的问题
- [Python图像处理] 二十五.图像特效处理之素描、怀旧、光照、流年以及滤镜特效
- 大学生要学计算机嘛,大学新生有必要买电脑吗,为什么很多人都带电脑去学校了?...
- 没有计算机的一天英语作文带翻译,初一英语作文我的一天带翻译
- 编写XML XmlTextWriter与XmlDocument(转载)
- python中必须使用import引入模块_Python之import方法引入模块详解
- gg修改器偏移量修改_GG修改器偏移是怎么弄 | 手游网游页游攻略大全
- batch spring 重复执行_Spring Batch_JOB重启机制
- Go 企业级框架 Revel 版全新发布
- python实现MACD策略背离点的判断
- 《AMNet: Deep Atrous Multiscale Stereo Disparity Estimation Networks》
- python自动爬取更新电影网站_Python爬虫之—微信实时爬取电影咨询
- 病案首页计算机管理系统功能一般不包括,病案管理系统
- ParaView绘制自由水面的等值线图
- 五面拿下阿里飞猪offer,java图形界面设置背景颜色
- 基于python和selenium爬取JD商城商品信息并且分析用户对于产品的满意程度
- 信息[http-nio-80-exec-9] org. apache. coyote. http11. Httpl1Processor.service解析注意:HTTP请求解析错误的进--步发生将记录