1. 什么是自动混合精度训练?

我们知道神经网络框架的计算核心是Tensor,也就是那个从scaler -> array -> matrix -> tensor 维度一路丰富过来的tensor。在PyTorch中,我们可以这样创建一个Tensor:


>>> import torch>>> gemfield = torch.zeros(70,30)
>>> gemfield.type()
'torch.FloatTensor'>>> syszux = torch.Tensor([1,2])
>>> syszux.type()
'torch.FloatTensor'

可以看到默认创建的tensor都是FloatTensor类型。而在PyTorch中,一共有10种类型的tensor:

● torch.FloatTensor (32-bit floating point)
● torch.DoubleTensor (64-bit floating point)
● torch.HalfTensor (16-bit floating point 1)
● torch.BFloat16Tensor (16-bit floating point 2)
● torch.ByteTensor (8-bit integer (unsigned))
● torch.CharTensor (8-bit integer (signed))
● torch.ShortTensor (16-bit integer (signed))
● torch.IntTensor (32-bit integer (signed))
● torch.LongTensor (64-bit integer (signed))
● torch.BoolTensor (Boolean)

由此可见,默认的Tensor是32-bit floating point,这就是32位浮点型精度的Tensor。自动混合精度的关键词有两个:自动、混合精度。这是由PyTorch 1.6的torch.cuda.amp模块带来的:

from torch.cuda.amp import autocast as autocast

混合精度预示着有不止一种精度的Tensor,那在PyTorch的AMP模块里是几种呢?2种:torch.FloatTensor和torch.HalfTensor;自动预示着Tensor的dtype类型会自动变化,也就是框架按需自动调整tensor的dtype(其实不是完全自动,有些地方还是需要手工干预);
       torch.cuda.amp 的名字意味着这个功能只能在cuda上使用,事实上,这个功能正是NVIDIA的开发人员贡献到PyTorch项目中的。而只有支持Tensor core的CUDA硬件才能享受到AMP的好处(比如2080ti显卡)。Tensor Core是一种矩阵乘累加的计算单元,每个Tensor Core每个时钟执行64个浮点混合精度操作(FP16矩阵相乘和FP32累加),英伟达宣称使用Tensor Core进行矩阵运算可以轻易的提速,同时降低一半的显存访问和存储。因此,在PyTorch中,当我们提到自动混合精度训练,我们说的就是在NVIDIA的支持Tensor core的CUDA设备上使用torch.cuda.amp.autocast (以及torch.cuda.amp.GradScaler)来进行训练。咦?为什么还要有torch.cuda.amp.GradScaler?

2. 为什么要使用混合精度?

这个问题其实暗含着这样的意思:为什么需要自动混合精度,也就是torch.FloatTensor和torch.HalfTensor的混合,而不全是torch.FloatTensor?或者全是torch.HalfTensor?如果非要以这种方式问,那么答案只能是,在某些上下文中torch.FloatTensor有优势,在某些上下文中torch.HalfTensor有优势呗。答案进一步可以转化为,相比于之前的默认的torch.FloatTensor,torch.HalfTensor有时具有优势,有时劣势不可忽视。torch.HalfTensor的优势就是存储小、计算快、更好的利用CUDA设备的Tensor Core。因此训练的时候可以减少显存的占用(可以增加batchsize了),同时训练速度更快;torch.HalfTensor的劣势就是:数值范围小(更容易Overflow / Underflow)、舍入误差(Rounding Error,导致一些微小的梯度信息达不到16bit精度的最低分辨率,从而丢失)。可见,当有优势的时候就用torch.HalfTensor,而为了消除torch.HalfTensor的劣势,我们带来了两种解决方案:
       1)梯度scale,这正是上一小节中提到的torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的underflow(这只是BP的时候传递梯度信息使用,真正更新权重的时候还是要把放大的梯度再unscale回去);
       2)回落到torch.FloatTensor,这就是混合一词的由来。那怎么知道什么时候用torch.FloatTensor,什么时候用半精度浮点型呢?这是PyTorch框架决定的,在PyTorch 1.6的AMP上下文中,如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

(1) matmul
(2) addbmm
(3) addmm
(4) addmv
(5) addr
(6) baddbmm
(7) bmm
(8) chain_matmul
(9) conv1d
(10) conv2d
(11) conv3d
(12) conv_transpose1d
(13) conv_transpose2d
(14) conv_transpose3d
(15) linear
(16) matmul
(17) mm
(18) mv
(19) prelu

3. 如何在PyTorch中使用自动混合精度?

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

from torch.cuda.amp import autocast as autocast# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)for input, target in data:optimizer.zero_grad()# 前向过程(model + loss)开启 autocastwith autocast():output = model(input)loss = loss_fn(output, target)# 反向传播在autocast上下文之外loss.backward()optimizer.step()

可以使用autocast的context managers语义(如上所示),也可以使用decorators语义。 当进入autocast的上下文后,上面列出来的那些CUDA ops 会把tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算。刚进入autocast的上下文时,tensor可以是任何类型,你不要在model或者input上手工调用.half() ,框架会自动做,这也是自动混合精度中“自动”一词的由来。
       另外一点就是,autocast上下文应该只包含网络的前向过程(包括loss的计算),而不要包含反向传播,因为BP的op会使用和前向op相同的类型。

2,GradScaler

但是别忘了前面提到的梯度scaler模块呀,需要在训练最开始之前实例化一个GradScaler对象。因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import GradScaler as GradScaler# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()for epoch in epochs:for input, target in data:optimizer.zero_grad()# 前向过程(model + loss)开启 autocastwith autocast():output = model(input)loss = loss_fn(output, target)# Scales loss. 为了梯度放大.scaler.scale(loss).backward()# scaler.step() 首先把梯度的值unscale回来.# 如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,# 否则,忽略step调用,从而保证权重不更新(不被破坏)scaler.step(optimizer)# 准备着,看是否要增大scalerscaler.update()

scaler的大小在每次迭代中动态的估计,为了尽可能的减少梯度underflow,scaler应该更大;但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。所以动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值——在每次scaler.step(optimizer)中,都会检查是否又inf或NaN的梯度出现:
       1)如果出现了inf或者NaN,scaler.step(optimizer)会忽略此次的权重更新(optimizer.step() ),并且将scaler的大小缩小(乘上backoff_factor);
       2)如果没有出现inf或者NaN,那么权重正常更新,并且当连续多次(growth_interval指定)没有出现inf或者NaN,则scaler.update()会将scaler的大小增加(乘上growth_factor)。

【Pytorch】Pytorch的自动混合精度(AMP)相关推荐

  1. Pytorch自动混合精度(AMP)训练

    相关问题:解决pytorch半精度amp训练nan问题 - 知乎 pytorch模型训练之fp16.apm.多GPU模型.梯度检查点(gradient checkpointing)显存优化等 - 知乎 ...

  2. torch.cuda.amp自动混合精度训练 —— 节省显存并加快推理速度

    torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 文章目录 torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 1.什么是amp? 2.为什么需要自动 ...

  3. PyTorch 1.6正式发布!新增自动混合精度训练、Windows版开发维护权移交微软

    点击上方"视学算法",选择加"星标"置顶 重磅干货,第一时间送达 本文转载自:机器之心 刚刚,Facebook 通过 PyTorch 官方博客宣布:PyTorc ...

  4. 自动混合精度(AMP)介绍与使用【Pytorch】

    文章目录 1 前言 2 Mixed Precision Training 3 torch自动混合精度(AMP)介绍与使用 4 torch1.6及以上版本 1 前言 pytorch从1.6版本开始,已经 ...

  5. PyTorch 1.6 发布:原生支持自动混合精度训练并进入稳定阶段

    PyTorch 1.6 稳定版已发布,此版本增加了许多新的 API.用于性能改进和性能分析的工具.以及对基于分布式数据并行(Distributed Data Parallel, DDP)和基于远程过程 ...

  6. pytorch显卡内存随训练过程而增加_PyTorch重大更新:将支持自动混合精度训练!...

    AI编辑:我是小将 混合精度训练(mixed precision training)可以让模型训练在尽量不降低性能的情形下提升训练速度,而且也可以降低显卡使用内存.目前主流的深度学习框架都开始支持混合 ...

  7. 浅尝Pytorch自动混合精度AMP

    AMP目录 浅尝Pytorch自动混合精度 从浮点数说起 深度学习中的浮点数 例1-上溢 例2-下溢 解决了什么问题? Pytorch相关功能简述 Autocasting Autocasting作上下 ...

  8. [Pytorch]基于混和精度的模型加速

    这篇博客是在pytorch中基于apex使用混合精度加速的一个偏工程的描述,原理层面的解释并不是这篇博客的目的,不过在参考部分提供了非常有价值的资料,可以进一步研究. 一个关键原则:"仅仅在 ...

  9. float32精度_PyTorch 1.6来了:新增自动混合精度训练、Windows版开发维护权移交微软...

    刚刚,Facebook 通过 PyTorch 官方博客宣布:PyTorch 1.6 正式发布!新版本增加了一个 amp 子模块,支持本地自动混合精度训练.Facebook 还表示,微软已扩大了对 Py ...

最新文章

  1. 从零开始学python数据分析-从零开始学Python数据分析(视频教学版)
  2. WPF里ItemsControl的分组实现
  3. 实验Matlab数值运算,MATLAB数值实验一(数据的插值运算及其应用完整版
  4. [设计模式]迪米特法则
  5. 混凝土静力受压弹性模量试验计算公式_【小马建考干货】天天送检,你知道混凝土试块检测哪些性能标指吗?...
  6. [Godot]使用精灵集的时候要注意关闭过滤器
  7. 使用Zookeeper实现负载均衡原理
  8. android网络编程登录和验证,ASP.NET实现用户注册和验证功能(第4节)
  9. Futter基础第14篇: 中的按钮组件 RaisedButton、FlatButton、OutlineButton、IconButton、ButtonBar以及自定义按钮组件
  10. SSM中 web.xml配置文件
  11. 计算机四个发展应用范围,计算机的四个发展阶段
  12. 复信号在信号处理中的意义
  13. 三子棋游戏(呆呆详解版)
  14. Springboot+Vue+Echarts实现51job大数据岗位分析数据大屏
  15. 关于怎么学习好一门技术一门语言
  16. CEF:JavaScript 调用 C++ 函数 Demo(VS2013)
  17. 浅谈机器学习中的过拟合
  18. Vue 电话号码344分割
  19. IPFS节点对外入口
  20. BSA分析拟南芥F2代分离群体混池测序

热门文章

  1. AD19无法生成PCB_对PCB印制线的传输线效应以及封装、连接器和电缆的频率响应进行全面分析...
  2. STM32F103:二.(1)点亮LED
  3. ajax onload怎么用,Ajax中onload和onreadystatechange两种请求方式的区别
  4. pycharm连接mysql1193错误_pycharm连接mysql数据库提示错误的解决方法_数据库
  5. STM32 NVIC中断
  6. 线性表之简介及顺序表
  7. S3C6410的DRAM控制器
  8. html中实现类似于弹幕的效果代码,javascript实现弹幕效果
  9. 力扣剑指 Offer 17. 打印从1到最大的n位数
  10. Linux流量监控工具 - iftop