基于梯度检查点的亚线性显存优化方法[1]由于较高的计算/显存性价比受到关注。MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch size,进一步提升模型性能,稳定batchwise算子。使用MegEngine训练ResNet18/ResNet50,显存占用分别最高降低23%/40%;在更大的Bert模型上,降幅更是高达75%,而额外的计算开销几乎不变。该技术已在MegEngine开源,欢迎大家上手使用:https://github.com/MegEngine。

作者 | 旷视研究院

深度神经网络训练是一件复杂的事情,它体现为模型的时间复杂度和空间复杂度,分别对应着计算和内存;而训练时内存占用问题是漂浮在深度学习社区上空的一块乌云,如何拨云见日,最大降低神经网络训练的内存占用,是一个绕不开的课题。

GPU显卡等硬件为深度学习提供了必需的算力,但硬件自身有限的存储,限制了可训练模型的尺寸,尤其是大型深度网络,由此诞生出一系列相关技术,比如亚线性显存优化、梯度累加、混合精度训练、分布式训练,进行GPU显存优化。

其中,亚线性显存优化方法[1]由于较高的计算/显存性价比备受关注;旷视基于此,经过工程扩展和优化,发展出加强版的MegEngine亚线性显存优化技术,轻松把大模型甚至超大模型装进显存,也可以毫无压力使用大batch训练模型。

这里将围绕着深度学习框架MegEngine亚线性显存优化技术的工程实现和实验数据,从技术背景、原理、使用、展望等多个方面进行首次深入解读。

背景

在深度学习领域中,随着训练数据的增加,需要相应增加模型的尺寸和复杂度,进行模型「扩容」;而ResNet [2] 等技术的出现在算法层面扫清了训练深度模型的障碍。不断增加的数据和持续创新的算法给深度学习框架带来了新挑战,能否在模型训练时有效利用有限的计算存储资源,尤其是减少GPU显存占用,是评估深度学习框架性能的重要指标。

在计算存储资源一定的情况下,深度学习框架有几种降低显存占用的常用方法,其示例如下:

  • 通过合适的梯度定义,让算子的梯度计算不再依赖于前向计算作为输入,从而in-place地完成算子的前向计算,比如Sigmoid、Relu等;

  • 在生命周期没有重叠的算子之间共享显存;

  • 通过额外的计算减少显存占用,比如利用梯度检查点重新计算中间结果的亚线性显存优化方法[1];

  • 通过额外的数据传输减少显存占用,比如把暂时不用的数据从GPU交换到CPU,需要时再从CPU交换回来。

上述显存优化技术在MegEngine中皆有不同程度的实现,这里重点讨论基于梯度检查点的亚线性显存优化技术。

原理

一个神经网络模型所占用的显存空间大体分为两个方面:1)模型本身的参数,2)模型训练临时占用的空间,包括参数的梯度、特征图等。其中最大占比是 2)中以特征图形式存在的中间结果,比如,从示例[1]可知,根据实现的不同,从70%到90%以上的显存用来存储特征图。

这里的训练过程又可分为前向计算,反向计算和优化三个方面,其中前向计算的中间结果最占显存,还有反向计算的梯度。第 1)方面模型自身的参数内存占用最小。

MegEngine加强版亚线性显存优化技术借鉴了[1]的方法,尤其适用于计算存储资源受限的情况,比如一张英伟达2080Ti,只有11G的显存;而更贵的Tesla V100,最大显存也只有32G。

图1:亚线性显存优化原理,其中 (b) 保存了Relu结果,实际中Relu结果可用in-place计算

图 1(a) 给出了卷积神经网络的基本单元,它由Conv-BN-Relu组成。可以看到,反向计算梯度的过程依赖于前向计算获取的中间结果,一个网络需要保存的中间结果与其大小成正比,即显存复杂度为O(n)。

本质上,亚线性显存优化方法是以时间换空间,以计算换显存,如图 1(b) 所示,它的算法原理如下:

  • 选取神经网络中k个检查点,从而把网络分成k个block,需要注意的是,初始输入也作为一个检查点;前向计算过程中只保存检查点处的中间结果;

  • 反向计算梯度的过程中,首先从相应检查点出发,重新计算单个block需要的中间结果,然后计算block内部各个block的梯度;不同block的中间结果计算共享显存。

这种方法有着明显的优点,即大幅降低了模型的空间复杂度,同时缺点是增加了额外的计算:

  • 显存占用从O(n)变成O(n/k)+ O(k),O(n/k)代表计算单个节点需要的显存,O(k)代表k个检查点需要的显存, 取k=sqrt(n),O(n/k)+ O(k)~O(sqrt(n)),可以看到显存占用从线性变成了亚线性;

  • 因为在反向梯度的计算过程中需要从检查点恢复中间结果,整体需要额外执行一次前向计算。

工程

在[1]的基础上,MegEngine结合自身实践,做了工程扩展和优化,把亚线性显存优化方法扩展至任意的计算图,并结合其它常见的显存优化方法,发展出一套行之有效的加强版亚线性显存优化技术。

亚线性优化方法采用简单的网格搜索(grid search)选择检查点,MegEngine在此基础上增加遗传算法,采用边界移动、块合并、块分裂等策略,实现更细粒度的优化,进一步降低了显存占用。

如图2所示,采用型号为2080Ti的GPU训练ResNet50,分别借助基准、亚线性、亚线性+遗传算法三种显存优化策略,对比了可使用的最大batch size。仅使用亚线性优化,batch size从133增至211,是基准的1.6x;而使用亚线性+遗传算法联合优化,batch size进一步增至262,较基准提升2x。

图2:三种显存优化方法优化batch size的对比:ResNet50

通过选定同一模型、给定batch size,可以更好地观察遗传算法优化显存占用的情况。如图3所示,随着迭代次数的增加,遗传算法逐渐收敛显存占用,并在第5次迭代之后达到一个较稳定的状态。

图3:遗传算法收敛示意图

此外,MegEngine亚线性优化技术通过工程改良,不再局限于简单的链状结构和同质计算节点, 可用于任意的计算图,计算节点也可异质,从而拓展了技术的适用场景;并可配合上述显存优化方法,进一步降低模型的显存占用。

实验

MegEngine基于亚线性显存技术开展了相关实验,这里固定batch size=64,在ResNet18和ResNet50两个模型上,考察模型训练时的显存占用和计算时间。

如图4所示,相较于基准实现,使用MegEngine亚线性显存技术训练ResNet18时,显存占用降低32%, 计算时间增加24%;在较大的ReNet50上,显存占用降低40%,计算时间增加25%。同时经过理论分析可知,模型越大,亚线性显存优化的效果越明显,额外的计算时间则几乎不变。

图4:MegEngine亚线性优化技术实验显存/时间对比:ReNet18/ReNet50

在更大模型Bert上实验数据表明,借助MegEngine亚线性显存技术,显存占用最高降低75%,而计算时间仅增加23%,这与理论分析相一致。有兴趣的同学可前往MegEngine ModeHub试手更多模型实验:https://megengine.org.cn/model-hub/。

使用

MegEngine官网提供了亚线性显存优化技术的使用文档。当你的GPU显存有限,苦于无法训练较深、较大的神经网络模型,或者无法使用大batch进一步提升深度神经网络的性能,抑或想要使batchwise算子更加稳定,那么,MegEngine亚线性显存优化技术正是你需要的解决方案。

上手MegEngine亚线性优化技术非常便捷,无需手动设定梯度检查点,通过几个简单的参数,轻松控制遗传算法的搜索策略。具体使用时,在MegEngine静态图接口中调用SublinearMemoryConfig设置trace的参数sublinear_memory_config,即可打开亚线性显存优化:

from megengine.jit import trace, SublinearMemoryConfig
config = SublinearMemoryConfig()
@trace(symbolic=True, sublinear_memory_config=config)def train_func(data, label, *, net, optimizer):    ...

MegEngine在编译计算图和训练模型时,虽有少量的额外时间开销,但会显著缓解显存不足问题。下面以ResNet50为例,说明MegEngine可有效突破显存瓶颈,训练batch size从100最高增至200:

import osfrom multiprocessing import Processdef train_resnet_demo(batch_size, enable_sublinear, genetic_nr_iter=0):import megengine as mgeimport megengine.functional as Fimport megengine.hub as hubimport megengine.optimizer as optimfrom megengine.jit import trace, SublinearMemoryConfigimport numpy as npprint("Run with batch_size={}, enable_sublinear={}, genetic_nr_iter={}".format(            batch_size, enable_sublinear, genetic_nr_iter        )    )# 使用GPU运行这个例子assert mge.is_cuda_available(), "Please run with GPU"try:# 我们从 megengine hub 中加载一个 resnet50 模型。        resnet = hub.load("megengine/models", "resnet50")optimizer = optim.SGD(resnet.parameters(), lr=0.1,)config = Noneif enable_sublinear:            config = SublinearMemoryConfig(genetic_nr_iter=genetic_nr_iter)@trace(symbolic=True, sublinear_memory_config=config)def train_func(data, label, *, net, optimizer):            pred = net(data)            loss = F.cross_entropy_with_softmax(pred, label)            optimizer.backward(loss)resnet.train()for i in range(10):            batch_data = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)            batch_label = np.random.randint(1000, size=(batch_size,)).astype(np.int32)            optimizer.zero_grad()            train_func(batch_data, batch_label, net=resnet, optimizer=optimizer)            optimizer.step()except:        print("Failed")        returnprint("Sucess")
# 以下示例结果在2080Ti GPU运行得到,显存容量为 11 GB# 不使用亚线性内存优化,允许的batch_size最大为 100 左右p = Process(target=train_resnet_demo, args=(100, False))p.start()p.join()# 报错显存不足p = Process(target=train_resnet_demo, args=(200, False))p.start()p.join()# 使用亚线性内存优化,允许的batch_size最大为 200 左右p = Process(target=train_resnet_demo, args=(200, True, 20))p.star

展望

如上所述,MegEngine的亚线性显存优化技术通过额外做一次前向计算,即可达到O(sqrt(n))的空间复杂度。如果允许做更多次的前向计算,对整个网络递归地调用亚线性显存算法,有望在时间复杂度为O(n log n)的情况下,达到 O(log n)的空间复杂度。

更进一步,MegEngine还将探索亚线性显存优化技术与数据并行/模型并行、混合精度训练的组合使用问题,以期获得更佳的集成效果。最后,在RNN以及GNN、Transformer等其他类型网络上的使用问题,也是MegEngine未来的一个探索方向。

参考文献

  1. Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.

  2. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 770-778).

推荐阅读

  • 饿了么交易系统 5 年演化史

  • 360金融首席科学家张家兴:别指望AI Lab做成中台

  • 干货 | 时间序列预测类问题下的建模方案探索实践

  • 写了Bug,误执行 rm -fr /*,我删删删删库了,要跑路吗?| 原力计划

  • 中国 App 出海“变形记”

  • 从货币历史,看可编程货币的升级

  • 你点的每个“在看”,我都认真当成了AI

深度解析MegEngine亚线性显存优化技术相关推荐

  1. 深度学习 占用gpu内存 使用率为0_深度解析MegEngine亚线性显存优化技术

    作者 | 旷视研究院 编辑 | Linda 基于梯度检查点的亚线性显存优化方法 [1] 由于较高的计算 / 显存性价比受到关注.MegEngine 经过工程扩展和优化,发展出一套行之有效的加强版亚线性 ...

  2. MegEngine亚线性显存优化

    MegEngine亚线性显存优化 MegEngine经过工程扩展和优化,发展出一套行之有效的加强版亚线性显存优化技术,既可在计算存储资源受限的条件下,轻松训练更深的模型,又可使用更大batch siz ...

  3. tensorflow 显存 训练_【他山之石】训练时显存优化技术——OP合并与gradient checkpoint...

    作者:bindog 地址:http://bindog.github.io/ 01 背景 前几天看到知乎上的文章FLOPs与模型推理速度[1],文中提到一个比较耗时又占显存的pointwise操作x * ...

  4. 深度学习分布式策略优化、显存优化、通信优化、编译优化综述

    综述 因为我个人最近在从事可能是AI领域对性能挑战最大的方向,自动驾驶领域,所以对整个深度学习训练的优化尤为关注,最近一直在学习相关内容,谨以此篇文章做一个总结. 我一直很看好深度学习训练优化这个方向 ...

  5. 阿里 NIPS 2017 Workshop 论文:基于 TensorFlow 的深度模型训练 GPU 显存优化

    NIPS 2017 在美国长滩举办,场面非常热烈.阿里巴巴一篇介绍深度模型训练 GPU 显存优化的论文<Training Deeper Models by GPU Memory Optimiza ...

  6. 深度学习中GPU和显存分析

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自:机器学习AI算法工程 深度学习最吃机器,耗资源,在本文,我将 ...

  7. 科普帖:深度学习中GPU和显存分析

    深度学习最吃机器,耗资源,在本文,我将来科普一下在深度学习中: 何为"资源" 不同操作都耗费什么资源 如何充分的利用有限的资源 如何合理选择显卡 并纠正几个误区: 显存和GPU等价 ...

  8. 深度学习中GPU和显存

    GPU状态的监控 nvidia-smi: 是Nvidia显卡命令行管理套件,基于NVML库,旨在管理和监控Nvidia GPU设备.nvidia-smi命令的输出中最重要的两个指标:显存占用和GPU利 ...

  9. GPU显存 - 深度学习中 GPU 和显存分析

    深度学习中 GPU 和显存分析 原文作者陈云. 本文原载于知乎专栏--人工智障的深度瞎学之路 深度学习最吃机器,耗资源,在本文,我将来科普一下在深度学习中: 何为"资源" 不同操作 ...

最新文章

  1. [15]APUE:pipe / FIFO
  2. Go(GoLang)解决 cannot find package/golang.org/问题 Grpc+ProtoBuf所需的一些资源
  3. Django-HttpResponse、render,、redirect(转载)
  4. applyDimension的用法
  5. java就业班学什么呀_传智播客JAVA就业班学习心得
  6. html转pdf后修改,pdf转换器smallpdf转成HTML后怎么排版
  7. android ExpandableListView
  8. POJ 3070 Fibonacci(矩阵高速功率)
  9. Jflash 工程配置及下载
  10. Webgl开发输入框兼容问题及开发注意的问题
  11. 计算机基础——Word 2010
  12. 使用Arcade制作的简单吃豆人游戏
  13. csgo 简单发光透视
  14. Photoshop基本使用
  15. 吴恩达机器学习课后作业——偏差和方差
  16. 全加器和半加器的区别
  17. 树莓派3b+ ubuntu-mate18.04系统安装 迅雷远程下载 搭建详解
  18. 记录 ESIM 安装、使用过程中遇到的问题
  19. 矩阵初等行变换的技巧
  20. 毕业论文的前言写什么?

热门文章

  1. Sunrun2016年Q3财务业绩强劲 冲刺全年目标
  2. 《C++代码设计与重用》——1.2 重用的神话
  3. awk: (FILENAME=- FNR=1) 致命错误: 试图访问字段 -2
  4. 在wamp环境下面安装Zend Optimizer的方法
  5. 一、数据库设计与性能优化--概述
  6. WindowsServer2012史记7-茴香豆的五种写法和四种”显示计算机”的方法
  7. emmmmmm(官宣?)
  8. [ JSOI 2015 ] Salesman
  9. Xamarin Android项目运行失败
  10. 【剑指offer 面试题47】不用加减乘除做加法