7. 网络结构的修剪

网络结构的压缩是近年来研究热点,接下来的两节,我们将介绍Deep Compression的两个策略网络修剪和网络权重共享量化的实现方法,我们通过mnist的LeNet5作为例子,而其他网络的实现也是类似的。

关于Deep Compression的原理,可以参见其论文:Han S, Mao H, Dally W J. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding[C]. In Proc. International Conference on Learning Representations. 2016.

所谓的网络修剪的概念,并不复杂,其大体思想是将不重要的权重连接删除,只保留最重要的连接,而什么是最重要的连接,一般作为是权重值接近0的连接越不重要,有些也认为权重值的Hession值越小的越不重要,而计算Hession值的计算太复杂了,这里参考Deep Compression通用思想将权重值接近0的视为不重要的权重的连接。

网络修剪的第二个问题是,网络修剪率如何设置,对某层到底应该删除多少连接。直观上说,如何某层连接越多,其该删除的连接也越多,比如全连接层的修剪率就要比卷积层多。但如何设置呢?目前的方法一般是采用试错实验来考察不同修剪率对网络的影响来确定合适的修剪率。

网络修剪的第三个问题是如何在修剪后,保证网络精确度不变。可以明确的说,在直接删除网络部分连接后,网络精度肯定会下降的。因此要保证网络精确度不变,需要对修剪后的网络进行重新训练,在多次重新训练后,网络的精度会提升,达到原始网络的精度,甚至在一些情况下,由于网络的稀疏度提高,减少了网络的过拟合,从而达到提高网络精度的效果。

值得注意的是,修剪后网络中,值为0的权重连接在重新训练过程中,将会一直保持为0。

7.1 网络权重的修剪

    def prune(threshold, test_net, layers):sqarse_net = {}for i, layer in enumerate(layers):print '\n============  Pruning %s : threshold=%0.2f   ============' % (layer,threshold[i])W = test_net.params[layer][0].datab = test_net.params[layer][1].datahi = np.max(np.abs(W.flatten()))hi = np.sort(-np.abs(W.flatten()))[int((len(W.flatten())-1)* threshold[i])]# abs(val)  = 0         ==> 0# abs(val) >= threshold ==> 1interpolated = np.interp(np.abs(W), [0, hi * threshold[i], 999999999.0], [0.0, 1.0, 1.0])# 小于阈值的权重被随机修剪random_samps = np.random.rand(len(W.flatten()))random_samps.shape = W.shape# 修剪阈值# mask = (random_samps 

7.2 考察不同修剪率下的网络精度变化

    def eval_prune_threshold(threshold_list, test_prototxt, caffemodel, prune_layers):def net_prune(threshold, test_prototx, caffemodel, prune_layers):test_net = caffe.Net(test_prototx, caffemodel, caffe.TEST)return prune(threshold, test_net, prune_layers)accuracy = []for threshold in threshold_list:results = net_prune(threshold, test_prototxt, caffemodel, prune_layers)print 'threshold: ', results[0]print '\ntotal_percentage: ', results[1]print '\npercentage_list: ', results[2]print '\ntest_loss: ', results[3]print '\naccuracy: ', results[4]accuracy.append(results[4])plt.plot(accuracy,'r.')plt.show()

下图显示不同层的不同修剪率对整个网络精度的影响,以下是修剪率实验设置

    test_threshold_list = [[0.3, 1 ,1 ,1], [0.4, 1 ,1 ,1], [0.5, 1 ,1 ,1], [0.6, 1 ,1 ,1], [0.7, 1 ,1 ,1],[1, 0.05, 1, 1], [1, 0.1, 1, 1], [1, 0.15, 1, 1], [1, 0.2, 1, 1], [1, 0.3, 1, 1],[1, 1, 0.05, 1], [1, 1, 0.1, 1], [1, 1, 0.15, 1], [1, 1, 0.2, 1], [1, 1, 0.3, 1],[1, 1, 1, 0.05], [1, 1, 1, 0.1], [1, 1, 1, 0.15], [1, 1, 1, 0.2], [1, 1, 1, 0.3]]

上面每个数组都有4个值,分别表示'conv1','conv2','ip1','ip2'各层的修剪率,为1表示不修剪,为0.3表示只保留权重值最大的30%的连接。

根据图上,我们可以选择'conv1','conv2','ip1','ip2'各层的修剪率分别为[0.3, 0.1, 0.01, 0.2]

7.3 修剪网络的重新训练

    # 迭代训练修剪后网络def retrain_pruned(solver, pruned_caffemodel, threshold, prune_layers):#solver = caffe.SGDSolver(solver_proto)retrain_iter = 20accuracys = []for i in range(retrain_iter):solver.net.copy_from(pruned_caffemodel)# solver.solve()solver.step(500)_,_,_,_,accuracy=prune(threshold, solver.test_nets[0], prune_layers)solver.test_nets[0].save(pruned_caffemodel)accuracys.append(accuracy)plt.plot(accuracys, 'r.-')plt.show()

重新迭代训练时,其精度的变化图,可以看出随着迭代次数增加,其精确度逐渐增加。最终大概只保留了2%左右的权重连接,就达到了原来的精确度。

7.4 稀疏结构的存储

实际上这里的网络修剪并不会在实际内存上减少网络的大小,只会减少网络模型的存储空间,因为该稀疏结构并不是一个通用结构,而是一组随机分布的结构,因此该稀疏结构我们是通过spicy的CSC格式来存储的。

所谓CSC格式,即为按行展开的形式,其将稀疏的矩阵按行展开成一列,只保存不为0的权重值及该值在矩阵中的相对位置。同理还有按列展开的形式CSR。

        test_net.params[layer][0].data[...] = W# net.params[layer][0].mask[...] = maskcsc_W, csc_W_indx = dense_to_sparse_csc(W.flatten(), 8)dense_W = sparse_to_dense_csc(csc_W, csc_W_indx)sqarse_net[layer + '_W'] = csc_Wsqarse_net[layer + '_W_indx'] = csc_W_indx

7.5 具体代码下载

GitHub仓库Caffe-Python-Tutorial中的prune.py

项目地址:https://github.com/tostq/Caffe-Python-Tutorial

【用Python学习Caffe】7. 网络结构的修剪相关推荐

  1. 深度学习Caffe 入门理解使用教程

    2019独角兽企业重金招聘Python工程师标准>>> 1.首先caffe 安装我就不解释了 如果有人安装不会的话 可以加我qq 1050316096 ,我会按照使用方式来介绍,首先 ...

  2. python sorted下标_Python学习教程(Python学习路线):第七天-字符串和常用数据结构

    Python学习教程(Python学习路线):字符串和常用数据结构 使用字符串 第二次世界大战促使了现代电子计算机的诞生,当初的想法很简单,就是用计算机来计算导弹的弹道,因此在计算机刚刚诞生的那个年代 ...

  3. caffe入门学习:caffe.Classifier的使用

    caffe入门学习:caffe.Classifier的使用 在学习pycaffe的时候,官方一直用到的案例就是net=caffe.net(.../deploy.protxt,..../xxx.caff ...

  4. SIGIA_4P python学习 列表 字典 集合 面对对象编程 闭包 装饰器 函数式编程 作用域 异常处理

    SIGIA_4P python学习 列表 字典 集合 面对对象编程 闭包 装饰器 函数式编程 作用域 异常处理 本文连接 简介 SIGIA_4P 网址 a. 课程OKR Objectives and ...

  5. python学习-简单图像识别分类

    python学习-图像识别 这是我从零基础开始学习的图像识别,当然用的是容易上手的python来写,持续更新中,记录我学习python基础到图像识别应用的一步步过程和踩过的一些坑.最终实现得到自己的训 ...

  6. 深度学习caffe(4)——caffe配置(GPU)

    电脑:win7  64位,NVIDIA GeForce GTX1080 Ti,visual studio 2013. 深度学习caffe(1)--windows配置caffe(vs2013+pytho ...

  7. pygame是python的一个库吗,python学习pygame,,基本库导入impor

    python学习pygame,,基本库导入impor 基本库导入 import pygame import sys from pygame.locals import * 初始化 pygame.ini ...

  8. python科学计数法转换_对比Python学习Go 基本数据结构

    公众号文章不方便更新,可关注底部「阅读原文」博客,文章随时更新. 本篇是「对比 Python 学习 Go」[1] 系列的第三篇,本篇文章我们来看下 Go 的基本数据结构.Go 的环境搭建,可参考之前的 ...

  9. python学习------tab补全

    python学习------tab补全   python也可以进行tab键补全 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 #!/usr/bin/env  ...

最新文章

  1. Microsoft月度中文速递
  2. 1.11 对象的销毁
  3. java 读utf-8 xml_〖JAVA经验〗JDom输出UTF-8的XML完美解决方法
  4. php hugepage,【原创】解决Redis启动报错:Transparent Huge Pages (THP) support enabled in your kernel...
  5. win7 php 上传文件,在LNMP原来的基础上,win7环境下如何上传PHP文件到Linux环境下...
  6. C++ 运算符重载(一) | 输入/输出,相等/不等,复合赋值,下标,自增/自减,成员访问运算符
  7. Java Scanner nextLine()方法与示例
  8. 并行计算(二)——通讯
  9. Android 音频开发(一) 基础入门篇
  10. mysql服务器cpu资源占用满
  11. 图的连通度问题的求法(转)
  12. 超级好用的画图吸色工具FastSton Capture
  13. 别了甲骨文,别了拉里·埃里森!
  14. 高、低成本MEMS惯导系统姿态、位置、速度更新算法的对比
  15. 精准关键词获取-行业搜索词分析
  16. Import Legacy CCSv3.3 Project到CCS5.5.0时出错
  17. unity接入讯飞AIUI(Windows SDK)
  18. 全球IEEE期刊大全(综合整理,附原文论文下载地址)
  19. 九九乘法表c语言编程java,九九乘法表(c语言和java语言)+心得
  20. Oracle数据导入导出详解

热门文章

  1. Android扫描车牌,车牌拍照识别SDK
  2. 一站放心购全球:亚马逊海外购开启2022年黑五全球购物季
  3. android 调出键盘表情_Android 显示输入法中的emoji表情以及String字符
  4. 倍数(Python)
  5. 抓紧收藏,9大短视频自媒体工具,帮你快速月入过万,不真人出镜
  6. HANA XS 匿名访问
  7. 基于BIM轻量化的智能建造OA管理系统
  8. css动画与渐变案例,使用动画和渐变做一个背景动态网页
  9. 服务器登录信息记录,服务器记录远程桌面登录的信息
  10. android获取手机IMSI号