This article is based on the Bilibili video "5-剪枝后模型参数赋值" and the repository:

https://github.com/foolwood/pytorch-slimming

I. Model Pruning: Theory

Paper: Learning Efficient Convolutional Networks through Network Slimming

(1) A convolution layer produces many feature maps (channels = 64, 128, 256, ...). Not all of them are equally important, so we need a quantitative measure of each feature map's importance.

(2) During training we add a strategy that drives the scaling weights clearly apart in magnitude, which makes it easy to select the important feature maps.

The values of the channel scaling factors are the scores of the feature maps: intuitively, feature maps with large scores should be kept, and feature maps with small scores can be removed.

II. Measuring Feature-Map Importance

Network Slimming uses the scaling factor γ of the BN (batch normalization) layers.

Theoretical basis of BN:

ẑ = (z_in − μ_B) / √(σ_B² + ε), which normalizes the activations to a standard normal distribution with mean 0 and variance 1.

Overall this looks like a plain normalization operation, but BN additionally introduces two trainable parameters: γ and β.

The essence of BatchNorm:

(1) BN pulls the activation distribution, which drifts further and further during training, back toward a fixed range.

(2) It re-normalizes the activations to a standard normal distribution with mean 0 and variance 1.

(3) This keeps the values in the numerically sensitive region of the activation function, so training converges faster.

(4) The resulting problem: after BN, the values are forced into the (nearly) linear region of the nonlinear activation function.

Explanation of point (3):

In an activation function such as sigmoid, the two saturated tails are insensitive to changes of the input, while the region near 0 is unsaturated and sensitive.

Explanation of point (4):

BN squeezes the data into the middle, nearly linear part of the activation (the central region of the sigmoid curve). If F(x) = sigmoid only performs an affine transformation there, then a composition of affine transformations is still affine, so stacking many hidden layers becomes equivalent to a single-layer network.

Therefore BN must preserve some nonlinearity by applying a further transformation to the normalized result.

Two parameters are added and learned during training:

z_out = γ · ẑ + β, where γ and β are obtained during network training rather than given as hyperparameters.

This formula is essentially the inverse of BN's normalization step: it rescales and shifts the standardized distribution, changing its shape and partially restoring the original activations.

The larger γ is, the more that channel is scaled up after normalization, and therefore the more important the corresponding feature map is.
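To make the role of γ concrete, here is a minimal sketch (illustrative, not from the reference repo) of what a single BN channel computes: a channel whose learned γ is tiny produces an almost-zero output regardless of its input, which is why |γ| can serve as that channel's importance score.

import torch

def bn_channel(z, gamma, beta, eps=1e-5):
    # normalize this channel's activations to roughly N(0, 1) ...
    z_hat = (z - z.mean()) / torch.sqrt(z.var(unbiased=False) + eps)
    # ... then apply the learned affine transform gamma * z_hat + beta
    return gamma * z_hat + beta

z = torch.randn(1000) * 3 + 2                 # activations that drifted away from N(0, 1)
weak = bn_channel(z, gamma=0.01, beta=0.0)    # tiny gamma: channel contributes almost nothing
strong = bn_channel(z, gamma=1.50, beta=0.0)  # large gamma: channel kept at full strength
print(weak.std().item(), strong.std().item())  # roughly 0.01 vs 1.5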

III. Making the Importance Scores More Polarized

Apply L1 regularization to the scaling factors to push them toward sparsity.

The (sub)gradient of the L1 penalty is sign(θ): every penalty step has constant magnitude, so the parameter marches steadily toward 0 and can reach exactly 0.

The gradient of the L2 penalty is θ: the penalty step shrinks as θ shrinks, so many parameters end up close to 0 but not exactly 0 (smooth shrinkage).
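A tiny numerical sketch (illustrative, not from the original post) of this difference: with the same step size, the L1 penalty moves a weight by a constant amount per step, while the L2 penalty shrinks it proportionally, slowing down and never quite reaching zero.

theta_l1 = theta_l2 = 0.35   # same starting weight
lr, lam = 0.1, 1.0           # illustrative step size and penalty strength

for step in range(3):
    theta_l1 -= lr * lam * (1.0 if theta_l1 > 0 else -1.0)  # L1 step: lr * lam * sign(theta)
    theta_l2 -= lr * lam * theta_l2                          # L2 step: lr * lam * theta
    print(step, round(theta_l1, 4), round(theta_l2, 4))

# L1: 0.25 -> 0.15 -> 0.05, reaching 0 within a few more steps
# L2: ~0.315 -> ~0.284 -> ~0.255, decaying ever more slowly and staying non-zero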

Core objective in the paper: the normal training loss plus an L1 penalty on the BN scaling factors,

L = Σ_(x,y) l(f(x, W), y) + λ Σ_(γ∈Γ) g(γ),  with g(s) = |s|,

where W denotes the trainable weights, Γ is the set of BN scaling factors, and λ balances the two terms.
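The reference repo realizes this objective with a subgradient update on the BN weight gradients (the updateBN function in the training script below). An alternative formulation of the same objective, shown here only as an illustrative sketch, adds the L1 term directly to the loss; the names sparse_loss and lambda_s are made up for this example.

import torch.nn as nn
import torch.nn.functional as F

lambda_s = 1e-4  # sparsity strength (the lambda in the objective); illustrative value

def sparse_loss(model, output, target):
    # data term: ordinary cross-entropy
    ce = F.cross_entropy(output, target)
    # sparsity term: L1 norm of every BN scaling factor (gamma) in the network
    l1 = sum(m.weight.abs().sum()
             for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    return ce + lambda_s * l1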

IV. Pruning Workflow

The VGG model architecture used (vgg.py):

import torch
import torch.nn as nn
import math
from torch.autograd import Variable


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', init_weights=True, cfg=None):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
                   512, 512, 512, 512, 'M', 512, 512, 512, 512]
        self.feature = self.make_layers(cfg, True)
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = vgg()
    x = Variable(torch.FloatTensor(16, 3, 40, 40))
    y = net(x)
    print(y.data.shape)

1. Training the original model:

(1) L1 sparsity regularization on BN: the gradients of the BN layer weights (the scaling factors γ) are additionally adjusted with a subgradient-descent step.

(2) After training completes, the main thing to save is the original model's parameters.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from vgg import vgg
import shutil
from tqdm import tqdm

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
sparsity_regularization = True
scale_sparse_rate = 0.0001
checkpoint_model_path = 'checkpoint.pth.tar'
best_model_path = 'model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)

model = vgg()
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if sparsity_regularization:
            updateBN()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=checkpoint_model_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_model_path)


def updateBN():
    # L1 sparsity via subgradient descent: add scale_sparse_rate * sign(gamma)
    # to the gradient of every BN scaling factor
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(scale_sparse_rate * torch.sign(m.weight.data))


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()}, is_best)

2. Pruning the model

(1) Pruning has two parts: first compute the channel masks, then adjust the shape of each layer according to the masks.

(2) BN channel count: the layer order is Conv -> BN -> ReLU -> MaxPool -> Linear, so each BN layer's input dimension equals the output channel count of the preceding Conv.

(3) Total BN channels: sum the channel counts of all BN layers.

(4) Pruning percentile for BN: take the given percentile of all BN scaling factors to obtain a concrete threshold; channels whose scaling factor is above the threshold get mask = 1, otherwise mask = 0.

(5) Adjusting the weights: each BN layer keeps only the channels whose mask is 1. This changes the BN layer's shape, so the neighboring Conv and Linear layers must be adjusted accordingly; MaxPool and ReLU have no channel parameters and are unaffected.

(6) A Conv layer's weight has shape [out_channels, in_channels, kernel_size1, kernel_size2], so it must be adjusted twice: first along in_channels, then along out_channels. The first Conv's input is the 3 RGB channels.

Suppose the computed numbers of retained channels per layer are:

[48, 60, 115, 118, 175, 163, 141, 130, 259, 267, 258, 249, 225, 212, 234, 97]

Then the Conv input/output channels become:

In shape: 3 Out shape:48

In shape: 48 Out shape:60

In shape: 60 Out shape:115

In shape: 115 Out shape:118

……

In shape: 234 Out shape:97

(7) When saving the model, both the surviving parameters and the pruned network's structural configuration (cfg) are stored, so the new model structure can be rebuilt for subsequent retraining.

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from vgg import vgg
import numpy as np
from tqdm import tqdm

percent = 0.5
batch_size = 100
raw_model_path = 'model_best.pth.tar'
save_model_path = 'prune_model.pth.tar'

model = vgg()
model.cuda()
if os.path.isfile(raw_model_path):
    print("==> loading checkpoint '{}'".format(raw_model_path))
    checkpoint = torch.load(raw_model_path)
    start_epoch = checkpoint['epoch']
    best_prec = checkpoint['best_prec']
    model.load_state_dict(checkpoint['state_dict'])
    print("==> loaded checkpoint '{}' (epoch {}) Prec: {:f}".format(raw_model_path, start_epoch, best_prec))
print(model)

# collect the absolute values of all BN scaling factors (gamma)
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index: index + size] = m.weight.data.abs().clone()
        index += size

# global threshold: the value at the given percentile of all BN scaling factors
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()
        pruned += mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
pruned_ratio = pruned / total
print('pruned_ratio: {}, Pre-processing Successful!'.format(pruned_ratio))


# simple test of the model after the pre-processing prune (BN scales of pruned channels set to zero)
def test():
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
        batch_size=batch_size, shuffle=True)
    model.eval()
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))


test()

# make the real prune: build a smaller model and copy the surviving weights into it
print(cfg)
new_model = vgg(cfg=cfg)
new_model.cuda()

layer_id_in_cfg = 0  # index into cfg_mask
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
for m0, m1 in zip(model.modules(), new_model.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change the final FC layer
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()   # prune input channels first
        w = w[idx1, :, :, :].clone()                # then prune output channels
        m1.weight.data = w.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[:, idx0].clone()

torch.save({'cfg': cfg, 'state_dict': new_model.state_dict()}, save_model_path)
print(new_model)
model = new_model
test()

3. Retraining the pruned model

The parameters saved after pruning are effectively a training checkpoint. Using the new model structure, continue training from this checkpoint until the metrics are satisfactory.

import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from vgg import vgg
import shutil
from tqdm import tqdm

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
sparsity_regularization = True
scale_sparse_rate = 0.0001

prune_model_path = 'prune_model.pth.tar'
prune_checkpoint_path = 'pruned_checkpoint.pth.tar'
prune_best_model_path = 'pruned_model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)

# rebuild the pruned architecture from the saved cfg, then load the pruned weights
checkpoint = torch.load(prune_model_path)
model = vgg(cfg=checkpoint['cfg'])
model.cuda()
model.load_state_dict(checkpoint['state_dict'])
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, reduction='sum').item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=prune_checkpoint_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, prune_best_model_path)


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()}, is_best)

4. Comparing the original and pruned models:

Both were trained on CIFAR-10 with the VGG model for 3 epochs each.

Original model: about 156 MB on disk, accuracy around 70%.

Pruned model: about 36 MB on disk, accuracy around 76%.

Note: ideally the original model should first be trained to its peak accuracy before pruning; only then does the before/after comparison reflect the true impact of pruning on accuracy.
