1.distiller剪枝模块的使用

(1)distiller自带剪枝实例测试

distiller自带一些测试实例如ResNet56+cifar-10,下面是对ResNet56+cifar-10的测试:

  • 测试前准备

  • yaml文件(注意:这里的yaml文件是coder配置好的,具体到自己的模型需要先对自己的model进行一次Sparsity Analysis,然后自己配置该文件) 在剪枝时所用到的yaml文件作用主要是配置了一些剪枝所需要的必要信息,比如下面ResNet56所需要用的yaml配置文件(路径:distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml):
version: 1                                        # 版本
pruners:   filter_pruner_60:                               # 后面的60表示剪掉60%的Filters,如[16, 16, 3, 3]剪掉之后就是[7, 16, 3, 3]class: 'L1RankedStructureParameterPruner' # 表示所使用的算法,这里使用L1Rankgroup_type: Filters                              # 表示剪切类型,一般两种Filters/Channeldesired_sparsity: 0.6                         # 剪掉60%的Filtersweights: [                                    # 下面是一些具体的需要剪切的权值module.layer1.0.conv1.weight,module.layer1.1.conv1.weight,module.layer1.2.conv1.weight,module.layer1.3.conv1.weight,module.layer1.4.conv1.weight,module.layer1.5.conv1.weight,module.layer1.6.conv1.weight,module.layer1.7.conv1.weight,module.layer1.8.conv1.weight]filter_pruner_50:                                # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.5weights: [module.layer2.1.conv1.weight,module.layer2.2.conv1.weight,module.layer2.3.conv1.weight,module.layer2.4.conv1.weight,module.layer2.6.conv1.weight,module.layer2.7.conv1.weight]filter_pruner_10:                                 # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.1weights: [module.layer3.1.conv1.weight]filter_pruner_30:                                 # 同上class: 'L1RankedStructureParameterPruner'group_type: Filtersdesired_sparsity: 0.3weights: [module.layer3.2.conv1.weight,module.layer3.3.conv1.weight,module.layer3.5.conv1.weight,module.layer3.6.conv1.weight,module.layer3.7.conv1.weight,module.layer3.8.conv1.weight]extensions:net_thinner:class: 'FilterRemover'thinning_func_str: remove_filtersarch: 'resnet56_cifar' # 使用的网络dataset: 'cifar10' # 数据集lr_schedulers:exp_finetuning_lr:class: ExponentialLRgamma: 0.95policies:- pruner:instance_name: filter_pruner_60epochs: [0]- pruner:instance_name: filter_pruner_50epochs: [0]- pruner:instance_name: filter_pruner_30epochs: [0]- pruner:instance_name: filter_pruner_10epochs: [0]- extension:instance_name: net_thinnerepochs: [0]- lr_scheduler:instance_name: exp_finetuning_lrstarting_epoch: 10ending_epoch: 300frequency: 1
  • 准备ResNet-56需要的模型文件,可下载:

    https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar

  • 剪枝

找到compress_classifier.py文件,如下:

 $python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0

参数解释: -a 表示模型名称(这里是工具自带的模型名称,其他的如resnet32_cifar, resnet44_cifar, resnet56_cifar等等 cifar的模型代码文件位于distiller/models/cifar10/resnet_cifar.py)

-p表示每隔多少打印一次

../../../data.cifar10是数据集路径

--epochs 表示剪枝过后继续训练次数

--compress 表示所用的‘策略’(compress_scheduler),一般是yaml文件的路径

--resume-from 表示保存的模型的路径

--reset-optimizer 如果设置此参数,那么start_epoch=0,将optimizer重置为SGD, 学习绿设置为传入的学习率

--vs validation-split

具体的其他参数参看distiller/apputils/image_classifier.py文件和distiller/quantization/range_linear.py文件以及github上参数解释。

运行时会对模型进行剪枝,然后在测试集上测试,打印出top1和top5以及loss,运行结束后量化模型会保存在logs下。

(2)distiller对自己的模型剪枝

  • 具体流程:1.Sparsity Analysis 分析各层weight的sensitivity,即对模型各个部分的稀疏性和可以pruning的程度有个了解。 2. Yaml file create 创建一个属于自己模型的配置文件,里面各层的稀疏度是由第一阶段分析出来的sensitivity而来的 3.Thinning 对网络进行真正的剪枝。

(3)代码

  • 代码已全部整理在 https://github.com/BlossomingL/compression_tool
  • distiller修改代码在 https://github.com/BlossomingL/Distiller
  • 有用的话麻烦手动点一下小星星,谢谢!

模型压缩工具Distiller-剪枝相关推荐

  1. FPGA加速BCNN,模型20倍剪枝率、边缘设备超5000帧/秒推理吞吐量

    ©作者 | 机器之心编辑部 来源 | 机器之心 来自康涅狄格大学等机构的研究者提出了一种基于结构剪枝的 BCNN 加速器,它能以较小的准确率损失获得 20 倍的剪枝率,并且在边缘设备上提供了超过 50 ...

  2. 模型如何压缩?使用轻量化的模型压缩技术剪枝(pruning)

    深度学习模型参数太多,本地服务器部署没有问题,但是如果部署到移动端.边缘端,像手机.树莓派等,它们的性能不能满足,所以我们要压缩模型大小,让他们可以部署到边缘端 模型压缩:使用轻量化的模型压缩技术,如 ...

  3. 【嵌入式AI】CNN模型压缩(剪枝,量化)详解与tensorflow实验

    1,CNN模型压缩综述 1 模型压缩的必要性及可行性 (1)必要性:首先是资源受限,其次在许多网络结构中,如VGG-16网络,参数数量1亿3千多万,占用500MB空间,需要进行309亿次浮点运算才能完 ...

  4. 基于Distiller的模型压缩工具简介

    Reference: https://github.com/NervanaSystems/distiller https://nervanasystems.github.io/distiller/in ...

  5. 模型压缩工具Distiller-INT8量化

    1.distiller工具介绍    Distiller是一个开源的Python软件包,用于神经网络压缩研究.网络压缩可以减少神经网络的内存占用,提高推理速度并节省能源.Distiller提供了一个P ...

  6. 《AI系统周刊》第4期:DNN模型压缩之剪枝(Pruning)

    No.04 智源社区 AI系统组 A I 系  统 研究 观点 资源 活动 关于周刊 AI系统是当前人工智能领域极具现实意义与前瞻性的研究热点之一,为了帮助研究与工程人员了解这一领域的进展和资讯,我们 ...

  7. 检测多边形是否重叠_只要保留定位感知通道,目标检测模型也能剪枝70%参数

    作者 | Bbuf 编辑 | 杨晓凡 下面要介绍的论文发于2019,题为「Localization-aware Channel Pruning for Object Detection」 axriv地 ...

  8. 《模型轻量化-剪枝蒸馏量化系列》YOLOv5无损剪枝(附源码)

    今天文章代码不涉密,数据不涉密,使用的是网上开源代码,做了修改,主要介绍如何实现的,另外,数据使用开放数据VisDrone的小部分数据来测试~ 今天的文章很短,主要附带一个视频讲解运行过程,我修改的地 ...

  9. 模型压缩Distiller学习

    摘要 在神经网络模型中,通过正则化或剪枝来诱导稀疏性是压缩网络的一种方法(量化是另一种方法).稀疏神经网络具有速度快.体积小和能量大的优点 总结 整个仓库庞大,不适合我这样的初学者,而且你要压缩,裁剪 ...

最新文章

  1. 技术图文:举例详解Python中 split() 函数的使用方法
  2. 面向对象程序设计基本概念
  3. kafka基本操作:创建topic、生产/消费消息(同一消费组均分消息;不同消费组订阅消息)
  4. 值得收藏!268条PCB layout设计规范
  5. linux 提示库文件,Linux系统下确实库文件的解决办法
  6. 畅易阁老是显示服务器忙,畅易阁全服开放 盘点天龙玩家卖号的几大原因
  7. Vue.js视频教程
  8. q函数表格怎么看_会计表格函数玩不会?送你会计表格函数公式大全,财务人都在用...
  9. 英特尔一口气发布了三款处理器、两款存储、一款以太网适配器
  10. android 禁止用户访问u盘_如何禁止u盘复制文件 禁止u盘复制文件方法【介绍】
  11. iterm2上传文件到linux,在iTerm2中使用Zmodem实现快速传输文件
  12. matlab中.m文件访问simulink
  13. DongDong认亲戚(map+并查集)
  14. mysql sid是什么_数据库名、数据库实例、全局数据库名、服务名、SID等的区别
  15. chrome extensions 中的交互
  16. 在Sdx中使用xfOpenCV
  17. 【Vue系列】Vue3.0知识点汇总整理
  18. ubuntu 18.04 安装 搜狗拼音输入法只有中文标点,没有文字
  19. ndnSIM学习(十)——apps之ndn-producer.cpp和ndn-consumer.cpp源码分析
  20. 水晶报表打印出错,未能加载文件或程序集“CrystalDecisions.CrystalReports.Engine, Version=10.5.3700.0

热门文章

  1. 悲观锁、乐观锁以及分布式锁
  2. iPad、iPad Pro反复自动重启怎么办?
  3. ubuntu系统卸载软件方法
  4. 【软件定义汽车】【硬件篇】特斯拉FSD芯片
  5. python pip 安装使用国内镜像源
  6. Java并发基础知识(五)
  7. 数学与计算机学院校友会,福州大学数学与计算机科学学院厦门校友会成立
  8. Python学习笔记(十三):异常处理机制
  9. 人工智能领域的十大算法
  10. Pytorch中的Conv1d()和Conv2d()函数