◎本文为极市开发者「ArtyZe」原创投稿,转载请注明来源。

◎极市「项目推荐」专栏,帮助开发者们推广分享自己的最新工作,欢迎大家投稿。联系极市小编(fengcall19)即可投稿~

量化简介

在实际神经网络在例如端侧的部署时,由于内存,带宽或者最重要计算资源的限制,通常会采用量化等手段来加速神经网络的表现。量化的意思即是将原来浮点运算转化为定点运算,例如最常见的8bit量化,无论是int8还是uint8,都是将浮点的区间参数映射到256个离散区间上。这样原来32位的运算就变成了8位的运算

r=S(q−Z)r=S(q-Z)r=S(q−Z)

这里我们以非对称量化到uint8举例,其中S代表量化因子(scale factor), Z代表zero point.

量化的优点非常明显,即使除去后处理,反量化或者非对称量化带来额外运算,单张图片的推理速度通常都能获得2-3倍的提升(这里不讨论针对硬件进行特殊优化带来的加速),但是随之而来的就是量化造成的精度下降问题。

简单来说,量化造成精度损失主要来自两个方面:

  • 取整损失,例如r = [6.8, 7.2, -0.6], scale = (7.2+0.6)/127 = 0.061417, q1 = 7.2/scale = 117.23,那么他的量化值就是117,有了0.23的损失

  • 截断损失 ,因为scale是取最优区间,那么边界的点势必会有超过最大量化值的情况,这些离群点就会被忽略掉,量化的最大最小值区间相比于原数据分布就有了截断损失

为了能够减少量化过程中的精度损失,我们参考google的论文

《Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference》,这种方法属于aware training quantization,与之对应的是post training quantization,后面一种方法是tensorRT使用的量化方法,后面有机会可以把实现的代码上传到github上。

事实上,学术界认为8bit的量化已经饱和了,已经开始做4bit的量化研究了,但是在实际的工作过程中,发现对于较小的识别网络,8bit的量化效果依然不是令人非常满意。

量化实现

为了方便的部署到嵌入式端,我最初选择实现框架定在实现语言为C或者C++,最终选定的框架为darknet,一方面darknet在工业界有着不错的应用群体,二来框架简单直接,实现起来非常方便,同时还可以验证反向过程是否正确。在复现过程中,为了能够将算法成功的集成进去,对darknet做了许多小的修改,正好这里也记录一下。

代码链接:
https://github.com/ArtyZe/yolo_quantization

伪量化

相信对量化了解的同学都读过这篇文章,tf-lite都是用的这种量化方式。区别于训练后量化的方式,google采用的是在训练过程中加入伪量化来模拟量化过程中由于取整造成的精度损失。

那么伪量化是个什么操作呢?

q=⌊x−as⌋x+aq=\left\lfloor\frac{x-a}{s}\right\rfloor x+aq=⌊sx−a​⌋x+a

其中,类似中括号那里就是取整的意思。可以看到,如果说没有取整这个操作,完全就是减一个数,除一个数,再乘回来,再加回来,完全就没有任何变化。但是因为有了这个取整,所以这中间就有了变化。

想象一下,如果在训练过程中,采取了这么一个操作,那不就相当于提前就把量化的损失考虑进去了吗?这样等到inference的时候,精度下降就少的多了呀。

那么要把这个伪量化放在哪里呢?

那当然是放在inference的时候需要进行量化的位置,以论文中给出的图来解析,

卷积的操作用公式来描述无非就是:

y=f(wx)y=f(w x)y=f(wx)

所以要量化的就是weights以及feature x。

这时候就有人提出疑问了,可是你看啊,人家给出的图中是weights和激活值的伪量化啊,你怎么说是input的feature呢,可是如果你这样想呢,除了第一层真正的输入之外,剩下的层,上一层的activ输出值不就是下一层的input值吗,而且使用activ值有一个什么最大的好处呢?在最后一层将定点值反量化回到浮点值需要用到激活值的scale和zero_point(如果是非对称量化的话)

在训练中融合BN到CONV

我们平时见到的最多的融合BN+CONV就是在inference的时候为了加速做的,但是你细想一下,你BN的参数在inference的时候怎么办呢?如果inference的时候不融合,那么BN的参数你要怎么量化,如果融合了,那么weights的量化参数是根据融合前生成的啊,那你怎么能用呢?

所以解决方案就是,把BN融合在训练阶段就加进去,如下图:

具体怎么做呢?

- 首先就y=wxy=w xy=wx的前向跑一遍,计算得到均值,方差等一系列BN的参数

- 然后,利用这些BN的参数,通过融合公式加到input和weights中去,将卷积公式变成真正的

y′=w′x+b′y^{\prime}=w^{\prime} x+b^{\prime}y′=w′x+b′

其中

w′=pwσ2+εb′=b−γμσ2+εw^{\prime}=\frac{p w}{\sqrt{\sigma^{2}+\varepsilon}} \quad b^{\prime}=b-\frac{\gamma \mu}{\sqrt{\sigma^{2}+\varepsilon}}w′=σ2+ε​pw​b′=b−σ2+ε​γμ​

为了后续能够更新原生www 和 b,b,b, 该过程中不仅需要保存 w′w^{\prime}w′ 和 b′,b^{\prime},b′, 还需要保存 www 和 bbb,至于反向更新过程中,需要使用Straight Through Estimator(STE)来跳过伪量化过程中的round使得梯度可以正常回传

- 之后根据不同层的type添加input, weights和activation量化即可。目前我采用的方式是第一层卷积input, weights和activation量化都要有,其他层如route后面的卷积层同样需要input量化,因为route的activation量化参数直接使用他的输入层的activation量化参数即可;maxpool或者upsample都是添加activation量化即可

需要注意的

Uint8推理实现

下面开始介绍定点推理,公式如下

y=wx+by=w x+by=wx+b

由前面可知

r=S(q−Z)r=S(q-Z)r=S(q−Z)

S3(q3−Z3)=S2(q2−Z2)S1(q1−Z1)+Sb(qb−Zb)S_{3}\left(q_{3}-Z_{3}\right)=S_{2}\left(q_{2}-Z_{2}\right) S_{1}\left(q_{1}-Z_{1}\right)+S_{b}\left(q_{b}-Z_{b}\right)S3​(q3​−Z3​)=S2​(q2​−Z2​)S1​(q1​−Z1​)+Sb​(qb​−Zb​)

为了保持量纲一致,令,Sb=S1S2,Zb=0S_{b}=S_{1} S_{2}, \quad Z_{b}=0Sb​=S1​S2​,Zb​=0

对上式进行简单的变换
q3=Z3+M(NZ1Z2−Z1∑q2−Z2∑q1+∑q1q2+qb)q_{3}=Z_{3}+M\left(N Z_{1} Z_{2}-Z_{1} \sum q_{2}-Z_{2} \sum q_{1}+\sum q_{1} q_{2}+q_{b}\right)q3​=Z3​+M(NZ1​Z2​−Z1​∑q2​−Z2​∑q1​+∑q1​q2​+qb​)

其中, M=S1S2/S3M=S_{1} S_{2} / S_{3}M=S1​S2​/S3​ 是唯一的浮点数, 因此采用 M=M0×2−shiftM=M_{0} \times 2^{-s h i f t}M=M0​×2−shift 来代表, M0M_{0}M0​ 和 shift 都是定点值,具体多大需要看精度需要,一般采用32位的值来表示。

  • 在进入到正式的推理之前,首先看上式哪些值是常量可以提前计算出来,例如Z3,Z1Z2,Z1∑q2,qbZ_{3}, Z_{1} Z_{2}, Z_{1} \sum q_{2}, q_{b}Z3​,Z1​Z2​,Z1​∑q2​,qb​都是常量,其中1代表ft,2代表weights

  • 进入到正式推理后,需要注意的问题就是溢出的问题,一般情况下为了防止这种情 况有两种方式,一种就是使用一个shift来统计溢出的情况,另一种就是直接把输出范围扩大,例如8bit的乘加输出到32bit。下面我们开始计算Z2∑q1Z_{2} \sum q_{1}Z2​∑q1​ 及 ∑q1q2\sum q_{1} q_{2}∑q1​q2​,为了能够尽可能的探索优化速度的极限,gemm函数我们使用的是mkl中的cblas库函数。

  • 得到q3q_{3}q3​之后的最后一步操作就是激活,这部分在实际使用过程中也是关乎到量化精度的一个关键点。如果激活函数是类似softmax,tanh,swish等非线性函数的话,都要通过lookup table查表的方式,为了能够尽快的实现,我这里选用的是tiny-yolov3,里面的激活函数都是leaky relu的线性激活函数

  • 其他层例如maxpool,route由于并不涉及到计算操作,因此直接将代码转成uint8的即可

  • 在最后一层yolo层的前面需要将uint8反量化回到float类型,方式如下

后续改进

目前已经实现了yolov3-tiny的所有算子的实现,为了方便,目前使用relu6替代了原来的

leakyrelu,包括conv, pooling, route, upsample,这些除了conv全部都是线性的算子,后续会

继续支持leaky relu, softmax, shortcut, elementwise add, concat等非线性算子。

量化performance

为了尽可能的不影响精度,我选择在yolo层的上面一层conv层不进行量化。测试结果如下,可以看到前向时间相比于原来的darknet压缩明显,同时精度下降非常低。

传送门

Github链接:
https://github.com/ArtyZe/yolo_quantization

开源项目|基于darknet实现量化感知训练,已实现yolov3-tiny所有算子相关推荐

  1. 深度学习框架量化感知训练的思考及OneFlow的解决方案

    作者 | BBuf 原文首发于公众号GiantPandaCV 0x0.总览 相信不少小伙伴都了解或者使用了一些深度学习框架比如PyTorch,TensorFlow,OneFlow(也是笔者目前正在参与 ...

  2. 深度学习框架量化感知训练的思考及OneFlow的一种解决方案

    [GiantPandaCV导语]这篇文章分享的是笔者最近在OneFlow做的一个项目,将Pytorch FX移植到OneFlow之后实现了自动量化感知训练动态图模型(在Pytorch和OneFlow中 ...

  3. 基于pytorch量化感知训练(mnist分类)--浮点训练vs多bit后量化vs多bit量化感知训练效果对比

    基于pytorch量化感知训练–浮点训练vs多bit后量化vs多bit量化感知训练效果对比 代码下载地址:下载地址 灰色线是量化训练,橙色线是后训练量化,可以看到,在 bit = 2.3 的时候,量化 ...

  4. 基于pytorch后量化(mnist分类)---浮点训练vs多bit后量化vs多bit量化感知训练效果对比

    基于pytorch后量化(mnist分类)-浮点训练vs多bit后量化vs多bit量化感知训练效果对比 代码下载地址:下载地址 试了 bit 数为 1-8 的准确率,得到下面这张折线图: 发现,当 b ...

  5. 适配PyTorch FX,OneFlow让量化感知训练更简单

    作者 | 刘耀辉 审稿 | BBuf.许啸宇 1 背景 近年来,量化感知训练是一个较为热点的问题,可以大大优化量化后训练造成精度损失的问题,使得训练过程更加高效. Torch.fx在这一问题上走在了前 ...

  6. 量化感知训练实践:实现精度无损的模型压缩和推理加速

    简介:本文以近期流行的YOLOX[8]目标检测模型为例,介绍量化感知训练的原理流程,讨论如何实现精度无损的实践经验,并展示了量化后的模型能够做到精度不低于原始浮点模型,模型压缩4X.推理加速最高2.3 ...

  7. pytorch量化感知训练(QAT)示例---ResNet

    pytorch量化感知训练(QAT)示例---ResNet 训练浮点模型,测试浮点模式在CPU和GPU上的时间; BN层融合,测试融合前后精度和结果比对; 加入torch的量化感知API,训练一个QA ...

  8. 量化感知训练_一文速览EMNLP 2020中的Transformer量化论文

    花了几天看了三篇EMNLP 2020中关于Transformer模型量化的相关论文,快速记一下要点. Fully Quantized Transformer for Machine Translati ...

  9. DoNet开源项目-基于jQuery EasyUI的后台管理系统

    本文转载于 石佳劼的博客,有问题请到原文咨询,原文连接. 博主在业余时间开发了一个简单的后台管理系统,其中用到了 jQuery EasyUI 框架,上次分享过系统布局,参考文章:jQuery Easy ...

最新文章

  1. ccf 高速公路(顺便当tarjan模板)
  2. Java五年,已财富自由,美人在手!
  3. 系统服务描述表指针查找
  4. 递归概念分类注意事项
  5. Tomcat源码解析五:Tomcat请求处理过程
  6. 推荐使用:Vue.js ReactJS Angular 2 AngularJS
  7. 线性结构基本概念【数据结构】F
  8. P1720 月落乌啼算钱(斐波那契数列)--python3实现
  9. python计算AA制时砍价后大家需要分摊的钱
  10. Python 列表下标操作
  11. openjdk8之编译和debug
  12. Gartner公布2017年全球云存储魔力象限:阿里云跻身四强
  13. Android 开源无线投屏,用树莓派Raspberry Pi 4B制作一个无线投屏器(20200803)
  14. Excel进行描述性统计分析
  15. 如何使用思维导图做计划的6种模板类型
  16. 红孩儿网狐Cocos经典棋牌开发教程-卞安-专题视频课程
  17. 模拟定时自动关机编程程序源码
  18. Android开发天气预报自我记录
  19. 腾讯、阿里、字节跳动的简单比较
  20. 基于RapidIO的应用接口使用说明

热门文章

  1. 使用STM32F103ZET霸道主板实现SD卡的读写(非文件系统)
  2. 关闭 Sublime Text 3 自动更新
  3. linux路由表命令
  4. DropDownList的常用属性和事件
  5. linux网络协议栈之数据包处理过程,Linux网络协议栈之数据包处理过程
  6. js生日计算年龄_生男生女计算公式超准
  7. Stata 17 for Win 最新中文附详细安装教程
  8. 云服务centos搭建jupyter notebook并通过外网访问
  9. 定义一个圆类——Circle,在类的内部提供一个属性:半径(r),同时 提供 两个 方 法 : 计算 面积 ( getArea() ) 和 计算 周长(getPerimeter()) 。
  10. 如何安装gnuplot