论文分析了 one-stage 网络训练存在的类别不平衡问题,提出能根据 loss 大小自动调节权重的 focal loss,使得模型的训练更专注于困难样本。同时,基于 FPN 设计了 RetinaNet,在精度和速度上都有不俗的表现

论文:Focal Loss for Dense Object Detection

  • 论文地址:http://arxiv.org/abs/1708.02002[1]

  • 论文代码:http://github.com/facebookresearch/Detectron[2]

Introduction


  目前 state-of-the-art 的目标检测算法大都是 two-stage、proposal-driven 的网络,如 R-CNN 架构。而 one-stage 检测器一直以速度为特色,在精度上始终不及 two-stage 检测器。因此,论文希望研究出一个精度能与 two-stage 检测器媲美的 one-stage 检测器 通过分析,论文认为阻碍 one-stage 精度主要障碍是类别不平衡问题(class imbalance)

  • 在 R-CNN 架构检测器中,通过 two-stage 级联和抽样探索法(sampling heuristics)来解决类别不平衡问题。proposal 阶段能迅速地将 bndbox 的数量缩小到很小的范围(1-2k),过滤了大部分背景。而第二阶段,则通过抽样探索法来保持正负样本的平衡,如固定的正负样本比例(1:3)和 OHEM

  • one-stage 检测器通常需要处理大量的 bndbox(~100k),密集地覆盖着各位置、尺度和长宽比。然而大部分 bndbox 都是不含目标的,即 easy background。尽管可以使用类似的抽样探索法(如 hard example mining)来补救,但这样的效率不高,因为训练过程仍然被简单的背景样本主导,导致模型更多地学习了背景而没有很好地学习检测的目标

  在解决以上问题的同时,论文产出了两个成果:

  • 新的损失函数 focal loss,该函数能够动态地调整交叉熵大小。当类别的置信度越大,权重就逐渐减少,最后变为 0。反之,置信度低的类别则得到大的权重

  • 设计了一个简单的 one-stage 检测器 RetinaNet 来演示 focal loss 的有效性。该网络包含高效的特征金字塔和特别的 anchor 设定,结合一些多种近期的 one-stage detectgor 的 trick(DNN/FPN/YOLO/SSD),达到 39.1 的 AP 精度和 5fps 的速度,超越了所有的单模型,如图 2 所示

FocalLoss


Balanced Cross Entropy

  交叉熵损失函数如图 1 最上曲线,当置信度大于 0.5 时,loss 的值也不小。若存在很多简单样本时,这些不小的 loss 堆积起来会对少样本的类别训练造成影响

  一种简单的做法是赋予不同的类不同的权重,即-balanced 交叉熵。在实际操作中,属于一个预设的超参,类别的样本数越多,则设置越小

Focal Loss Definition

  -balanced 交叉熵仅根据正负样本的数量进行权重的平衡,没有考虑样本的难易程度。因此,focal loss 降低了容易样本的损失,从而让模型更专注于难的负样本

  focal loss 在交叉熵的基础上添加了调节因子,其中是超参数。的 loss 曲线如图 1 所示,focal loss 有两个特性:

  • 当一个样本被误分且置信度很低时,调节因子会接近 1,整体的 loss 都很小。当置信度接近 1 的时候,调节因子会接近于 0,整体的 loss 也被降权了

  • 超参数平滑地调整了简单样本的降权比例。当,Focal loss 与交叉熵一致,随着增加,调节因子的影响也相应增加。当时,置信度为 0.9 的样本的 loss 将有 100 倍下降,而 0.968 的则有 1000 倍下降,这变相地增加了误分样本的权重

  实际使用时中,focal loss 会添加-balanced,这是从后面的实验中总结出来的

Class Imbalance and Model Initialization

  二分类模型初始化时对于正负样本预测是均等的,而在训练时,样本数多的类别会主导网络的学习,导致训练初期不稳定。为了解决这问题,论文在模型初始化的时候设置先验值(如 0.01),使模型初始输出偏向于低置信度来加大少数(正)样本的学习。在样本不平衡情况下,这种方法对于提高 focal loss 和 cross entropy 训练稳定性有很大帮助

RetinaNet Detector


Architecture

  RetinaNet 是 one-stage 架构,由主干网络和两个 task-specific 子网组成。主干网络用于提取特征,第一个子网用于类别分类,第二个子网用于 bndbox 回归

  • Feature Pyramid Network Backbone

  RetinaNet 采用 FPN 作为主干,FPN 通过自上而下的路径以及横行连接来增强卷积网络的特征提取能力,能够从一张图片中构造出丰富的以及多尺度特征金字塔,结构如图 3(a)-(b)。  FPN 构建在 ResNet 架构上,分别在 level -,每个 level l 意味着的尺度缩放,且每个 level 包含 256 通道

  • Anchors

  level到对应的 anchor 尺寸为到,每个金字塔层级的的长宽比均为,为了能够预测出更密集的目标,每个长宽比的 anchor 添加原设定尺寸的大小的尺寸,每个 level 总共有 9 个 anchor   每个 anchor 赋予长度为 K 的 one-hot 向量和长度为 4 的向量,K 为类别数,4 为 box 的坐标,与 RPN 类似。IoU 大于 0.5 的 anchor 视为正样本,设定其 one-host 向量的对应值为 1,的 anchor 视为背景,的 anchor 不参与训练

  • Classification Subnet

  分类子网是一个 FCN 连接 FPN 的每一 level,分类子网是权值共享的,即共用一个 FPN。子网由 4xCx(3x3 卷积+ReLU 激活层)+KxA(3x3 卷积)构成,如图 3(c),C=256,A=9

  • Box Regression Subnet

  定位子网结构与分类子网类似,只是将最后的卷积大小改为 4xAx3x3,如图 3(d 所示)。每个 anchor 学习 4 个参数,代表当前 bndbox 与 GT 间的偏移量,这个与 R-CNN 类似。这里的定位子网是类不可知的(class-agnostic),这样能大幅减少参数量

Inference and Training

  • Inference

  由于 RetinaNet 结构简单,在推理的时候只需要直接前向推算即可以得到结果。为了加速预测,每一个 FPN level 只取置信度 top-1k bndbox(),之后再对所有的结果进行 NMS()

  • Focal Loss

  训练时,focal loss 直接应用到所有~ 100k anchor 中,最后将所有的 loss 相加再除以正样本的数量。这里不除以 achor 数,是由于大部分的 bndbox 都是 easy 样本,在 focal loss 下仅会产生很少 loss。权值的设定与存在一定的关系,当增加时,则需要减少,(表现最好)

  • Initialization

  Backbone 是在 ImageNet 1k 上预训练的模型,FPN 的新层则是根据论文进行初始化,其余的新的卷积层(除了最后一层)则偏置,权重为的高斯分布

  最后一层卷积的权重为的高斯分布,偏置(偏置值的计算是配合最后的激活函数来推),使得训练初期的前景置信度输出为,即认为大概率都是背景。这样背景就会输出很小的 loss,前景会输出很大的 loss,从而阻止背景在训练前期产生巨大的干扰 loss

  • Optimization

  RetinaNet 使用 SGD 作为优化算法,8 卡,每卡 batchSize=2。learning rate=0.01,60k 和 80k 轮下降 10 倍,共进行 90k 迭代,Weight decay=0.0001,momentum=0.9, training loss 为 focal loss 与 bndbox 的 smooth L1 loss

Experiments


Training Dense Detection

  • Network Initialization

  论文首先尝试直接用标准交叉熵进行 RetinaNet 的训练,不添加任何修改和特殊初始化,结果在训练时模型不收敛。接着论文使用先验概率对模型进行初始化,模型开始正常训练,并且最终达到 30.2AP,训练对的值不敏感

  • Balanced Cross Entropy

  接着论文进行平衡交叉熵的实验,结果如 Table1a,当时,模型获得 0.9 的 AP 收益

  • Focal Loss

  接着论文进行了 focal loss 实验,结果如 Table 1b,当时,模型在-balanced 交叉熵上获得 2.9AP 收益。论文观察到,与成反向关。整体而言,带来的收益更大,此外,的值一般为(从中实验得出)

  • Analysis of the Focal Loss

  为了进一步了解 focal loss,论文分析了一个收敛模型(,ResNet-101)的 loss 经验分布。首先在测试集的预测结果中随机取个正样本和个负样本,计算其 FL 值,再对其进行归一化令他们的和为 1,最后根据归一化后的 loss 进行排序,画出正负样本的累积分布函数(CDF),如图 4

  不同的值下,正样本的 CDF 曲线大致相同,大约 20%的难样本占据了大概一半的 loss,随着的增大,更多的 loss 集中中在 top20%中,但变化比较小   不同的值下,负样本的 CDF 曲线截然不同。当时,正负样本的 CDF 曲线大致相同。当增大时,更大的 loss 集中在难样本中。当时,很大一部分的 loss 集中在很小比例的负样本中。可以看出,focal loss 可以很有效的减少容易样本的影响,让模型更专注于难样本

  • Online Hard Example Mining (OHEM)

  OHEM 用于优化 two-stage 检测器的训练,首先根据 loss 对样本进行 NMS,再挑选 hightest-loss 样本组成 minibatches,其中 NMS 的阈值和 batch size 都是可调的。与 FL 不同,OHEM 直接去除了简单样本,论文也对比了 OHEM 的变种,在 NMS 后,构建 minibatch 时保持 1:3 的正负样本比。实验结果如 Table 1d,无论是原始的 OHEM 还是变种的 OHEM,实验结果都没有 FL 的性能好,大约有 3.2 的 AP 差异。因此,FL 更适用于 dense detector 的训练

Model Architecture Design

  • Anchor Density

  one-stage 检测器使用固定的网格进行预测,一个提高预测性能的方法是使用多尺度/多长宽比的 anchro 进行。实验结果如 Table 1c,单 anchor 能达到 30.3AP,而使用 9 anchors 能收获 4AP 的性能提升。最后,当增加到 9anchors 时,性能反而下降了,这说明,当 anchor 密度已经饱和了

  • Speed versus Accuracy

  更大 Backbone 和 input size 意味着更高准确率和更慢的推理速度,Table 1e 展示了这两者的影响,图 2 展示了 RetinaNet 与其它主流检测器的性能和速度对比。大尺寸的 RetinaNet 比大部分的 two-stage 性能要好,而且速度也更快

  • Comparison to State of the Art

  与当前的主流 one-stage 算法对比,RetinaNet 大概有 5.9 的 AP 提升,而与当前经典的 two-stage 算法对比,大约有 2.3 的 AP 提升,而使用 ResNeXt32x8d-101-FPN 作为 backbone 则能进一步提升 1.7AP

Conclusion


  论文认为类别不平衡问题是阻碍 one-stage 检测器性能提升的主要问题,为了解决这个问题,提出了 focal loss,在交叉熵的基础上添加了调节因子,让模型更集中于难样本的训练。另外,论文设计了 one-stage 检测器 RetinaNet 并给出了相当充足的实验结果

END

联盟学术交流群

扫码添加联盟小编,可与相关学者研究人员共同交流学习:目前开设有人工智能、机器学习、计算机视觉、自动驾驶(含SLAM)、Python、求职面经、综合交流群扫描添加CV联盟微信拉你进群,备注:CV联盟  

最新热文荐读

GitHub | 计算机视觉最全资料集锦

Github | 标星1W+清华大学计算机系课程攻略!

Github | 吴恩达新书《Machine Learning Yearning》

收藏 | 2020年AI、CV、NLP顶会最全时间表!

收藏 | 博士大佬总结的Pycharm 常用快捷键思维导图!

收藏 | 深度学习专项课程精炼图笔记!

笔记 | 手把手教你使用PyTorch从零实现YOLOv3

笔记 | 如何深入理解计算机视觉?(附思维导图)

笔记 | 深度学习综述思维导图(可下载)

笔记 | 深度神经网络综述思维导图(可下载)

点个在看支持一下吧

目标检测 | RetinaNet:Focal Loss for Dense Object Detection相关推荐

  1. 目标检测--Focal Loss for Dense Object Detection

    Focal Loss for Dense Object Detection ICCV2017 https://arxiv.org/abs/1708.02002 本文算是用简单的方法解决复杂的问题了,好 ...

  2. RetinaNet——《Focal Loss for Dense Object Detection》论文翻译

    <Focal Loss for Dense Object Detection> 摘要 迄今为止最高精度的对象检测器基于由R-CNN推广的 two-stage 方法,其中分类器应用于稀疏的候 ...

  3. RetinaNet论文详解Focal Loss for Dense Object Detection

    一.论文相关信息 ​ 1.论文题目:Focal Loss for Dense Object Detection ​ 2.发表时间:2017 ​ 3.文献地址:https://arxiv.org/pdf ...

  4. 【翻译】Focal Loss for Dense Object Detection(RetinaNet)

    [翻译]Focal Loss for Dense Object Detection(RetinaNet) 目录 摘要 1.介绍 2.相关工作 3.Focal Loss 3.1 平衡的交叉熵损失 3.2 ...

  5. Focal Loss for Dense Object Detection(整理后转载)

    @[TOC](Focal Loss for Dense Object Detection 论文目标 核心思想 focal loss的提出 交叉熵损失函数 focal loss的重要性质 focal l ...

  6. Focal Loss for Dense Object Detection(RetinaNet)(代码解析)

    转自:https://www.jianshu.com/p/db4ccd194109 转载于:https://www.cnblogs.com/leebxo/p/10485740.html

  7. 【目标检测】Focal Loss详解

    论文题目:<Focal Loss for Dense Object Detection> 论文链接:https://arxiv.org/pdf/1708.02002.pdf 1. 前言 我 ...

  8. 一种新的无监督前景目标检测方法 A New Unsupervised Foreground Object Detection Method

    14.一种新的无监督前景目标检测方法 A New Unsupervised Foreground Object Detection Method 摘要:针对基于无监督特征提取的目标检测方法效率不高的问 ...

  9. 【目标检测】cvpr2021_VarifocalNet: An IoU-Aware Dense Object Detector

    文章目录 一.背景 二.动机 三.方法 3.1 IACS--IoU-Aware Classification Score 3.2 Varifocal loss 3.3 Star-Shaped Box ...

最新文章

  1. 论 MySql InnoDB 如何通过插入意向锁控制并发插入
  2. 图解MyEclipse用DB Browser连接四种数据库
  3. python axis 0_axis=0在sum()和dropna()中的行为似乎不同
  4. LiveVideoStack线上分享第五季(三):新一代直播传输协议SRT
  5. event loop那些事儿
  6. Flink SQL Client的Rolling Aggregation实验解析
  7. 自动检测CSRF漏洞的工具
  8. 如何更改 C# Record 构造函数的行为
  9. 职称计算机考试 数量,职称计算机考试WPS基础考点:自动求和
  10. matlab2c使用c++实现matlab函数系列教程-ones函数
  11. 【重磅预告】揭秘阿里双11技术进步历程!
  12. 最牛ai波士顿动力上台阶_波士顿动力的位置如何使美国成为人工智能的关键参与者...
  13. linux模拟gps,Android之GPS研究(实战篇二)
  14. 提取SHP格式文件折点(拐点)地理坐标(经纬度)
  15. LeetCode题目Java代码解答 (详细解释!!!)
  16. 苹果系统被曝漏洞, 大麦网再遭撞库攻击 | 宅客周刊
  17. 5G工业无线网关在物联网的应用优势
  18. python爬取百度的工具_Python爬虫之小试牛刀——使用Python抓取百度街景图像
  19. uniapp开发的多端影视APP,对接的苹果CMS
  20. (5)多体量子态与统计力学基础

热门文章

  1. python比较列表所有字符串_python – 将字符串与数组中的所有值进行比较
  2. java jedis_Java操作Redis之Jedis用法详解
  3. python独立图形_在networkx中查找图形对象中的独立图形
  4. linux mysql jdbc_linux下jdbc连Mysql异常 郁闷了一天!
  5. mac nginx加载php 配置,Mac下Nginx安装环境配置详解
  6. html漂亮的表格模板+背景_教育与课程主题响应式网站着陆页模板
  7. [LeetCode]819. 最常见的单词
  8. Android wifi驱动的移植 realtek 8188
  9. 朝花夕拾-4-shell
  10. python time,datetime与highchart中的time