转自PyTorch 0.4新版本 升级指南,博主为ShellCollector。

PyTorch 0.4新版本 升级指南

PyTorch 终于从0.3.1升级到0.4.0了, 首先引入眼帘的,是PyTorch官方对自己的描述的巨大变化.

PyTorch 0.3.1说:

PyTorch is a python package that provides two high-level features:

• Tensor computation (like numpy) with strong GPU acceleration

• Deep Neural Networks built on a tape-based autodiff system

而PyTorch 0.4.0说:

PyTorch is a python based scientific computing package targeted at two sets of audiences:

• A replacement for NumPy to use the power of GPUs

• a deep learning research platform that provides maximum flexibility and speed

本次升级, 只做了一件事情, 就是将Tensor 类和 Variable 类 合并, 这一合并, 解决掉了很多原来令人困扰的问题.

在旧版本, Variable和Tensor分离, Tensor主要是多维矩阵的封装, 而Variable类是计算图上的节点, 它对Tensor进行了进一步的封装.

所以, 在训练过程中, 一个必要的步骤就是, 把Tensor转成Variable以便在模型中运行; 运行完之后, 我们还要将Variable转成Tensor,甚至Numpy. 我们在写代码和读代码的时候, 看到了各种辅助函数, 比如下面就是我常用的辅助函数:

# 旧版本实现
import torch# 从Tensor转换到Vairable
def to_var(x):if torch.cuda.is_available():x = x.cuda()return Variable(x)  # 从CUDA Variable转换到Numpy
def to_np(x):return x.data.cpu().numpy()for epoch in range(3):   # 训练3轮for step, (batch_x, batch_y) in enumerate(loader):  # 每一步# 把训练数据转成Variablebatch_x, batch_y = to_var(batch_x), to_var(batch_y)pass

0.4.0, 我们就可以不用这么转化了

for epoch in range(3):   # 训练3轮for step, (batch_x, batch_y) in enumerate(loader):  # 每一步optimizer.zero_grad()# forward + backward + optimizeoutputs = net(batch_x)loss = criterion(outputs, batch_y)loss.backward()optimizer.step()print('Finished Training')

好处当然很大, 但是我们更关心以下几个问题:

Variable没了, Variable 的功能怎么办?

1.requires_grad 标志怎么处理了?

requires_grad 在Variable中,用来标志一个Variable是否要求导(或者说,要不要放到计算图中), 合并之后,这个标志处理的?

2.volatile 标志怎么处理了?

volatile在Variable中,用来标志一个Variable是否要被计算图隔离出去, 合并之后, 这个标志怎么处理的?

3.data方法呢?

Variable中,都是将封装的Tensor数据存储在.data里, 现在Variable和Tensor合并了, .data怎么办?

4.张量和标量怎么统一?

在Tensor元素内部都是Python 标量类型, 而Variable都是Tensor 张量类型, 原本它们井水不犯河水, 但现在合并了, 怎么处理?

# 旧版 0.3.1
>>> import torch
>>> from torch.autograd import Variable
>>> a = torch.Tensor([1,2,3])
>>> a[0]  # 内部元素是Python 标量
1.0
>>> type(a[0]) # 类别是Python float
<class 'float'>
>>> b = Variable(a)
>>> b[0] # 内部元素是Tensor类型, 张量
Variable containing:1
[torch.FloatTensor of size 1]

合并之后的Tensor是什么样的?

5.合并之后, 新版本Tensor是什么类型?

回答如下

1. requires_grad 标志怎么处理了?
直接挂在Tensor类下

>>> import torch
>>> x = torch.ones(1)
>>> x.requires_grad
False

2.volatile 标志怎么处理了?
弃用 , 但是做了一些替代, 比如torch.no_grad(), torch.set_grad_enabled(grad_mode)

>>> import torch
>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad(): # 将y 从计算图中排除
...     y = x * 2
>>> y.requires_grad
False

3.data方法呢?
保留功能, 但建议替代为x.detach()

.data方法,本质上是给当前Tensor加一个新引用, 它们指向的内存都是一样的, 因此不安全 。

比如y = x.data(), 而x参与了计算图的运算, 那么, 如果你不小心修改了y的data, xdata也会跟着变, 然而反向传播是监听不到xdata变化的, 因此造成梯度计算错误。

y = x.detach()正如其名, 将返回一个不参与计算图的Tensor y, Tensor y 一旦试图改变修改自己的data, 会被语法检查和python解释器监测到, 并抛出错误。

4.张量和标量怎么统一?
新增0维张量(0-dimensional Tensor), 用以封装标量(scalar), 将张量(Tensor), 标量(Scalar)都统一成张量.

>>> import torch
>>> torch.tensor(3.1416)         # 创建标量
tensor(3.1416)
>>> torch.tensor(3.1416).size()  # 其实是0维的张量
torch.Size([])
>>> torch.tensor([3]).size()     # 1维张量
torch.Size([1])

5.合并之后, 新版本Tensor是什么类型?
torch.Tensor类型, 但是, 详细类型需要进一步调用方法:

>>> import torch
>>> x = torch.DoubleTensor([1, 1, 1])
>>> type(x)
<class 'torch.Tensor'>
>>> x.type()
'torch.DoubleTensor'
>>> isinstance(x, torch.DoubleTensor)
True

旧版本的PyTorch, 你可以在类型上直接看出一个Tensor的基本信息, 比如devicecuda上, layoutsparse,dtypeFloat型的Tensor, 你可以:

# 0.3.1
>>> type(a)
<class 'torch.cuda.sparse.FloatTensor'>

由新版本, 所有的Tensor对外都是torch.Tensor类型, 上述的属性, 从类名转移到了Tensor的属性了。

  • torch.device 描述设备的位置, 比如torch.device('cuda'), torch.device('cpu')
>>> import torch
>>> cuda = torch.device('cuda')
>>> cpu  = torch.device('cpu')
>>> a = torch.tensor([1,2,3], device=cuda)
>>> a.device
device(type='cuda', index=0)
>>> b = a.to(cpu) # 将数据从cuda copy 到 cpu
>>> b.device
device(type='cpu')
>>> type(a)  # type a 和 tpye b, 看不出谁在cuda谁在cpu
<class 'torch.Tensor'>
>>> type(b)
<class 'torch.Tensor'>
  • torch.layout
    torch.layout 是 一个表示Tensor数据在内存中样子的类, 默认torch.strided, 即稠密的存储在内存上, 靠stride来刻画tensor的维度. 目前还有一个实验版的对象torch.sparse_coo, 一种coo格式的稀疏存储方式, 但是目前API还不固定, 大家谨慎使用.

  • torch.dtype

[转] PyTorch 0.4新版本 升级指南 no_grad相关推荐

  1. PyTorch 0.4新版本 升级指南 no_grad

    PyTorch 0.4新版本 升级指南 [导读]今天大家比较关心的是PyTorch在GitHub发布0.4.0版本,专知成员Huaiwen详细讲解了PyTorch新版本的变动信息, 本次升级, 只做了 ...

  2. Spring Boot 3.0 正式发布,这份升级指南必须收藏

    Spring Boot 3.0 现已正式发布,它包含了 12 个月以来 151 个开发者的 5700 多次代码提交.这是自 4.5 年前发布 2.0 以来,Spring Boot 的第一次重大修订. ...

  3. 小米9什么时间升级android10,小米9/MIX 3 现在即可升级安卓10.0!升级指南戳这里...

    原标题:小米9/MIX 3 现在即可升级安卓10.0!升级指南戳这里 [手机频道·原创] 最近小米接连有消息爆出,在今天早些时候的I/O开发者大会上,谷歌在Android Q中推出了许多新功能.现在, ...

  4. San CLI 4.0 升级指南

    San CLI 历经多个版本迭代,目前已经进入 4.0 版本,增加 webpack5 支持.优化配置机制等,本文会对升级经验做出总结,期望给读者带来一些启发. 前言 San CLI 更新到 3.0 版 ...

  5. composer升级_Composer 使用姿势与 Lumen 升级指南

    Composer 使用姿势 这里主要说说 composer.json 和 composer.lock 文件的作用. composer.json composer.json 文件包含了项目的依赖和其它的 ...

  6. CDH6官方文档中文系列(8)----Cloudera升级指南

    Cloudera升级指南 最近在学习cdh6的官方文档,网上也比较难找到中文的文档. 其实官方英文文档的阅读难度其实并不是很高,所以在这里在学习官方文档的过程中,把它翻译成中文,在翻译的过程中加深学习 ...

  7. ie11java阻止_企业IT管理员IE11升级指南【10】—— 如何阻止IE11的安装

    企业IT管理员IE11升级指南 系列: 如何阻止IE11的安装 希望自行管理更新计划的企业和组织可以使用 IE11 Automatic Update Blocker Toolkit (自动更新拦截工具 ...

  8. JEECG Framework 3.5.0 GA 新版本终于发布了,重量级功能(数据权限,国际化,多数据源),团队会努力推出新版本,希望大家多多支持!!

     JEECG Framework 3.5.0 GA 新版本终于发布了,重量级功能(数据权限,国际化,多数据源),            今年团队会努力不断推出新版本,希望大家多多支持!! 发布地址: ...

  9. 005-Sencha Cmd 5升级指南

    Sencha Cmd 5升级指南 本指南旨在帮助开发人员使用Sencha Cmd从ExtJS 4.1.1 a+升级到ExtJS 5.0.x. 尽管在这个版本中有一些重要的变化,但是我们已经尝试使升级过 ...

最新文章

  1. 元组类型与列表类型的操作函数和方法
  2. wxWidgets:布局窗口/窗扇示例
  3. BZOJ5467 PKUWC2018Slay the Spire(动态规划)
  4. Java中转发(Forward)和重定向(Redirect)的区别
  5. 根据语句自动生成正则表达式
  6. sql配置管理器服务是空的_PostgreSQL 12 安装和配置
  7. python介绍环境搭建、变量输入输出
  8. m-qam matlab,基于matlab的M_QAM通信系统仿真.doc
  9. html无序列表只能横着排吗,[三地连线走势图]css 怎样让无序列表 横着排列
  10. 和专业计算机男生谈恋爱,和不同专业的男生谈恋爱是什么感觉?
  11. 使用 prismjs 在网页中高亮显示代码
  12. 理解套间(涉及进程、线程、COM线程模型)(转载)
  13. 小红书差评笔记下沉 | 如何让小红书笔记下沉
  14. 【Adobe国际认证中文官网】Adobe中国摄影计划,免费安装 正版激活
  15. 新时代城市规划建设需新基建与传统基建携手共同打造
  16. Java Web项目开发流程
  17. Android手机开发总结——Android核心分析
  18. 编辑审稿时不会从头看到尾!所以论文应该这样写……
  19. QWidget 半透明窗口解决方案
  20. notepad++功能简介

热门文章

  1. python新手入门代码-[代码全屏查看]-新手初学Python实现某论坛自动签到功能
  2. 2018年python工作好找吗-Python的发展状况-2018年
  3. 深度学习的应用:语音识别、图像理解、自然语言处理
  4. NIO流程记录(非源码,单reacter单线程)
  5. php中sisson用法,详细介绍php中session的用法
  6. 【MYSQL笔记】MYSQL监视器
  7. FFmpeg源代码简单分析:内存的分配和释放(av_malloc()、av_free()等)
  8. jQuery 文本编辑器插件 HtmlBox 使用
  9. RTMPDump源代码分析 0: 主要函数调用分析
  10. python字典浅复制_元组,字典,浅复制,集合