Pytorch的反向传播backward()详解
在Pytorch中,我们有时候会进行多个loss的回传,然而回传中,会发生一些错误。例如:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
以下我们举几个回传例子便可理解:
1、当我们对同一个loss进行多次回传时:即
loss.backward()
loss.backward()
上述回传方式必然出错。这时我们只需要在backward()中加入参数retain_graph=True后,便可正常回传。此时两次的回传损失会叠加。需要注意,当我们的对相同的loss回传多次,只有最后一次不加retain_graph参数外,其余都得加,否则报错。例子如下:
import torch
from torch.autograd import Variablex = Variable(torch.FloatTensor([3]), requires_grad=True)
y = x * 2 + x ** 2 + 3
print(y)
y.backward(retain_graph=True) # 设置 retain_graph 为 True 来保留计算图
print(x.grad)
y.backward() # 再做一次自动求导,这次不保留计算图
print(x.grad)# 输出
# tensor([18.], grad_fn=<AddBackward0>)
# tensor([8.])
# tensor([16.])
2、当然,以上是对同一个loss进行回传。那么对多个不同loss回传呢?例如:
loss1.backward()
loss2.backward()
此时是可以正常回传的,且两次的回传结果会进行叠加。例子:
import torchx = torch.tensor(2.0, requires_grad=True)
y = x**2
z = x
# 反向传播
y.backward()
print(x.grad)
# tensor(4.)
z.backward()
print(x.grad)
# tensor(5.) ## 累加
3、当然,以上回传我们还可以加在一起,一并回传,那么梯度也会叠加,同上面的结果等价。例如:
loss = loss1 + loss2
loss.backward()
4、但是,在有些时候,我们会同时训练两个网络,例如生成对抗网络(GAN)。我们在利用方式3的回传时,也会报错,因为两个网络之间有了交叉。这时候,我们就需要用上方式2的分步回传了,结果是不变的。然而,需要注意:我们的回传某个网络的loss时,是不能有其他网络输出的可求导数据的,也就是我们在将其他网络的输出传入需要回传的网络进行结果的损失计算时,需要将其他网络的输出加上detach()才不会报错。例子:这里我们回传D网络,所以G网络得加detach()。
fake = netG(noise)
output = netD(fake.detach()) # 加上detach()errD_fake = loss_function(output, label)
errD_fake.backward()
Pytorch的反向传播backward()详解相关推荐
- pytorch学习 -- 反向传播backward
pytorch学习 – 反向传播backward 入门学习pytorch,在查看pytorch代码时对autograd的backard函数的使用和具体原理有些疑惑,在查看相关文章和学习后,简单说下我自 ...
- 反向算法_10分钟带你了解神经网络基础:反向传播算法详解
作者:Great Learning Team deephub.ai 翻译组 1.神经网络 2.什么是反向传播? 3.反向传播是如何工作的? 4.损失函数 5.为什么我们需要反向传播? 6.前馈网络 7 ...
- Pytorch autograd.grad与autograd.backward详解
Pytorch autograd.grad与autograd.backward详解 引言 平时在写 Pytorch 训练脚本时,都是下面这种无脑按步骤走: outputs = model(inputs ...
- Pytorch|YOWO原理及代码详解(二)
Pytorch|YOWO原理及代码详解(二) 本博客上接,Pytorch|YOWO原理及代码详解(一),阅前可看. 1.正式训练 if opt.evaluate:logging('evaluating ...
- Pytorch | yolov3原理及代码详解(二)
阅前可看: Pytorch | yolov3原理及代码详解(一) https://blog.csdn.net/qq_24739717/article/details/92399359 分析代码: ht ...
- 人工智能-作业1:PyTorch实现反向传播
人工智能-作业1:PyTorch实现反向传播 人工智能-作业1:PyTorch实现反向传播 环境配置: 计算过程 反向传播 PyTorch Autograd自动求导 人工智能-作业1:PyTorch实 ...
- List逆向遍历、反向遍历--Iterator详解
List逆向遍历.反向遍历–Iterator详解 概述 在使用java集合的时候,都需要使用Iterator.但是java集合中还有一个迭代器ListIterator,在使用List.ArrayLis ...
- Nginx反向代理配置详解
Nginx反向代理配置详解 Nginx简单的反向代理配置,包括配置文件中各项参数的的注释,好了,开始! 开始首先安装Nginx 一.建立用户和用户组 1 2 ./usr/sbin/groupadd w ...
- 【Pytorch】torch.argmax 函数详解
文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...
最新文章
- linux环境安装部署mark
- 【数据结构与算法】之深入解析“重新安排行程”的求解思路与算法示例
- python中多维数组_python学习笔记-多维数组
- Javascript基础学习12问(四)
- 电话机器人图文+源码介绍
- 42个最好的海外 app ASO工具
- 中文如何翻译成英文?手机中英文一键翻译超简单
- 05.看板方法——在制品
- 女装网 www.nzw.com.cn
- MATLAB学习【第五部分】--第一节:矩阵的输入//冒号表达式矩阵---linspace函数生成向量---一般矩阵输入
- 2022年交通工具公开拍卖市场研究报告
- “左眼跳财,右眼跳灾”
- 三分钟读懂双十二布局玩法,大促流量销量双翻倍so easy!
- fiddler拦截模拟器中app的请求设置方法
- 人类遗传密码97%待解读
- win10虚拟机创建
- mysql 多个字段排序
- [HNOI2010] 平面图判定
- 萌萌媛の【剑指offer笔记】二维数组中的查找
- java代码命名规范