1. 概述

一般在做模型的时候开始最关心的是模型的性能,也就是模型的精度,我们可以增加网络的宽度与深度,来不断增加模型的表达能力。在精度达标之后,网络也变地很臃肿了,其实里面很多的参数都是非必须的,也就是冗余的。如何去掉这些冗余呢?在之前的文章中讲到了了几种方法,这篇论文中给出的方法与之前的方法思路不同,是按照聚类的思想来去除冗余的filters,从而减少网络中filters的数量,达到网络剪裁的目的。在文章中给出了filter之间相似度的度量方法,并用这个度量方法作为filter合并的依据。
论文地址:Building Efficient ConvNets using Redundant Feature Pruning
代码地址:Redundant-Feature-Pruning-Pytorch-Implementation

2. 实现

这篇文章中涉及到的剪枝方法示意图见下图:

在图1中ZZZ代表feature map,WWW代表filters。这里假设ϕ1\phi_1ϕ1​是filter的聚类表达,那么ϕ3,ϕ5\phi_3, \phi_5ϕ3​,ϕ5​由于与ϕ1\phi_1ϕ1​的相似性高,那么基于冗余去除原则,其在Zl+1Z_{l+1}Zl+1​中的feature map与Wl+1W^{l+1}Wl+1中对应的部分会被删除掉。
这里涉及到这些filter是之间距离度量函数的选择问题。首先,将每个filter作为一个聚类的中心,然后将其中的聚类中心,按照阈值τ\tauτ进行合并(大于这个阈值那就合并,小于该阈值不合并),其合并的距离度量可以使用下面的方式进行描述:

其中,SIMC(ϕ1,ϕ2)=&lt;ϕ1,ϕ2&gt;∣∣ϕ1∣∣∣∣ϕ2∣∣SIM_C(\phi_1, \phi_2)=\frac{&lt;\phi_1, \phi_2&gt;}{||\phi_1|| ||\phi_2||}SIMC​(ϕ1​,ϕ2​)=∣∣ϕ1​∣∣∣∣ϕ2​∣∣<ϕ1​,ϕ2​>​是两个特征之间的cos距离度量,&lt;ϕ1,ϕ2&gt;&lt;\phi_1, \phi_2&gt;<ϕ1​,ϕ2​>是特征的内积。

基于冗余特征裁剪在lthl^{th}lth卷积层的使用可以归纳为如下步骤:
1)对filters ϕi\phi_iϕi​进行聚类得到nfn_fnf​个聚类中心,使用阈值τ\tauτ;
2)考虑两种启发式方式:(A)在nfn_fnf​个聚类中心中随机选取一个filter,然后去剪裁剩余的filters与其对应的feature map;(B)随机剪裁掉n′−nfn^{'} - n_fn′−nf​个filters和其对应的feature map。在后面一层(l+1)th(l+1)^{th}(l+1)th中的对应参数也需要剪裁掉。对于这两种不同的策略,论文在后面也对其做了实验分析,详见后文。
3)得到lthl^{th}lth与(l+1)th(l+1)^{th}(l+1)th新的kernel矩阵。

3. 实验结论

首先,来看一下论文整体的性能吧,这里也比较了上面提到的A、B两种启发式策略,并比较了两种策略的结果

之后在每个层上按照之前的两个方案进行裁剪,得到下图的前两个,之后进行finetune得到下图:

可以看到方案A的曲线比方案B的曲线更加柔和一点。

3.2 ResNet网络

这篇论文中没有涉及到对shortcut连接两端的剪裁,而是只操作了残差块的第一个卷积。实验的讨论和刚才看的VGG网络的一样。首先是各层上冗余度与阈值的关系

ResNet-56中各层裁剪比例与错误率的关系

4. 代码实现

论文中的算法在上面给出了代码链接,代码量还是比较少的,核心的函数就一个,这里给出我对它的理解:

'''/*
函数功能:对权重按照给出的阈值进行聚类,返回聚类之后的中心数目与对应的index
weight:聚类的权重,维度为[n_features, n_samples]
threshold:阈值控制聚类的程度,越大表示结果聚类数目越多
*/
'''
def cluster_weights_agglo(weight, threshold, average=True):t0 = time.time()weight = weight.Tweight = normalize(weight, norm='l2', axis=1)threshold =  1.0-threshold   # Conversion to distance measureclusters = hcluster.fclusterdata(weight, threshold, criterion="distance", metric='cosine', depth=1, method='centroid')z = hac.linkage(weight, metric='cosine', method='complete')labels = hac.fcluster(z, threshold, criterion="distance")labels_unique = np.unique(labels)n_clusters_ = len(labels_unique)#print(n_clusters_)elapsed_time = time.time() - t0# print(elapsed_time)a=np.array(labels)sort_idx = np.argsort(a)a_sorted = a[sort_idx]unq_first = np.concatenate(([True], a_sorted[1:] != a_sorted[:-1]))unq_items = a_sorted[unq_first]unq_count = np.diff(np.nonzero(unq_first)[0])unq_idx = np.split(sort_idx, np.cumsum(unq_count))first_ele = [unq_idx[idx][-1] for idx in xrange(len(unq_idx))]return n_clusters_, first_ele

《Building Efficient ConvNets using Redundant Feature Pruning》论文笔记相关推荐

  1. 论文笔记之Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 论文链接:http://dwz.cn/6qPeIb 本文的主要思想是为了剖析出一个跟踪算法中到 ...

  2. 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    本人为目标追踪初入小白,在博客下第一次记录一下自己的论文笔记,如有差错,恳请批评指正!! 论文相关信息:<Understanding and Diagnosing Visual Tracking ...

  3. 论文笔记Understanding and Diagnosing Visual Tracking Systems

    最近在看目标跟踪方面的论文,看到王乃岩博士发的一篇分析跟踪系统的文章,将目标跟踪系统拆分为多个独立的部分进行分析,比较各个部分的效果.本文主要对该论文的重点的一个大致翻译,刚入门,水平有限,如有理解错 ...

  4. 目标跟踪笔记Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 原文链接:https://blog.csdn.net/u010515206/article/d ...

  5. 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)

    追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...

  6. ICCV 2015 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    目录 写在前面 文章大意 一些benchmark 实验 实验设置 基本模型 数据集 实验1 Featrue Extractor 实验2 Observation Model 实验3 Motion Mod ...

  7. Understanding and Diagnosing Visual Tracking Systems

    文章把一个跟踪器分为几个模块,分别为motion model, feature extractor, observation model, model updater, and ensemble po ...

  8. CVPR 2017 SANet:《SANet: Structure-Aware Network for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做SANet.作者在论文中提到,CNN模型主要适用于类间判别,对于相似物体的判别能力不强.作者提出使用RNN对目标物体的self-structure进行建模,用于提 ...

  9. ICCV 2017 UCT:《UCT: Learning Unified Convolutional Networks forReal-time Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做UCT.就像论文题目一样,作者提出了一个基于卷积神经网络的end2end的tracking模型.模型的整体结构如下图所示(图中实线代表online trackin ...

  10. CVPR 2018 STRCF:《Learning Spatial-Temporal Regularized Correlation Filters for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文提出的模型叫做STRCF. 在DCF中存在边界效应,SRDCF在DCF的基础上中通过加入spatial惩罚项解决了边界效应,但是SRDCF在tracking的过程中要使用 ...

最新文章

  1. 字符串面试题(一)字符串逆序
  2. linux 下 jenkins 安装注意事项
  3. python基础代码事例-学习笔记:python3,代码。小例子习作(2017)
  4. Vue3入门笔记—2022年1月9日
  5. 通过服务器给多台计算机装系统,怎么快速给机房多台电脑安装系统?
  6. LeetCode 1874. 两个数组的最小乘积和
  7. C语言libcurl例程:multi 多线程,多任务
  8. Wolfram 语言之父 Stephen Wolfram :编程的未来
  9. asp.net下Response.ContentType类型汇总
  10. 【日期工具类】DateUtils
  11. 移动通讯市场发展概况及预测
  12. VUE-waterfall瀑布流组件使用
  13. java8对类集合使用 Comparator.comparing 进行排序
  14. 深度学习上采样下采样概念以及实现
  15. Android下载并打开PDF文件
  16. python与c语言的区别-c语言和python之间有什么区别
  17. antdvue的table合计行
  18. 计算机一级基础题库,2016计算机一级公共基础练习题
  19. java连接Zookeeper,获取节点数据报错
  20. Linux ALSA驱动之三:PCM创建流程源码分析(基于Linux 5.18)

热门文章

  1. 【计量经济学导论】05. 异方差
  2. android代码 获取本次通话时间
  3. oracle存储过程游标写法,Oracle存储过程,游标使用
  4. 关于春晚红包活动自己的思考
  5. 新风口下,物联网将在哪些方面改善物流行业
  6. 微服务SpringCloud Alibaba架构
  7. 射频电路与天线(华南理工金品公开课)学习笔记--绪论
  8. 如何让一个2008年的电脑可以正常服役
  9. Out of Distribution(OoD)检测相关方法综述
  10. java方法命名规范(持续更新)