©作者 | 郑怜悯、陈键飞

来源 | 机器之心

随着超大规模深度学习模型逐渐成为 AI 的趋势,如何在有限的 GPU 内存下训练这些模型成为了一个难题。

本文将介绍来自加州伯克利大学的 ActNN,一个基于 PyTorch 的激活压缩训练框架。在同样的内存限制下,ActNN 通过使用 2 bit 激活压缩,可以将 batch size 扩大 6-14 倍,将模型尺寸或者输入图片扩大 6-10 倍。ActNN 相关论文已被 ICML 2021 接收为 Long Talk,代码开源于 github。

论文地址:

https://arxiv.org/abs/2104.14129

代码地址:

https://github.com/ucbrise/actnn

AI 训练撞上「内存墙」

从 AlexNet,ResNet 到 GPT-3,深度学习性能的突破都离不开模型规模的疯狂增长。大模型有更好的性能已经成为业界的共识。过去几年,不仅训练一个最先进模型需要的算力在指数增长,训练一个最先进模型需要的内存也在指数增长。如下图所示,大型 Transformer 模型的参数量以每两年翻 240 倍的速度指数增长。但是,单个 GPU 的内存却只以每两年翻 2 倍的速度在缓慢增长。另外,在训练模型时,不光要存储模型参数,还要存储中间结果激活值和优化器状态,所需要的内存更多。如何在有限的 GPU 内存下训练这些大规模模型成为了挑战。

source:Gholami A, Yao Z, Kim S, Mahoney MW, Keutzer K. AI and Memory Wall. RiseLab Medium Blog Post, University of California Berkeley

节省训练内存的方法

目前,节省训练内存的方法主要有三类:1. 重计算(Gradient checkpointing/Rematerialization)  2. 使用 CPU 内存进行交换 (swapping)  和 3. 使用分布式训练将 Tensor 分散存储在多个 GPU 上。这三类方法互相不冲突,可以结合使用。大部分机器学习框架对这些方法都提供了一些支持,也有不少相关的论文。但是,想要高效、自动化地实现这些策略并不容易。

与已有方法不同,我们提出了 ActNN,一个新的基于压缩的内存节省框架。在提供理论证明的同时,我们基于 PyTorch 提供了一个高效易用的实现。Table.1 比较了 ActNN 和已有的一些内存节省系统。ActNN 支持 PyTorch 的动态图执行模式,并且不需要预先进行复杂的策略搜索。ActNN 作为一个独立的 Python 库,使用时 import 即可,不需要修改或重新编译 PyTorch。与已有的工作相比,ActNN 灵活且易于使用。同时,ActNN 在理论上也可以和已有的技术相互叠加。

ActNN:2 bit 激活压缩训练

在训练一个多层神经网络时,在前向传播中,每一层的中间结果都要被存下来用于计算反向传播的梯度。这些中间结果,又被叫做「激活值」(activation),实际上占据了大部分的内存消耗,尤其是在 batch size 较大或者输入图片较大的时候。ActNN 的原理是就是压缩这些激活值来节省内存。如下图所示,左图表示的是普通的前向传播和反向传播,前向传播时会存下所有层的 fp32 激活值用于反向传播,内存使用在计算 loss 的时候达到峰值。右图表示的是 ActNN 的训练方法:在前向传播时,通过一个压缩操作 Q 将激活值压缩后再存储;反向传播时,通过解压缩操作 Q^-1 将激活值解压再计算梯度。

如果只是为了节省内存,这里可以使用各种压缩算法,但是大部分现有的压缩算法并不能高效地运行在 GPU 上,会引入较大的开销。ActNN 选择了使用 2-bit 量化作为这里的压缩算法。量化操作的代价较小,而且有一些好的数学性质允许我们使用有损压缩达到较大的压缩比。

把 fp32 浮点数量化为 2-bit 整数是一个有损压缩,会引入一些误差。论文从理论上分析了量化引入的误差是如何影响训练的收敛性的。

第一,存在一个随机化的量化策略,使得使用有损量化压缩后,估计出的有损梯度是原梯度的一个无偏估计。

在这一条件下,我们套用已有的随机梯度下降收敛性定理,得出最后收敛时的误差会被梯度的方差所限制。

第二,我们推导出了使用量化压缩之后,随机梯度下降计算出的梯度的方差。

等号右边的第一项是随机梯度下降在 minibatch 采样时产生的方差,等号右边的第二项是有损压缩额外引入的方差。这条公式显示地刻画了有损压缩带来的影响。注意到,当有损量化压缩带来的方差远小于原来随机梯度下降自带的方差时,ActNN 引入的有损压缩就不会影响训练的收敛性。更多关于公式的推导和可视化参见文末的论文链接。论文对不同的算子(conv2d,batch norm,linear等)都提供了详细的分析。

由上述公式启发,我们提出了一些新的量化技巧用于降低有损压缩引入的额外方差。我们引入了新的量化技巧 ( Per-group Quantization,Fine-Grained Mixed-Precision,Runtime Adaptation) 来利用梯度在不同样本,不同纬度,不同层之间的异构特性。最后的压缩算法会分配更多的 bit 给更重要的激活值。平均每个浮点数分配到 2 bit。

在具体实现压缩算法时,还有很多可以调节的参数。这里产生了一个内存节省和训练速度的取舍。一般来说,使用更复杂的压缩算法可以节省更多的内存,但是也会引入更多额外的开销,使训练速度变慢。为了给用户较大的灵活性,ActNN 提供了 5 个优化等级 L1-L5 供用户选择。低的优化等级节省的内存较少,但是运行速度快。高的优化等级节省的内存多,但是运行也更慢。在最高优化等级 L5 下,ActNN 会结合一个简单的内存交换策略,将压缩后的激活值移到 CPU 内存上,进一步节省内存。

实现

要在 PyTorch 实现 ActNN 算法非常简单。对于一个 PyTorch nn Module,我们只需要在其 forward 函数里加入量化压缩,在其 backward 函数里加入解压缩操作。所有的计算还是在 fp32 下进行,与原来一样,伪代码如下图所示。

ActNN 为大部分常用的 PyTorch nn.Module 实现了使用量化压缩的版本。用户只需将模型里的所有 PyTorch nn.Module 替换成 ActNN 对应的 Module (如把 nn.Conv2d 替换成 actnn.Conv2d),即可节省内存,不需要更改其他代码。ActNN 同时也提供了一个 wrapper 实现一行代码自动替换。

实验结果

因为 ActNN 进行的是有损压缩,所以最重要的一点是先验证 ActNN 是否会影响模型的精度。下图是使用 ActNN 在 ImageNet 上训练 ResNet-50 的结果。FP 代表普通的 fp32 训练, BLPA 是来自 NeurIPS 2019 的一个相关工作。可以看到,在 ActNN 的 2-bit 压缩模式下,模型几乎没有损失精度。在更极限的 1.25 bit 的情况下,ActNN 也能收敛,只不过会损失一些精度。而之前的工作 BLPA 在小于 4 bit 的情况就下无法收敛。

我们还在图像分割,物体检测,以及自监督学习等多个任务上进行了实验。ActNN 都能在 2-bit 压缩模式下达到和普通 fp32 几乎一样的结果。在部分任务上,因为 ActNN 可以使用更大的 batch size,甚至可以取得更好的测试结果。详细的实验结果和训练记录参见文末的论文与 github 链接。

之后,我们对比了 ActNN 与普通 fp32 训练的实际内存使用情况。如下表所示,ActNN 可以将激活值占用的内存压缩 12 倍,将训练使用的总内存压缩 4 - 7 倍。这一实际内存压缩效果符合理论推导。为什么激活值压缩倍率是 12 而不是 32 bit / 2 bit = 16?主要是因为 ActNN 不能使用 inplace 的 ReLU,以及需要存储少量额外的 min 和 scale 用于解压缩。

最后,我们测试了 ActNN 的训练速度。因为 ActNN 在训练过程中进行了压缩,这些压缩在节省内存的同时也会引入额外的计算开销。一般来说,省得内存越多,进入的额外开销就越多,训练也就越慢。我们在 NVIDIA T4 (16 GB 内存) 上对比了 ActNN 和已有内存节省系统的训练速度。如下图所示,DTR (ICLR 2020),BLPA (NeurIPS 2019)和 swap 分别是基于重计算,压缩和内存交换的三种方法,红叉代表 Out-of-memory。y 轴是训练吞吐量 (images per second),越高越好。绿色的曲线是综合 ActNN 在不同优化等级下的最优结果。可以看到,ActNN 不仅能开到最大的 batch size(即最省内存),同时在所有 batch size 下都比 baseline 的训练速度更快。

我们还对更多的网络进行了测试。在同样的内存限制下,ActNN 可以将 batch size 扩大 6-14 倍,将模型尺寸或者输入图片扩大 6-10 倍。详细的实验设置和结果参见文末的论文链接。

两行代码即可在 PyTorch 中使用

import actnn
model = actnn.QModule(model)

ActNN 提供了一个自动模型转换封装。只需在训练脚本里插入两行代码,即可将普通的 PyTorch  模型转换为使用 ActNN 的模型。同时,ActNN 也提供了更高级的 API 支持定制化的使用场景。

更多的例子参见 github 链接。我们提供了在图像识别、图像分割、物体检测,以及自监督学习等多个任务上使用 actnn 的完整例子和训练记录,欢迎试用!

特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

???? 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

???? 投稿通道:

• 投稿邮箱:hr@paperweekly.site

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

△长按添加PaperWeekly小编

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

节省显存新思路,在PyTorch里使用2 bit激活压缩训练神经网络相关推荐

  1. PyTorch节省显存占用方法

    1-使用inplace操作 2-使用混合精度运算 参考: [1]混合精度训练 http://kevinlt.top/2018/09/14/mixed_precision_training/ [2]py ...

  2. pytorch节省显存_节省新房子的照明

    pytorch节省显存 Our final move into the new house is this weekend. We did a three phase, three week move ...

  3. 释放pytorch占用的gpu显存_Pytorch 节省显存的训练方法总结

    前言 最近的工作中,用到了Pytorch框架训练医学图像分割模型.精心设计的模型经常会因为显存不足而失败.减小模型训练过程中对显存的占用,可能我们能想到最简单的方法就是减小batchsize,减少卷积 ...

  4. torch.cuda.amp自动混合精度训练 —— 节省显存并加快推理速度

    torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 文章目录 torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 1.什么是amp? 2.为什么需要自动 ...

  5. GLM国产大模型训练加速:性能最高提升3倍,显存节省1/3,低成本上手

    作者|BBuf.谢子鹏.冯文 2017 年,Google 提出了 Transformer 架构,随后 BERT .GPT.T5等预训练模型不断涌现,并在各项任务中都不断刷新 SOTA 纪录.去年,清华 ...

  6. 显存优化 | Pytorch的显存机制torch.cuda.empty_cache及周边概念

    注:文中涉及一些内部底层实现机制,可能和大家外界看到的不一样,这里略过不做介绍.借着笔记,分享平时碰到的技术点,不高端,不炫酷,对你有用更好了. 最近在做模型的优化工作,主要涉及精度,速度和显存优化, ...

  7. pytorch如何计算显存大小

    参考连接 pytorch 减小显存消耗,优化显存使用避免 outofmemory https://blog.csdn.net/qq_28660035/article/details/80688427 ...

  8. OOM?教你如何在PyTorch更高效地利用显存

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨派派星 来源丨CVHub 编辑丨极市平台 导读 本文介绍了如何在不减少输入数据尺寸以及Batch ...

  9. tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...

    作者:bindog 地址:http://bindog.github.io/ 01 背景 前几天看到知乎上的文章FLOPs与模型推理速度[1],文中提到一个比较耗时又占显存的pointwise操作x * ...

最新文章

  1. PAT(甲级)2020年春季考试 7-4 Replacement Selection
  2. TensorFlow中的计算图
  3. RxSwift 小记 Error Handling Operators(catchError,retry)
  4. 练习:WinForm (PictureBox和Timer)
  5. Qt creator使用笔记
  6. python写端口扫描器_使用Python编写简单的端口扫描器的实例分享
  7. .NET平台相关概念(简单了解)
  8. HTML特殊转义字符对照表
  9. Web Dynpro for ABAP 之 Web Dynpro Window Web Dynpro Application
  10. android组件通讯 Intent-Action属性
  11. POJ1358 Agri-Net
  12. Spring Stateless State Security第3部分:JWT +社会认证
  13. java连接mysql时区修改_java连接mysql数据库时的时区设置问题(time_zone)
  14. 英文版Windows2k服务器无法正常返回中文的问题
  15. Excel学习笔记:P22-时间格式、工龄与工时计算
  16. php转jsp,阿里西西Html多功能代码转换器(html转js/jsp/php工具)
  17. vscode代码自动保存插件_VSCode 云同步扩展设置 Settings Sync 插件
  18. 住宅代理和数据中心代理哪个更好?
  19. 玩客云内置EMMC存储刷入Armbian
  20. 中英文颜色RGB数值对照表(python cv2)

热门文章

  1. 基于互联网生态积累,百度Apollo智舱产品斩获智能网联创新奖
  2. php5.6.30源码下载,PHP 5.6.30 正式发布,安全漏洞修复
  3. 结构体里有指针 scanf赋值_C++|链表中常见的链表节点指针操作
  4. python paramiko安装_Python Paramiko模块的安装与使用详解
  5. Coding Interview Guide -- 数组的partition调整
  6. 小程序基于mpvue开发坑一
  7. 关于nsurlsession
  8. linux学习笔记-不定时更新
  9. Ajax(一)显示可用内存空间
  10. oracle omf管理 and asm omf