Onnx_calibrate calibration代码原理分析

Calibration的思想是通过一堆验证数据集输入到网络中,统计每一层layer的输出值,通过对比量化前后数据统计分布之间的KL散度找到最佳的映射值T.具体参考NVIDIAGTC2017的ppt。

def onnx_runtime(model_path,image_files):'''Helper function run input image,and output each node tensor to calibration.parameter model_path: the onnx modelparameter image_files: calibrate input imagesreturn: '''sess = rt.InferenceSession(model_path)Input_name = sess.get_inputs()[0].namemodel_outputs = sess.get_outputs()print(len(model_outputs)) # 1.对每一个需要量化的node 计算其输入的tensor中最大值for i,image in enumerate(image_files):img = cv2.imread(image)img = cv2.resize(img,(224,224))img = np.transpose(img,(2,0,1))img = img.astype('float32')/255img = img.reshape(1,224,224,3)start_time = datetime.datetime.now()for node in quantize_node_list:Output_name = node.output_nameres = sess.run([Output_name],{Input_name:img})node.initial_input_max(np.array(res).flatten())end_time = datetime.datetime.now()print('it`s cost :', (end_time - start_time))if i % 100 == 0:print('loop stage 1 : %d/%d' % (i,len(image_files)))# calculate statistic node scope and interval distribution# 2.计算统计值分布的间隔用最大值除2048,即间隔的大小用于还原计算Tfor node in quantize_node_list:node.initial_input_distubution_interval()# for each nodes# collect histograms of activations# 3.得到每个node的数据分布,对于一个node得到的是在(0,max)划分2048块每个块内数据落在其中的数量统计值,假设区间[0,1]有20个数值落在里面。print('\n Collect histograms of activations: ')for i, image in enumerate(image_files):img = cv2.imread(image)img = cv2.resize(img,(224,224))img = np.transpose(img,(2,0,1))#print(img.shape)img = img.astype('float32')/255img = img.reshape(1,224,224,3)for node in quantize_node_list:Output_name = node.output_nameres = sess.run([Output_name],{Input_name:img})node.initial_histograms(np.array(res).flatten())if i % 100 == 0:print('loop stage 2 : %d/%d' % (i,len(image_files)))# calculate threshold with KL divergence# 4. 核心计算KL散度for node in quantize_node_list:node.quantize_input()return None

前面三步分别对应的核心代码

   # 1. 对每一个需要量化的node 计算其输入的tensor中最大值def initial_input_max(self, input_data):# get the max value of inputmax_val = np.max(input_data)min_val = np.min(input_data)self.input_max = max(self.input_max, max(abs(max_val), abs(min_val)))# 2.计算统计值分布的间隔用最大值除2048,即间隔的大小用于还原计算Tdef initial_input_distubution_interval(self):self.input_distubution_interval = STATISTIC * self.input_max / INTERVAL_NUMprint("%-20s max_val : %-10.8f distribution_intervals : %-10.8f" % (self.node_name, self.input_max, self.input_distubution_interval))#3.得到每个node的数据分布,对于一个node得到的是在(0,max)划分2048块每个块内数据落在其中的数量统计值,假设区间[0,1]有20个数值落在里面。def initial_histograms(self, input_data):# collect histogram of every group channel inputth = self.input_max# hist:Number of values in the interval for each hist,hist_edge:array of dtype float for interval. range: change the max and min value for inputdata. hist, hist_edge = np.histogram(input_data, bins=INTERVAL_NUM, range=(0, th))self.input_distubution += hist

核心是如何做calibration

    def quantize_input(self):# calculate threshold  distribution = np.array(self.input_distubution)# pick threshold which minimizes KL divergencethreshold_bin = threshold_distribution(distribution) self.input_threshold = threshold_binthreshold = (threshold_bin + 0.5) * self.input_distubution_interval# get the activation calibration valueself.input_scale = QUANTIZE_NUM / threshold

calibration的过程

def threshold_distribution(distribution, target_bin=128):"""Return the best threshold value. Args:distribution: list, activations has been processed by histogram and normalize,size is 2048target_bin: int, the num of bin that is used by quantize, Int8 default value is 128Returns:target_threshold: int, num of bin with the minimum KL """   distribution = distribution[1:]length = distribution.sizethreshold_sum = sum(distribution[target_bin:])kl_divergence = np.zeros(length - target_bin)# 遍历从128到2048开始搜索for threshold in range(target_bin, length):sliced_nd_hist = copy.deepcopy(distribution[:threshold])# generate reference distribution p# 得到p比较简单,遍历的前threshold-1个组,将最后所有的组累加到threshold-1组上。p = sliced_nd_hist.copy()p[threshold-1] += threshold_sumthreshold_sum = threshold_sum - distribution[threshold]# is_nonzeros[k] indicates whether hist[k] is nonzero# 判断p中元素是否有0存在,得到的is_nonzeros=[1,1,1,1,0,....]类似的arrayis_nonzeros = (p != 0).astype(np.int64)# quantized_bins = np.zeros(target_bin, dtype=np.int64)# calculate how many bins should be merged to generate quantized distribution qnum_merged_bins = sliced_nd_hist.size // target_bin# merge hist into num_quantized_bins bins# 这里是量化的原理,并不是数值的fp32-int8,只是将数据分布合并到128个组中,注意理解for j in range(target_bin):start = j * num_merged_bins #按照组的大小得到新组(128个组)前后位置stop = start + num_merged_binsquantized_bins[j] = sliced_nd_hist[start:stop].sum()#属于同组的累加起来quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum()#最后末尾的数据全部累加到新组的最后一组中# expand quantized_bins into p.size bins. compare with quantizated_bins merge, that is inverse process# 将量化后的组重新扩大到与p相同大小的范围,就是按照前面量化的过程逆过来计算。q = np.zeros(sliced_nd_hist.size, dtype=np.float64)for j in range(target_bin):start = j * num_merged_bins #找起始位置if j == target_bin - 1:stop = -1else:stop = start + num_merged_bins #计算终止位置norm = is_nonzeros[start:stop].sum()if norm != 0:q[start:stop] = float(quantized_bins[j]) / float(norm)# 把数据平均分配,这是逆过程差异的地方,只能平均分配。q[p == 0] = 0# p = _smooth_distribution(p) # with some bugs, need to fix# q = _smooth_distribution(q)p[p == 0] = 0.0001q[q == 0] = 0.0001# calculate kl_divergence between q and pkl_divergence[threshold - target_bin] = stats.entropy(p, q)min_kl_divergence = np.argmin(kl_divergence)threshold_value = min_kl_divergence + target_binreturn threshold_value

2020/5/27 再次回顾记录:
calibration 的原理:

1.对于需要校正的op,得到其输入的tensor
2. 将其划分到2048个bin中,先得到tensor的最大值,然后除以2048,得到每个bin的区间,就可以统计每个bin的区间内tensor分布的数量,因此2048个bin是tensor数据的分布.
3. 遍历128 - 2048的范围找到合适的T.对于遍历的i,将2048个bin划分成了i个bin,其中0到i-2与2048的前i-2个bin是一致的,第i-1个bin等于2048的bin中i-1到2047的累计和.

3.1 现在需要将i个bin量化到128个bin,因此此时的i个bin是包含了所有数据的分布,希望将其映射到128个bin的范围中,假设i等于1280,那么每十个bin对应128中的一个bin,所以需要将i划分到128个值中,将i个bin中按照i//128(eg:等于10)累加起来,例如0到9个bin的数值累加起来对应第一个bin,这样就得到量化的128个bin分布.  3.2 反量化到i个bin与原始的i计算kl散度.那么反量化操作就是量化的逆操作,要想还原得到i个bin,就是需要将128个bin扩充到i个bin,做法是仍然按照i//128作为区间,还原的区间(eg:0到9)每个值都等于128中对应bin的值除以区间大小,即平均分配.从操作上看应该就是逆过程,反而把bin区间(eg:每10个取均值)平均化了,其实在大量数据遍历过程中是有效的.这样就得到两组i个bin计算其KL散度.  3.2 对于每个i都得到一个KL散度值,最后获取KL散度最小的那个i,即对应为t
  1. 得到t只是128-2048中的一个整数,需要还原得到最佳的T.最终得到的是每个需要校正的op都对应得到一个T,这便是calibration table.

转载请注明出处:https://blog.csdn.net/tbl1234567.作者:陶表犁

onnx_calibrate calibration代码原理分析相关推荐

  1. KMP算法之NEXT数组代码原理分析 - 数据结构和算法38

    KMP算法之NEXT数组代码原理分析 让编程改变世界 Change the world by program KMP算法之NEXT数组代码原理分析 NEXT数组:当模式匹配串T失配的时候,NEXT数组 ...

  2. 数据结构与算法之KMP算法中Next数组代码原理分析

    2019独角兽企业重金招聘Python工程师标准>>> 一.KMP算法之Next数组代码原理分析       1.Next数组定义 当模式匹配串T失配的时候,Next数组对应的元素指 ...

  3. 《一周学完光线追踪》学习 十一点五 离焦模糊代码原理分析

    蒙特卡洛光线追踪技术系列 见 蒙特卡洛光线追踪技术 首先分析一下生成随机Ray的程序: vec3 random_in_unit_disk() {vec3 p;do {p = 2.0*vec3(rand ...

  4. unity 双指触控(以及多指触摸的代码原理分析)

    双指触摸规律:从第一根触摸的手指开始 ,会从0开始为其编号,假设中间抬起手指,假设现在有两根手指,抬起编号为0的手指的话,将会导致原本编号为1的手指编号变为0,而当再次按下一根新手指时,原本编号0的手 ...

  5. 树莓派学习笔记(十六)编写内核驱动操控IO口代码原理分析

    驱动源码.测试源码可查看博文:内核驱动操控IO口源码(pin4引脚) 寄存器地址.引脚对应的位数等原理可查看博文:BCM2835芯片手册导读 1.驱动代码编写 框架查看博文:基于框架编写驱动 1.1 ...

  6. ADAS-开源环视360全景拼接代码原理分析与实现(一)

    引言 " 汽车360影像是一项比较先进的技术,它通过多个高清摄像头将车辆的外部环境进行拍摄,并将这些影像进行处理和融合,以生成一张完整的全景图像.这种技术已经被广泛应用于汽车行业,为驾驶员提 ...

  7. Redux-React 代码原理分析

    目标 react使用redux的主要目的是: 1)实现简洁统一的状态维护,提高代码可维护性: 2)实现简洁的注入依赖,避免重重传递参数: Plug Any Data Into Any Componen ...

  8. 后门BROOTKIT代码学习和原理分析

    周末闲来无事,想找点东西学习一下,随手翻到了之前看到的一篇关于brootkit的文章,知道它是用Bash写的一个后门程序.刚好最近在做Bash相关的工作,就想着学习一下这方面的知识,稍作整理之后就有了 ...

  9. Adaboost算法原理分析和实例+代码(简明易懂)

    Adaboost算法原理分析和实例+代码(简明易懂) [尊重原创,转载请注明出处] http://blog.csdn.net/guyuealian/article/details/70995333   ...

最新文章

  1. MySQL中改变相邻学生座位_力扣——换座位(数据库的题
  2. 代码中特殊的注释技术——TODO、FIXME和XXX的用处
  3. learn Linux sed command
  4. LVS(13)——DR模型准备工作及ip地址冲突问题
  5. leetcode 688. Knight Probability in Chessboard | 688. “马”在棋盘上的概率(dp,记忆化搜索)
  6. java 超时集合_Java之集合(二十三)SynchronousQueue
  7. 姑苏山塘飞雪披银装[组图]
  8. 怎么用python画风车_小清新风车短教程:10步教你绘制一副插画
  9. UNIX文件系统概述
  10. ffmpeg 推流MP4文件,采用rtmp协议
  11. 织梦dedecms采集规则,东方资讯娱乐新闻采集规则
  12. Update批量更新
  13. Verdi HW/SW co-debug 简单使用
  14. [AC自动机]luoguP3966
  15. 101平衡模式 DIR的理解
  16. linux下磁盘坏道修复,linux磁盘坏道修复记录
  17. 195元爱奇艺会员只卖5元 揭秘背后黑色产业链
  18. SF笔试编程1:幸运数
  19. 8/27 Python.Numpy.01
  20. APP渗透测试-----APK反编译

热门文章

  1. JavaWEB十七:bookShop项目 - 错误总结
  2. python 边缘计算_如何实现高效的边缘计算?边缘计算如何快速处理数据缺陷
  3. 服务器ipv4协议认证,基于TCP/IP应用层密码认证协议的研究
  4. 利用Disqus快速搭建评论系统
  5. [Vue warn]: Error in v-on handler (Promise/async): “TypeError: Cannot read property ‘status‘ of 问题详解
  6. 品创科技有限公司·食安溯源使用隐私政策服务协议
  7. 正则化(regularization)
  8. 【MOT】目标追踪DeepSORT与ByteTrack
  9. [JQ权威指南]JQ遍历JSON数据
  10. Oracle查询字段以外的内容,Oracle查询字段内容为非数字的记录