点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

来源: arXiv

编辑:刘晓坤、思源

自 2012 年 AlexNet 大展神威以来,研究者已经提出了各种卷积架构,包括 VGG、NiN、Inception、ResNet、DenseNet 和 NASNet 等,我们会发现模型的准确率正稳定提升。

但是现在这些提升并不仅仅来源于架构的修正,还来源于训练过程的改进:包括损失函数的优化、数据预处理方法的提炼和最优化方法的提升等。在过去几年中,卷积网络与图像分割出现大量的改进,但大多数在文献中只作为实现细节而简要提及,而其它还有一些技巧甚至只能在源代码中找到。

在这篇论文中,李沐等研究者研究了一系列训练过程和模型架构的改进方法。这些方法都能提升模型的准确率,且几乎不增加任何计算复杂度。它们大多数都是次要的「技巧」,例如修正卷积步幅大小或调整学习率策略等。总的来说,采用这些技巧会产生很大的不同。因此研究者希望在多个神经网络架构和数据集上评估它们,并研究它们对最终模型准确率的影响。

研究者的实验表明,一些技巧可以显著提升准确率,且将它们组合在一起能进一步提升模型的准确率。研究者还对比了基线 ResNet 、加了各种技巧的 ResNet、以及其它相关的神经网络,下表 1 展示了所有的准确率对比。这些技巧将 ResNet50 的 Top-1 验证准确率从 75.3%提高到 79.29%,还优于其他更新和改进的网络架构。此外,研究者还表示这些技巧很多都可以迁移到其它领域和数据集,例如目标检测和语义分割等。

论文:Bag of Tricks for Image Classification with Convolutional Neural Networks

论文地址:https://arxiv.org/pdf/1812.01187.pdf

摘要:图像分类研究近期的多数进展都可以归功于训练过程的调整,例如数据增强和优化方法的变化。然而,在这些文献中,大多数微调方法要么被简单地作为实现细节,或仅能在源代码中看到。在本文中,我们将测试一系列的微调方法,并通过控制变量实验评估它们对最终准确率的影响。我们将展示通过组合不同的微调方法,我们可以显著地改善多种 CNN 模型。例如,我们将 ImageNet 上训练的 ResNet-50 的 top-1 验证准确率从 75.3% 提升到 79.29。本研究还表明,图像分类准确率的提高可以在其他应用领域(如目标检测和语义分割)中实现更好的迁移学习性能。

2 训练过程

目前我们基本上都用小批量 SGD 或其变体训练神经网络,Algorithm 1 展示了 SGD 的模版过程(感兴趣的读者可以查阅原论文)。利用广泛使用的 ResNet 实现作为我们的基线,训练过程主要分为以下六个步骤:

  1. 随机采样一张图片,并解码为 32 位的原始像素浮点值,每一个像素值的取值范围为 [0, 255]。

  2. 随机以 [3/4, 4/3] 为长宽比、[8%, 100%] 为比例裁减矩形区域,然后再缩放为 224*224 的方图。

  3. 以 0.5 的概率随机水平翻转图像。

  4. 从均匀分布 [0.6, 1.4] 中抽取系数,并用于缩放色调和明亮度等。

  5. 从正态分布 N (0, 0.1) 中采样一个系数,以添加 PCA 噪声。

  6. 图像分别通过减去(123.68, 116.779, 103.939),并除以(58.393, 57.12, 57.375)而获得经归一化的 RGB 三通道。

经过六步后就可以训练并验证了,以下展示了基线模型的准确率:

表 2:文献中实现的验证准确率与我们基线模型的验证准确率,注意 Inception V3 的输入图像大小是 299*299。

3 高效训练

随着 GPU 等硬件的流行,很多与性能相关的权衡取舍或最优选择都已经发生了改变。在这一章节中,我们研究了能利用低精度和大批量训练优势的多种技术,它们都不会损害模型的准确率,甚至有一些技术还能同时提升准确率与训练速度。

3.1 大批量训练

对于凸优化问题,随着批量的增加,收敛速度会降低。人们已经知道神经网络会有类似的实证结果 [25]。换句话说,对于相同数量的 epoch,大批量训练的模型与使用较小批量训练的模型相比,验证准确率会降低。因此有很多方法与技巧都旨在解决这个问题:

线性扩展学习率:较大的批量会减少梯度的噪声,从而可以增加学习率来加快收敛。

学习率预热:在预热这一启发式方法中,我们在最初使用较小的学习率,然后在训练过程变得稳定时换回初始学习率。

Zero γ:注意 ResNet 块的最后一层可以是批归一化层(BN)。在 zero γ启发式方法中,我们对所有残差块末端的 BN 层初始化γ=0。因此,所有的残差块仅返回输入值,这相当于网络拥有更少的层,在初始阶段更容易训练。

无偏衰减:无偏衰减启发式方法仅应用权重衰减到卷积层和全连接层的权重,其它如 BN 中的γ和β都不进行衰减。

表 4:ResNet-50 上每种有效训练启发式的准确率效果。

3.2 低精度训练

然而,新硬件可能具有增强的算术逻辑单元以用于较低精度的数据类型。尽管具备性能优势,但是精度降低具有较窄的取值范围,因此有可能出现超出范围而扰乱训练进度的情况。

表 3:ResNet-50 在基线(BS = 256 与 FP32)和更高效硬件设置(BS = 1024 与 FP16)之间的训练时间和验证准确率的比较。

4 模型变体

我们将简要介绍 ResNet 架构,特别是与模型变体调整相关的模块。ResNet 网络由一个输入主干、四个后续阶段和一个最终输出层组成,如图 1 所示。输入主干有一个 7×7 卷积,输出通道有 64 个,步幅为 2,接着是 3 ×3 最大池化层,步幅为 2。输入主干(input stem)将输入宽度和高度减小 4 倍,并将其通道尺寸增加到 64。

从阶段 2 开始,每个阶段从下采样块开始,然后是几个残差块。在下采样块中,存在路径 A 和路径 B。路径 A 具有三个卷积,其卷积核大小分别为 1×1、3×3 和 1×1。第一个卷积的步幅为 2,以将输入长度和宽度减半,最后一个卷积的输出通道比前两个大 4 倍,称为瓶颈结构。路径 B 使用步长为 2 的 1×1 卷积将输入形状变换为路径 A 的输出形状,因此我们可以对两个路径的输出求和以获得下采样块的输出。残差块类似于下采样块,除了仅使用步幅为 1 的卷积。

我们可以改变每个阶段中残差块的数量以获得不同的 ResNet 模型,例如 ResNet-50 和 ResNet-152,其中的数字表示网络中卷积层的数量。

图 1:ResNet-50 的架构。图中说明了卷积层的卷积核大小、输出通道大小和步幅大小(默认值为 1),池化层也类似。

图 2:三个 ResNet 变体。ResNet-B 修改 ResNet 的下采样模块。ResNet-C 进一步修改输入主干。在此基础上,ResNet-D 再次修改了下采样块。

表 5:将 ResNet-50 与三种模型变体进行模型大小(参数数量)、FLOPs 和 ImageNet 验证准确率(top-1、top-5)的比较。

5 训练方法改进

5.1 余弦学习率衰减

Loshchilov 等人 [18] 提出余弦退火策略,其简化版本是按照余弦函数将学习速率从初始值降低到 0。假设批次总数为 T(忽略预热阶段),然后在批次 t,学习率η_t 计算如下:

其中η是初始学习率,我们将此方案称为「余弦」衰减。

图 3:可视化带有预热方案的学习率变化。顶部:批量大小为 1024 的余弦衰减和按迭代步衰减方案。底部:关于两个方案的 top-1 验证准确率曲线。

5.2 标签平滑

标签平滑的想法首先被提出用于训练 Inception-v2 [26]。它将真实概率的构造改成:

其中ε是一个小常数,K 是标签总数量。

图 4:ImageNet 上标签平滑效果的可视化。顶部:当增加ε时,目标类别与其它类别之间的理论差距减小。下图:最大预测与其它类别平均值之间差距的经验分布。很明显,通过标签平滑,分布中心处于理论值并具有较少的极端值。

5.3 知识蒸馏

在知识蒸馏 [10] 中,我们使用教师模型来帮助训练当前模型(被称为学生模型)。教师模型通常是具有更高准确率的预训练模型,因此通过模仿,学生模型能够在保持模型复杂性相同的同时提高其自身的准确率。一个例子是使用 ResNet-152 作为教师模型来帮助训练 ResNet-50。

5.4 混合训练

在混合训练(mixup)中,每次我们随机抽样两个样本 (x_i,y_i) 和 (x_j,y_j)。然后我们通过这两个样本的加权线性插值构建一个新的样本:

其中 λ∈[0,1] 是从 Beta(α, α) 分布提取的随机数。在混合训练中,我们只使用新的样本 (x hat, y hat)。

5.5 实验结果

表 6:通过堆叠训练改进方法,得到的 ImageNet 验证准确率。基线模型为第 3 节所描述的。

6 迁移学习

6.1 目标检测

表 8:在 Pascal VOC 上评估各种预训练基础网络的 Faster-RCNN 性能。

6.2 语义分割

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

亚马逊:用CNN进行图像分类的Tricks相关推荐

  1. 亚马逊专家揭秘:如何建立自动检测乳腺癌的深度学习模型

    安妮 编译自 Insight Data Science 量子位出品 | 公众号 QbitAI 本文作者Sheng Weng,现亚马逊Alexa项目组数据专家,莱斯大学应用物理专业已毕业博士生,主要研究 ...

  2. Kaggle亚马逊比赛冠军专访:利用标签相关性来处理分类问题

    近日,Kaggle Blog上刊登了对「Planet: Understanding the Amazon from Space」比赛冠军的专访,在访问中,我们了解到了冠军选手bestfitting的一 ...

  3. 亚马逊马超:如何使用DGL进行大规模图神经网络训练?

    演讲嘉宾 | 马超(亚马逊应用科学家) 整理 | 刘静  出品 | AI科技大本营(ID:rgznai100) 与传统基于张量(Tensor)的神经网络相比,图神经网络将图 (Graph) 作为输入, ...

  4. 图表对比详解:亚马逊、微软和谷歌云的机器学习即服务哪家强

    林鳞 编译自 KDnuggets 量子位 出品 | 公众号 QbitAI 对于大多数公司来说,机器学习是一项复杂而伤神的工作,花销大.对人才要求高.机器学习即服务针对这个痛点应运而生. 什么是&quo ...

  5. 亚马逊贝索斯伸出橄榄枝后,巴菲特也力挺特朗普

    11月11日消息,据路透社报道称,在大选期间高调支持希拉里并攻击特朗普的"股神"巴菲特现在也转变立场,看好特朗普当选后的美股走势,称不管大选结果如何,美股长期而言势将上涨. 巴菲特 ...

  6. 亚马逊一口气发布了9款机器学习产品

    AI前线导读: 今天,在拉斯维加斯举行的AWS re:invent进行到第三个日程,大会上,AWS CEO Andy Jassy在主题演讲上一口气做了二十个新发布,其中包括9款机器学习产品! 更多干货 ...

  7. 解构亚马逊Alexa的1.5万种技能

     解构亚马逊Alexa的1.5万种技能:三大派系.口碑落差,长尾死亡 本文作者:邹霖 2017-07-12 18:45 导语:上周,Voicebot 对外宣称 Alexa技能突破1.5万.那么这1 ...

  8. 阿里巴巴宣布架构调整;英伟达放大招!重磅发布 ​TensorRT 7 ,支持超千种计算变换;苹果、谷歌和亚马逊罕见结盟……...

    戳蓝字"CSDN云计算"关注我们哦!  嗨,大家好,重磅君带来的[云重磅]特别栏目,如期而至,每周五第一时间为大家带来重磅新闻.把握技术风向标,了解行业应用与实践,就交给我重磅君吧 ...

  9. 伯克利2019深度学习课程—李沐及其亚马逊同事一起讲述(内附视频链接及PDF下载)

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 伯克利2019深度学习课程是李沐老师大致按照李沐老师的开源新书<动手学深度学习>来安排的(和去年放出的同 ...

  10. 30多门免费课程上线,亚马逊“机器学习大学”开学了

    圆栗子 发自 凹非寺  量子位 报道 | 公众号 QbitAI 今天,亚马逊"机器学习大学"开课了. 团队昭告天下,那些用来培训亚马逊工程师的秘密课程,已经通过AWS免费向所有开发 ...

最新文章

  1. linux如何取文件列名,Linux_根据表名和索引获取需要的列名的存储过程,复制代码 代码如下: create proc p - phpStudy...
  2. Jquery 【select 通过value来寻找对应的项name】
  3. Flink应用实战案例50篇(一)- Flink SQL 在京东的优化实战
  4. asp.net 生命周期中的时间流程
  5. 判断文件是否改变php,PHP判断文件是否被修改实例
  6. 学什么c语言标准,C语言的标准 “输入输出”!今天是你学C语言的第几天?
  7. myeclipse 8.5 注册码
  8. HTML small元素
  9. java pc 蓝牙_Nokia PC 套件与蓝牙适配器连接教程(转)--个人推荐
  10. 技巧 | 如何使用R语言的基础绘图系统的拼图功能
  11. Copy(定义,特点,深复制,浅复制)(非ARC,ARC的运用范围)
  12. 开发软件快捷键(持续更新中)
  13. spring学习笔记--IOC接口
  14. npm升级所有可更新包
  15. 数据处理中常用的Excel函数
  16. 李航老师《统计学习方法》及相关资源(代码、课件)的汇总及下载
  17. win7 x64怎么枚举所有快捷键呢
  18. SOP是Standard Operation Procedure三个单词中首字母的大写 ,即标准作业程序
  19. 语音识别之wave文件(*.wav)格式、PCM数据格式介绍
  20. Java爬取英雄联盟官网,全英雄皮肤背景图片

热门文章

  1. 给算法工程师和研究员的「霸王餐」| 附招聘信息
  2. 医疗影像处理:去除医疗影像中背景的影响2D/3D【numpy-code】| CSDN博文精选
  3. 一次改变未来10年人生的机会
  4. 不止临床应用,AI还要帮不懂编程的医生搞科研
  5. 初学者的机器学习入门实战教程!
  6. 又一届Google Cloud Next,李飞飞发布TPU 3.0,两大AutoML新品
  7. 量子技术发展的一小步:Google AI推出开源框架Cirq
  8. Python需求增速达174%,AI人才缺口仍超百万!这份来自2017年的实际招聘数据如是说
  9. 面试官:抛开Spring来说,如何自己实现Spring AOP?
  10. 这次性能优化, QPS 翻倍了