关于Pytorch中detach
1 首先没有detach的情况
定义了一系列操作,如下,中间结点y1和y2没有梯度。没有采取detach。
import torchw1 = torch.tensor([2.], requires_grad=True)
# print(w1.type()) # torch.FloatTensor
w2 = torch.tensor([4.], requires_grad=True)
w3 = torch.tensor([6.], requires_grad=True)
w4 = torch.tensor([8.], requires_grad=True)x = torch.tensor([10.])y1 = x * w1
y2 = y1 * w2print(y2.requires_grad) # True
print(y2.is_leaf) # False y2为中间结点z1 = w3 * y2
z2 = w4 * y2
z3 = z1 + z2
z3.backward()print(w1.grad) # tensor([560.])
print(w2.grad) # tensor([280.])
print(w3.grad) # tensor([80.])
print(w4.grad) # tensor([80.])
print(x.grad) # None
print(y1.grad) # None 中间节点
print(y2.grad) # None 中间节点
2 detach
将y2 detach(),可以看到和y2之前操作有关的变量的梯度为空,即grad属性没有被赋值。
import torchw1 = torch.tensor([2.], requires_grad=True)
# print(w1.type()) # torch.FloatTensor
w2 = torch.tensor([4.], requires_grad=True)
w3 = torch.tensor([6.], requires_grad=True)
w4 = torch.tensor([8.], requires_grad=True)x = torch.tensor([10.])y1 = x * w1
y2 = y1 * w2
y2 = y2.detach()
print(y2.requires_grad) # False
print(y2.is_leaf) # Truez1 = w3 * y2
z2 = w4 * y2
z3 = z1 + z2
z3.backward()print(w1.grad) # None
print(w2.grad) # None
print(w3.grad) # tensor([80.])
print(w4.grad) # tensor([80.])
print(x.grad) # None
print(y1.grad) # None 中间节点
print(y2.grad) # None 中间节点
3 应用
1 迁移学习,通常会先冻结部分网络层,可以使用detach函数。
2 GAN网络中,训练判别器D时,需要用到生成器G生成的图片fake,但是不想更新G的参数同时也不想往G里边求导,可以使用detach(),如
output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake.backward()
3 参见文档
torch.tensor() always copies data. If you have a Tensor data and just want to change its requires_grad flag, use requires_grad_() or detach() to avoid a copy. If you have a numpy array and want to avoid a copy, use torch.as_tensor().
即,torch.tensor()会复制数据,如果想避免复制可以使用detach()函数,这样
detach()返回的tensor和原来的tensor数据一样,但是没有发生复制。当修改了detach之后的tensor,原来的tensor数值也会改变。
如
import torch
import numpy as np
aa = np.array([[1., 2, 3], [4, 5, 6]]).astype(np.float32)
bb = torch.tensor(aa, requires_grad=True) # 复制了数据
aa[0, 0] = 100
print(bb)
'''
tensor([[1., 2., 3.],[4., 5., 6.]], requires_grad=True)
'''print(aa)
'''
[[100. 2. 3.][ 4. 5. 6.]]
'''cc = bb.detach()
print(cc.requires_grad) # False
cc[0, 0] = 1000print(cc)
'''
tensor([[1000., 2., 3.],[ 4., 5., 6.]])
'''print(bb)
'''
tensor([[1000., 2., 3.],[ 4., 5., 6.]], requires_grad=True)
'''
4 其他
关于Pytorch中detach相关推荐
- 实践教程 | 浅谈 PyTorch 中的 tensor 及使用
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...
- Pytorch中的向前计算(autograd)、梯度计算以及实现线性回归操作
在整个Pytorch框架中, 所有的神经网络本质上都是一个autograd package(自动求导工具包) autograd package提供了一个对Tensors上所有的操作进行自动微分的功能. ...
- 如何利用PyTorch中的Moco-V2减少计算约束
介绍 SimCLR论文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解释了这个框架如何从更大的模型和更大的批处理中获益,并且如果有足够的计算 ...
- Lesson 15.2 学习率调度在PyTorch中的实现方法
Lesson 15.2 学习率调度在PyTorch中的实现方法 学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...
- Pytorch中的variable, tensor与numpy相互转化
来源:https://blog.csdn.net/m0_37592397/article/details/88327248 1.将numpy矩阵转换为Tensor张量 sub_ts = torch.f ...
- 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题
关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...
- 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...
在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...
- 利用Pytorch中深度学习网络进行多分类预测(multi-class classification)
从下面的例子可以看出,在 Pytorch 中应用深度学习结构非常容易 执行多类分类任务. 在 iris 数据集的训练表现几乎是完美的. import torch.nn as nn import tor ...
- Pytorch中的梯度知识总结
文章目录 1.叶节点.中间节点.梯度计算 2.叶子张量 leaf tensor (叶子节点) (detach) 2.1 为什么需要叶子节点? 2.2 detach()将节点剥离成叶子节点 2.3 什么 ...
- 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层
requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...
最新文章
- Swift之Vision 图像识别框架
- Enterprise Vault 系列 [CA和DA]
- Java实现文件压缩与解压[zip格式,gzip格式]
- python 高并发 select socket_python – 使用select处理多个请求
- 上届作品回顾丨如何在 Innovation 2021 开发者大赛中脱颖而出?
- Minimum Triangulation
- PTA浙大版python程序设计题目集--第1章-2 从键盘输入三个数到a,b,c中,按公式值输出 (30 分)
- Boost:基于Boost的发送者和接收者的测试程序
- python --- 使用socket创建tcp服务
- Linux驱动开发基础
- Qt获取本机硬盘序列号,不受IDE硬盘与SCSI硬盘类型影响
- 如何提高FPGA工作频率?影响FPGA运行速度的几大因素
- 内存溢出(OOM)及解决方案
- GIT:cherry-pick挑拣提交
- PAT A1008 Elevator
- 使用canvas 绘制象棋棋盘
- 软件打开文件夹后闪退
- 数字图像处理(第二章)
- 查询linux下有多少用户,Linux 查看系统现存所有用户命令
- synplify user guide note1
热门文章
- 斑马旅游在千帆竞发的出境游市场能否找到属于自己的道路?
- html 规定输入框必须输入
- Java多线程并发笔记01 对象锁 类锁 对象锁的同步和异步 脏读
- 【Python】Base64编码和解码
- 查看html代码来下载mp4视频的一次记录
- 匈牙利命名法为何被淘汰_体育午报:15年魔咒破除!国足淘汰赛终迎一胜
- emacs terminal
- 号称“不限速“的阿里网盘,官宣要停止了,寿命仅仅1年
- 计组头哥实验 第1关 8位可控加减法电路设计
- 手机电源键关不了屏幕_手机关机关不了,屏幕也划不了,怎么办