Variable 与 Tensor

tensor 是 PyTorch 中的完美组件,高效的数据格式,但是构建神经网络还远远不够,我们需要能够构建计算图的 tensor,这就是 Variable。Variable 是对 tensor 的封装,操作和 tensor 是一样的,但是每个 Variabel都有三个属性,Variable 中的 tensor本身.data,对应 tensor 的梯度.grad以及这个 Variable 是通过什么方式得到的.grad_fn,是由什么函数得到的张量,如果是自己创建的,则维None

# 通过下面这种方式导入 Variable
import torch
from torch.autograd import Variablex_tensor = torch.randn(4, 5)
y_tensor = torch.randn(4, 5)

1. requires_grad

# 将 tensor 变成 Variable
x = Variable(x_tensor)
x
tensor([[ 0.6031, -0.6642,  1.0491, -0.5876,  0.6080],[ 0.9331, -1.8954,  1.2234,  0.1483,  1.0758],[-0.5292,  1.3870, -1.6189,  1.0741,  0.9438],[ 1.4417,  0.7225, -1.2392, -0.1838,  1.3174]])

默认 Variable 是不需要求梯度的,所以我们用这个方式申明需要对其进行求梯度

# requires_grad 申明需要对其进行求梯度
x = Variable(x_tensor, requires_grad=True)
y = Variable(y_tensor, requires_grad=True)

例:定义$ z=\sum(x+y),分别求,分别求,分别求x和和和y$的梯度。

z = torch.sum(x+y)
# 数据类型
z.type()
'torch.FloatTensor'
print(z.data)
tensor(6.5849)
print(z.grad_fn)
<SumBackward0 object at 0x0000026452B1C748>

上面我们打出了zzz 中的 tensor 数值,同时通过grad_fn知道了其是通过 Sum 这种方式得到的。

# 求 x 和 y 的梯度
z.backward()
print(x.grad)
print(y.grad)
tensor([[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]])
tensor([[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.],[1., 1., 1., 1., 1.]])

通过.grad我们得到了 xxx 和 yyy 的梯度,这里我们使用了 PyTorch 提供的自动求导机制(见自动求导)

练习:尝试构建y=x2y=x^2y=x2,然后求x=3x=3x=3出的梯度。
求梯度的时候,一定要把变量转成Variable,并且指定需要梯度

# x = Variable(torch.tensor(2,dtype=torch.float64),requires_grad=True)
x = Variable(torch.Tensor([3]),requires_grad=True)
x
tensor([3.], requires_grad=True)
y = x**2
y
tensor([9.], grad_fn=<PowBackward0>)
print(y.grad_fn)
<PowBackward0 object at 0x0000026452B231C8>
y.backward()
print(x.grad)
tensor([6.])

在上面的俩个例子中,我们已经指定了需要自动求导

如果有一个单一的输入操作需要梯度,它的输出也需要梯度。相反,只有所有输入都不需要梯度,输出才不需要。如果其中所有的变量都不需要梯度进行,后向计算不会在子图中执行。具体如下:

x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5))
z = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y
a.requires_grad
False

输入xxx和yyy不需要梯度,aaa便不需要梯度

b = a + z
b.requires_grad
True
True

输入zzz需要梯度,则aaa需要梯度

这个标志特别有用,在实际运用过程中,当需要冻结部分模型,或者事先知道不会使用某些参数的梯度,可以不对其指定True。

例如,如果要对预先训练的CNN进行优化,只要切换冻结模型中的requires_grad标志就足够了,直到计算到最后一层才会保存中间缓冲区,其中的仿射变换将使用需要梯度的权重并且网络的输出也将需要它们。

import torchvision
model = torchvision.models.resnet18(pretrained=True)for param in model.parameters():param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

总结
pytorch的variable是一个存放会变化值的地理位置,里面的值会不停变化,像装糖果(糖果就是数据,即tensor)的盒子,糖果的数量不断变化。pytorch都是由tensor计算的,而tensor里面的参数是variable形式。

autograd根据用户对Variable的操作来构建其计算图。 variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。

2. 具体运用

我们现在已经知道了PyTorch为了实现GPU加速功能,引入了Tensor,为了实现自动求导功能引入了Variable。我们一般读取的数据都是以Numpy Array方式的。在TensorFlow,Numpy的数据会在输入网络后自动转换为Tensor,一般不需要我们进行显性操作,但是在PyTorch,需要我们自己进行显性操作才可以的。

在一个网络训练过程中:

  1. 首先我们会用NumPy读取数据格式为ndarray
  2. 我们为了能够送入网络,使用GPU计算加速,所以要进行Numpy2Tensor操作,把数据转成Tensor
  3. 由于网络输入输出都是Variable,我们还需要Tensor2Variable,数据成可以构建计算图的Variable。
  4. 在训练的过程中,我们需要取出loss的值并打印,由于loss参与了backward(),所以此时的loss已经变成了Variable,我们取出loss时需要取出的是Tensor。同样的,如果我想取出网络输出的结果时,由于网络输入输出都是Variable,也需要执行Variable2Tensor,如果进一步我们想把loss显示出来,就需要Tensor2Numpy。

转换方法:
Numpy2Tensor:torch.from_numpy(Numpy_data)torch.tensor(Numpy_data)

Tensor2Variable: Variable(Tensor_data)

Variable2Tensor: Variable_data.data()提出数据

Tensor2Numpy : Tensor_data.numpy()

注意一点,Numpy与Variable无法直接转换,需要经过Tensor作为中介。

重点:

1.新版本中,torch.autograd.Variabletorch.Tensor 将同属一类。更确切地说,torch.Tensor 能够追踪日志并像旧版本的 Variable那样运行;Variable 封装仍旧可以像以前一样工作,但返回的对象类型是torch.Tensor。这意味着你的代码不再需要变量封装器。

2.作为 autograd 方法的核心标志,requires_grad 现在是 Tensors 类的一个属性。
autograd 使用先前用于 Variable 的相同规则。当操作中任意输入 Tensor 的 require_grad = True 时,它开始跟踪历史记录。如:

>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True

3.除了直接设置属性之外,你还可以使用 my_tensor.requires_grad_(requires_grad = True) 在原地更改此标志,如:

>>> my_tensor = torch.zeros(3, 4, requires_grad=True)
>>> my_tensor.requires_grad
True

4..data是从 Variable 中获取底层 Tensor 的主要方式。 合并后,调用 y = x.data仍然具有相似的语义。因此 y 将是一个与 x 共享相同数据的 Tensor,并且 requires_grad = False,它与 x 的计算历史无关。

然而,在某些情况下 .data可能不安全。 对 x.data 的任何更改都不会被 autograd 跟踪,如果在反向过程中需要 x,那么计算出的梯度将不正确。另一种更安全的方法是使用 x.detach(),它将返回一个与 requires_grad = False 时共享数据的 Tensor,但如果在反向过程中需要 x,那么 autograd 将会就地更改它。

PyTorch Variable与Tensor 【详解】相关推荐

  1. 2021 PyTorch官方实战教程(一)Tensor 详解

    点击上方"AI算法与图像处理",选择加"星标"或"置顶"重磅干货,第一时间送达 这个系列时pytorch官方实战教程,后续会继续更新.. 一 ...

  2. 【PyTorch系例】torch.Tensor详解和常用操作

    学习教材: 动手学深度学习 PYTORCH 版(DEMO) (https://github.com/ShusenTang/Dive-into-DL-PyTorch) PDF 制作by [Marcus ...

  3. pytorch自动求梯度—详解

    构建深度学习模型的基本流程就是:搭建计算图,求得损失函数,然后计算损失函数对模型参数的导数,再利用梯度下降法等方法来更新参数.搭建计算图的过程,称为"正向传播",这个是需要我们自己 ...

  4. 【小白学PyTorch】13.EfficientNet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 12 SENet详解及PyTorch实现 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小 ...

  5. 【小白学PyTorch】12.SENet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小白学PyTorch | 10 pytorch常见运算详解 小白学Py ...

  6. 【小白学PyTorch】11.MobileNet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 10 pytorch常见运算详解 小白学PyTorch | 9 tensor数据结构与存储结构 小白学PyTorch | ...

  7. pytorch scatter和scatter_详解

    文章目录 0. Introduction 1. 定义 2. 详解 例1 例2 Reference: 0. Introduction scatter() 和 scatter_() 的作用是一样的,只不过 ...

  8. pytorch nn.LSTM()参数详解

    输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num ...

  9. [pytorch]yolov3.cfg参数详解(每层输出及route、yolo、shortcut层详解)

    文章目录 Backbone(Darknet53) 第一次下采样(to 208) 第二次下采样(to 104) 第三次下采样(to 52) 第四次下采样(to 26) 第五次下采样(to 13) YOL ...

最新文章

  1. c++ gdb 绑定源码_【Vue原理】VNode 源码版
  2. mysql创建用户并授登录权限_mysql创建用户并授予权限
  3. HP小型机superdome配置MC双机、PV、VG、LV初体验
  4. mysql更新记录_如何查看 mysql 表中最近更新的记录
  5. 惯性导航技术, IMU, AHRS
  6. 墨刀的html压缩包是什么,墨刀那些事
  7. 美育在计算机教育中应用,浅谈在小学信息技术课堂中有效实施美育.
  8. Pluck 代码问题漏洞( CVE-2022-26965)
  9. Origin 2017 调整默认字体的方法
  10. 宝宝起名神器微信小程序源码下载支持多种流量主模式
  11. 微信小程序支付错误提示“商户号mch_id或sub_mch_id不存在”
  12. 倍福--授权等级的区别
  13. MPLS LDP的原理与配置
  14. C语言中int、long等类型所占的字节数
  15. 艾伟:WCF从理论到实践(13):事务投票
  16. star ccm 报java错误_在 Linux VM 上运行 STAR-CCM+ 与 HPC Pack - Azure Virtual Machines | Microsoft Docs...
  17. 安装Discuz开源论坛
  18. 国开计算机上机表格试题答案,国家开放大学《计算机应用基础》考试与答案形考任务模块3模块3Excel2010电子表格系统—客观题答案...
  19. 【国际】智利金融监管机构加入R3区块链联盟
  20. win10系统下释放c盘空间

热门文章

  1. 联科教育【免费公开课】每周一和周三晚19:30分:C#程序设计--基础篇,赶快围观啦~~~
  2. Pytorch拟合直线方法
  3. postgresql数据库无法连接,提示 Is the server running on host localhost (127.0.0.1) and accepting TCP/IP conn
  4. 北京大学计算机硕博连读5年,关于2019年北京大学硕博连读研究生选拔工作的通知...
  5. 缓解拖延症的12个小技巧
  6. 膜电极(MEA)是质子交换膜燃料电池(PEMFC)
  7. 02.OC对象的本质
  8. Wdcdn缓存加速系统1.0发布
  9. 自己动手做自动发布系统三
  10. 串口实时数据显示、记录、绘图、计算软件(hdntCenter)