在Pytorch中,我们有时候会进行多个loss的回传,然而回传中,会发生一些错误。例如:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

以下我们举几个回传例子便可理解:

1、当我们对同一个loss进行多次回传时:即

loss.backward()
loss.backward()

上述回传方式必然出错。这时我们只需要在backward()中加入参数retain_graph=True后,便可正常回传。此时两次的回传损失会叠加。需要注意,当我们的对相同的loss回传多次,只有最后一次不加retain_graph参数外,其余都得加,否则报错。例子如下:

import torch
from torch.autograd import Variablex = Variable(torch.FloatTensor([3]), requires_grad=True)
y = x * 2 + x ** 2 + 3
print(y)
y.backward(retain_graph=True)  # 设置 retain_graph 为 True 来保留计算图
print(x.grad)
y.backward()  # 再做一次自动求导,这次不保留计算图
print(x.grad)# 输出
# tensor([18.], grad_fn=<AddBackward0>)
# tensor([8.])
# tensor([16.])

2、当然,以上是对同一个loss进行回传。那么对多个不同loss回传呢?例如:

loss1.backward()
loss2.backward()

此时是可以正常回传的,且两次的回传结果会进行叠加。例子:

import torchx = torch.tensor(2.0, requires_grad=True)
y = x**2
z = x
# 反向传播
y.backward()
print(x.grad)
# tensor(4.)
z.backward()
print(x.grad)
# tensor(5.) ## 累加

3、当然,以上回传我们还可以加在一起,一并回传,那么梯度也会叠加,同上面的结果等价。例如:

loss = loss1 + loss2
loss.backward()

4、但是,在有些时候,我们会同时训练两个网络,例如生成对抗网络(GAN)。我们在利用方式3的回传时,也会报错,因为两个网络之间有了交叉。这时候,我们就需要用上方式2的分步回传了,结果是不变的。然而,需要注意:我们的回传某个网络的loss时,是不能有其他网络输出的可求导数据的,也就是我们在将其他网络的输出传入需要回传的网络进行结果的损失计算时,需要将其他网络的输出加上detach()才不会报错。例子:这里我们回传D网络,所以G网络得加detach()。

fake = netG(noise)
output = netD(fake.detach()) # 加上detach()errD_fake = loss_function(output, label)
errD_fake.backward()

Pytorch的反向传播backward()详解相关推荐

  1. pytorch学习 -- 反向传播backward

    pytorch学习 – 反向传播backward 入门学习pytorch,在查看pytorch代码时对autograd的backard函数的使用和具体原理有些疑惑,在查看相关文章和学习后,简单说下我自 ...

  2. 反向算法_10分钟带你了解神经网络基础:反向传播算法详解

    作者:Great Learning Team deephub.ai 翻译组 1.神经网络 2.什么是反向传播? 3.反向传播是如何工作的? 4.损失函数 5.为什么我们需要反向传播? 6.前馈网络 7 ...

  3. Pytorch autograd.grad与autograd.backward详解

    Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...

  4. Pytorch|YOWO原理及代码详解(二)

    Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...

  5. Pytorch | yolov3原理及代码详解(二)

    阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...

  6. 人工智能-作业1:PyTorch实现反向传播

    人工智能-作业1:PyTorch实现反向传播 人工智能-作业1:PyTorch实现反向传播 环境配置: 计算过程 反向传播 PyTorch Autograd自动求导 人工智能-作业1:PyTorch实 ...

  7. List逆向遍历、反向遍历--Iterator详解

    List逆向遍历.反向遍历–Iterator详解 概述 在使用java集合的时候,都需要使用Iterator.但是java集合中还有一个迭代器ListIterator,在使用List.ArrayLis ...

  8. Nginx反向代理配置详解

    Nginx反向代理配置详解 Nginx简单的反向代理配置,包括配置文件中各项参数的的注释,好了,开始! 开始首先安装Nginx 一.建立用户和用户组 1 2 ./usr/sbin/groupadd w ...

  9. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

最新文章

  1. linux环境安装部署mark
  2. 【数据结构与算法】之深入解析“重新安排行程”的求解思路与算法示例
  3. python中多维数组_python学习笔记-多维数组
  4. Javascript基础学习12问(四)
  5. 电话机器人图文+源码介绍
  6. 42个最好的海外 app ASO工具
  7. 中文如何翻译成英文?手机中英文一键翻译超简单
  8. 05.看板方法——在制品
  9. 女装网 www.nzw.com.cn
  10. MATLAB学习【第五部分】--第一节:矩阵的输入//冒号表达式矩阵---linspace函数生成向量---一般矩阵输入
  11. 2022年交通工具公开拍卖市场研究报告
  12. “左眼跳财,右眼跳灾”
  13. 三分钟读懂双十二布局玩法,大促流量销量双翻倍so easy!
  14. fiddler拦截模拟器中app的请求设置方法
  15. 人类遗传密码97%待解读
  16. win10虚拟机创建
  17. mysql 多个字段排序
  18. [HNOI2010] 平面图判定
  19. 萌萌媛の【剑指offer笔记】二维数组中的查找
  20. java代码命名规范

热门文章

  1. ftp搭建方式-server-u安装步骤
  2. Storm实时处理架构
  3. 爬虫工具可以干什么_10个爬虫工程师必备的工具了解一哈
  4. 种植牙好不好?该怎么选择?
  5. 笔记本python安装教程_《笔》字意思读音、组词解释及笔画数 - 新华字典 - 911查询...
  6. 5年经验,没听过XFF漏洞
  7. 工业机器人入门z50的含义_工业机器人轴座系名称有什么
  8. 对一个项目如何写一个方案?
  9. css中appearance:none修改select option样式
  10. 总结移动端video视频播放的坑