点击上方“视学算法”,星标公众号
重磅干货,第一时间送达
作者:zhouyuangan
来源:CVer

这篇文章将从一个证件检测网络(Retinanet)的轻量化谈起,简洁地介绍,我在实操中使用到的设计原则和idea,并贴出相关的参考资料和成果供读者参考。因此本文是一篇注重工程性、总结个人观点的文章,存在不恰当的地方,请读者在评论区指出,方便交流。

目前已有的轻量网络有:MobileNet V2和ShuffleNet v2为代表。在实际业务中,Retinanet仅需要检测证件,不涉及过多的类别物体的定位和分类,因此,我认为仅仅更换上述两个骨架网络来优化模型的性能是不够的,需要针对证件检测任务,专门设计一个更加轻量的卷积神经网络来提取、糅合特征。

设计原则:

1. 更多的数据

轻量的浅层网络特征提取能力不如深度网络,训练也更需要技巧。假设保证有足够多的训练的数据,轻量网络训练会更加容易。

Facebook研究院的一篇论文[1]提出了“数据蒸馏”的方法。实际上,标注数据相对未知数据较少,我使用已经训练好、效果达标的base resnet50的retinanet来进行自动标注,得到一批10万张机器标注的数据。这为后来的轻量网络设计奠定了数据基础。我认为这是构建一个轻量网络必要的条件之一,网络结构的有效性验证离不开大量的实验结果来评估。

接下来,这一部分我将简洁地介绍轻量CNN地设计的四个原则

2. 卷积层的输入、输出channels数目相同时,计算需要的MAC(memory access cost)最少

3. 过多的分组卷积会增加MAC

对于1x1的分组卷积(例如:MobileNetv2的深度可分离卷积采用了分组卷积),其MAC和FLOPS的关系为:

g代表分组卷积数量,很明显g越大,MAC越大。详细参考[2]

4. 网络结构的碎片化会减少可并行计算

这些碎片化更多是指网络中的多路径连接,类似于short-cut,bottle neck等不同层特征融合,还有如FPN。拖慢并行的一个很主要因素是,运算快的模块总是要等待运算慢的模块执行完毕。

5. Element-wise操作会消耗较多的时间(也就是逐元素操作)

从表中第一行数据看出,当移除了ReLU和short-cut,大约提升了20%的速度。

以上是从此篇论文[2]中转译过来的设计原则,在实操中,这四条原则需要灵活使用。

根据以上几个原则进行网络的设计,可以将模型的参数量、访存量降低很大一部分。

接下来介绍一些自己总结的经验。

6. 网络的层数不宜过多

通常18层的网络属于深层网络,在设计时,应选择一个参考网络基线,我选择的是resnet18。由于Retinanet使用了FPN特征金字塔网络来融合各个不同尺度范围的特征,因此Retinanet仍然很“重”,需要尽可能压缩骨架网络的冗余,减少深度。

7. 首层卷积层用空洞卷积和深度可分离卷积替换

一个3x3,d=2的空洞卷积在感受野上,可以看作等效于5x5的卷积,提供比普通3x3的卷积更大的感受野,这在网络的浅层设计使用它有益。计算出网络各个层占有的MAC和参数量,将参数量和计算量“重”的卷积层替换成深度可分离卷积层,可以降低模型的参数量。

这里提供一个计算pytorch 模型的MAC和FLOPs的python packages[3]

if __name__ == "__main__":    from ptflops import get_model_complexity_info

    net = SNet(num_classes=1)    x = torch.Tensor(1, 3, 224, 224)

    net.eval()

    if torch.cuda.is_available():        net = net.cuda()        x = x.cuda()

    with torch.cuda.device(0):        flops, params = get_model_complexity_info(net, (224, 224), print_per_layer_stat=True, as_strings=True, is_cuda=True)        print("FLOPS:", flops)        print("PARAMS:", params)

output:

(regressionModel): RegressionModel(    0.045 GMac, 27.305% MACs,    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))    (act1): ReLU(0.0 GMac, 0.041% MACs, )    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))    (act2): ReLU(0.0 GMac, 0.041% MACs, )    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))    (act3): ReLU(0.0 GMac, 0.041% MACs, )    (output): Conv2d(0.002 GMac, 0.982% MACs, 256, 24, kernel_size=(1, 1), stride=(1, 1))  )  (classificationModel): ClassificationModel(    0.044 GMac, 26.569% MACs,    (conv1): Conv2d(0.009 GMac, 5.257% MACs, 128, 256, kernel_size=(1, 1), stride=(1, 1))    (act1): ReLU(0.0 GMac, 0.041% MACs, )    (conv2): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))    (act2): ReLU(0.0 GMac, 0.041% MACs, )    (conv3): Conv2d(0.017 GMac, 10.472% MACs, 256, 256, kernel_size=(1, 1), stride=(1, 1))    (act3): ReLU(0.0 GMac, 0.041% MACs, )    (output): Conv2d(0.0 GMac, 0.245% MACs, 256, 6, kernel_size=(1, 1), stride=(1, 1))    (output_act): Sigmoid(0.0 GMac, 0.000% MACs, )  )

8. Group Normalization 替换 Batch Normalization

BN在诸多论文中已经被证明了一些缺陷,而训练目标检测网络耗费显存,开销巨大,通常冻结BN来训练,原因是小批次会让BN失效,影响训练的稳定性。建议一个BN的替代--GN,pytorch 0.4.1内置了GN的支持。

9. 减少不必要的shortcut连接和RELU层

网络不够深,没有必要使用shortcut连接,不必要的shortcut会增加计算量。RELU与shortcut一样都会增加计算量。同样RELU没有必要每一个卷积后连接(需要实际训练考虑删减RELU)。

10. 善用1x1卷积

1x1卷积可以改变通道数,而不改变特征图的空间分辨率,参数量低,计算效率也高。如使用kernel size=3,stride=1,padding=1,可以保证特征图的空间分辨率不变,1x1的卷积设置stride=1,padding=0达到相同的目的,而且1x1卷积运算的效率目前有很多底层算法支持,效率更高。[5x1] x [1x5] 两个卷积可以替换5x5卷积,同样可以减少模型参数。

11. 降低通道数

降低通道数可以减少特征图的输出大小,显存占用量下降明显。参考原则2

12. 设计一个新的骨架网络找对参考网络

一个好的骨架网络需要大量的实验来支撑它的验证,因此在工程上,参考一些实时网络结构设计自己的骨架网络,事半功倍。我在实践中,参考了这篇[4]paper的骨架来设计自己的轻量网络。

总结

我根据以上的原则和经验对Retinanet进行瘦身,不仅局限于骨架的新设计,FPN支路瘦身,两个子网络(回归网络和分类网络)均进行了修改,期望性能指标FPS提升到63,增幅180%。

FPS

mAP

Model size

注:本文中部分观点参考来源

1 https://towardsdatascience.com/types-of-convolutions-in-deep-learning-717013397f4d

2 The Receptive Field, the Effective RF, and how it s hurting your results

https://www.linkedin.com/pulse/receptive-field-effective-rf-how-its-hurting-your-rosenberg/

3 https://arxiv.org/abs/1807.11164

4 mp.weixin.qq.com/s?

参考

  1. Data Distillation Towards Omni-Supervised Learning

  2. ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design

  3. https://github.com/zhouyuangan/flops-counter.pytorch

  4. ThunderNet: TowardsReal-timeGenericObjectDetection

- END -
如果看到这里,说明你喜欢这篇文章,请转发、点赞。扫描下方二维码或者微信搜索「perfect_iscas」,添加好友后即可获得10套程序员全栈课程+1000套PPT和简历模板向我私聊「进群」二字即可进入高质量交流群。
扫描二维码进群↓


在看 

轻量级卷积神经网络的设计技巧相关推荐

  1. 分组卷积计算量_轻量级卷积神经网络的设计技巧

    作者 | zhouyuangan 来源 | CVer 这篇文章将从一个证件检测网络(Retinanet)的轻量化谈起,简洁地介绍,我在实操中使用到的设计原则和idea,并贴出相关的参考资料和成果供读者 ...

  2. LiteFlowNet:用于光流估计的轻量级卷积神经网络

    LiteFlowNet:用于光流估计的轻量级卷积神经网络 原文链接 摘要   FlowNet2 [14] 是用于光流估计的最先进的卷积神经网络 (CNN),需要超过 160M 的参数才能实现准确的流估 ...

  3. 基于Pytorch框架的轻量级卷积神经网络垃圾分类识别系统

    今天在查资料的时候在网上看到一篇文章,博主是基于TensorFlow实现的CNN来完成对垃圾分类识别的,想到最近正好在使用Pytorch就想也做一下,就当是项目开发实践了.先看下动态操作效果: 原文在 ...

  4. 【基于zynq的卷积神经网络加速器设计】(一)熟悉vivado和fpga开发流程:使用Vivado硬件调试烧写hello-world led闪烁程序实现及vivado软件仿真

    HIGHLIGHT: vivado设计流程: note: 分析与综合 和 约束输入 可以调换顺序 [基于zynq的卷积神经网络加速器设计](一)熟悉vivado和fpga开发流程:使用Vivado硬件 ...

  5. 【博士论文】深度卷积神经网络架构设计及优化问题研究

    来源:专知 本文为论文介绍,建议阅读5分钟 近年来,深度卷积神经网络在计算机视觉领域取得了革命性的进展,并被广泛地应用到图像分类.物体检测.实例分割等经典的计算机视觉问题当中. 来自南京理工大学的李翔 ...

  6. 基于ZYNQ平台的卷积神经网络加速器设计及其应用研究

    摘 要 近些年来,深度学习作为机器学习的一种新的形式,它使计算机能够从经验 中学习并根据概念层次来理解世界.作为一种崭新的人工神经网络方法,卷积神 经网络(CNN)使神经元之间可以权值共享来减少样本的 ...

  7. 卷积神经网络调参技巧(2)--过拟合(Dropout)

    Dropout(丢弃) 首先需要讲一下过拟合,训练一个大型网络时,因为训练数据有限,很容易出现过拟合.过拟合是指模型的泛化能力差,网络对训练数据集的拟合能力很好,但是换了其他的数据集,拟合能力就变差了 ...

  8. 卷积神经网络超详细介绍

    文章目录 1.卷积神经网络的概念 2. 发展过程 3.如何利用CNN实现图像识别的任务 4.CNN的特征 5.CNN的求解 6.卷积神经网络注意事项 7.CNN发展综合介绍 8.LeNet-5结构分析 ...

  9. 卷积神经网络超详细介绍(转载)

    卷积神经网络超详细介绍 文章目录 1.卷积神经网络的概念 2. 发展过程 3.如何利用CNN实现图像识别的任务 4.CNN的特征 5.CNN的求解 6.卷积神经网络注意事项 7.CNN发展综合介绍 8 ...

最新文章

  1. RelativeLayout(相对布局)的分析
  2. 使用VMware Workstation搭建基于Linux的Oracle 10g RAC
  3. mysql 日均pv100w_日均百万PV架构第四弹(分布式监控)_MySQL
  4. [css] 不用换行的标签,怎么伪元素实现换行的效果?
  5. 逆向知识第十讲,循环在汇编中的表现形式,以及代码还原
  6. java sql异常_java.sql.SQLException: Io 异常: Got minus one from a
  7. 小小一行Python命令,居然把电脑变成服务器
  8. 高效排错系列--摘要
  9. 螺旋进刀非法平面选择_进刀方法、刀片类型、术语...螺纹加工重点知识你都知道吗?...
  10. Android开源库--ActiveAndroid(active record模式的ORM数据库框架)
  11. 疫苗接种率低?不用怕,互联网公司给出解决之道
  12. Fiddler4的安装与使用
  13. java tomcat 内存溢出怎么解决_Tomcat内存溢出分析及解决方法
  14. 机器学习读书笔记:样本降维
  15. layui表格中显示内容换行
  16. html图片过渡,CSS3 过渡
  17. HTML文本框边框宽度,如何设置文本框尺寸 word文本框怎么设置统一大小
  18. aggr代码 cellranger_CellRanger初探
  19. python控制多个屏幕_使用Python控制屏幕
  20. 联通项目中的常见术语(BTS、BSC、MSC、VLR、HLR)

热门文章

  1. 厉害了,用Python绘制动态可视化图表,并保存成gif格式
  2. 这些算法在印度农村医疗中发挥极大作用,未来还将发挥哪些作用?
  3. 可租赁、可定制的虚拟人居然还能这么玩?9月25日来百度大脑人像特效专场一探究竟!...
  4. 芬兰开放“线上AI速成班”课程,全球网民均可免费观看
  5. 神州数码与神州控股、神州信息共同主办首届技术年会,透露出什么信号?
  6. 深度CTR预估模型的演化之路2019最新进展
  7. 2019全球AI 100强,中国占独角兽半壁江山,但忧患暗存
  8. AI人才招聘:估值超400亿美元,即将IPO的独角兽招AI专家
  9. Google发布机器学习术语表 (中英对照)
  10. Java 8 开发的 4 大技巧