这篇文章来自于旷视。旷视内部有一个基础模型组,孙剑老师也是很看好 NAS 相关的技术,相信这篇文章无论从学术上还是工程落地上都有可以让人借鉴的地方。回到文章本身,模型剪枝算法能够减少模型计算量,实现模型压缩和加速的目的,但是模型剪枝过程中确定剪枝比例等参数的过程实在让人头痛。

这篇文章提出了 PruningNet 的概念,自动为剪枝后的模型生成权重,从而绕过了费时的 retrain 步骤。并且能够和进化算法等搜索方法结合,通过搜索编码 network 的 coding vector,自动地根据所给约束搜索剪枝后的网络结构。和 AutoML 技术相比,这种方法并不是从头搜索,而是从已有的大模型出发,从而缩小了搜索空间,节省了搜索算力和时间。

个人觉得这种剪枝和 NAS 结合的方法,应该会在以后吸引越来越多人的注意。这篇文章的代码已经开源在了 Github:

https://github.com/liuzechun/MetaPruning

Motivation

模型剪枝是一种能够减少模型大小和计算量的方法。模型剪枝一般可以分为三个步骤:

  • 训练一个参数量较多的大网络

  • 将不重要的权重参数剪掉

  • 剪枝后的小网络做 fine tune

其中第二步是模型剪枝中的关键。有很多 paper 围绕“怎么判断权重是否重要”以及“如何剪枝”等问题进行讨论。困扰模型剪枝落地的一个问题就是剪枝比例的确定。

传统的剪枝方法常常需要人工 layer by layer 地去确定每层的剪枝比例,然后进行 fine tune,用起来很耗时,而且很不方便。不过最近的 Rethinking the Value of Network Pruning [1] 指出,剪枝后的权重并不重要,对于 channel pruning 来说,更重要的是找到剪枝后的网络结构,具体来说就是每层留下的 channel 数量。

受这个发现启发,文章提出可以用一个 PruningNet,对于给定的剪枝网络,自动生成 weight,无需进行 retrain,然后评测剪枝网络在验证集上的性能,从而选出最优的网络结构。

具体来说,PruningNet 的输入是剪枝后的网络结构,必须首先对网络结构进行编码,转换为 coding vector。这里可以直接用剪枝后网络每层的 channel 数来编码。在搜索剪枝网络的时候,我们可以尝试各种 coding vector,用 PruningNet 生成剪枝后的网络权重。网络结构和权重都有了,就可以去评测网络的性能。进而用进化算法搜索最优的 coding vector,也就是最优的剪枝结构。在用进化算法搜索的时候,可以使用自定义的目标函数,包括将网络的 accuracy,latency,FLOPS 等考虑进来。

PruningNet

从上一小节已经可以知道,PruningNet 是整个算法的关键。那么怎么才能找到这样一个“神奇网络”呢?

先做一下符号约定,使用 ci 表示剪枝之后第 i 层的 channel 数量, l 为网络的层数, W 表示剪枝后网络的权重。那么 PruningNet 的输入输出如下所示:

训练

先结合下图看一下 forward 部分。PruningNet 是由 l 个 PruningBlock 组成的,每个 PruningBlock 是一个两层的 MLP。

首先看图 b,编码着网络结构信息的 coding vector 输入到当前 block 后,输出经过 Reshape,成了一个 Weight Matrix。注意哦,这里的 WeightMatrix 是固定大小的(也就是未剪枝的原始 Weight shape 大小),和剪枝网络结构无关。

再看图 a,因为要对网络进行剪枝,所以 WeightMatrix 要进行 Crop。对应到图 b,可以看到,Crop 是在两个维度上进行的。首先,由于上一层也进行了剪枝,所以 input channel 数变少了;其次,由于当前层进行了剪枝,所以 output channel 数变少了。这样经过 Crop,就生成了剪枝后的网络 weight。我们再输入一个 mini batch 的训练图片,就可以得到剪枝后的网络的 loss。

在 backward 部分,我们不更新剪枝后网络的权重,而是更新 PruningNet 的权重。由于上面的操作都是可微分的,所以直接用链式法则传过去就行。如果你使用 PyTorch 等支持自动微分的框架,这是很容易的。

下图所示是训练过程的整个 PruningNet(左侧)和剪枝后网络(右侧,即 PrunedNet)。训练过程中的 coding vector 在状态空间里随机采样,随机选取每层的 channel 数量。

PS:和原始论文相比,下图和上图顺序是颠倒的。这里从底向上介绍了 PruningNet 的训练,而论文则是自顶向下。

搜索

训练好 PruningNet 后,就可以用它来进行搜索了!我们只需要输入某个 coding vector,PruningNet 就会为我们生成对应每层的 WeightMatrix。别忘了 coding vector 是编码的网络结构,现在又有了 weight,我们就可以在验证集上测试网络的性能了。进而,可以使用进化算法等优化方法去搜索最优的 coding vector。当我们得到了最优结构的剪枝网络后,再 from scratch 地训练它。

进化算法这里不再赘述,很多优化的书中包括网上都有资料。这里把整个算法流程贴出来:

实验

作者在 ImageNet 上用 MobileNet 和 ResNet 进行了实验。训练 PruningNet 用了 1/4 的原模型的 epochs。数据增强使用常见的标准流程,输入 image 大小为 224×224。

将原始 ImageNet 的训练集做分割,每个类别选 50 张组成 sub-validation(共计 50000),其余作为 sub-training。在训练时,我们使用 sub-training 训练 PruningNet。在搜索时,使用 sub-validation 评估剪枝网络的性能。不过,还要注意,在搜索时,使用 20000 张 sub-training 中的图片重新计算 BatchNorm layer 中的 running mean 和 running variance。

shortcut 剪枝

在进行模型剪枝时,一个比较难处理的问题是 ResNet 中的 shortcut 结构。因为最后有一个 element-wise 的相加操作,必须保证两路 feature map 是严格 shape 相同的,所以不能随意剪枝,否则会造成 channel 不匹配。下面对几种论文中用到的网络结构分别讨论。

MobileNet-v1

MobileNet-v1 是没有 shortcut 结构的。我们为每个 conv layer 都配上相应的 PruningBlock——一个两层的 MLP。PruningNet 的输入 coding vector 中的元素是剪枝后每层的 channel 数量。而输入第 i 个 PruningBlock 的是一个 2D vector,由归一化的第 i-1 层和第 i 层的剪枝比例构成。这部分可以结合代码来看:

https://github.com/liuzechun/MetaPruning/blob/master/mobilenetv1/training/mobilenet_v1.py#L15

注意第 1 个 conv layer 的输入是 1D vector,因为它是第一个被剪枝的 layer。在训练时,coding vector 的搜索空间被以一定步长划分为 grid,采样就是在这些格点上进行的。

MobileNet-v2

MobileNet-v2 引入了类似 ResNet 的 shortcut 结构,这种 resnet block 必须统一看待。具体来说,对于没有在 resnet block 中的conv,处理方法如 MobileNet-v1。对每个 resnet block,配上一个相应的 PruningBlock。由于每个 resnet block 中只有一个中间层(3×3 的 conv),所以输出第 i 个 PruningBlock 的是一个 3D vector,由归一化的第 i-1 个 resnet block,第 i 个 resnet block 和中间 conv 层的剪枝比例构成。其他设置和 MobileNet-v1 相同。这里可以结合代码来看:

https://github.com/liuzechun/MetaPruning/blob/master/mobilenetv2/training/mobilenet_v2.py#L109

ResNet

处理方法如 MobileNet-v2 所示。可以结合代码来看:

https://github.com/liuzechun/MetaPruning/blob/master/resnet/training/resnet.py#L75

实验结果

在相近 FLOPS 情况下,和 MobileNet 论文中改变 ratio 参数得到的模型比较,MetaPruning 得到的模型 accuracy 更高。尤其是压缩比例更大时,该方法更有优势。

和其他剪枝方法(如 AMC [2])等方法比较,该方法也得到了 SOTA 的结果。MetaPruning 方法能够以一种统一的方法处理 ResNet 中的 shortcut 结构,并且不需要人工调整太多的参数。

上面的比较都是基于理论 FLOPS,现在更多人在关注网络在实际硬件上的 latency 怎么样。文章对此也进行了讨论。如何测试网络的 latency?

当然可以每个网络都实际跑一下,不过有些麻烦。基于每个 layer 的 inference 时间是互相独立的这个假设,作者首先构造了各个 layer inference latency 的查找表(参见论文 Fbnet: Hardware-aware efficient convnet design via differentiable neural architecture search [3]),以此来估计实际网络的 latency。作者这里和 MobileNet baseline 做了比较,结果也证明了该方法更优。

PruningNet 结果分析

此外,作者还对 PruningNet 的预测结果进行可视化,试图找出一些可解释性,并找出剪枝参数的一些规律。

  • down-sampling 的部分 PruningNet 倾向于保留更多的 channel,如 MobileNet-v2 block 中间的那个 conv;

  • 优先剪浅层 layer 的 channel,FLOPS 约束太强剪深层的 channel,但可能会造成网络 accuracy 下降比较多。

总结

这篇文章从“剪枝后的 weight 作用不大”的现象出发,将剪枝和 NAS 结合,提出了 PruningNet 为剪枝后的网络预测 weight,避免了网络的 retrain,从而可以快速衡量剪枝网络的性能。并在编码网络信息的 coding vector 状态空间进行搜索,找到给定约束条件下的最优网络结构,在 ImageNet 数据集和 ResNet/MobileNet-v1/v2 上取得了比之前剪枝算法更好的效果。

总结

随着深度神经网络模型在各个场景下的落地,模型的压缩和加速越来越受到大家的重视,剪枝是其中的重要方法。传统的剪枝算法人工确定较多的参数,所以很多文章开始考虑端到端的剪枝。

这篇论文把剪枝算法和 NAS 结合,取两者之长,用待剪枝的模型缩小了搜索空间,用进化算法自动搜索最优网络结构。使用 coding vector 编码网络结构,用一个很简单的双隐层感知机预测网络权重,并提出了一种 shortcut 的处理方法,在 ImageNet 数据集和几种常用网络结构上取得了不错的结果。文章提出的方法简单易于操作,可以很方便地应用到自己的业务场景中。相关代码已经开源在 Github 上。

相关链接

[1] https://arxiv.org/abs/1810.05270

[2] https://arxiv.org/abs/1802.03494

[3] https://arxiv.org/abs/1812.03443

点击以下标题查看更多往期内容:

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site

• 所有文章配图,请单独在附件中发送

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

?

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

▽ 点击 | 阅读原文 | 下载论文 & 源码

ICCV 2019 开源论文 | 基于元学习和AutoML的模型压缩新方法相关推荐

  1. ICCV 2019 | 旷视提出MetaPruning:基于元学习和AutoML的模型压缩新方法

    点击我爱计算机视觉标星,更快获取CVML新技术 两年一度的国际计算机视觉大会 ICCV 2019 ( IEEE International Conference on Computer Vision) ...

  2. CVPR 2019开源论文 | 基于“解构-重构”的图像分类学习框架

    作者丨白亚龙 单位丨京东AI研究院研究员 研究方向丨表示学习.图像识别 基于深度卷积图像识别的相关技术主要专注于高层次图像特征的理解,而对于相似物体之间的细节差异和具有判别意义的区域(discrimi ...

  3. ICCV 2019 开源论文 | ShapeMatchingGAN:打造炫酷动态的艺术字

    作者丨杨帅 学校丨北京大学博士生 研究方向丨图像风格化 论文引入 当你制作 PPT 时想要打造与背景图片风格一致的标题,当你设计宣传海报时想要一个引人注意的标题,当你发朋友圈时想要更生动地展示文字所传 ...

  4. ICCV 2019开源论文 | 捕捉图像DNA——单幅图像下可实现任意尺度自然变换

    作者丨武广 学校丨合肥工业大学硕士生 研究方向丨图像生成 图像翻译这个领域的应用是相当的多,图像风格迁移.图像修复.图像属性变换.图像分割.图像模态的转换等都可以统称为图像翻译的任务.本文将介绍一个图 ...

  5. CVPR 2019 开源论文 | 基于空间自适应归一化的图像语义合成

    作者丨武广 学校丨合肥工业大学硕士生 研究方向丨图像生成 深度学习在算力的推动下不断的发展,随着卷积层的堆叠,模型的层数是越来越深,理论上神经网络中的参数越多这样对数据的拟合和分布描述就能越细致.然而 ...

  6. ACL 2019 开源论文 | 基于知识库和大规模网络文本的问答系统

    作者丨张琨 学校丨中国科学技术大学博士生 研究方向丨自然语言处理 论文动机 当前问答系统面对的一大问题就是如何利用先验知识.我们人类可以通过不断的学习,掌握非常多的先验知识,并通过这些知识来回答问题. ...

  7. ACL 2019开源论文 | 基于图匹配神经网络的跨语言知识图对齐

    作者丨王文博 学校丨哈尔滨工程大学硕士生 研究方向丨知识图谱.表示学习 动机 在本篇文章之前,跨语言知识图谱对齐研究仅依赖于从单语知识图谱结构信息中获得的实体嵌入向量.并且大多数研究将实体映射到低维空 ...

  8. CVPR 2019 开源论文 | 基于翻译向量的图像翻译

    作者丨薛洁婷 学校丨北京交通大学硕士生 研究方向丨图像翻译 图像翻译通常要解决两个问题:将原域图像翻译至目标域并且翻译后的图像和原域图像保持相似性.我们利用 GAN 可以很好的解决第一个问题,而针对第 ...

  9. SIGIR 2019 开源论文 | 基于图神经网络的协同过滤算法

    作者丨纪厚业 单位丨北京邮电大学博士生 研究方向丨异质图神经网络,异质图表示学习和推荐系统 引言 协同过滤作为一种经典的推荐算法在推荐领域有举足轻重的地位.协同过滤(collaborative fil ...

最新文章

  1. 静态时序分析的概念以及约束的作用理解
  2. 学习 Message(3): 响应 WM_LBUTTONDOWN 消息
  3. 用注解方式写定时任务
  4. 2018.08.09洛谷P3959 宝藏(随机化贪心)
  5. 永远不要低估“价值互联网”!| 技术头条
  6. _GNUC__宏函数
  7. Linux下把U盘格式化为fat32
  8. STL -- string类字符串
  9. app浮层html,App设计之五:弹窗与浮层
  10. Oracle PLM,协同研发的产品生命周期管理平台
  11. Linux创建WIFI热点
  12. 山西流传于百姓餐桌的宫府名菜——山西过油肉
  13. java时间戳转换工具类
  14. (logN)²是O(N)的
  15. 放大电路①---共射极放大电路
  16. 工厂设备管理远程监控方案
  17. 贪吃小怪物显示服务器人数爆满,贪吃小怪物进不去怎么办 贪吃小怪物为什么进不去...
  18. 折弯机使用说明书_折弯机基本操作说明
  19. P2P平台方案——亿网软通“互联网+”金融解决方案
  20. 烤仔观察 | NFT+社交,2021年欧洲杯观赛新“姿势”来啦~速戳!

热门文章

  1. 语音计算矩形面积_【2020年第7期】螺旋折流板换热器质心当量矩形通用计算模型...
  2. FZU Monthly-201903 获奖名单
  3. Linux进入单用户模式(passwd root修改密码)
  4. 使用SQL脚本创建数据库,操作主键、外键与各种约束(MS SQL Server)
  5. Bootstrap 环境安装
  6. 【转】让itunes下载加速的真正办法,转向至香港台湾澳门苹果服务器 -- 不错不错!!!...
  7. POJ 3167 Cow Pattern ★(KMP好题)
  8. C专家编程-Chapter6 运行时数据结构(转)
  9. 计算机网络.doc,计算机网络network.doc
  10. vue dplayer 加载失败_最新vue脚手架项目搭建,并解决一些折腾人的问题