参考文章:

Pruning Filters for Efficient Convnets

Compressing deep neural nets

压缩神经网络 实验记录(剪枝 + rebirth + mobilenet)

为了在手机上加速运行深度学习模型,目前实现的方式基本分为两类:一是深度学习框架层面的加速,另一个方向是深度学习模型层面的加速。

深度学习模型的加速又可以分为采用新的卷积算子来加速模型,另一个方向是通过对已有模型进行剪枝操作得到一个参数更少的模型来加速模型。

通过观察深度学习模型,可以发现其中很多kernel的权重很小,均在-1~1之间震荡,对于这些绝对值很小的参数,可以视其对整体模型贡献很小,将其删除,然后将剩余的权重构成新的模型,以达到模型压缩,加速,并保证精准度不变的目的。

基本步骤:

1.实现原始网络,并将其训练到收敛,保存权重

2.观察对每一层的权重,判断其对模型的贡献大小,删除贡献较小的kernel,评判标准可以是std,sum(abs),mean等

3.当删除部分kernel后,会导致输出层的channel数变化,需要删除输出层对应kernel的对应channel

4.构建剪枝后的网络,加载剪枝后的权重,与原模型对比精准度。

5.使用较小的学习率,rebirth剪枝后的模型

6.重复第1步

上图展示了conv的kernel剪枝后导致的输出维度变化

对于conv层后面接续全连接层的情况:

conv层在接续全连接层前,会先reshape为一个维度。假设conv层输出为 (h,w,c),其会reshape为 h*w*c, 假设删除的kernel下标为[ 2,5,7],对应的conv输出通道也会减少 [2,5,7] 。reshape后 会减少 [2,5,7,...,h*w*2,h*w*5,h*w*7]。

全连接层接全连接层的逻辑基本和conv接conv层的逻辑一样。

一个使用mnist的简单示例

# 读取保存的权重和所有训练的var
model_path = './checkpoints/net_2018-12-19-10-05-17.ckpt-99900'
reader = tf.train.NewCheckpointReader(model_path)
all_variables = reader.get_variable_to_shape_map()
{'conv1/biases': [16],'conv1/weights': [3, 3, 1, 16],'conv2/biases': [32],'conv2/weights': [3, 3, 16, 32],'conv3/biases': [32],'conv3/weights': [3, 3, 32, 32],'fc1/biases': [128],'fc1/weights': [512, 128],'fc2/biases': [256],'fc2/weights': [128, 256],'global_step': [],'logits/biases': [10],'logits/weights': [256, 10]}
# 分析 conv1 的权重
conv1_weight = reader.get_tensor("conv1/weights")
# 计算每个kernel权重的和 (也可以使用其他指标,如std,mean等)
conv1_weight_sum = np.sum(conv1_weight, (0,1,2))
sort_conv1_weights = np.sort(conv1_weight_sum)
# 绘制conv1的
x = np.arange(0,len(sort_conv1_weights),step=1)
plt.plot(x,sort_conv1_weights)

# 保留权重和最大的8个kernel
pure_conv1_weight_index = np.where(conv1_weight_sum >= sort_conv1_weights[8])
pure_conv1_weight = conv1_weight[:,:,:,pure_conv1_weight_index[0]]
# conv1对应的bias 也做相同处理
conv1_bias = reader.get_tensor("conv1/biases")
pure_conv1_bias = conv1_bias[pure_conv1_weight_index[0]]
# 对后面接续的 conv 层的kernel做相同处理
conv2_weight = reader.get_tensor("conv2/weights")
conv2_bias = reader.get_tensor("conv2/biases")
conv2_weight = conv2_weight[:,:,pure_conv1_weight_index[0],:]

后面层重复以上操作

剪枝后结果对比

原始模型精度

剪枝后模型精度

模型权重大小从400多kb减小到了100多kb

jupyter文件及代码

深度学习 模型 剪枝相关推荐

  1. PyTorch 深度学习模型压缩开源库(含量化、剪枝、轻量化结构、BN融合)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文为52CV群友666dzy666投稿,介绍了他最近开源的PyTorch模型压缩库,该库开源不到20天已经收获 219 颗星,是最近值得关注的模型压缩 ...

  2. 深度学习模型压缩(量化、剪枝、轻量化结构、batch-normalization融合)

    "目前在深度学习领域分类两个派别,一派为学院派,研究强大.复杂的模型网络和实验方法,为了追求更高的性能:另一派为工程派,旨在将算法更稳定.高效的落地在硬件平台上,效率是其追求的目标.复杂的模 ...

  3. 深度学习模型压缩算法综述(二):模型剪枝算法

    深度学习模型压缩算法综述(二):模型剪枝算法 本文禁止转载 联系作者: 模型剪枝算法 : L1(L2)NormFilterPruner: 主要思想: 修剪策略: 微调策略: 残差网络的处理: 缺点: ...

  4. 深度学习模型压缩方法(3)-----模型剪枝(Pruning)

    link 前言 上一章,将基于核的稀疏化方法的模型压缩方法进行了介绍,提出了几篇值得大家去学习的论文,本章,将继续对深度学习模型压缩方法进行介绍,主要介绍的方向为基于模型裁剪的方法,由于本人主要研究的 ...

  5. 深度学习模型压缩与加速综述!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Pikachu5808,编辑:极市平台 来源丨https://zh ...

  6. 深度学习模型压缩与加速综述

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文详细介绍了4种主流的压缩与加速技术:结构优化.剪枝.量化 ...

  7. 一文看懂深度学习模型压缩和加速

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:opencv学堂 1 前言 近年来深度学习模型在计算机视 ...

  8. 不用GPU,稀疏化也能加速你的YOLOv3深度学习模型

    水木番 发自 凹非寺 来自|量子位 你还在为神经网络模型里的冗余信息烦恼吗? 或者手上只有CPU,对一些只能用昂贵的GPU建立的深度学习模型"望眼欲穿"吗? 最近,创业公司Neur ...

  9. 深度学习模型的中毒攻击与防御综述

    来源:专知本文约2000字,建议阅读5分钟本文首次综述了深度学习中的中毒攻击方法,回顾深度学习中的中毒攻击,分析了此类攻击存在的可能性,并研究了现有的针对这些攻击的防御措施.最后,对未来中毒攻击的研究 ...

最新文章

  1. 如何利用SOM网络进行柴油机故障诊断
  2. Symfony2Book16:Symfony2内部02-内核
  3. 解决Tensorflow 使用时cpu编译不支持警告
  4. python笔记: staticmethod classmethod
  5. kali linux下安装TOR
  6. java redis缓存实例_spring项目整合ehcache和redis缓存实例
  7. java正则匹配英文句号_Scala 正则表达式 0411
  8. acer电脑设置u盘启动方法
  9. Triumph X发布著名摄影师Kim Joong-man首个NFT系列
  10. [Iphone开发]如何在GDB中查看变量的值
  11. xuperchain 查看源码代码版本号
  12. 数学建模人口模型及matlab算法解
  13. USBclean for Mac(U盘病毒查杀工具)
  14. 移动硬盘显示要格式化怎么办
  15. SQL sever 查询及格率
  16. 进去计算机组策略的命令,组策略怎么打开,组策略命令打开方法
  17. WiFi远程监控,监控摄像头只有在WiFi环境才能使用吗
  18. 创蓝253云通讯paas平台PHP短信接口demo分享
  19. 【IDEA】IDEA的高级Debug技巧
  20. 计算机二级的Word知识点,计算机二级word知识点

热门文章

  1. 日常随笔: React useEffect中使用异步更新数据方法遇到的问题
  2. Jest+Enzyme测试React组件(上)
  3. 【C语言】输入一行字符,分别统计出其中英文字母 空格 数字和其他字符的个数
  4. linux atop日志查看,atop
  5. 巧把任意程序添加到Win10控制面板(添加“系统配置”为例)
  6. mysql复合主键做外键,mysql – 使用复合主键作为外键
  7. Error while executing: npm ERR! D:\Program Files\Git\cmd\git.EXE ls-remote -h -t git://github.com/ad
  8. 开源(离线)中文语音识别ASR(语音转文本)工具整理
  9. java servlet jsp (服务器端编程)
  10. LyScript 实现Hook隐藏调试器