1. 导读

导读:这篇文章从不同于原始网络剪裁的角度出发,分析了预训练给网络剪裁带来的影响,之后分析是否在网络剪裁的时候需要进行预训练,从而提出了一种不需要预训练进行的网络剪裁方式,其是通过随机初始化的方式去获取剪裁的网络结构,在这个结构上也获得了与在预训练模型上性能接近的结果,好处就是省去了花费较多的预训练模型获取过程。

在文章中将网络剪裁的方式划分为如下的三种形式:

  • (a)传统的网络剪裁方式,使用预训练模型,之后在剪裁之后的网络结构上finetune;
  • (b)相比于(a)中的方法将finetune改成了从头训练,取得效果也与之近似;
  • (c)在(b)的基础上思考是否真的需要预训练,直接在随机初始化之后模型的基础上进行剪裁,再从头进行训练,省去获取预训练模型的开销;

2. 方法设计

2.1 预训练模型与网络剪裁

在传统的观念中网络的结构性剪裁是在网络进行训练完成之后,根据网络层中的参数按照设置的阈值进行剪裁,之后再在剪裁模型权重的基础上进行finetune。后续也有论文论证了直接使用剪裁之后的结构train from scratch也能达到相似的效果。因而这篇文章在这个基础上分析是否有可能实现不需要预训练就可以实现网络的剪裁。因而文章指出可不可以直接不经过预训练就进行剪裁。

这里通过保存不同迭代次数的模型,观察这些不同迭代次数的模型对于最后剪裁结果的影响。这样也就探明了预训练模型对于最后剪裁结果的影响。

2.2 剪裁结构的相似性

对于每个剪裁的模型,这里计算网络中每个层的剪裁比例,之后将网络中各层的剪裁比例放入一个向量中,从而借此表达剪裁之后的网络结构。得到这个表达向量之后就可以计算不同模型剪裁之后的相似比。为了确保实验的有效性,文章在CIFAR10数据集上使用VGG16选择了5个随机种子进行训练,作为对比,分析初始化对于剪裁结果的影响。
在下图中展示了所有剪裁模型的相关系数矩阵。

从上图中可以分析得到3点有用信息:

  • 1)从随机权重初始化得到的剪裁结果与在预训练权重上进行剪裁之后的结果并不一致,见图2(a,b)左上图。在pretrained不同epoch上进行剪裁其带来的变化比较大,而使用随机初始化使用文章的方法其剪裁更加固定,且在10个epoch之后高度近似(只迭代10 epoch就行了)
  • 2)从随机权重初始化得到的剪裁结构在相关系数上具有更大的变化,但是在经过一定epoch(代码中为10)训练之后,训练模型上得到的相关系数就更加趋向于一致(减少少花费在pretrained上的时间),见图2(c)图;
  • 3)在预训练模型的相邻checkpoint之间具有更高的相似性(在同一次run里面),见图2(a,b)中右图;

从上面的结论可以推知,在预训练阶段其带来的剪裁结构空间在训练过程中被严重压缩了,这也许会带来性能的局限,而相反使用随机的权重初始化使得剪裁算法可以探索更加多样的剪裁结构。

2.3 网络裁剪结构的性能

文章将上文章提到在不同条件下裁剪得到的网络结构进行训练,最后得到的性能见下表1所示:

从上面的表结果中可以得出以下结论:

  • 1)从随机初始化剪裁得到的模型经过训练之后也能获得与基于预训练模型的结果近似,这一定程度上表明随机初始化进行剪裁带来的优化结果可能更好;
  • 2)上面的结果现实基于预训练的模型在最后的结果上能够稍好一些,但是基于预训练模型需要的庞大计算量,使用随机初始化进行剪裁,通过训练之后也能获得类似的性能;

2.4 裁剪方法

文章中对于网络结构的表达使用f(x;W,α)f(x;W,\alpha)f(x;W,α)进行表示,其中x,W,αx,W,\alphax,W,α分别代表输入数据,网络权值与网络结构信息。在进行网络剪裁的时候,文章提出在网络层(jjj层)的channel-wise上使用一个标量值λj\lambda_jλj​进行相乘(对应的是Conv层之后的BatchNorm中scale参数,有对应channel个,可不是每个Conv层就关联一个数),由于这个标量值是接近于0的存在因而会抑制某些channel上的输出,因而可以产生类似网络剪枝的效果,对于网络中的KKK个层可以学习得到一组标量值∧={λ1,λ2,…,λK}\wedge=\{\lambda_1,\lambda_2,\dots,\lambda_K\}∧={λ1​,λ2​,…,λK​},因而网络的优化的目标函数可以写为:
min⁡∧∑iNL(f(xi;W,∧),yi)+γ∑jK∣λj∣1\min_{\wedge}\sum_i^NL(f(x_i;W,\wedge),y_i)+\gamma\sum_j^K|\lambda_j|_1∧min​i∑N​L(f(xi​;W,∧),yi​)+γj∑K​∣λj​∣1​
s.t.0≤λj≤1,∀j=1,2,…,Ks.t. 0\le \lambda_j \le 1,\forall j=1,2,\dots,Ks.t.0≤λj​≤1,∀j=1,2,…,K
在上面的网络剪裁过程中与传统的剪裁方法不同的地方在于两点:

  • 1)在计算网络层中channel重要性的时候并没有更新权值(先使用随机初始化的参数迭代,之后再在这个模型基础上获取BatchNorm的scale数值进行剪裁,剪裁完了之后再进行正常训练),也就是选择channel的过程没有在训练的过程中;
  • 2)使用随机初始化权值进行重要channel,并没有依赖于预训练;

在上面的剪枝过程中,可以使用一个期望的剪裁值rrr,来设置剪裁的期望,对此可以添加一个正则项:
ω(∧)=(∑j∣λj∣1∑jCj−r)2\omega(\wedge)=(\frac{\sum_j|\lambda_j|_1}{\sum_jC_j}-r)^2ω(∧)=(∑j​Cj​∑j​∣λj​∣1​​−r)2
其中,CjC_jCj​是网络层的channel数量。在决定最后的网络结构的时候文章通过在满足给定FLOPS下的全局阈值τ\tauτ来得到网络结构,这个阈值是通过二分查找的形式进行确定的。计算过程如下:

3. 实验结果

CIFAR10上剪裁方法的性能比较:

《Pruning from Scratch》论文笔记相关推荐

  1. 论文笔记之Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 论文链接:http://dwz.cn/6qPeIb 本文的主要思想是为了剖析出一个跟踪算法中到 ...

  2. 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    本人为目标追踪初入小白,在博客下第一次记录一下自己的论文笔记,如有差错,恳请批评指正!! 论文相关信息:<Understanding and Diagnosing Visual Tracking ...

  3. 论文笔记Understanding and Diagnosing Visual Tracking Systems

    最近在看目标跟踪方面的论文,看到王乃岩博士发的一篇分析跟踪系统的文章,将目标跟踪系统拆分为多个独立的部分进行分析,比较各个部分的效果.本文主要对该论文的重点的一个大致翻译,刚入门,水平有限,如有理解错 ...

  4. 目标跟踪笔记Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 原文链接:https://blog.csdn.net/u010515206/article/d ...

  5. 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)

    追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...

  6. ICCV 2015 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    目录 写在前面 文章大意 一些benchmark 实验 实验设置 基本模型 数据集 实验1 Featrue Extractor 实验2 Observation Model 实验3 Motion Mod ...

  7. Understanding and Diagnosing Visual Tracking Systems

    文章把一个跟踪器分为几个模块,分别为motion model, feature extractor, observation model, model updater, and ensemble po ...

  8. CVPR 2017 SANet:《SANet: Structure-Aware Network for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做SANet.作者在论文中提到,CNN模型主要适用于类间判别,对于相似物体的判别能力不强.作者提出使用RNN对目标物体的self-structure进行建模,用于提 ...

  9. ICCV 2017 UCT:《UCT: Learning Unified Convolutional Networks forReal-time Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做UCT.就像论文题目一样,作者提出了一个基于卷积神经网络的end2end的tracking模型.模型的整体结构如下图所示(图中实线代表online trackin ...

  10. CVPR 2018 STRCF:《Learning Spatial-Temporal Regularized Correlation Filters for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文提出的模型叫做STRCF. 在DCF中存在边界效应,SRDCF在DCF的基础上中通过加入spatial惩罚项解决了边界效应,但是SRDCF在tracking的过程中要使用 ...

最新文章

  1. 敏捷开发之道(二)极限编程XP
  2. php mysql预处理_php mysqli扩展之预处理
  3. Python代码提取时间序列特征基于tsfeature
  4. 聚合中返回source_Java 8 中的 Streams API 详解—— Streams 的背景以及 Java 8 中的使用详解...
  5. python中的单引号双引号和三引号
  6. 二分查找法---java实现
  7. php 内部异步执行顺序,event_loop中不同异步操作的执行顺序
  8. PID控制器开发笔记之二:积分分离PID控制器的实现
  9. HTML5地图分布动画
  10. 如何使用DevStack在Ubuntu 18.04上安装OpenStack
  11. DNW启动异常的问题
  12. 没有权限角色管理功能菜单加载
  13. 解析分级存储管理(HSM)
  14. shl归纳推理测试题库_强生2020秋招笔试面试经验合集
  15. 跟青翼一起学Qt4编程大纲目录
  16. Java.util包简单总结
  17. 抖音无水印视频抓取与按帧截取图片
  18. Armeria 小试牛刀
  19. Java实现输出水仙花(易懂)
  20. nginx proxy_pass规则

热门文章

  1. Qt QString详解
  2. Java乐图下载_Java平台乐图导航地图测评:实时跟踪是亮点
  3. c/c++中system函数
  4. 普通人有捷径可以走吗?
  5. java 树形数据_JAVA获取树形结构数据
  6. VueJS中axios关于回调函数this为undefined的问题
  7. Android 7.0应用抽屉,安卓7.0抛弃应用抽屉是致敬苹果iOS?
  8. python发送阿里云短信教程
  9. 成都女孩学什么技术好
  10. C# 生成一维码(条形码)和二维码