MegEngine 提供从训练到部署完整的量化支持,包括量化感知训练以及训练后量化,凭借“训练推理一体”的特性,MegEngine更能保证量化之后的模型与部署之后的效果一致。本文将简要介绍神经网络量化的原理,并与大家分享MegEngine量化方面的设计思路与实操教程。也欢迎阅读我们此前的MegEngine系列文章:

  • 工程之道,CPU推理性能提高数十倍,MegEngine计算图、MatMul优化解析

  • 工程之道,MegEngine推理性能极致优化之综述篇

  • 工程之道,深度解析MegEngine亚线性显存优化技术

背景

近年来随着边缘计算和物联网的兴起与发展,许多移动终端(比如手机)成为了深度学习应用的承载平台,甚至出现了各式各样专用的神经网络计算芯片。由于这些设备往往对计算资源和能耗有较大限制,因此在高性能服务器上训练得到的神经网络模型需要进行裁剪以缩小内存占用、提升计算速度后,才能较好地在这些平台上运行。

一种最直观的裁剪方式就是用更少位数的数值类型来存储网络参数,比如常见的做法是将 32 位浮点数模型转换成 8 位整数模型,模型大小减少为 1/4,而运行在特定的设备上其计算速度也能提升为 2~4 倍,这种模型转换方式叫做量化(Quantization)。

量化的目的是为了追求极致的推理计算速度,为此舍弃了数值表示的精度,直觉上会带来较大的模型掉点,但是在使用一系列精细的量化处理之后,其在推理时的掉点可以变得微乎其微,并能支持正常的部署应用。

原理

实现量化的算法多种多样,一般按照代价从低到高可以分为以下四种:

  • Type1 和 Type2 由于是在模型浮点模型训练之后介入,无需大量训练数据,故而转换代价更低,被称为后量化(Post Quantization),区别在于是否需要小批量数据来校准(Calibration);

  • Type3 和 Type4 则需要在浮点模型训练时就插入一些假量化(FakeQuantize)算子,模拟量化过程中数值截断后精度降低的情形,故而称为量化感知训练(Quantization Aware Training, QAT)。

以常用的 Type3 为例,一个完整的量化流程分为三阶段:

(1)以一个训练完毕的浮点模型(称为 Float 模型)为起点;

(2)包含假量化算子的用浮点操作来模拟量化过程的新模型(Quantized-Float 模型或 QFloat 模型);

(3)可以直接在终端设备上运行的模型(Quantized 模型,简称 Q 模型)。

由于三者的精度一般是 Float > QFloat > Q ,故量化算法也就分为两步:

  • 拉近 QFloat 和 Q:这样训练阶段的精度可以作为最终 Q 精度的代理指标,这一阶段偏工程;

  • 拔高 QFloat 逼近 Float:这样就可以将量化模型性能尽可能恢复到 Float 的精度,这一阶段偏算法。

第一步在MegEngine框架的“训练推理一体化”特性下得到了保证,而第二步则取决于不同的量化算法。

尽管不同量化算法可能在假量化的具体实现上有所区别,但是一般都会有一个“截断”的操作,即把数值范围较大的浮点数转换成数值范围较小的整数类型,比如下图,输入一个[-1, 1)范围的浮点数,如果转换为 4 位整型,则最多只能表示 2^4 个值,所以需要将输入的范围划分为16段,每段对应一个固定的输出值,这样就形成了一个类似分段函数的图像,计算公式为:

另外,由于分段函数在分段点没有梯度,所以为了使假量化操作不影响梯度回传,就需要模拟一个梯度,最简单的方法就是用y=x来模拟这一分段函数,事实证明这么做也是有效的,这种经典的操作被称为“Straight-Through-Estimator”(STE)。

工程

量化部分作为模型推理部署的重要步骤,是业界在大规模工业应用当中极为关注的部分,它在 MegEngine 的底层优化中占了很大比重。在目前开源的版本里,针对三大平台(X86、CUDA、ARM),MegEngine都有非常详细的支持,尤其是ARM平台。

一般在通用计算平台上,浮点计算是最常用的计算方式,所以大部分指令也是针对浮点计算的,这使得量化模型所需的定点计算性能往往并不理想,这就需要针对各个平台优化其定点计算的性能。

ARM 平台

ARM平台一般是指手机移动端,其系统架构和底层指令都不同于我们熟知的电脑CPU,而随着架构的变迁,不同架构之间的指令也存在不兼容的问题。为此,MegEngine针对ARM v8.2前后版本分别实现了不同的优化:

  • ARM v8.2 主要的特性是提供了新的引入了新的 fp16 运算和 int8 dot 指令,MegEngine基于此进行一系列细节优化(细节:四个int8放到一个128寄存器的32分块里一起算),最终实现了比浮点版本快2~3倍的速度提升

  • 而对于v8.2之前的ARM处理器,MegEngine则通过对Conv使用nchw44的layout和细粒度优化,并创新性地使用了int8(而非传统的int6)下的winograd算法来加速Conv计算,最使实现能够和浮点运算媲美的速度。

CUDA 平台

CUDA 平台是指 NVIDIA 旗下 GPU 平台,由于提供 CUDNN 和 Toolkit 系列接口以及 TensorRT 专用推理库,大部分算子可以使用官方优化,而 MegEngine 则在此基础上进行了更多细节的优化,比如如何更好地利用 GPU 的TensorCore 进行加速,不同型号之间一些差异的处理等,最终效果根据不同模型也有非常明显的推理加速。

X86 平台

X86 平台是指 Intel CPU 平台,近年来随着深度学习的发展,其也慢慢提供了针对定点运算更多的支持。

  • 在新一代至强(Xeon)处理器上,通过使用 VNNI(Vector Neural Network Instructions)指令,MegEngine 将 CPU 的 int8 推理性能优化到了浮点性能的 2~3 倍。

  • 而对于不支持 VNNI 指令的 CPU,一般只提供最低 int16 的数值类型支持,则通过使用 AVX2(Advanced Vector Extensions)这一向量格式,实现了 int8 推理性能与浮点性能持平。

以上是对各个平台推理加速效果的整体介绍,更多更细节的介绍可以期待之后的系列文章。

使用

除了底层实现上的加速与优化,在 Python 侧训练部分,MegEngine对接口也有很多细节设计,使得整体代码逻辑清晰简洁。

我们在 Module 中额外引入了两个基类:QATModule、QuantizedModule 。分别代表上文提及的带假量化算子的 QFloat 模型与 Q 模型,并提供普通 Module → QATModule → QuantizedModule 三阶段的转换接口。各个版本的算子是一一对应的,且通过合理的类继承免除了大量算子实现中的冗余代码,清晰简洁。

如上图,用户首先在普通 Module 上进行正常的模型训练工作。训练结束后可以转换至 QFloat 模型上,通过配置不同的 Observer 和假量化算子来选择不同的量化参数 scale 获取方式,从而选择进行 QAT 或 Calibration 后量化。之后可以再转换至 Q 模型上,通过 trace.dump 接口就可以直接导出进行部署

针对推理优化中常用的算子融合,MegEngine 提供了一系列已 fuse 好的 Module,其对应的 QuantizedModule 版本都会直接调用底层实现好的融合算子(比如 conv_bias)。

这样实现的缺点在于用户在使用时需要修改原先的网络结构,使用 fuse 好的 Module 搭建网络,而好处则是用户能更直接地控制网络如何转换,比如同时存在需要 fuse 和不需要 fuse 的 Conv 算子,相比提供一个冗长的白名单,我们更倾向于在网络结构中显式地控制,而一些默认会进行转换的算子,也可以通过 disable_quantize 方法来控制其不进行转换。

另外我们还明确了假量化算子(FakeQuantize)和Observer的职责,前者将主要负责对输入进行截断处理的计算部分,而后者则只会记录输入的值,不会改变输出,符合 Observer 的语义。

在配置使用上,用户需要显式指定针对 weight、activation 分别使用哪种 Observer 和 FakeQuantize,比如:

ema_fakequant_qconfig = QConfig(    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),    act_observer=partial(        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False    ),    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),)

这样的好处在于,用户可以控制每一处量化过程的细节,可以分别采用不同量化算子和数值类型。

下文简单说明一下在 MegEngine 中转换一个 ResNet 网络的全流程代码:

Float → QFloat:

from megengine.quantization import ema_fakequant_qconfigfrom megengine.quantization.quantize import quantize_qat # 使用fuse好的Module搭建的网络model = ResNet18() # 使用默认的配置进行模型转换quantize_qat(model, ema_fakequant_qconfig) # 与 Float 模型完全一致的训练函数train(model)

QFloat → Q 并导出用于部署:

from megengine.quantization.quantize import quantize# 使用fuse好的Module搭建的网络
model = ResNet18()# 执行模型转换
quantize(model)# 将模型进行编译,infer_func是trace类的实例,通过trace方法进行编译
infer_func(processed_img).trace()# 调用dump方法将模型导出,用于部署
infer_func.dump(output_file, arg_names=["data"])

更多接口细节可以参考官网文档。

MegEngine Website:

https://megengine.org.cn

总结

本文简单介绍了神经网络模型实际应用在移动平台必不可少的一步——量化,以及天元(MegEngine )在量化上做的一些工作:包括底层针对不同平台的一些优化效果,在用户接口使用上的一些设计理念。

天元(MegEngine)相信,通过简洁清晰的接口设计与极致的性能优化,“深度学习,简单开发”将不仅惠及旷视自身,也能便利所有的研究者,开发者。

参考文献

[1] Moons, B., Goetschalckx, K., Van Berckelaer, N., & Verhelst, M. (2017, October). Minimum energy quantized neural networks. In 2017 51st Asilomar Conference on Signals, Systems, and Computers (pp. 1921-1925). IEEE.

[2] Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., ... & Kalenichenko, D. (2018). Quantization and training of neural networks for efficient integer-arithmetic-only inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 2704-2713).

[3] Zhou, A., Yao, A., Guo, Y., Xu, L., & Chen, Y. (2017). Incremental network quantization: Towards lossless cnns with low-precision weights. arXiv preprint arXiv:1702.03044.

[4] Li, F., Zhang, B., & Liu, B. (2016). Ternary weight networks. arXiv preprint arXiv:1605.04711.

[5] Rastegari, M., Ordonez, V., Redmon, J., & Farhadi, A. (2016, October). Xnor-net: Imagenet classification using binary convolutional neural networks. In European conference on computer vision (pp. 525-542). Springer, Cham.

欢迎访问

  • MegEngine Website:
    https://megengine.org.cn

  • MegEngine GitHub(欢迎Star):
    https://github.com/MegEngine

或加入「天元开发者交流QQ群」,一起看直播学理论、做作业动手实践、直接与框架设计师交流互动。

同时,群内还会不定期给大家发放各种福利:学习礼包、算力、周边等。

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

工程之道,深度学习的工业级模型量化实战相关推荐

  1. Python工程能力进阶、数学基础、经典机器学习模型实战、深度学习理论基础和模型调优技巧……胜任机器学习工程师岗位需要学习什么?...

    咱不敢谈人工智能时代咋样咋样之类的空话,就我自己来看,只要是个营收超过 5 亿的互联网公司,基本都需要具备机器学习的能力.因为大部分公司盈利模式基本都会围绕搜索.推荐和广告而去. 就比如极客时间,他的 ...

  2. 从FM推演各深度学习CTR预估模型

    本文的PDF版本.代码实现和数据可以在我的github取到. 1.引言 点击率(click-through rate, CTR)是互联网公司进行流量分配的核心依据之一.比如互联网广告平台,为了精细化权 ...

  3. 2023北京智源大会亮点回顾 | 高性能计算、深度学习和大模型:打造通用人工智能AGI的金三角

    AIGC | Aquila | HuggingFace AGI | DeepMind  | Stability AI 通用人工智能(AGI)是人工智能领域的最终目标,也是一项极具挑战性的任务.在诸多技 ...

  4. R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification)

    R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification) Long Short Term 网络-- 一般就叫做 LSTM --是一 ...

  5. 深度学习100+经典模型TensorFlow与Pytorch代码实现大合集

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! [导读]深度学习在过去十年获得了极大进展,出现很多新的模型,并且伴随TensorF ...

  6. 深度学习CTR预估模型凭什么成为互联网增长的关键?

    本文是王喆在InfoQ开设的原创技术专栏"深度学习CTR预估模型实践"的第一篇文章(以下"深度学习CTR预估模型实践"简称"深度CTR模型" ...

  7. 深度学习的seq2seq模型——本质是LSTM,训练过程是使得所有样本的p(y1,...,yT‘|x1,...,xT)概率之和最大...

    from:https://baijiahao.baidu.com/s?id=1584177164196579663&wfr=spider&for=pc seq2seq模型是以编码(En ...

  8. 深度学习 vs. 概率图模型 vs. 逻辑学

    深度学习 vs. 概率图模型 vs. 逻辑学 发表于2015-04-30 21:55|6304次阅读| 来源quantombone|1 条评论| 作者Tomasz Malisiewicz 深度学习de ...

  9. [caffe]深度学习之图像分类模型VGG解读

    一.简单介绍 vgg和googlenet是2014年imagenet竞赛的双雄,这两类模型结构有一个共同特点是go deeper.跟googlenet不同的是.vgg继承了lenet以及alexnet ...

最新文章

  1. pandas使用nlargest函数返回特定数据列中前N个最大值(搜寻最大的n个元素)、pandas使用nlargest函数返回特定数据列中前N个最大值所对应的数据行
  2. redis mysql主从同步_手撕Redis,主从同步
  3. OpenCV 加载图像、转换图像和保存图像
  4. jQuery Mobile
  5. jQuery length 和 size()区别
  6. 十天学会ASP.Net——(2)
  7. MyEclipse2015Stable2.0安装破解
  8. 机器学习中防止过拟合的方法总结
  9. python dataframe 取每行的最大值,在python数据框中的每一行中查找最大值
  10. laravel5.5 php7,ubuntu 16.04+nginx+mysql+php7.1+laravel5.5环境
  11. linux path原理,面试题:Linux中的环境变量PATH
  12. eclipse+java类不报错_eclipse,代码中有错误,项目或者java类中却不显示红叉
  13. centos 7单网卡实现双路由,同时访问内外网
  14. jQuery 效果 - stop() 方法
  15. 风光过后就崩溃,互联网公司让你心好累
  16. 风帆头,旗帜服,“背”在肩上的古国王印
  17. php silk文件转换pcm,微信语音silk格式文件转换处理记录
  18. Shopee聊聊客服工作日常
  19. PHP中的empty()函数
  20. OSChina 周四乱弹 ——金毛如何实现部门自助化管理案例图

热门文章

  1. openssl 64位编译_海思hi3516dv300开发--live555交叉编译
  2. mysql第3章数据定义_【MySQL数据库】第3章解读:服务器性能剖析 (下)
  3. java----java工具包
  4. 12月5日 第二冲刺周期个人站立会议内容报告(第五天)
  5. 传央行闭门会议将出台两项举措 等同降准150基点
  6. 插入数据到hive_Hive实现网站PV分析
  7. linux批量管理服务,通过PSSH批量管理Linux服务器
  8. tf callbacks
  9. 永劫无间为啥显示连接服务器失败,永劫无间服务器故障怎么办?永劫无间服务器故障解决办法...
  10. spark任务shell运行_了解Spark 应用的一生