1 首先没有detach的情况

定义了一系列操作,如下,中间结点y1和y2没有梯度。没有采取detach。

import torchw1 = torch.tensor([2.], requires_grad=True)
# print(w1.type())  # torch.FloatTensor
w2 = torch.tensor([4.], requires_grad=True)
w3 = torch.tensor([6.], requires_grad=True)
w4 = torch.tensor([8.], requires_grad=True)x = torch.tensor([10.])y1 = x * w1
y2 = y1 * w2print(y2.requires_grad)  # True
print(y2.is_leaf)  # False y2为中间结点z1 = w3 * y2
z2 = w4 * y2
z3 = z1 + z2
z3.backward()print(w1.grad)  # tensor([560.])
print(w2.grad)  # tensor([280.])
print(w3.grad)  # tensor([80.])
print(w4.grad)  # tensor([80.])
print(x.grad)  # None
print(y1.grad)  # None  中间节点
print(y2.grad)  # None  中间节点

2 detach

将y2 detach(),可以看到和y2之前操作有关的变量的梯度为空,即grad属性没有被赋值。

import torchw1 = torch.tensor([2.], requires_grad=True)
# print(w1.type())  # torch.FloatTensor
w2 = torch.tensor([4.], requires_grad=True)
w3 = torch.tensor([6.], requires_grad=True)
w4 = torch.tensor([8.], requires_grad=True)x = torch.tensor([10.])y1 = x * w1
y2 = y1 * w2
y2 = y2.detach()
print(y2.requires_grad)  # False
print(y2.is_leaf)  # Truez1 = w3 * y2
z2 = w4 * y2
z3 = z1 + z2
z3.backward()print(w1.grad)  # None
print(w2.grad)  # None
print(w3.grad)  # tensor([80.])
print(w4.grad)  # tensor([80.])
print(x.grad)  # None
print(y1.grad)  # None  中间节点
print(y2.grad)  # None  中间节点

3 应用

1 迁移学习,通常会先冻结部分网络层,可以使用detach函数。

2 GAN网络中,训练判别器D时,需要用到生成器G生成的图片fake,但是不想更新G的参数同时也不想往G里边求导,可以使用detach(),如

output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake.backward()

3 参见文档

torch.tensor() always copies data. If you have a Tensor data and just want to change its requires_grad flag, use requires_grad_() or detach() to avoid a copy. If you have a numpy array and want to avoid a copy, use torch.as_tensor().

即,torch.tensor()会复制数据,如果想避免复制可以使用detach()函数,这样detach()返回的tensor和原来的tensor数据一样,但是没有发生复制。当修改了detach之后的tensor,原来的tensor数值也会改变。

import torch
import numpy as np
aa = np.array([[1., 2, 3], [4, 5, 6]]).astype(np.float32)
bb = torch.tensor(aa, requires_grad=True)  # 复制了数据
aa[0, 0] = 100
print(bb)
'''
tensor([[1., 2., 3.],[4., 5., 6.]], requires_grad=True)
'''print(aa)
'''
[[100.   2.   3.][  4.   5.   6.]]
'''cc = bb.detach()
print(cc.requires_grad)  # False
cc[0, 0] = 1000print(cc)
'''
tensor([[1000.,    2.,    3.],[   4.,    5.,    6.]])
'''print(bb)
'''
tensor([[1000.,    2.,    3.],[   4.,    5.,    6.]], requires_grad=True)
'''

4 其他

关于Pytorch中detach相关推荐

  1. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

  2. Pytorch中的向前计算(autograd)、梯度计算以及实现线性回归操作

    在整个Pytorch框架中, 所有的神经网络本质上都是一个autograd package(自动求导工具包) autograd package提供了一个对Tensors上所有的操作进行自动微分的功能. ...

  3. 如何利用PyTorch中的Moco-V2减少计算约束

    介绍 SimCLR论文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解释了这个框架如何从更大的模型和更大的批处理中获益,并且如果有足够的计算 ...

  4. Lesson 15.2 学习率调度在PyTorch中的实现方法

    Lesson 15.2 学习率调度在PyTorch中的实现方法   学习率调度作为模型优化的重要方法,也集成在了PyTorch的optim模块中.我们可以通过下述代码将学习率调度模块进行导入. fro ...

  5. Pytorch中的variable, tensor与numpy相互转化

    来源:https://blog.csdn.net/m0_37592397/article/details/88327248 1.将numpy矩阵转换为Tensor张量 sub_ts = torch.f ...

  6. 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

    关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...

  7. 更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...

    在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新 2020/4/11 FesianXu 前言 在现在的深度模型软件框架中,如TensorFlow和PyTorch等等,都是实现了自动求导 ...

  8. 利用Pytorch中深度学习网络进行多分类预测(multi-class classification)

    从下面的例子可以看出,在 Pytorch 中应用深度学习结构非常容易 执行多类分类任务. 在 iris 数据集的训练表现几乎是完美的. import torch.nn as nn import tor ...

  9. Pytorch中的梯度知识总结

    文章目录 1.叶节点.中间节点.梯度计算 2.叶子张量 leaf tensor (叶子节点) (detach) 2.1 为什么需要叶子节点? 2.2 detach()将节点剥离成叶子节点 2.3 什么 ...

  10. 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层

    requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...

最新文章

  1. Swift之Vision 图像识别框架
  2. Enterprise Vault 系列 [CA和DA]
  3. Java实现文件压缩与解压[zip格式,gzip格式]
  4. python 高并发 select socket_python – 使用select处理多个请求
  5. 上届作品回顾丨如何在 Innovation 2021 开发者大赛中脱颖而出?
  6. Minimum Triangulation
  7. PTA浙大版python程序设计题目集--第1章-2 从键盘输入三个数到a,b,c中,按公式值输出 (30 分)
  8. Boost:基于Boost的发送者和接收者的测试程序
  9. python --- 使用socket创建tcp服务
  10. Linux驱动开发基础
  11. Qt获取本机硬盘序列号,不受IDE硬盘与SCSI硬盘类型影响
  12. 如何提高FPGA工作频率?影响FPGA运行速度的几大因素
  13. 内存溢出(OOM)及解决方案
  14. GIT:cherry-pick挑拣提交
  15. PAT A1008 Elevator
  16. 使用canvas 绘制象棋棋盘
  17. 软件打开文件夹后闪退
  18. 数字图像处理(第二章)
  19. 查询linux下有多少用户,Linux 查看系统现存所有用户命令
  20. synplify user guide note1

热门文章

  1. 斑马旅游在千帆竞发的出境游市场能否找到属于自己的道路?
  2. html 规定输入框必须输入
  3. Java多线程并发笔记01 对象锁 类锁 对象锁的同步和异步 脏读
  4. 【Python】Base64编码和解码
  5. 查看html代码来下载mp4视频的一次记录
  6. 匈牙利命名法为何被淘汰_体育午报:15年魔咒破除!国足淘汰赛终迎一胜
  7. emacs terminal
  8. 号称“不限速“的阿里网盘,官宣要停止了,寿命仅仅1年
  9. 计组头哥实验 第1关 8位可控加减法电路设计
  10. 手机电源键关不了屏幕_手机关机关不了,屏幕也划不了,怎么办