【CVPR 2021】剪枝篇(一):Network Pruning via Performance Maximization

  • 论文地址:
  • 主要问题:
  • 主要思路:
  • 具体实现:
    • 基本符号:
    • 子网络生成:
    • 性能预测网络:
    • 事件记忆模块:
    • 精度分布不平衡:
    • 性能最大化算法:
  • 实验结果:

论文地址:

https://openaccess.thecvf.com/content/CVPR2021/papers/Gao_Network_Pruning_via_Performance_Maximization_CVPR_2021_paper.pdf

主要问题:

通道剪枝是一种比较通用而且效果较好的模型压缩方法,通过修剪权重和激活的通道来获得一个小的子网络

为了找到这样的子网络,许多现有的信道剪枝方法都使用分类损失作为指导

然而修剪过的子网络并不一定具有较高的精度和低分类损失,而这往往是由于损失度量不匹配造成的(作者在后面的实验中也证明了性能预测网络的梯度和分类损失有不同的方向)

因此在本文中,作者首先考虑了剪枝的损失度量不匹配问题,并提出了一种新的卷积神经网络的通道剪枝方法,用来直接最大化子网络的性能(即精度)

主要思路:

具体地说就是,我们训练一个独立的神经网络来预测子网络的性能,然后最大化网络的输出作为指导剪枝

并且在这个剪枝过程中,我们可以直接使用子网络和小批量精度作为样本来训练性能预测网络(是否想到了强化学习)

但是有效地训练这种性能预测网络往往是比较困难的,可能会面临灾难性遗忘和子网络分布不平衡的问题

为了解决这个问题,作者使用一个事件记忆模块来沿着剪枝轨迹收集样本

直接使用这些样本是有问题的,因为这些样本的准确性分布远不均匀,但是这个问题可以通过重新采样这些样本来解决(跟DQN里面经验池很像)

利用上述技术,性能预测网络在剪枝过程中进行了增量训练

在性能预测网络访问足够的样本并足够精确后,放到剪枝过程中,为通道剪枝提供额外的监督

由于性能预测网络的训练和剪枝工作同时进行,因此没有额外的成本

此外作者并没有放弃分类损失(通常是交叉熵损失),并且参考多目标学习的理念,在最终损失函数的定义中同时考虑了分类损失和性能最大化

这背后的基本原理是,分类损失和性能最大化都为修剪提供了有用但不同的信息,并且合并它们将会得到更好的结果

具体实现:

基本符号:

在CNN中,第 iii 层的特征图可以表示为 Fi∈RCi×Wi×Hi,i=1,...,LF_i\in R^{C_i×W_i×H_i},i=1,...,LFi​∈RCi​×Wi​×Hi​,i=1,...,L,其中 CiC_iCi​ 是通道数,LLL 是层数

用 1(⋅)1(\cdot)1(⋅) 表示指标函数, ⊙⊙⊙ 表示点积

子网络生成:

正如我们前面所讨论的,直接采样子网络通常会产生琐碎的结果,特别是当剪枝率很高的时候

为了训练一个针对信道剪枝的性能预测网络,我们不需要遍历所有的子网络

假设原始模型的 FLOPs 为 TtotalT_{total}Ttotal​,剪枝率为ppp,我们对从pTtotalpT_{total}pTtotal​ 到TtotalT_{total}Ttotal​ 的子网络很感兴趣

我们首先丢弃 FLOPs 低于 pTtotalpT_{total}pTtotal​ 的子网络,因为它们不满足 FLOPs 约束

因此,我们更喜欢生成具有特定 FLOPs 的有意义的子网络作为性能预测网络的训练样本

我们开始先通过将原始模型修剪到目标 FLOPs pTtotalpT_{total}pTtotal​ 来生成这些子网络

为了实现这一点,我们首先引入了基本的可微剪枝算法

我们使用可微门来描述一个通道,对于第 iii 层,门的定义为:

oi=1/(1+e−(wi+s)/τ)o_i=1/(1+e^{-(w_i+s)/\tau})oi​=1/(1+e−(wi​+s)/τ)

其中 1/(1+e−(wi+s)/τ)1/(1+e^{-(w_i+s)/\tau})1/(1+e−(wi​+s)/τ) 就是 sigmoid 函数,oi∈RCio_i\in R^{C_i}oi​∈RCi​ 并且 oi∈[0,1]o_i\in [0,1]oi​∈[0,1],

wi∈RCiw_i\in R^{C_i}wi​∈RCi​ 是门控单元可学习的参数,sss 是从 Gumbel 分布的采样,s∈Gumbel(0,1)s\in Gumbel(0,1)s∈Gumbel(0,1),τ\tauτ 是来控制锐度的超参数

这里的 oio_ioi​ 是连续的,用来精确地生成子网络

然后我们进一步将其规范到 000 或 111:

ai=1oi>12(oi)a_i=\mathbb{1}_{o_i>\frac{1}{2}}(o_i)ai​=1oi​>21​​(oi​)

其中 ai∈{0,1}Cia_i\in \{0,1\}^{C_i}ai​∈{0,1}Ci​

由于指示函数 1(⋅)1(\cdot)1(⋅) 不可微,作者使用了直通式估计器(straight-through estimator)来计算梯度

上述两个等式中的可微分门则使用 Gumbel-Solftmax 算法来近似伯努利分布(虽然有近似伯努利分布,但作者发现它们的差异也并不显著)

为了实现最终的剪枝,我们将门控单元应用到特征图上:

Fi^=ai⊙Fi\hat{F_i}=a_i⊙F_iFi​^​=ai​⊙Fi​

其中 aia_iai​ 被扩充到了 FiF_iFi​ 的大小

这样整个剪纸过程的优化目标就可以写作:

min⁡wL(f(x;a,Θ,y)+R(T(a),pTtotal))\min_wL(f(x;a,\Theta,y)+R(T(a),pT_{total}))minw​L(f(x;a,Θ,y)+R(T(a),pTtotal​))

其中 www 包括了所有门控单元的可学习参数,aaa 是用于表示 CNN 模型结构的向量,a=cat(a1,...,ai,...,aL)a=cat(a_1,...,a_i,...,a_L)a=cat(a1​,...,ai​,...,aL​),T(a)T(a)T(a) 是由 aaa 定义的子网络的 FLOPs ,x,yx,yx,y 分别是输入图片及其标签,f(⋅;a,Θ)f(\cdot;a,\Theta)f(⋅;a,Θ) 是由 Θ\ThetaΘ 参数化的 CNN 模型(其结构由 aaa 定义),R(T(a),pTtotal)=log(max(T(a),pTtotal)/pTtotal)R(T(a),pT_{total})=log(max(T(a),pT_{total})/pT_{total})R(T(a),pTtotal​)=log(max(T(a),pTtotal​)/pTtotal​) 是推动子网络到达目标 FLOPs 的正则化项

在上述等式的优化过程中会生成许多具有不同结构 aaa 的子网络

假设精度 qqq 是基于给定的小批量计算的,我们就可以得到一对代表一个子网络及其精度的样本(a,q)(a,q)(a,q)

性能预测网络:

一旦我们获得了样本(a,q)(a,q)(a,q),我们就可以训练一个神经网络来预测给定子网络结构的性能

我们首先定义了性能预测网络:qpred=PN(a)q_{pred}=PN(a)qpred​=PN(a)

PN(⋅)PN(·)PN(⋅) 就是所提出的性能预测网络

然后使用 sigmoid 函数作为输出激活,因此 qpredq_{pred}qpred​ 在 000 到 111 的范围之间

性能预测网络由全连接层和 GRU 组成。简而言之,全连接的层将每个层的结构向量转换为一个紧凑的向量表示形式,并使用GRU来连接不同的层

作者使用GRU,是考虑到 ai−1a_{i−1}ai−1​ 和 aia_iai​ 有隐式的依赖性,因此 GRU 可能会使得性能预测网络有可能捕获子网络中复杂的交互

PN优化是一个回归问题,作者采用平均绝对误差损失(MAE)进行优化:

minwpLp=∣q−PN(a)∣min_{w_p}L_p=|q-PN(a)|minwp​​Lp​=∣q−PN(a)∣

其中 wpw_pwp​ 是性能预测网络的参数

事件记忆模块:

作者做的这项工作的早期版本直接利用当前迭代的子网络来训练性能预测网络

然而,作者发现它只会恶化剪枝过程,这是因为性能预测网络很难预测早期子网络,这种现象被称为灾难性遗忘(catastrophic forgetting)

为了克服这个问题,我们需要定期回放以前的子网络

因此作者进一步提出了一个事件记忆模块(Episodic Memory Module)来记忆早期的子网络

事件记忆定义为 EM=(A,Q)EM=(A,Q)EM=(A,Q),其中 A∈Rm×K,Q∈RKA\in R^{m×K},Q\in R^KA∈Rm×K,Q∈RK,mmm 是向量 aaa 的长度,KKK 是当前 EM 的大小

当我们向 EM 添加一个子网络时,KKK 自动加 111,并且 KKK 小于预先定义的最大 EM 容量 KmaxK_{max}Kmax​

但是正如前面提到的,小批量精度不是一个很好的估计精度;而另一方面,如果我们使用整个训练数据集来计算精度,计算成本就太过昂贵

为了利用效率和精度,我们每 ccc 次迭代收集子网络和相应的小批量精度的均值,以构建一个增强的子网络表示:

a‾=1a>12(1c∑i=1cai),q‾=1c∑i=1cqi\overline{a}=1_{a>\frac{1}{2}}(\frac{1}{c}\sum^c_{i=1}a_i),\overline{q}=\frac{1}{c}\sum^c_{i=1}q_ia=1a>21​​(c1​∑i=1c​ai​),q​=c1​∑i=1c​qi​

注意如果 ccc 太大,那么上面的参数是无效的,导致增强的表示是无用的

此外在收集子网络时,我们不比计算梯度

假设我们在 EM 模块中已经有 KKK 个子网,那么 EM 就可以通过以下方式更新:

Pr−j={Ai=a‾,i=arg min⁡i∣Qi−q‾∣ifK=Kmax,AK+1=a‾otherwiseP_{r-j}= \begin{cases} A_i=\overline{a},i=\argmin_i |Q_i-\overline{q}|& if K=K_{max},\\ A_{K+1}=\overline{a} & otherwise \end{cases} Pr−j​=⎩⎨⎧​Ai​=a,i=iargmin​∣Qi​−q​∣AK+1​=a​ifK=Kmax​,otherwise​

也就是说当 K=KmaxK=K_{max}K=Kmax​ 时,我们将替换掉跟当前样本精度最相近的样本

事实上,在修剪过程中,大多数子网络在满足目标 FLOPS 后都具有相似的性能

因此,我们将使用 KmaxK_{max}Kmax​ 来鼓励子网络的多样性

精度分布不平衡:

在剪枝过程中,作者绘制了子网络精度的经验分布,发现精度集中在84左右:

为了防止性能预测网络提供没有价值的解决方案,并使其收敛速度更快,作者提出可以根据子网络的准确性进行重新采样

即将所有的子网络按照 QQQ 的等差 1N−1(max(Q)−min(Q))\frac{1}{N-1}(max(Q)-min(Q))N−11​(max(Q)−min(Q)) 分为 NNN 个组

然后我们对每一组的子网络进行计数,并根据其计数的倒数对其进行重新抽样

这就相当于创建 NNN 个伪类并进行重新采样

性能最大化算法:

在拥有了一个相对靠谱的性能预测网络后,我们开始最大限度地提高搜索更好的子网络的性能

子网络的性能可以表示为 PN(a)PN(a)PN(a),因此我们可以最大化 PN(a)PN(a)PN(a) 作为精度的估值

max⁡wPN(a)\max_wPN(a)maxw​PN(a) 跟 min⁡w1PN(a)\min_w\frac{1}{PN(a)}minw​PN(a)1​ 是等价的,因此为了使得训练更稳定,我们优化目标就可以写作:min⁡wlog(1PN(a))\min_wlog(\frac{1}{PN(a)})minw​log(PN(a)1​)

这样总的优化目标就可以表示为:

min⁡wJ(w)=L(f(x;a,Θ),y)+γ(K,LP)⋅log(1PN(a))+λR(T(a),pTtotal)\begin{aligned} \min_wJ(w)=&L(f(x;a,\Theta),y)+\gamma(K,L_P)\cdot log(\frac{1}{PN(a)})\\ &+\lambda R(T(a),pT_{total}) \end{aligned} wmin​J(w)=​L(f(x;a,Θ),y)+γ(K,LP​)⋅log(PN(a)1​)+λR(T(a),pTtotal​)​

其中 γ(K,LP)\gamma(K,L_P)γ(K,LP​) 是一个反映性能预测网络置信度的函数,被用来自动控制 log(1PN(a))log(\frac{1}{PN(a)})log(PN(a)1​) 的量级,λ\lambdaλ 被用来控制正则化的量级

γ(K,LP)\gamma(K,L_P)γ(K,LP​) 定义为:γ(K,LP)=1K≥Kmax4(K)⋅(1−Lp)2\gamma(K,L_P) =1_{K\geq\frac{K_{max}}{4}}(K)\cdot(1-L_p)^2γ(K,LP​)=1K≥4Kmax​​​(K)⋅(1−Lp​)2,范围为 [0,1][0,1][0,1]

通常来说更低的 LpL_pLp​ 表示 PN(⋅)PN(\cdot)PN(⋅) 具有更高的置信度,但是 PNPNPN 的训练是一项增量学习任务,在 PNPNPN 能够访问足够的样本之前,LpL_pLp​ 可能不是很可靠

虽然存在损失-度量不匹配的问题,但来自损失函数和性能最大化的信息仍有一些重叠

由于我们已经使用了分类损失,因此希望从性能最大化中获得独特的信息

为了实现这一点,我们使梯度彼此正交

假设 gLi=∂L∂wig^i_L=\frac{\partial L}{\partial w_i}gLi​=∂wi​∂L​ 表示第 iii 层的分类损失的梯度,gpi=∂log(1PN(a))∂wig^i_p=\frac{\partial log(\frac{1}{PN(a)})}{\partial w_i}gpi​=∂wi​∂log(PN(a)1​)​ 是性能最大化损失的梯度

这两个项的修正梯度为:

gi=gLi+g^pig^i=g^i_L+\hat{g}^i_pgi=gLi​+g^​pi​

其中 gpig_p^igpi​ 可以分解成两部分:gpi=g^pi+g‾pig_p^i=\hat{g}_p^i+\overline{g}_p^igpi​=g^​pi​+g​pi​,其中 g^pi⊥gLi\hat{g}_p^i\bot{g}_L^ig^​pi​⊥gLi​ 并且 g‾pi\overline{g}_p^ig​pi​ 跟 gLig^i_LgLi​ 方向相同

作者也给出了整个算法的伪代码:

实验结果:

【CVPR 2021】剪枝篇(一):Network Pruning via Performance Maximization相关推荐

  1. 微软亚研院 CVPR 2021 9篇视觉研究前沿进展

    本文转载自微软研究院AI头条. 编者按:作为世界顶级的 AI 会议,CVPR 一直引领着计算机视觉与模式识别技术领域的学术与工业潮流.今年的 CVPR 于6月19日至25日在线上举办.在此,为大家精选 ...

  2. CVPR 2021|一个绝妙的想法:在类别不平衡的数据上施展半监督学习

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨kid丶@知乎(已授权) 来源丨https://zhuanla ...

  3. 【论文解读】CVPR 2021 妆容迁移 论文+ 代码 汇总,美得很美得很!

    妆容迁移是指将目标图上的妆容直接迁移到原图上的技术.相比传统贴妆技术,妆容迁移具有极高的自由度,它可以让用户不再局限于设计师设计好的妆容,而是可以自主.任意地从真实模特图中获取妆容,极大地丰富了妆容的 ...

  4. 【CVPR 2021】剪枝篇(二):Convolutional Neural Network Pruning with Structural Redundancy Reduction

    [CVPR 2021]剪枝篇(二):Convolutional Neural Network Pruning with Structural Redundancy Reduction 论文地址: 主要 ...

  5. 【CVPR 2021】剪枝篇(五):基于关键通路的神经网络可解释剪枝

    [CVPR 2021]剪枝篇(五):Neural Response Interpretation through the Lens of Critical Pathways 论文地址: 主要问题: 主 ...

  6. 闲话模型压缩之网络剪枝(Network Pruning)篇

    1. 背景 今天,深度学习已成为机器学习中最主流的分支之一.它的广泛应用不计其数,无需多言.但众所周知深度神经网络(DNN)有个很大的缺点就是计算量太大.这很大程度上阻碍了基于深度学习方法的产品化,尤 ...

  7. 华人占大半壁江山!CVPR 2021 目标检测论文大盘点(65篇论文)

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Amusi  |  来源:CVer 前言 CVer 正式盘点CVPR 2021上各个方向的工作,本篇是 ...

  8. 重磅!腾讯优图20篇论文入选CVPR 2021

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 来源:腾讯优图 计算机视觉世界三大顶会之一的CVPR 2021论文接收结果出炉!本次大会收到来自全球共701 ...

  9. 【论文】模型剪枝(Network Pruning)论文详细翻译

    前言: 这是关于模型剪枝(Network Pruning)的一篇论文,论文题目是:Learning both weights and connections for efficient neural ...

最新文章

  1. DOM4J对于XML的用法
  2. JavaScript 知识图谱
  3. Python Django模板页面过滤器使用示例
  4. Little Sub and Triangles
  5. python turtle库画椭圆_如何用Python画一只肥肥的柯基狗狗——turtle库绘制椭圆与弧线实践...
  6. 单链表逆置 java_单链表的就地逆置--java实现(含头节点和不包含头节点)
  7. ssh远程连接不上linux
  8. ROS的学习(十二)用C++写一个简单的发布者
  9. .Net中的并行编程-6.常用优化策略
  10. Eclipse导入Ant项目
  11. 下列不属于计算机网络特点的是自主性,计算机网络技术B卷
  12. type=button 字体大一点_CAD设计师喜欢用SHX字体的原因你知道吗?
  13. 3串口多串口双串口以及2串口UART转WiFi多跳通讯实现三
  14. linux安装pdf阅读器 | 安装删除有道词典
  15. 常见的数据结构及其特征
  16. PPT:动画出现设置
  17. 马云正式辞职,那么天才郭盛华到底在追求什么呢?
  18. python变量相关性,数据科学:定量和定性变量之间的相关性(python语言)
  19. 短视频优质作者必备|配音神器分享|那些你刷视频时肯定听过的声音
  20. 百度网盘下载显示系统限制,无法下载解决

热门文章

  1. python常用代码大全-Python常用库大全
  2. ppt是学计算机的那一块,如何制作一个一眼就喜欢的PPT封面
  3. css布局作业:京东首页轮播图
  4. linux-kernel-ecmp-ipv4
  5. cad图纸怎么转换成pdf格式
  6. PT2264解码心得
  7. 星际空间环境地面模拟:气氛、气压或真空度的精确模拟及控制
  8. 湖南大学计算机学院唐汝英,唐旺旺-湖南大学环境科学与工程学院
  9. 手写8个常用的自定义hooks(推荐阅读)
  10. 【图形学基础】光栅图