随机权重平均

  • 摘要
  • Introduction
  • SWA
  • 实验部分
  • 消融实验

摘要

您想在不增加推断成本和不改变检测器的情况下提高对象检测器的1.0 AP吗?让我们告诉您一个这样的秘方。这个秘方令人惊讶地简单:使用循环学习率训练您的检测器额外的12个epoches,然后将这12个checkpoints平均作为您的最终检测模型。这个强大的秘方受到随机权重平均(SWA)的启发,该方法在[1]中提出,用于改善深度神经网络的泛化能力。我们发现它在对象检测中也非常有效。在这个技术报告中,我们系统地研究了将SWA应用于对象检测以及实例分割的效果。通过大量的实验,我们发现了在对象检测中执行SWA的可行策略,并在具有挑战性的COCO基准测试中始终实现了对各种流行检测器(包括Mask RCNN、Faster RCNN、RetinaNet、FCOS、YOLOv3和VFNet)的∼1.0 AP的提高。我们希望这项工作能让更多的对象检测研究人员了解这项技术,并帮助他们训练更好的对象检测器。代码可在https://github.com/hyzxmaster/swa-object-detection找到。

Introduction

由于深度学习的巨大成功,目标检测在近年来取得了巨大进展。2015年, Faster RCNN[2]在COCO testdev[3]上仅达到了21.9 AP,而2020年最新的COCO排行榜[4]上这一数字已经提高到了约61.0。尽管如此,我们可以看到目标检测的演进正在变得缓慢,因为深度网络的特征表示学习能力几乎已经被挤干。根据2020年COCO+LVIS联合识别挑战赛[5]的报告,目标检测(实例分割赛道)在COCO上的表现已经达到饱和,这意味着进一步提高目标检测性能正在变得更加困难。即使研究人员费尽心思地设计更好的检测器模块,他们可能会发现在具有挑战性的COCO基准测试上进一步提高1.0 AP的性能是非常困难的。

另一方面,我们最近发现了一种非常简单但有效的增强物体检测器的方法,我们非常激动地与社区分享。您只需要使用周期性学习率训练您的检测器额外的12个epoches,然后将这12个checkpoints平均作为您的最终检测模型。因此,您可以在具有挑战性的COCO基准测试上获得约1.0 AP的提高。由于这种技术只会产生一些训练开销,因此您不需要担心任何推理成本或对检测器的任何更改。这种技术是在[1]中开发的,旨在改善深度网络中的泛化,称为随机权重平均(SWA)。我们在目标检测的研究中尝试了它,并对其在改进我们的目标检测器VarifocalNet [6]或VFNet进行了惊人的有效性。我们发现罕见的目标检测工作[7]采用了这种技术。因此,我们对将SWA应用于目标检测的效果进行了系统研究。由于其代表性和普及性,我们首先选择了Mask RCNN [8]作为我们的研究目标检测器。然后,我们尝试了不同的训练策略,并发现了在目标检测中执行SWA的可行策略。通过广泛的实验,我们发现SWA可以为各种目标检测器(包括Mask RCNN [8]、Faster RCNN [2]、RetinaNet [9]、FCOS [10]、YOLOv3 [11]和我们的VFNet [6])在COCO基准测试中提高约1.0 AP。这使我们感到非常激动,希望这项工作能够帮助社区训练更好的目标检测器。

SWA

我们简要介绍了SWA是什么以及为什么它有效。更多细节请参考SWA论文[1]、其博客[12]或相关教程[13]。

简单地说,SWA是在SGD优化策略上使用高常数学习率或周期性学习率的多个检查点的平均值。设wi为第i个检查点。==在传统的SGD中,通常选择最后一个检查点wn或最佳验证w∗i作为最终模型。==相比之下,在SWA中,多个检查点的平均值被采用作为最终模型。

为什么这个简单的方法有效?作者认为SGD通常会收敛到一组好的权重空间的边缘解(如图1中的W1),这个解通常比那些位于空间中心的解泛化性能差。使用周期性或高常数学习率的SGD优化可以探索靠近与深度神经网络高精度对应的平坦权重空间边界的多个点,如图1中的W1、W2和W3。然后,通过对这些点进行平均,SWA可以找到一个更为集中的解WSWA,其泛化性能显著提高。

在实践中,应用SWA训练物体检测器要回答两个主要问题。首先,从第m个epoch到第n个epoch,我们应该使用什么学习率计划进行SWA训练?使用高恒定学习率或循环学习率?其次,我们应该平均多少个checkpoints?也就是说,我们应该进行多少个epoches的SWA训练?在本报告中,我们通过广泛的实验回答了这些问题。

实验部分

在本章节中,我们进行了一系列实验,旨在研究SWA的效果,并发现适用于物体检测的合适方式。
数据集和评估指标。我们在广泛使用的MS COCO 2017数据集[3]上进行实验。我们在train2017数据集上训练检测器,并在val2017数据集上报告结果。我们采用标准的COCO风格的平均精度(AP)作为评估指标。

实现和训练细节。我们依靠MMDetection [14]进行实验。我们使用8个V100 GPU进行训练,总批量大小为16(每个GPU上2张图像)。为方便起见,我们在此介绍1x和2x训练计划[15]。==1x计划表示模型训练12个时期,并在第9个和第12个时期分别将初始学习率降低10倍,2x计划表示模型训练24个时期,并在第17个和第23个时期分别将初始学习率降低10倍。==为简洁起见,我们还在此描述了本报告中使用的物体检测器的命名规则。以MaskRCNN-R101-2x-0.02-0.0002-40.8-36.6为例解释。它意味着预训练的检测器Mask RCNN具有ResNet-101 [16]骨干网络,在2x计划下训练,初始学习率为0.02,结束学习率为0.0002,在COCO val2017上分别达到40.8 bbox AP和36.6 mask AP。

循环余弦退火学习率的示意图。在每个循环中,学习率从初始学习率lrmax(本例中为0.02)逐渐降低到结束学习率lrmin(本例中为0.0002),并在每个循环结束后重新开始。

消融实验

随后我们选择 Mask RCNN [8] 作为研究对象,以探索在目标检测和实例分割中如何正确使用 SWA。我们首先从 MMDetection 模型库中下载预训练模型 MaskRCNNR101-2x-0.02-0.0002-40.8-36.6 以及其配置文件作为起点。然后,我们使用不同的学习率策略对模型进行额外的 24 或 48 轮训练。第一种策略是固定学习率计划,其中选择了 0.02、0.002 和 0.0002。请注意,这些学习率对应于预训练模型的不同训练阶段中使用的学习率。第二种策略是循环学习率计划。如图 2 所示,每个epoch中,学习率从一个较大的值 lrmax 开始,然后相对地逐渐减小,直到达到 lrmin,然后再逐渐增加,最终再次回到 lrmax,以此循环。通过这些实验,我们可以评估 SWA 在不同训练策略下的性能表现。
请注意,学习率的减少是在每个 iteration中而不是每个epoch中发生的。在这项研究中,我们采用余弦退火学习率调度,选择两组(lrmax,lrmin),即(0.01,0.0001)和(0.02,0.0002),并选择1个epoch作为周期长度。

最终,我们将不同数量的checkpoints(6、12、24和48)平均,作为我们的最终SWA模型,并在COCO val2017上评估它们的性能。请注意,由于骨干网络中的批量归一化层被冻结[14],因此我们不需要按照原始的SWA论文再运行一遍数据来计算新的统计数据。

结果呈现在表格1中。如上所述,我们尝试了五种不同的训练策略,它们被分成两组。对于固定学习率的策略,我们在前15个epoch中使用了一个较大的学习率,并将其降低到原来的十分之一。对于一次性余弦退火策略,我们使用一个初始学习率并将其降低到零。最好的性能在一次性余弦退火策略中获得,其mAP为49.4%。

对于固定的学习率组,我们可以看到学习率对每个SGD epoch的性能有很大影响。具体而言,当学习率为0.02时,每个SGD epoch的表现比预训练模型差得多,例如bbox AP为33.0-34.0,而bbox AP为40.8。相比之下,当学习率为0.0002时,每个SGD epoch的表现与预训练模型相当。虽然使用不同学习率的每个SGD epoch所达到的性能差别很大,但令人惊讶的是,通过对每个训练策略下的一定数量的检查点进行平均得到的SWA模型达到了相似的结果。我们可以看到,在表1的SWA 1-12列中,所有这三个SWA模型都获得了约40.5 bbox AP和36.5 mask AP的成绩。然而,这些结果都不如起始模型,并且表明恒定的学习率策略效果不佳。相比之下,循环学习率组在每个SGD epoch中实现了更稳定的结果,他们的SWA模型也取得了更好的结果。可以看到,学习率范围(0.02,0.0002)的表现要优于范围(0.01,0.0001),表明在预训练阶段使用的学习率已经很好地发挥作用。更详细地观察(0.02,0.0002)范围的结果,其SWA 1-12模型获得了41.7 bbox AP和37.4 mask AP的成绩,分别比预训练模型提高了0.9 bbox AP和0.8 mask AP。此外,SWA 1-12模型表现优于SWA 1-6模型,并且与SWA 1-24模型和SWA 1-48模型相当。这表明,训练另外12个epoch足以生成一个良好的SWA模型,特别是考虑到计算负担和收益之间的权衡。

通过比较表格1中的结果,我们可以推断出一种可行的策略,即在使用SWA训练更好的物体检测器时,经过传统训练使用初始学习率lrini和结束学习率lrend训练一个物体检测器,然后使用循环学习率(lrini,lrend)进行额外的12个时期的训练,最后将这12个检查点平均作为最终的检测模型。

基于上述观察结果,我们还尝试从头开始训练SWA Mask RCNN(主干在ImageNet [17]上预训练)。我们首先使用学习率0.02训练原始Mask RCNN模型16个时期,得到模型MaskRCNN-R101-16e-0.02-0.02-33.4-30.8。然后,我们使用循环学习率在另外12个epoch内训练模型,并将这些12个检查点平均作为SWA模型。如表1的最后一部分所示,我们可以看到这个SWA 1-12模型实现了41.7个bbox AP和37.4个mask AP,与通过训练MaskRCNN-R101-2x0.02-0.0002-40.8-36.6获得的SWA 1-12模型相同。这表明这样的混合训练策略也可以生成更好的物体检测器,并可用于从头开始训练新的物体检测器。

为了验证我们在物体检测中发现的SWA策略的有效性,我们将其应用于不同骨干网络的各种物体检测器,包括Mask RCNN、Faster RCNN、RetinaNet、FCOS、YOLOv3和VFNet。这些结果分别在表2、表3、表4、表5、表6和表7中呈现。从这些结果中,我们可以看到SWA在我们的训练策略下始终将这些检测器的性能提高了约1.0 AP,无论它们的原始性能是高还是低。这是非常令人鼓舞的,并使我们兴奋地将这一发现分享给社区。

我们可以在图3中查看比较性的定性例子。通过比较这些检测例子,我们可以看出SWA提高了物体定位和物体分类准确性,从而减少了误报和提高了召回率。

为了进一步了解SWA带来的改进的来源,我们分析了Mask RCNN和FCOS的结果。遵循论文Diagnosing Error in Object Detectors [18]的做法,我们绘制了预训练Mask RCNN(MaskRCNNR101-2x-0.02-0.0002-40.8-36.6)和其SWA模型以及FCOS(FCOS-R101-2x-0.01-0.0001-39.1)和其SWA模型的错误分布。COCO API生成的每个图都是一系列精度-召回(PR)曲线,其中每个PR曲线都保证在评估设置变得更加宽松时严格高于前一个,每个曲线下的面积对应于AP(在图例中显示)。

通过比较图4a和图4b、图4c和图4d以及图5a和图5b,我们可以推断出SWA不仅提高了物体定位准确性,而且还提高了物体分类准确性。例如,图4a显示预训练Mask RCNN在IoU=0.75时的总体AP为44.5,但SWA Mask RCNN将这个数字提高了1.0,达到45.5 AP,表明SWA提高了定位准确性。类似地,当忽略定位误差时,即图例中表示为Loc时,预训练Mask RCNN实现了68.1 AP,但SWA Mask RCNN实现了69.0 AP,这意味着SWA还提高了物体分类准确性。对于FCOS也可以看到类似的比较结果。

总之,我们系统地研究了将SWA应用于物体检测和实例分割的效果。我们发现,使用循环学习率在另外12个epoch内训练模型并平均这些12个checkpoints,可以在具有挑战性的COCO基准测试中将检测器的性能提高约1.0 AP。我们的广泛实验表明,这种技术适用于各种物体检测器,包括Mask RCNN、Faster RCNN、RetinaNet、FCOS、YOLOv3和VFNet。我们希望我们的工作可以让更多的研究人员知道这个简单而有效的方法,并帮助他们训练更好的物体检测器。

swa可以降低模型震荡:在训练过程中,模型的权重会不断变化,而SWA策略可以减少这种变化,使模型更加平滑。这样可以降低模型的震荡,提高训练的稳定性,有助于提高模型的性能。
提高模型泛化能力:SWA策略对训练期间的权重进行平均,可以减少模型在训练集上的过拟合现象,提高模型的泛化能力。
由于模型的参数属于高维空间,SGD训练的模型往往收敛到最优解的边界区域,采用swa可以使其接近最优解
一言概括就是:采用周期式学习速率(余弦退火学习速率)额外再训练你的模型12个epoch,然后简单地平均每个epoch训练得到的weights作为最终的模型。

SWA Object Detection随机权重平均【论文+代码】相关推荐

  1. Few-shot Object Detection via Feature Reweighting论文学习以及复现

    复现Few-shot Object Detection via Feature Reweighting论文代码 写在前面 本电脑配置 环境配置 Prepare dataset Base Trainin ...

  2. 论文学习笔记《SWA Object Detection》

    这是一篇2020年年底挂在arxiv上的论文,主要思想很简单,就做了一件事情:采用周期性学习率迭代策略(余弦退火算法)额外再训练模型12个epoch,然后平均每个epoch训练得到的weights作为 ...

  3. 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

    文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...

  4. SWA(随机权重平均)——一种全新的模型优化方法

    这两天被朋友推荐看了一篇热乎的新型优化器的文章,文章目前还只挂在arxiv上,还没发表到顶会上.本着探索的目的,把这个论文给复现了一下,顺便弥补自己在优化器方面鲜有探索的不足. 论文标题:Averag ...

  5. 【提分trick】SWA(随机权重平均)和EMA(指数移动平均)

    1. SWA随机权重平均 1.1步骤 1.2代码 2.EMA指数移动平均 2.1步骤 2.2代码 3.总结 在kaggle比赛中,不管是目标检测任务.语义分割任务中,经常能看到SWA(Stochast ...

  6. SWA(随机权重平均) for Pytorch

    Stochastic Weight Averaging for Pytorch 随机权重平均 一.什么是Stochastic Weight Averaging(SWA) 二.SWA与SGD的对比 三. ...

  7. Sparse R-CNN: End-to-End Object Detection with Learnable Proposals论文翻译

    Sparse R-CNN: End-to-End Object Detection with Learnable Proposals论文翻译 摘要 1.介绍 2.相关工作 3.Sparse R-CNN ...

  8. Dynamic Head: Unifying Object Detection Heads with Attentions论文阅读

    Dynamic Head: Unifying Object Detection Heads with Attentions论文阅读 摘要 介绍 相关工作 方法 Dynamic Head 扩展到现存的检 ...

  9. Dynamic Head Unifying Object Detection Heads with Attentions 论文阅读笔记

    Dynamic Head Unifying Object Detection Heads with Attentions论文阅读笔记 这是微软在CVPR2021发表的文章,在coco数据集上取得了目前 ...

最新文章

  1. python技巧 使用值来排序一个字典
  2. Java基础—序列化底层原理
  3. stm32串口通讯问题
  4. (62)SPI外设驱动协议(一)(第13天)
  5. Android4.0升级新特性
  6. php 把查询数据转json格式,php将从数据库查询到的数据转化为json格式,并写入json文件中...
  7. 国家机构评测主流电视:长虹人工智能语音识别第一
  8. 路径规划—入门路径规划概念
  9. 怎么把系统桌面设置到D盘
  10. 环信头像和昵称显示问题 (添加消息扩展)--本人已实现效果
  11. 锂离子电池种类介绍和分类
  12. ICCV 2021 口罩人物身份鉴别全球挑战赛冠军方案分享
  13. 英语背单词有用吗_对于大学生英语背单词软件哪个好可以用_最好的背单词
  14. 苹果cmsV10MXone Pro自适应模板 站长亲测 全网首发
  15. 突破无人驾驶量产瓶颈,威蓝科技利用仿真测试降本增效
  16. MVC中的URL路由(一)
  17. Python 爬取必应(壁纸+搜索词)
  18. 手机android.sys木马,使用kali生成木马入侵安卓手机
  19. miui11 android,悉数MIUI11不容易注意到的细节新特性
  20. 用双十一的故事串起碎片的网络协议(上)

热门文章

  1. 解决 小程序界面数据不显示问题
  2. oj 3014 文件格式变换
  3. DevOps团队如何为网络星期一做准备
  4. hdu4082 Hou Yi's secret(相似三角形)
  5. 终极五笔 v6.02 正式版 下载
  6. ESP32 ESP-IDF 项目文件结构
  7. MANIFEST.MF文件详解
  8. CHROME扩展开发之·迁移到 Manifest V3
  9. oracle输出实心三角型,C语言帕斯卡三角形打印示例
  10. “笨办法”学Python3,Zed A. Shaw, 习题3