(八)还没写,先跳过。。。

总说

简单来说detach就是截断反向传播的梯度流

    def detach(self):"""Returns a new Variable, detached from the current graph.Result will never require gradient. If the input is volatile, the outputwill be volatile too... note::Returned Variable uses the same data tensor, as the original one, andin-place modifications on either of them will be seen, and may triggererrors in correctness checks."""result = NoGrad()(self)  # this is needed, because it merges version countersresult._grad_fn = Nonereturn result

可以看到Returns a new Variable, detached from the current graph。将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。

从GAN的代码中看detach()

GAN的G的更新,主要是GAN loss。就是G生成的fake图让D来判别,得到的损失,计算梯度进行反传。这个梯度只能影响G,不能影响D!可以看到,由于torch是非自动求导的,每一层的梯度的计算必须用net:backward才能计算gradInput和网络中的参数的梯度。

先看Torch版本的代码

local fGx = function(x)netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)gradParametersG:zero()-- GAN losslocal df_dg = torch.zeros(fake_B:size())if opt.use_GAN==1 thenlocal output = netD.output -- netD:forward{input_A,input_B} was already executed in fDx, so save computationlocal label = torch.FloatTensor(output:size()):fill(real_label) -- fake labels are real for generator costerrG = criterion:forward(output, label)local df_do = criterion:backward(output, label)df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)elseerrG = 0end-- unary loss-- 得到 df_do_AE(已省略)   netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))return errG, gradParametersG
end

在下面代码中,是先得到fake图进入D的loss,然后这个loss的梯度df_do进行反传,首先要这个梯度经过D。此时不能改变D的参数的梯度,所以这里用updateGradInput,不能用backward。这是因为backward是调用2个函数updateGradInputaccGradParameters。后者是计算loss对于网络中参数的梯度,这些梯度是不断累加的!除非手动gradParametersG:zero()置零。

       errG = criterion:forward(output, label)local df_do = criterion:backward(output, label)df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)-- unary loss-- 得到 df_do_AE(已省略)   netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

然后得到的df_dg才是要更新G的GAN损失的梯度,当然G的另一个损失是L1损失(unary loss)这个没啥好说了。

pytorch的GAN实现

由于Pytorch是自动反向传播,

    def backward_D(self):# Fake# stop backprop to the generator by detaching fake_Bfake_AB = self.fake_B# fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))self.pred_fake = self.netD.forward(fake_AB.detach())self.loss_D_fake = self.criterionGAN(self.pred_fake, False)# Realreal_AB = self.real_B # GroundTruth# real_AB = torch.cat((self.real_A, self.real_B), 1)self.pred_real = self.netD.forward(real_AB)self.loss_D_real = self.criterionGAN(self.pred_real, True)# Combined lossself.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5self.loss_D.backward()def backward_G(self):# First, G(A) should fake the discriminatorfake_AB = self.fake_Bpred_fake = self.netD.forward(fake_AB)self.loss_G_GAN = self.criterionGAN(pred_fake, True)# Second, G(A) = Bself.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_Aself.loss_G = self.loss_G_GAN + self.loss_G_L1self.loss_G.backward()def forward(self):self.real_A = Variable(self.input_A)self.fake_B = self.netG.forward(self.real_A)self.real_B = Variable(self.input_B)# 先调用 forward, 再 D backward, 更新D之后; 再G backward, 再更新Gdef optimize_parameters(self):self.forward()self.optimizer_D.zero_grad()self.backward_D()self.optimizer_D.step()self.optimizer_G.zero_grad()self.backward_G()self.optimizer_G.step()

解释backward_D:

对于D,我们值需要,如果输入是真实图,那么产生loss,输入真实图,也产生loss。
这两个梯度进行更新D。如果是真实图(real_B),由于real_B是初始结点,所以没什么可担心的。但是对于生成图fake_B,由于 fake_B是由 netG.forward(real_A)产生的。我们只希望 该loss更新D不要影响到 G. 因此这里需要“截断反传的梯度流”,用 fake_AB = fake_B, fake_AB.detach()从而让梯度不要通过 fake_AB反传到netG中!

解释backward_G:

由于在调用 backward_G已经调用了zero_grad,所以没什么好担心的。
更新G时,来自D的GAN损失是, netD.forward(fake_AB),得到 pred_fake,然后得到损失,反传播即可。
注意,这里反向传播时,会先将梯度传到 fake_AB结点,然而我们知道 fake_AB即 fake_B结点,而fake_B正是由netG(real_A)产生的,所以还会顺着继续往前传播,从而得到G的对应的梯度。

对比 Torch代码

df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

Torch中,没有计算netD的参数的梯度,而是直接用 updateGradInput。在pytorch中,我们也是希望GAN loss只能更新G。但是pytorch是自动求导的,所以我们没法手动像Torch一样只调用updateGradInput

        self.loss_G_GAN = self.criterionGAN(pred_fake, True)# Second, G(A) = Bself.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_Aself.loss_G = self.loss_G_GAN + self.loss_G_L1self.loss_G.backward()

在这里,虽然pytorch中会自动计算所有的结点的梯度,但是我们执行loss_G.backward()后,按照Torch的理解是,这里直接调用backward。即不仅调用了updateGradInput(我们只需要这个),还额外的计算了accGradParameters(这个是没用的),但是看到,在optimize_parameters中,只是进行 optimizer_G.step()所以只会更新G的参数。所以没有更新D(虽然此时D中有dummy gradient)。等下一回合,又调用 optimizer_D.zero_grad(), 因此会把刚才残留的D的梯度清空。所以仍旧是符合的。

自动求导反向书写的简洁

得出结论,书写自动求导的代码完全还是很简洁的。只需要进行loss计算。loss可以直接相加,然后loss.backward()即可。loss的定义比如:

self.optimizer_G = torch.optim.Adam(self.netG.parameters(),lr=opt.lr, betas=(opt.beta1, 0.999))

Adam是继承自Optimizer类。该类的step函数会将构建loss的所有的Variable的参数进行更新。

    def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:for p in group['params']: #如果这个参数有没有grad(这个Variable的requries_grad为False)#则直接跳过。if p.grad is None:continuegrad = p.grad.datastate = self.state[p]# 对p.data进行更新!就是对参数进行更新!# State initializationif len(state) == 0:state['step'] = 0# Exponential moving average of gradient valuesstate['exp_avg'] = grad.new().resize_as_(grad).zero_()# Exponential moving average of squared gradient valuesstate['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']beta1, beta2 = group['betas']state['step'] += 1if group['weight_decay'] != 0:grad = grad.add(group['weight_decay'], p.data)# Decay the first and second moment running average coefficientexp_avg.mul_(beta1).add_(1 - beta1, grad)exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)denom = exp_avg_sq.sqrt().add_(group['eps'])bias_correction1 = 1 - beta1 ** state['step']bias_correction2 = 1 - beta2 ** state['step']step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1p.data.addcdiv_(-step_size, exp_avg, denom)

Pytorch入门学习(九)---detach()的作用(从GAN代码分析)相关推荐

  1. PyTorch框架学习九——网络模型的构建

    PyTorch框架学习九--网络模型的构建 一.概述 二.nn.Module 三.模型容器Container 1.nn.Sequential 2.nn.ModuleList 3.nn.ModuleDi ...

  2. pytorch 入门学习多分类问题-9

    pytorch 入门学习多分类问题 运行结果 [1, 300] loss: 2.287[1, 600] loss: 2.137[1, 900] loss: 1.192 Accuracy on test ...

  3. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  4. pytorch 入门学习处理多维特征输入-7

    pytorch 入门学习处理多维特征输入 处理多维特征输入 import torch import numpy as np import torchvision import numpy as np ...

  5. pytorch 入门学习使用逻辑斯蒂做二分类-6

    pytorch 入门学习使用逻辑斯蒂做二分类 使用pytorch实现逻辑斯蒂做二分类 import torch import torchvision import numpy as np import ...

  6. pytorch 入门学习 实现线性回归-5

    pytorch 入门学习实现线性回归 使用pytorch实现线性回归 import numpy as np import matplotlib.pyplot as plt import torch#p ...

  7. pytorch 入门学习反向传播-4

    pytorch 入门学习反向传播 反向传播 import numpy as np import matplotlib.pyplot as plt import torchdef forward(x): ...

  8. 程序媛养成第0天--pytorch入门学习

    本篇基于<深度学习框架-pytorch入门与实践>陈云 有一起监督学习打卡的小伙伴请私信 2.2 pytorch入门第一步 2.2.1 Tensor # 分配矩阵空间但不初始化 #使用 [ ...

  9. pytorch 入门学习 MSE

    <PyTorch深度学习实践>完结合集-线性模型 import numpy as np import matplotlib.pyplot as pltx_data = [1.0,2.0,3 ...

  10. PCIe学习笔记之MSI/MSI-x中断及代码分析

    本文基于linux 5.7.0, 平台是arm64 1. MSI/MSI-X概述 PCIe有三种中断,分别为INTx中断,MSI中断,MSI-X中断,其中INTx是可选的,MSI/MSI-X是必须实现 ...

最新文章

  1. ASP .NET Core Web Razor Pages系列教程五:更新Razor Pages页面
  2. java 取得textfield_怎样获取java中textfield的内容
  3. HTML页面转换asp,将asp页面转换成html页面 代码
  4. 实战SSM_O2O商铺_26【商品类别】批量新增商品类别从Dao到View层的开发
  5. ubuntu下使用filezilla上传文件权限问题(open for write: permission denied)
  6. socket的NIO操作
  7. Magicodes.IE之快速导出Excel
  8. discuz仿手游控游戏论坛商业版网站模板
  9. PAT乙级 1094 谷歌的招聘(柳婼代码,测试点1、2、4、5分析)
  10. asp脚本和php脚本,有经典ASP的缓存脚本吗?
  11. Fixjs——显示基类DisplayObject
  12. java+整合handwrite_E-signature-master
  13. 国际学术期刊排名按照姓氏字母排吗?
  14. 深圳移动 神州行(大众卡/轻松卡/幸福卡)套餐资费(含香港日套餐)信息及使用方法...
  15. 建立时间与保持时间计算
  16. 智能手机中MEMS传感器应用浅析
  17. DDSM数据库转换图像格式——LJPEG转为PNG格式
  18. 计算两个时间的间隔时长
  19. 如何有效地召开会议?
  20. 移动互联网-2011 年值得关注的100个应用程序

热门文章

  1. 看看最新BTA大厂的Java程序员的招聘技术标准,Java篇
  2. MySQL数据库基础03 韩顺平 自学笔记
  3. Ruby电子书教程、经典脚本合集
  4. 高中计算机应用面试教资真题,2019下半年高中信息技术教师资格证面试试题(精选)第四批...
  5. Linux之——命令大全
  6. 第一章 .NET体系结构
  7. Drupal7学习笔记之Theme感觉非常好转来共享啊!
  8. Python爬取新东方在线网站大学英语六级词汇
  9. 简易数据分析 04 | Web Scraper 初尝--抓取豆瓣高分电影
  10. 流媒体/流媒体文件格式详解