本文的内容包括对神经网络模型量化的基本介绍、对Tensorflow量化训练的理解与上手实操。

此外,后续系列还对量化训练中的by pass和batch norm两种情况进行补充解释,欢迎点击浏览,量化训练:Quantization Aware Training(二)。


量化(quantized),即将神经网络前向过程中浮点数运算量化为整数运算,以达到计算加速的目的。通常是指将float32转化为int8进行运算,经实践部署dsp上能提升2.5~3倍左右的推理速度。我们知道对浮点数的量化操作是将其从一个高维度映射到低维度的转换过程,如图所示:

量化的主要流程如下:

(1)统计出网络某一层的最大值与最小值:

(2)计算scale与zero_point

(3)通过以下公式计算出任意float32量化后的int8结果

由公式可以看出量化中的精度损失不可避免的,当浮点数的分布均匀时,精度损失较小。但当浮点数分布不均匀时,按照最大最小值映射,则实际有效的int8动态范围就更小了,精度损失变大。

对此,NVIDIA的PPT中提出通过寻找最优阈值进行非饱和截取的思路改善精度损失,如下图,先不做具体赘述,有时间再填坑。

2017年,Google发表了关于神经网络量化方面的文章Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference。文中提出一种新的量化框架,在训练过程中引入伪量化的操作,用于模拟量化过程带来的误差(这一框架无论在resnet这种大模型,还是mobilenet这种本身比较精简的网络上效果都不错)。所谓的伪量化,即将模拟量化操作引入训练过程中,在每个weight的输入后与output的输出前进行伪量化,将浮点量化到定点,再反量化成浮点,用round过程中所产生的误差的浮点值进行前向运算。伪量化的操作可以使权值、激活值的分布更加均匀,也就是方差更小,相比直接进行后量化的精度损失能更小,其次能够控制每层的输出在一定范围内,对溢出处理更有帮助。值得注意的是,量化训练中都是采用浮点运算来模拟定点运算,所以训练过程中的量化结果与真实量化结果是有差异的。

Tensorflow中开源了量化训练的方法,首先需要创建一个图,图中会自动添加伪量化节点,实现代码如下:

tf

delay_step表示多少个epoch后开启量化训练。

在tensorboard中能看到min、max节点,表示开启了量化训练。

保存完训练结果后,要重新创建一个测试用的Graph,实现如下:

g 

将网络模型结构与参数用pb文件保存

from 

用Netron打开生成的pb文件,如下图。可以看到pb中生成了伪量化节点,伪量化节点中保存了min、max,意味着转换成功。

最后将pb转换为tflite格式,可以使用Tensorflow提供官方转换工具toco。推荐用如下简单的Python API来转。

import 

对于全整数模型,输入为 uint8。mean 和 std_dev values 指定在训练模型时这些 UINT8 的值是如何值映射到输入的浮点值。mean 是 0 到 255 之间的整数值,映射到浮点数 0.0f。std_dev = 255 /(float_max - float_min) 。

用Netron打开生成的pb文件,如下图,若伪量化节点被消除,则意味着uint8量化转换成功。

最后可以通过tf.lite.Interpreter加载tflite模型,测试量化后的inference结果。

import 


说几点注意事项:

1.train和inference创建的是不同的量化网络,前者是伪量化,后者是真实的量化,故结果会不同。

2.尽量使用标准的网络(尽量不要使用dilated conv)。

3.激活层目前尽量使用relu,relu6和identity。sigmoid,softmax量化误差会比较大。

4.先训练float,在训练好的float上微调quant8,即尽量在float网络收敛后再进行量化训练。

5.尽量在conv时传入激活函数和bn的参数,否则bn在folding时会失败。支持的api为slim.convolution2d或tf.contrib.layers.conv2d。

参考链接:

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize?spm=a2c4e.10696291.0.0.1b6519a4o0pfwg

https://tensorflow.google.cn/lite/convert/python_api?hl=zh-cn

https://stackoverflow.com/questions/53680660/attributeerror-module-tensorflow-has-no-attribute-lite-in-keras-model-to-te

https://tensorflow.google.cn/lite/convert/quantization

码字不易,看完记得点赞哦~

tensorflow sigmoid 如何计算训练数据的正确率_量化训练:Quantization Aware Training in Tensorflow(一)...相关推荐

  1. tensorflow sigmoid 如何计算训练数据的正确率_用于高级机器学习的自定义TensorFlow损失函数...

    在本文中,我们将看看: 在高级机器学习(ML)应用程序中使用自定义损失函数 定义自定义损失函数并集成到基本Tensorflow神经网络模型 一个简单的知识蒸馏学习的例子 介绍 机器学习中预定义的损失函 ...

  2. tensorflow sigmoid 如何计算训练数据的正确率_“来自蒙娜丽莎的凝视”— 结合 TensorFlow.js 和深度学习实现...

    客座博文 / Emily Xie,软件工程师 背景 坊间传闻,当您在房间里走动时,蒙娜丽莎的眼睛会一直盯着您. 这就是所谓的"蒙娜丽莎效应".兴趣使然,我最近就编写了一个可互动的数 ...

  3. mysql实验训练2 数据查询操作_实验训练2:数据查询操作.doc

    实验训练2:数据查询操作.doc 实验训练2数据查询操作请到电脑端查看实验目的基于实验1创建的汽车用品网上商城数据库Shopping,理解MySQL运算符.函数.谓词,练习Select语句的操作方法. ...

  4. 多元回归训练数据和测试数据_回归基础-数据结构提醒,如果和切换之后的寿命...

    多元回归训练数据和测试数据 I just had a great one on one coding learning session with a good friend of mine over ...

  5. 训练不出结果_智能训练仪:专业化智能防控近视训练设备

    视觉训练精准化,近视防控效果佳 智能训练仪小百科                                                                        智能训练 ...

  6. mysql实验训练2 数据查询操作_实验训练2:数据查询操作

    <实验训练2:数据查询操作>由会员分享,可在线阅读,更多相关<实验训练2:数据查询操作(6页珍藏版)>请在人人文库网上搜索. 1.实验训练2:数据查询操作请到电脑端查看实验目的 ...

  7. 从Dataframe训练数据,构造可迭代训练的batch数据

    最主要的方法是: data_loader = Data.DataLoader(dataset=Data.TensorDataset(x, y), # 封装进Data.TensorDataset(ten ...

  8. python量化分析数据_Python数据分析_量化分析.pdf

    法律声明  本课件包括:演示文稿,示例,代码,题库,视频和声 音等,小象学院拥有完全知识产权的权利:只限于善意 学习者在本课程使用,不得在课程范围外向任何第三方 散播.任何其他人或机构不得盗版.复制 ...

  9. python训练模型、如何得到模型训练总时长_模型训练时间的估算

    模型训练时间的估算 昨天群里一个朋友训练一个BERT句子对模型,使用的是CPU来进行训练,由于代码是BERT官方代码,并没有显示训练需要的总时间,所以训练的时候只能等待.他截图发了基本的信息,想知道训 ...

最新文章

  1. 服务器架设笔记——搭建用户注册和验证功能
  2. 网络推广外包——网络推广外包专员浅析网站流量应该如何提升?
  3. 关于HTTP和HTTPS的区别
  4. 前端学习(2905):用vite的2.0构建程序
  5. VScode 透明背景设置
  6. 【JSP开发】有关session的一些重要的知识点
  7. pure CSS3 triangle icon
  8. 工程师,你的钱究竟从哪来?
  9. python从html中提取文本_使用Python从HTML中提取可读文本?
  10. 实验2-4-6 求幂之和 (C语言)
  11. ant vue 设置中文_ant design vue导航菜单与路由配置操作
  12. 路由器、交换机配置命令简写对照表
  13. Linux虚拟网络基础——Bridge
  14. 手机扫描识别Vin码识别
  15. 小黄鸡 php,Simsimi (小黄鸡) API接口(PHP)公布,小黄鸡API接口非官方PHP版本来啦...
  16. 深度学习中神经网络的几种权重初始化方法
  17. 韩天峰php教程,韩天峰 - Swoole4-全新的PHP编程模式
  18. window 和linux系统分隔符的不同
  19. array_column() expects parameter 1 to be array, array given
  20. 通过西联快汇收取Google Adsense收入的详细步骤

热门文章

  1. HttpClient 解释
  2. C#学习常用类(1003)---Timer类(System.Timers.Timer)
  3. 移动应用广告盈利-KeyMob移动广告聚合平台
  4. apache commons - lang 常用方法记录
  5. Custom Sharepoint Lookup Field
  6. Django学习入门步骤 教程步骤 python
  7. Linux系统下智能DNS服务器BIND9.7.2安装配置
  8. JavaScript 使用面向对象的技术创建高级 Web 应用程序
  9. 2-7 微信摇一摇_实现分析
  10. 计算机专业虽然好,但是也要有这些潜质才去选择