模型压缩的目的是减小模型计算量(FLOPs or MACC)、减小模型参数量/体积、减小模型的推理时间(latency)。主要方法有知识蒸馏、紧凑网络设计、剪枝、量化、低秩近似等。今天分享一篇通道剪枝相关的论文。商汤出品,CVPR 2020 Oral。

论文传送门:http://arxiv.org/abs/2005.03354v1

商汤官方解读:https://zhuanlan.zhihu.com/p/146721840


Introduction

CNN是过参数化的,通道剪枝可以加速和压缩模型,去除不重要的通道使得模型更加高效、紧凑。通道剪枝可以被视为在原始网络中寻找一个最优的子结构的问题。典型的剪枝程式:预训练大模型 - 通道/权重剪枝 - 微调小模型。有一篇论文指出,剪枝后的模型结构是其性能表现的核心因素,而不是继承的所谓的重要的权重。这一发现表明通道剪枝的本质是找到良好的剪枝结构,即逐层通道数目。

搜索空间很大,如何高效的搜索?本文提出了DMCP,给定限制条件下可以端到端优化的通道剪枝算法。作者将通道剪枝建模为一个马尔可夫过程,马尔可夫过程中的状态

对应通道k被保留、

的转移概率对应当第k个通道被保留时第k+1个通道的保留概率。各状态的边缘概率可以由转移概率的乘积计算出来,该边缘概率被视作通道的重要程度。在网络前馈的时候,边缘概率作为系数乘以特征图的对应通道。该方法可以使用SGD来进行端到端的优化。优化结束之后,根据各状态的边缘概率对模型进行剪枝,重训剪枝后的模型来获得更高的性能。

Related Work

神经网络结构搜索(NAS):一种与此相似的技术是NAS,例如DARTS。DARTS使用一组可学习的权重来设置每个候选操作的概率的参数,层的输出是概率和相应操作的特征图的线性组合。训练后,选择概率最高的候选操作构成最终架构。但是,DARTS是在小型代理任务上执行,然后将搜索到的结构迁移到大规模目标任务上。ProxylessNAS 通过仅采样两条路径来搜索大规模目标任务上的体系结构,从而避免了使用代理任务。但是DARTS中提出的可微方法并不能直接应用在剪枝过程中,原因在于:两者的所搜空间是不同的、DARTS总的各个操作是相互独立的而剪枝中存在一些隐含的逻辑关系。

通道剪枝:可以被分为硬剪枝(直接去除通道)和软剪枝(通道系数设置为0)。硬剪枝方法间的差异存在于剪枝标准的不同,例如权重标准,输出中零的平均百分比或每个通道对最终损失的影响。比如一种做法是采用BN中的比例因子作为通道重要程度的度量,稀疏训练后去除相对不重要的通道。软修剪方法主要是使修剪的通道为零或接近于零,以减少这些通道的影响。比如一种做法是首先通过层内准则和计算的层比率将一些滤波器置零,然后会逐渐增加置零的滤波器的比例,直到达到给定的计算预算为止。本文的方法可以归为软剪枝,与别的方法相比,本文方法最大的区别在于简化了搜索空间:

。另外的一些相关工作设计一个搜索过程,直接从未剪枝的网络中搜索最佳子结构。AMC提出采用强化学习的方法来训练一个控制器,它输出每层的剪枝率。MetaPruning使用元学习的方法来预测模型权重、结合遗传算法来搜索子结构。这两种方法存在的共性问题是:该过程在一个非常大的结构采样空间里面进行,这些方法的可扩展性受到限制。有关AMC和MetaPruning的简要介绍参见文尾。

Method

本文将通道剪枝建模为马尔可夫过程,如上图所示,这是一个有向无环图。状态

代表第k个通道被保留;转移概率
代表:如果第k的通道被保留,那么第 k+1个通道的保留概率。任何一个状态都可以转移到终止状态
就预示着剪枝过程的结束。该过程具有以下特性:如果层L保留k个通道,则它们必须是前k个通道,也就是说如果第k个通道保留,那么前k-1个通道也一定要保留。马尔可夫过程的无记忆性:给定第k个通道被保留,那么第 k+1个通道是否保留与前 k-1个通道是否保留是条件独立的。注意作者限制每层至少保留一个通道,即
。前 k个通道的保留概率为:

那么转移概率

如何得到呢?作者将该过程看作一个随机过程,然后通过一组可学习的参数来对转移概率进行参数化。为了与网络权重(network weights)相区别,作者将这组可学习参数称为结构参数(architecture parameters)

。结构参数是直接由网络学习得到的,转移概率
由结构参数计算得到,边缘概率
由转移概率计算得到:

一些细节:1. 将输出特征图的每个通道乘以相应的保留概率,即实现将结构参数整合进未剪枝的模型中;2. 注意,BN会对特征图做批标准化,ReLU激活特征图,所以结构参数应该用在它们后面;3. 对于shortcut的处理。残差结构中的shortcut要求被连接的两个卷积层具有相同的通道数,作者采用结构参数共享的方法来满足这一点;4. 作者不为结构参数添加权重衰减,原因:因为当某些通道的保留概率接近于0或1时,可学习参数

的范数将变得非常大,强迫它们接近0会损害优化过程。

下面来为结构参数A添加正则项 / 计算损失。首先要明确Budget Regularization. 本文中使用FLOPs作为Budget Regularization。但是FLOPs不能直接使用梯度下降法(GD)来进行优化,作者提出了以下的解决方案。层L的平均通道数通过下式计算:

给定层L的平均输入通道数E(in)和平均输出通道数E(out),那么E(FLOPs)可以通过下式计算:

上式中

是输入特征图的宽/高,
是卷积核的宽/高,
是padding size,stride是步长。整个模型的
由逐层相加而得。

行文至此,损失函数也就水到渠成了。假设是分类任务, 假设

是给定的FLOPs Budget Regularization,那么损失函数可以写作:

注意,为了使得

严格低于
但又不对其过于敏感,作者为该项损失增加一条边界,即限制:
,其中
称为容忍率,默认设置为0.95。

DMCP的训练过程可以分为两个阶段,即未剪枝网络的权重更新和结构参数更新。在训练过程中,阶段1和阶段2被迭代调用。

热身阶段:在迭代阶段1和阶段2之前,DMCP首先训练阶段1一些epoch以进行预热,其中子网络使用的结构参数采样自随机初始化的Markov过程。此过程旨在避免由于权重训练不足而导致更新结构参数时,网络掉入不良的局部最小值。

阶段1:利用

,仅更新未剪枝网络的权重。作者引入variant sandwich rule来训练未剪枝的网络,如上图(a)所示,以使未剪枝模型中的通道组比紧随之后的通道组更“重要”,这一训练方式将同一层中的通道看作是不平等的,此时具有k个通道的层的最佳选择是前k个通道,而不是其他可能的组合。在此基础上,引入马尔可夫建模变得合理。

阶段2:利用

,仅更新结构参数。为了进一步减少搜索空间,作者将通道均匀地分为几组(≥10组),每个结构结构参数α负责一组而不是仅一个通道。每层具有相同数量的组。

通过Pruned Model Sampling来实现通道剪枝的最终目的:获得一个紧凑的、复杂度低的模型。在DMCP训练结束之后,来采样选择满足给定限制条件的模型。作者提出了两种采样方法:其一,Direct Sampling (DS),在target FLOPs budget的限制下,根据每一层优化得到的马尔可夫过程的转移概率采样,层间是独立的;其二,Expected Sampling (ES),将每层的通道数设置为平均通道数,平均通道数由马尔可夫过程的状态概率相加而得。作者称,在实验中

总是可以优化到0,所以FLOPs的限制是容易满足的,即Expected Sampling (ES)总是可以满足FLOPs的限制。最后,从头训练获得的剪枝模型。

Experiments

作者利用MobileNetV2ResNetImageNet上进行classification实验。

消融实验之可恢复验证(Recoverability verification)。DMCP应该具有的一个特性是,在没有FLOPs约束的预训练模型上进行搜索时,它应该保留几乎所有通道。作者在MobileNetV2 1.0x上进行实验,随机初始化结构参数、无FLOPs约束。训练过程中freeze网络的权重,仅仅训练结构参数,结果如下图。从图中可以看到,DMCP具有可恢复性。

消融实验之Expected sampling和Direct Sampling的对比。作者在MobileNetV2-210M和ResNet50-1.1G(注:ResNet50-1.1G表示未剪枝模型选用ResNet50,target FLOPs=1.1G,其余命名方式类似)上进行实验。MobilenetV2 0.75x和ResNet50 0.5x的FLOPs分别是210M和1.1G,这两个baseline的分类精度分别是70.4%和71.9%。对于DS,作者采样了五个模型。从下述结果可以看到:1.不论哪种方法,都优于baseline;2.ES与DS的最优结果相近。ES具有鲁棒性,作者默认采用ES的采样方式。

消融实验之Influence of warmup phase。Warmup可以带来更好的性能。一个可能的原因是,使用预热使网络权重在更新结构参数之前得到了足够的训练,这使得权重更加可区分,并防止了结构参数陷入不良的局部最小值。作者在补充材料中提到了一个细节:使用预训练模型代替预热,精度会下降0.6%。

消融实验之Impact of the variant sandwich rule。变体与原始的“三明治规则”之间有两个区别。首先,前者每层中保留通道的比率不同,而后者是固定的。因为剪枝网络对不同的层可能具有不同的偏好。其次,前者保留通道比率的随机抽样服从结构参数分布,而后者服从均匀分布。从实验结果可以看到,变体的“三明治规则”更适合于当前任务。

消融实验之Training scheme。在warmup阶段之后,第一种设置是:仅根据

更新结构参数;第二种设置是:根据
更新结构参数;第三种设置是:根据
同时更新结构参数和网络权重。第一个实验是FLOPs引导的baseline,具有相同FLOPs的层修剪程度也相同,其结果远比其他实验差。与第二个实验相比,可知任务损失可以帮助区分不同层的重要性,即使不同的层具有相同FLOPs它们的重要程度也很可能是不同的。最后,第三个实验说明当结构参数发生变化时,权重也应进行调整。

DMCP与其他SOTA的方法进行对比,包括reinforcement learning method AMC, evolution method MetaPruning, one-shot method AutoSlim, and traditional channel pruning methods SFP and FPGM。从下表中可以看出,在相同的FLOPs限制下,DMCP优于所有其他方法。请注意,AMC,MetaPruning和DMCP通过标准的硬标签丢失从头开始训练修剪的模型, 而AutoSlim采用了知识蒸馏。为了与AutoSlim进行公平比较,作者还使用相同的训练方法来训练剪枝后的模型。结果表明,这种训练方法可以进一步提高性能,并且在不同的FLOPs模型中DMCP都超过了AutoSlim。

最后补充一点:作者在附录中提到,不用可尔科夫过程建模结构参数,而是用贝努利模型来建模,剪枝后的模型分类精度为70.1%,比DMCP低了2.3%。

Conclusion

小结一下:本文提出了一种新的通道剪枝方法(称为可微马尔可夫通道剪枝,DMCP),以解决现有方法需要训练和评估大量子结构的缺陷。DMCP将通道剪枝建模为马尔可夫过程,该方法是可微分的,因此可以采用梯度下降法来端到端的优化。优化之后,可以通过简单的“期望采样”来采样所需模型,并从头训练以微调。


相关工作补充:

AMC: AutoML for Model Compression and Acceleration on Mobile Devices (ECCV18)

传统的模型压缩方法依靠手工设计的启发式策略,或者领域专家在大的设计空间中对精度、速度、模型体积进行均衡。搜索空间太大引发两个问题:耗时、结果次优。AMC旨在利用自动机器学习(AutoML)的方法来自动寻找优于 人工设计的、模型特异的 压缩策略。AMC采用了强化学习的方法。作者发现压缩后模型的性能对每一层的稀疏度非常敏感,因此需要一个细粒度的动作空间,所以作者提出了一种使用DDPG代理程序的连续压缩比控制策略,通过惩罚精度损失、反复试验来学习,同时鼓励模型压缩和加速。Actor的内部结构:FC-FC-sigmoid,每个FC含有300个神经单元。具体来说,DDPG代理以分层方式处理网络。对于每个层Lt,代理接收一个embedding St,该embedding编码该层的有用特性,然后输出精确的压缩比at。在压缩了层Lt之后,代理将移动到下一层。在没有微调的情况下评估当前的剪枝模型的精度,这种简单的近似可以缩短搜索时间。搜索完成后微调搜索到的模型。两种压缩模型搜索思路:给定FLOPs, latency或model size搜索最佳精度的模型(通过限制搜索空间实现);给定精度搜索最小的模型(在reward中同时考虑精度和资源限制)。

St如上,其中t是层序号、c*k*k*n是卷积核尺寸、c*h*w是输入特征图尺寸、reduced是前面所有层整个的FLOPs削减量、Rest是剩余的FLOPs削减量、at-1属于(0, 1]是前一层的削减比例。

MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning (ICCV19)

常规的通道剪枝方法主要依靠数据驱动的稀疏性约束或人为设计的策略。最近的AutoML-style的一些方法基于反馈循环或强化学习,以迭代的方式自动剪枝通道。MetaPruning是港科大、旷视、清华、华科一起提出的一种采用元学习(Meta Learning)的方法来自动进行网络通道剪枝的算法。元学习是在学习学习方式。本文中作者提出训练一个可以为所有候选的剪枝网络结构生成权重的PruningNet,这样就可以通过在验证数据上评估其准确性来搜索性能良好的结构,非常高效。采用随机结构采样的方法来训练PruningNet。PruningNet会使用相应的网络编码矢量(即每一层中的通道数)为剪枝网络生成权重。通过随机输入不同的网络编码向量,PruningNet逐渐学会为各种剪枝网络结构生成权重。训练完成后,通过遗传算法来搜索性能良好的剪枝网络,该方法可以灵活地合并各种约束。 PruningNet由FC-ReLU-FC组成。

深度学习每层的通道数如何计算_模型通道剪枝之DMCP: Differentiable Markov Channel Pruning...相关推荐

  1. 深度学习每层的通道数如何计算_深度学习基础系列(一)| 一文看懂用kersa构建模型的各层含义(掌握输出尺寸和可训练参数数量的计算方法)...

    我们在学习成熟网络模型时,如VGG.Inception.Resnet等,往往面临的第一个问题便是这些模型的各层参数是如何设置的呢?另外,我们如果要设计自己的网路模型时,又该如何设置各层参数呢?如果模型 ...

  2. 吴恩达神经网络与深度学习——浅层神经网络

    吴恩达神经网络与深度学习--浅层神经网络 神经网络概述 神经网络表示 计算神经网络的输出 m个样本的向量化 for loop 向量化 向量化实现的解释 激活函数 sigmoid tanh函数 ReLu ...

  3. 动手学深度学习——卷积层里的填充和步幅

    1.填充 填充( padding )是指在输⼊⾼和宽的两侧填充元素(通常是 0 元素). 给定(32x32)输入图像: 应用5x5大小的卷积核,第一层得到输出大小28x28,第七层得到输出大小4x4: ...

  4. 深度学习 | BN层原理浅谈

    深度学习 | BN层原理浅谈 文章目录 深度学习 | BN层原理浅谈 一. 背景 二. BN层作用 三. 计算原理 四. 注意事项 为什么BN层一般用在线性层和卷积层的后面,而不是放在激活函数后 为什 ...

  5. Keras深度学习实战(1)——神经网络基础与模型训练过程详解

    Keras深度学习实战(1)--神经网络基础与模型训练过程详解 0. 前言 1. 神经网络基础 1.1 简单神经网络的架构 1.2 神经网络的训练 1.3 神经网络的应用 2. 从零开始构建前向传播 ...

  6. 深度学习论文随记(二)---VGGNet模型解读-2014年(Very Deep Convolutional Networks for Large-Scale Image Recognition)

    深度学习论文随记(二)---VGGNet模型解读 Very Deep Convolutional Networks forLarge-Scale Image Recognition Author: K ...

  7. 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别

    2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...

  8. 【深度学习】解析神经网络中的数值稳定性、模型初始化和分布偏移(Pytorch)

    [深度学习]解析神经网络中的数值稳定性.模型初始化和分布偏移 文章目录 1 概述1.1 梯度消失和梯度爆炸1.2 打破对称性 2 参数初始化 3 环境和分布偏移3.1 协变量偏移3.2 标签偏移3.3 ...

  9. 【深度学习】Swin-Unet图像分割网络解析(文末提供剪枝仓库)

    [深度学习]Swin-Unet图像分割网络解析(文末提供剪枝仓库) 文章目录 1 概述 2 Swin-Unet架构 3 bottleneck理解 4 具体结构4.1 Swin Transformer ...

最新文章

  1. 给vim添加自动跳出括号的功能
  2. 【python数据挖掘课程】十.Pandas、Matplotlib、PCA绘图实用代码补充
  3. pycharm如何汉化
  4. 第一节: Timer的定时任务的复习、Quartz.Net的入门使用、Aop思想的体现
  5. c语言 系统 dome,订餐系统(Dome)
  6. 【C#/WPF】用Thumb做可拖拽的UI控件
  7. 构建线性表的c语言代码,数据结构严蔚敏C语言版—线性表顺序存储结构(顺序表)C语言实现相关代码...
  8. windows10下 tensorflow gpu版本安装配置方法
  9. java文件读写工具类
  10. Xcode9之折叠代码
  11. BUUCTF-MISC-被劫持的神秘礼物~梅花香之苦寒来
  12. linux考试不及格反思100字,考试没考好的反思总结(精选10篇)
  13. 牛客网面试提错题集(1)
  14. 行政组织理论-阶段测评4
  15. timer延迟1us_C# 高精度延迟代码执行时间(us/ns)
  16. postman设置前置条件
  17. 系统 应用程序 提示 初始化失败 或 无法加载模块 等错误
  18. 牛客小白月赛16——D 小阳买水果
  19. [演讲]北大鄂维南院士:智能时代意味着什么?
  20. BP神经网络隐藏层的作用,bp神经网络输出层函数

热门文章

  1. 138. 兔子与兔子【字符串哈希】
  2. 1128 N Queens Puzzle (20 分)【难度: 一般 / 知识点: 模拟】
  3. unity开宝箱动画_[技术博客]Unity3d 动画控制
  4. 4G室内直放站_室内信号不太好,安装一个手机信号放大器,有效果吗?
  5. IDEA 真牛逼,900行 又臭又长 的类重构,几分钟搞定
  6. Missing artifact com.github.pagehelper:pagehelper:jar:3.4.2-fix的解决方法
  7. 计数排序及其改进 C++代码实现与分析 恋上数据结构笔记
  8. Android网络编程的Socket通信总结
  9. 06--MySQL自学教程:DML(Data Manipulation Language:数据库操作语言),只操作表
  10. (Mybatis)XML配置解析