pytorch中深度拷贝_pytorch:对比clone、detach以及copy_等张量复制操作
pytorch提供了clone、detach、copy_和new_tensor等多种张量的复制操作,尤其前两者在深度学习的网络架构中经常被使用,本文旨在对比这些操作的差别。
1. clone
返回一个和源张量同shape、dtype和device的张量,与源张量不共享数据内存,但提供梯度的回溯。
下面,通过例子来详细说明:
示例:
(1)定义
import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.clone()
print(a_) # tensor(1., device='cuda:0', dtype=torch.float64, grad_fn=)
注意:grad_fn=,说明clone后的返回值是个中间variable,因此支持梯度的回溯。因此,clone操作在一定程度上可以视为是一个identity-mapping函数。
(2)梯度的回溯
clone作为一个中间variable,会将梯度传给源张量进行叠加。
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.clone()
z = a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad) # None. 中间variable,无grad
print(a.grad) # 5. a_的梯度会传递回给a,因此2+3=5
但若源张量的require_grad=False,而clone后的张量require_grad=True,显然此时不存在张量回溯现象,clone后的张量可以求导。
import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_()
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) # 2. 可得到导数
(3)张量数据非共享
import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
a.data *= 3
a_ += 1
print(a) # tensor(3., requires_grad=True)
print(a_) # tensor(2., grad_fn=). 注意grad_fn的变化
综上论述,clone操作在不共享数据内存的同时支持梯度回溯,所以常用在神经网络中某个单元需要重复使用的场景下。
2. detach
detach的机制则与clone完全不同,即返回一个和源张量同shape、dtype和device的张量,与源张量共享数据内存,但不提供梯度计算,即requires_grad=False,因此脱离计算图。
同样,通过例子来详细说明:
(1)定义
import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.detach()
print(a_) # tensor(1., device='cuda:0', dtype=torch.float64)
(2)脱离原计算图
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.detach()
print(a_.grad) # None,requires_grad=False
a_.requires_grad_() # 强制其requires_grad=True,从而支持求导
z = a_ * 3
y.backward()
z.backward()
print(a.grad) # 2,与a_无关系
print(a_.grad) #
可见,detach后的张量,即使重新定义requires_grad=True,也与源张量的梯度没有关系。
(3)共享张量数据内存
import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.detach()
print(a) # tensor(1., requires_grad=True)
print(a_) # tensor(1.)
a_ += 1
print(a) # tensor(2., requires_grad=True)
print(a_) # tensor(2.)
a.data *= 2
print(a) # tensor(4., requires_grad=True)
print(a_) # tensor(4.)
综上论述,detach操作在共享数据内存的脱离计算图,所以常用在神经网络中仅要利用张量数值,而不需要追踪导数的场景下。
3. clone和detach联合使用
clone提供了非数据共享的梯度追溯功能,而detach又“舍弃”了梯度功能,因此clone和detach意味着着只做简单的数据复制,既不数据共享,也不对梯度共享,从此两个张量无关联。
置于是先clone还是先detach,其返回值一样,一般采用tensor.clone().detach()。
4. new_tensor
new_tensor可以将源张量中的数据复制到目标张量(数据不共享),同时提供了更细致的device、dtype和requires_grad属性控制:
new_tensor(data, dtype=None, device=None, requires_grad=False)
注意:其默认参数下的操作等同于.clone().detach(),而requires_grad=True时的效果相当于.clone().detach()requires_grad_(True)。上面两种情况都推荐使用后者。
5. copy_
copy_同样将源张量中的数据复制到目标张量(数据不共享),其device、dtype和requires_grad一般都保留目标张量的设定,仅仅进行数据复制,同时其支持broadcast操作。
a = torch.tensor([[1,2,3], [4,5,6]], device="cuda")
b = torch.tensor([7.0,8.0,9.0], requires_grad=True)
a.copy_(b)
print(a) # tensor([[7, 8, 9], [7, 8, 9]], device='cuda:0')
【Ref】:
pytorch中深度拷贝_pytorch:对比clone、detach以及copy_等张量复制操作相关推荐
- pytorch中深度拷贝_在ubuntu20.04下搭建深度学习环境(pytorch1.5)
首先声明,完成这个环境的搭建不是我一个人的努力,是我在网上查找好多相关资料并得益于他们的帮助完成的,在下文对应的地方会放上相关链接.整个过程我将它分为4个步骤. 一.在ubuntu20.04上安装py ...
- 利用Pytorch中深度学习网络进行多分类预测(multi-class classification)
从下面的例子可以看出,在 Pytorch 中应用深度学习结构非常容易 执行多类分类任务. 在 iris 数据集的训练表现几乎是完美的. import torch.nn as nn import tor ...
- pytorch张量复制clone()和detach()
1. pytorch张量复制clone()和detach() https://blog.csdn.net/Answer3664/article/details/104417013 2. [Pytorc ...
- 【深度学习】在PyTorch中使用 LSTM 进行新冠病例预测
时间序列数据,顾名思义是一种随时间变化的数据.例如,24 小时时间段内的温度,一个月内各种产品的价格,特定公司一年内的股票价格.长短期记忆网络(LSTM)等高级深度学习模型能够捕捉时间序列数据中的模式 ...
- 详解PyTorch中的copy_()函数、detach()函数、detach_()函数和clone()函数
参考链接: copy_(src, non_blocking=False) → Tensor 参考链接: detach() 参考链接: detach_() 参考链接: clone() → Tensor ...
- 深度拷贝 java_Java深度拷贝方式和性能对比
前言 Java的深度拷贝大致分为克隆(实现Java的Clone接口)和序列化(实现Java的Serializable接口)两种,但是基于不同的序列化方式,有可以延伸出几种方式.下面分析一下每种的注意事 ...
- DL:深度学习框架Pytorch、 Tensorflow各种角度对比
DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...
- 【深度学习理论】一文搞透pytorch中的tensor、autograd、反向传播和计算图
转载:https://zhuanlan.zhihu.com/p/145353262 前言 本文的主要目标: 一遍搞懂反向传播的底层原理,以及其在深度学习框架pytorch中的实现机制.当然一遍搞不定两 ...
- Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配、对应版本安装之详细攻略
Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配.对应版本安装之详细攻略 目录 深度学习中pytorch/torchvision版本和CUDA版本最正确 ...
最新文章
- RDKit | 基于多片段的分子生成(骨架A+骨架B+骨架C)
- Cocos2dx游戏开发系列笔记6:怎样让《萝莉快跑》的例子运行在vs和手机上
- linux find 反转 查找没有被找到的结果
- 软件工程综合实践阶段小结(2)
- 离散数学范式c语言实验报告,离散数学实验报告-利用真值表法求主析取范式及主合取范式的实现...
- ambari hive mysql_ambari方式安装hadoop的hive组件遇到的问题
- python中json模块读写数据
- wince 本地播放器界面
- Java加密与解密的艺术~安全协议~单向认证服务
- Linux搜寻文件或目录命令解析
- php redisson,排查redisson中订阅connection无故消失的问题
- 那些不需要你知道的Chrome DevTool - 使用技巧篇
- C语言学生成绩排名系统
- Arduino 和 雨滴传感器 滴水实验
- 采集屏幕编码H264
- 黑苹果内置硬盘识别成外置硬盘_空间大?安全更重要,麦沃 K35274D硬盘阵列盒使用体验--数据无价...
- 快速部署PHP Web环境(docker nginx php mysql redis)
- [机器学习与数据分析] 数据分析常用方法
- BIGEMAP下载等高线(高程)使用教程
- Android手机直播(三)声音采集
热门文章
- python程序员怎么面试_Python程序员面试,这些问题你必须提前准备!
- python用Matplotlib画箭头
- MFC中树形控件的应用——电话簿
- 概率论基础知识各种分布
- 小甲鱼 OllyDbg 教程系列 (六) :PJ 软件功能限制(不修改jnz的非爆破方法)
- oracle查询语句大全(oracle 基本命令大全一)
- 简明Python教程学习笔记_6_面向对象编程
- Spring Data JPA 禁止自动更新
- 【加解密学习笔记:第二天】动态调试工具OllyDbg使用基础介绍
- Modbus协议栈综合实例设计