NNI剪枝入门可参考:nni模型剪枝_benben044的博客-CSDN博客_nni 模型剪枝

1、背景

本文的剪枝操作针对CenterNet算法的BackBone,即MobileNetV3算法。

该Backbone最后的输出格式如下:

假如out = model(x),则x[-1]['hm']可获得heatmap的shape。

2、直接添加nni操作

直接添加的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedupclass hswish(nn.Module):def __init__(self):super(hswish, self).__init__()self.relu6 = nn.ReLU6(inplace=True)def forward(self, x):out = x * self.relu6(x + 3) / 6return outclass hsigmoid(nn.Module):def __init__(self):super(hsigmoid, self).__init__()self.relu6 = nn.ReLU6(inplace=True)def forward(self, x):out = self.relu6(x + 3) / 6return out# 注意力机制
class SE(nn.Module):def __init__(self, in_channels, reduce=4):super(SE, self).__init__()self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),nn.BatchNorm2d(in_channels // reduce),nn.ReLU6(inplace=True),nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),nn.BatchNorm2d(in_channels),hsigmoid())def forward(self, x):out = self.se(x)out = x * outreturn outclass Block(nn.Module):def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):super(Block, self).__init__()self.se = nn.Sequential()if se:self.se = SE(expand_size)if nolinear == 'RE':self.nolinear = nn.ReLU6(inplace=True)elif nolinear == 'HS':self.nolinear = hswish()self.block = nn.Sequential(nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),nn.BatchNorm2d(expand_size),self.nolinear,nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False),nn.BatchNorm2d(expand_size),self.se,self.nolinear,nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),nn.BatchNorm2d(out_channels))self.shortcut = nn.Sequential()if stride == 1 and in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels))self.stride = stridedef forward(self, x):out = self.block(x)if self.stride == 1:out += self.shortcut(x)return outclass MobileNetV3(nn.Module):def __init__(self, class_num):super(MobileNetV3, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(16),hswish())self.neck = nn.Sequential(Block(3, 16, 16, 16, 2, se=True),Block(3, 16, 72, 24, 2),Block(3, 24, 88, 24, 1),Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),)self.conv2 = nn.Sequential(nn.Conv2d(96, 576, 1, bias=False),nn.BatchNorm2d(576),hswish())self.avgpool = nn.AdaptiveAvgPool2d(1)self.conv3 = nn.Sequential(nn.Conv2d(576, 1280, 2, bias=False),nn.BatchNorm2d(1280),hswish())self.hm = nn.Conv2d(20, class_num, kernel_size=1)self.wh = nn.Conv2d(20, 2, kernel_size=1)self.reg = nn.Conv2d(20, 2, kernel_size=1)def forward(self, x):x = self.conv1(x)x = self.neck(x)x = self.conv2(x)x = self.conv3(x)y = x.view(x.shape[0], -1, 128, 128)z = {}z['hm'] = self.hm(y)z['wh'] = self.wh(y)z['reg'] = self.reg(y)return [z]if __name__ == '__main__':model = MobileNetV3(10)print('-----------raw model------------')print(model)config_list = [{'sparsity_per_layer': 0.8,'op_types': ['Conv2d']}]pruner = L1NormPruner(model, config_list)_, masks = pruner.compress()for name, mask in masks.items():print(name, ' sparsity: ', '{:.2f}'.format(mask['weight'].sum() / mask['weight'].numel()))pruner._unwrap_model()ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()print('------------after speedup------------')print(model)

如果参考nni入门直接添加nni压缩的代码,则会报如下错误:
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions。

 File "D:\programs\python37\lib\site-packages\nni\common\graph_utils.py", line 78, in _traceself.trace = torch.jit.trace(model, dummy_input, **kw_args)File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 742, in trace_module_class,File "D:\programs\python37\lib\site-packages\torch\jit\_trace.py", line 940, in trace_module_force_outplace,
RuntimeError: Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions

原因,返回的数据不符合torch.jit.trace的要求,而示例model返回的是一个dict,它不是tensors | lists | tuples of tensors | dictionary of tensors中的一种。

所以需要对MobileNetv3进行改造,以满足torch.jit.trace的返回要求。

3、MobileNetV3针对NNI的改造

改造方法:

(1)将输出从dict修改为tuple形式

(2)hm、wh、reg的定义从__init__()函数移到forward中。因为hm中conv的in_channel是会变化的,未剪枝前是A,剪枝后是B,所以在__init__()中定义没法动态修改in_channel值,只能放到forward中进行处理。

以下代码只适用于CPU模式下,不适用GPU上运行。

改造后的示例代码如下:

import torch
from torch import nn
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedupclass hswish(nn.Module):def __init__(self):super(hswish, self).__init__()self.relu6 = nn.ReLU6(inplace=True)def forward(self, x):out = x * self.relu6(x + 3) / 6return outclass hsigmoid(nn.Module):def __init__(self):super(hsigmoid, self).__init__()self.relu6 = nn.ReLU6(inplace=True)def forward(self, x):out = self.relu6(x + 3) / 6return out# 注意力机制
class SE(nn.Module):def __init__(self, in_channels, reduce=4):super(SE, self).__init__()self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, in_channels // reduce, 1, bias=False),nn.BatchNorm2d(in_channels // reduce),nn.ReLU6(inplace=True),nn.Conv2d(in_channels // reduce, in_channels, 1, bias=False),nn.BatchNorm2d(in_channels),hsigmoid())def forward(self, x):out = self.se(x)out = x * outreturn outclass Block(nn.Module):def __init__(self, kernel_size, in_channels, expand_size, out_channels, stride, se=False, nolinear='RE'):super(Block, self).__init__()self.se = nn.Sequential()if se:self.se = SE(expand_size)if nolinear == 'RE':self.nolinear = nn.ReLU6(inplace=True)elif nolinear == 'HS':self.nolinear = hswish()self.block = nn.Sequential(nn.Conv2d(in_channels, expand_size, 1, stride=1, padding=0, bias=False),nn.BatchNorm2d(expand_size),self.nolinear,nn.Conv2d(expand_size, expand_size, kernel_size, stride=stride, padding=kernel_size // 2, groups=expand_size, bias=False),nn.BatchNorm2d(expand_size),self.se,self.nolinear,nn.Conv2d(expand_size, out_channels, 1, stride=1, padding=0, bias=False),nn.BatchNorm2d(out_channels))self.shortcut = nn.Sequential()if stride == 1 and in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels))self.stride = stridedef forward(self, x):out = self.block(x)if self.stride == 1:out += self.shortcut(x)return outclass MobileNetV3(nn.Module):def __init__(self, class_num, sparsity_ratio):super(MobileNetV3, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),nn.BatchNorm2d(16),hswish())self.neck = nn.Sequential(Block(3, 16, 16, 16, 2, se=True),Block(3, 16, 72, 24, 2),Block(3, 24, 88, 24, 1),Block(5, 24, 96, 40, 2, se=True, nolinear='HS'),Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),Block(5, 40, 240, 40, 1, se=True, nolinear='HS'),Block(5, 40, 120, 48, 1, se=True, nolinear='HS'),Block(5, 48, 144, 48, 1, se=True, nolinear='HS'),Block(5, 48, 288, 96, 2, se=True, nolinear='HS'),Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),Block(5, 96, 576, 96, 1, se=True, nolinear='HS'),)self.conv2 = nn.Sequential(nn.Conv2d(96, 576, 1, bias=False),nn.BatchNorm2d(576),hswish())self.avgpool = nn.AdaptiveAvgPool2d(1)self.conv3 = nn.Sequential(nn.Conv2d(576, 1280, 2, bias=False),nn.BatchNorm2d(1280),hswish())self.class_num = class_numdef forward(self, x):x = self.conv1(x)x = self.neck(x)x = self.conv2(x)x = self.conv3(x)y = x.view(x.shape[0], -1, 128, 128)in_channel = y.shape[1]hm = nn.Conv2d(in_channel, self.class_num, kernel_size=1)wh = nn.Conv2d(in_channel, self.class_num, kernel_size=1)reg = nn.Conv2d(in_channel, self.class_num, kernel_size=1)return (hm(y), wh(y), reg(y))if __name__ == '__main__':model = MobileNetV3(10, 0.2)print('-----------raw model------------')print(model)config_list = [{'sparsity_per_layer': 0.2,'op_types': ['Conv2d']}]pruner = L1NormPruner(model, config_list)_, masks = pruner.compress()for name, mask in masks.items():print(name, ' sparsity: ', '{:.2f}'.format(mask['weight'].sum() / mask['weight'].numel()))pruner._unwrap_model()ModelSpeedup(model, torch.rand(2, 3, 516, 516), masks).speedup_model()print('------------after speedup------------')print(model)input = torch.randn(2, 3, 516, 516)   # batch_size =1 会报错out = model(input)print(out[0].shape)

4、cuda模式下适配CenterNet的MobileNetv3无法剪枝

上面第3段提到的方法只针对cpu,但是在gpu下是运行不成功的。

如果适配CenterNet的MobileNetV3不进行剪枝的话,如果在forward中定义hm、wh、reg的卷积方法,只需要改动3个地方,核心改动点如下:

但是一旦再加上NNI的代码,则会报错,报错信息为:“RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)”。

报错原因参考:Pytorch出现RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) - 水果+麦片 - 博客园

也就是说网络层的定义必须放在__init__()方法中,否则该问题无法避免。

所以,mobileNetv3针对CenterNet的剪枝+再训练就只能在CPU环境下进行。

5、cpu下训练效果

可以得到loss都无法收敛,所以nni剪枝暂告失败。

MobileNetV3基于NNI剪枝操作相关推荐

  1. 基于Alpha-Beta剪枝树的井字棋人机博弈实现

    1 Alpha-Beta剪枝树的简单介绍 Alpha-Beta剪枝的本质就是基于极小化极大算法的一种改进算法.因此先简单地介绍下极小化极大算法,这样有利于我们更好的理解Alpha-Beta剪枝算法. ...

  2. orm mysql_PHP基于ORM方式操作MySQL数据库实例

    本文实例讲述了PHP基于ORM方式操作MySQL数据库.分享给大家供大家参考,具体如下: ORM----Oriented Relationship Mapper,即用面向对象的方式来操作数据库.归根结 ...

  3. php swoole多进程,PHP基于swoole多进程操作示例

    本文实例讲述了PHP基于swoole多进程操作.分享给大家供大家参考,具体如下: 多个任务同时执行 将顺序执行的任务,转化为并行执行(任务在逻辑上可以并行执行) 比如,我们要对已知的用户数据进行判断, ...

  4. 基于 DocumentFormat.OpenXml 操作 Excel (1)-- 初识

    最近抽空研究了一下 基于DocumentFormat.OpenXml操作Excel,也把自己的理解记录下来,便于日后可以查阅. 各种系统中,导出Excel是一种很常见的功能,在C#/.Net 环境下, ...

  5. 一文带你学会基于SpringAop实现操作日志的记录

    前言 大家好,这里是经典鸡翅,今天给大家带来一篇基于SpringAop实现的操作日志记录的解决的方案.大家可能会说,切,操作日志记录这么简单的东西,老生常谈了.不! 网上的操作日志一般就是记录操作人, ...

  6. 基于 Bochs 的操作系统内核实现

    简介 Bochs 简介 Bochs(读音Box)是一个开源的模拟器(Emulator),它可以完全模拟x86/x64的硬件以及一些外围设备.与VirtualBox / VMware等虚拟机(Virtu ...

  7. oss客户端工具_干货 | 基于Go SDK操作京东云对象存储OSS的入门指南

    前言 本文介绍如何使用Go语言对京东云对象存储OSS进行基本的操作,帮助客户快速通过Go SDK接入京东云对象存储,提高应用开发的效率. 在实际操作之前,我们先看一下京东云OSS的API接口支持范围和 ...

  8. 虹科干货 | 仅需3步!教你如何基于Windows系统操作使用RELY-TSN-KIT评估套件

    虹科RELY-TSN-KIT是首款针对TSN的开箱即用的解决方案,它可以无缝实施确定性以太网网络,并从这些技术复杂性中抽象出用户设备和应用.该套件可评估基于IEEE 802.1AS同步的时间常识的重要 ...

  9. php京东云oss,干货 | 基于Go SDK操作京东云对象存储OSS的入门指南

    前言 本文介绍如何使用Go语言对京东云对象存储OSS进行基本的操作,帮助客户快速通过Go SDK接入京东云对象存储,提高应用开发的效率. 在实际操作之前,我们先看一下京东云OSS的API接口支持范围和 ...

最新文章

  1. python 多态 协议详解
  2. 使用 Chrome DevTools 调试 JavaScript
  3. 轻量级流程图控件GoJS示例连载(一):最小化
  4. 我の第一篇万字博文 | 带大家开开心心地进入Python世界
  5. Oracle 客户端 使用 expdp/impdp 示例 说明
  6. 【公益】开放一台Eureka注册中心给各位Spring Cloud爱好者
  7. OSPF通过MPLS ×××
  8. 用TensorFlow的Linear/DNNRegrressor预测数据
  9. 分布式事务实践 解决数据一致性 分布式事务实现:消息驱动模式
  10. StreamingAssets文件夹的读取异常
  11. C++/mfc错误总结
  12. 搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(五)
  13. 最新大数据案例分享:2019微信数据报告(图集)
  14. 王道机试 第十二章 动态规划 12.5 背包问题(0-1背包,完全背包,多重背包)
  15. arm-linux-gcc camke,Window平台基于CMake与linaro交叉编译arm程序
  16. 百度——测试开发实习生面试记录
  17. 解决win10小娜无法搜索本地应用程序
  18. 解决百度云管家导入未完成下载任务
  19. GeForce GTX 1050-2G驱动安装
  20. win10 GTX1060 安装CUDA+PyTorch GPU

热门文章

  1. 菜鸟到大神的上位历程,即学即用走向人生巅峰
  2. 研究生语音识别课程作业记录(三) 非特定人孤立词识别
  3. Flask成长笔记--依赖包操作
  4. 【matlab】:matlab的linspace函数解析
  5. python和java数据类型
  6. 高中数学40分怎么办_新高一第一次考试数学只考了40分,还有救吗?
  7. android投屏小米电视软件,小米投屏神器安卓版
  8. netstat 的各个 state 什么意思
  9. 软件设计模式“单例模式”和“工厂模式”
  10. 什么是Redis?为什么要用Redis?