本文带领大家重温 Objects as Points 一文,其于2019年4月发布于arXiv,谷歌学术显示目前已有403次引用,Github代码仓库已有5.2K星标,无论在工业界和学术界均有巨大影响力。

论文作者信息:

论文:

https://arxiv.org/abs/1904.07850

0.动机

使用卷积神经网络做目标检测,大体可以分为单阶段(one-stage)方法和二阶段(two-stage)方法。单阶段检测器预定义很多anchor,基于这些anchor去做检测,密集的anchor有助于提高检测精度,然而在预测阶段,真正起到检测作用的只有少部分anchor,因此导致了计算资源的浪费。

无论是单阶段检测器还是二阶段检测器,在后处理阶段都需要NMS(Non-Maximum Suppression)操作去除多余框,NMS操作只在推理阶段存在,NMS的存在导致现有的检测器并不是严格的端到端训练的。

基于上述现象,作者提出了CenterNet算法,直接预测物体bounding box的中心点和尺寸。相比其他方法,该方法在预测阶段不需要NMS操作,极大的简化了网络的训练和推理过程。

1.CenterNet原理

CenterNet主要原理为:输入尺寸为






的3通道图像,表示为















,经过卷积神经网络运算,输出尺寸为












、通道数为数据集类别数




的heatmap,且heatmap中每个值在








区间内,表示为






























。默认情况下令






,即网络输出的heatmap的长和宽分别为输入图像的







使用


















表示heatmap中的第




个通道位置








处的值,当




















时,表示heatmap的








处是一个关键点(key point),若




















,表示heatmap的








处是背景。heatmap中关键点所在的位置对应原图像中目标的bounding box中心。

若输入图像的位置













处为类别




的bounding box中心,令


















,那么heatmap的第




个通道的








处的值为:

上式为高斯函数,式中的









与物体尺寸有关。若输入图片中有2个或多个相邻的类别为




的目标,在求取heatmap时,某个位置处的元素可能会得到根据多个目标求到的多个值,此时取最大值作为heatmap中该位置的值。

1.1 关键点(key point)损失函数

参考focal loss,构造如下形式的损失函数:




















是输入图片中目标的个数,也是heatmap中关键点的个数。上式中的













































是交叉熵损失函数,





















































是focal loss项。

focal loss的存在有如下影响:















  • 时,若


















    接近1,由于



























    项的存在,损失函数会急剧衰减,而当


















    不接近1时,损失函数轻微衰减,使得优化器更关注


















    不接近1的样本。















  • 时,若


















    接近0,由于

























    项的存在,损失函数急剧衰减,而当


















    接近1时,损失函数轻微衰减,使得优化器更关注


















    接近1的样本。















的情况中,损失函数还包含






















,结合下图讲述该项的作用:

上图为根据训练集的标注信息得到的heatmap,该heatmap作为网络的监督信息训练网络。关注上图中左边的目标,深绿色的点为输入图片bounding box中心位置













在heatmap上对应的位置


















,heatmap中该位置的值为1。

前文谈到,根据标注信息生成heatmap时,使用了二维高斯函数确定heatmap中









处周围位置的值,即上图中浅绿色方框位置的值,它们的值不为1但是接近1。

在网络训练过程中,由于






















的存在,当











很接近1时,损失函数的值被进一步压制,即使它们的值


















接近1也要让优化器不特别关注这些浅绿色位置的损失函数值,因为这些位置离物体的bounding box中心很接近。

1.2 offset损失函数

为了更准确地预测出输入图像中bounding box中心点的位置,该网络除了输出




个通道的heatmap外,还会输出2个通道的offset信息,记作




























如前文所述,输入图片bounding box中心位置













在heatmap上对应的位置为


















,这个过程存在向下取整运算,导致得到网络输出的heatmap中关键点的位置后,难以精确得到输入图片中bounding box中心的位置,offset用于弥补这一精度损失。

offset表示的值为


















,使用L1损失构造offset的损失函数,表示如下:

上式中



















表示网络预测的offset,


















可以根据训练集的标注信息得到。需要特别指出的是,offset损失只针对heatmap中的关键点,对于非关键点,不存在offset损失。

1.3 尺寸(size)损失函数

表示目标




的bounding box左上角和右下角的坐标,则该目标的尺寸为









可以通过训练集的标注信息的到。网络输出2个通道的信息用来预测目标的尺寸,记作




























。使用L1损失构造尺寸损失函数,表示如下:

1.4 整体的损失函数

根据上文内容可得到整体的损失函数,表示如下:

实验过程中,取上式的




























。使用这个损失函数训练网络,得到






个通道的feature map,分别表示关键点









、偏移









和尺寸









2.网络结构

作者尝试了4种网络结构,分别为ResNet-18、ResNet-101、DLA-34、Hourglass-104,如下图所示,方框内的数字用于指出特征的尺寸,当方框内数字为4时,表示此时特征的长和宽分别为输入图片的1/4。

  • Hourglass-104

如上图中图(a)所示,每个Hourglass模块是对称的5个下采样和5个上采样,中间使用skip connection进行短接,该网络的配置与文章《Cornernet:Detecting objects as paired keypoints》基本一致。

  • ResNet-18和ResNet-101

如上图中图(b)所示,作者在ResNet中做了些改动:在每个upsample操作之前加入了1个3x3的deformable convolutional layer,即在做upsample操作时,先通过deformable convolutional layer改变通道数,再进行反卷积(transposed convolution)操作。

  • DLA-34

上图中图(c)是原始的DLA-34结构,作者在其基础上进行改进,变成了图(d)中的结构。主要改动为:增加了一些skip connection,在upsample操作时使用deformable convolutional layer。

对于每种网络结构,backbone后面会增加3个网络head,分别预测关键点









、偏移









和尺寸









,每个head包括1个3x3的卷积、ReLU和1个1x1的卷积。网络输出结果的示意图如下图所示:

3.使用CenterNet做3D目标检测

2D目标检测只需要网络输出目标的位置和尺寸即可,而3D目标检测还需要网络输出目标的深度、(长、宽、高)、目标的角度这3个额外的信息。

3.1 深度

在backbone后面增加一个head用于预测深度信息,网络的输出并不是最终的深度。对于目标




,网络输出的深度信息记作
















,则目标的真正深度为




























表示sigmoid函数。

使用L1损失构造深度损失函数,表示如下:

上式中的









表示训练集中目标




的真实深度,单位为米。

3.2 长、宽、高

将目标的长、宽、高用向量表示,对于目标




,网络输出的长、宽、高记作





















,使用L1损失构造损失函数,表示如下:

上式中









表示训练集中目标




的长、宽、高,单位为米。

3.3 角度

使用卷积神经网络直接回归角度比较困难,因此作者将角度信息用8个标量表示,对于目标




,网络输出角度信息记作



















































































进行分组,其中


























表示第1组,用于预测目标角度位于




















范围内的角度值;


























表示第2组,用于预测目标角度位于




















范围内的角度值。

对于每一组,





















用于使用softmax函数进行分类,从而决定目标的预测角度是由第1组的信息表示还是由第2组的信息表示;





















分别用于预测“目标角度与该组内角度范围的中心角度”差值的






值和






值。使用交叉熵损失函数训练














,使用L1损失训练














,得到如下损失函数:

上式中,









表示角度区间的中间值,









由训练集中的标注信息得到,









用于指明训练集中目标的角度在哪一组的角度范围内,





























是卷积神经网络的输出值。

推理时,若输入一张图片,根据神经网络输出的





























































得到角度值:














































上式中




的取值为1或2,具体取值由









中的





























决定,根据





























选择使用第1组还是第2组的角度信息。

4.实验结果

4.1 2D目标检测

作者使用了ResNet-18、ResNet-101、DLA-34、Hourglass-104这4种网络进行实验,输入图片分辨率为512x512,输出图片分辨率为128x128,训练时使用了如下数据增强方法:

  • random flip

  • random scaling

  • cropping

  • color jittering

在COCO训练集上训练,使用Adam优化器,ResNet-101和DLA-34的下采样层由ImageNet预训练权重初始化,上采样层随机初始化;Hourglass-104以ExtremeNet为基础微调。在8卡TITAN-V GPU上,ResNet-101和DLA-34训练了2.5天,Hourglass-104训练了5天。在COCO验证集上测试,结果如下图所示:

上图中,“N.A”表示测试时未使用数据增强;“F”表示测试时使用了flip方式进行数据增强,在解码bounding box之前平均2个网络的输出结果;“MS”表示使用了5个尺度(0.5,0.75,1,1.25,1.5)进行推理,使用NMS融合5个网络的结果。“FPS”的测试是基于Intel Core i7-8086K CPU、TITAN Xp GPU、Pytorch 0.4.1、CUDA9.0和CUDNN7.1环境。

下图为使用DLA-34和Hourglass-104这2种结构在COCO测试集上的测试结果:

包含“/”的项表示“单尺度/多尺度”结果。可以看到使用Hourglass-104结构精度可以达到45.1% AP,超越了其他单阶段检测器。

4.2 3D目标检测

在KITTI数据集上训练3D目标检测算法,训练时没有使用数据增强技术。在训练和测试时,输入图片分辨率为1280x384,使用DLA-34网络结构。对于包含深度、(长、宽、高)和方向的损失函数,训练时权重均设置为1。

在测试时,recall的值设置为从0到1步长为0.1的11个值,IOU阈值为0.5,计算2D bounding box的AP、度量角度准确性的AOS、鸟瞰图bounding box的BEV AP。针对该数据集训练5个模型,测试结果取5个模型平均值,并给出标准差。测试结果如下:

上图中分别与Deep3DBox和Mono3D进行比较,并且按照要比较的算法划分测试集。在AP和AOS两个指标下,CenterNet略差于Deep3DBox和Mono3D方法,在BEV AP指标下明显优于这2个方法。

5.总结

这篇文章有如下亮点:

  • 提出了CenterNet框架用于目标检测,该方法预测目标的关键点和尺寸,简单、速度快、精度高,不需要预定义anchor,也不需要NMS,完全实现端到端训练;

  • 在CenterNet框架下,可以通过增加网络的head预测目标的其他属性,比如3D目标检测中的目标深度、角度等信息,可扩展性强。

  • 源码:

    https://github.com/xingyizhou/CenterNet

仅用于学习交流!

END

备注:目标检测

目标检测交流群

2D、3D目标检测等最新资讯,若已为CV君其他账号好友请直接私信。

我爱计算机视觉

微信号:aicvml

QQ群:805388940

微博知乎:@我爱计算机视觉

投稿:amos@52cv.net

网站:www.52cv.net

在看,让更多人看到  

重读 CenterNet,一个在Github有5.2K星标的目标检测算法相关推荐

  1. 基于anchor-free的目标检测算法CenterNet研究

    2020.04.18 小记 近期一直在MOT算法研究,目前SOTA算法核心还是基于CenterTrack网络,而CenterTrack又是基于anchor-free式的CenterNet网络,Cent ...

  2. 从零实现一个3D目标检测算法(3):PointPillars主干网实现(持续更新中)

    在上一篇文章<从零实现一个3D目标检测算法(2):点云数据预处理>我们完成了对点云数据的预处理. 从本篇文章,我们开始正式实现PointPillars网络,我们将按照本系列第一篇文章介绍的 ...

  3. 从零实现一个3D目标检测算法(2):点云数据预处理

    在上一篇文章<从零实现一个3D目标检测算法(1):3D目标检测概述>对3D目标检测研究现状和PointPillars模型进行了介绍,在本文中我们开始写代码一步步实现PointPillars ...

  4. 学习笔记--一个自管理(组织)的多目标进化算法(SMEA)

    学习笔记–一个自管理(组织)的多目标进化算法(SMEA) 摘要:在温和条件下,一个连续m维目标的优化问题的帕累托前沿(解集)可以形成一个(m-1)维的分段连续流形.基于这个性质,这篇文章提出了一个自管 ...

  5. CVPR 2020 Oral | 旷视提出目前最好的密集场景目标检测算法:一个候选框,多个预测结果...

    作为 CVPR 2020 Oral展示论文之一, 该文提出一种简单却有效的基于候选框的物体检测方法,尤其适用于密集物体检测.该方法通过一个候选框.多个预测框的概念,引入 EMD Loss.Set NM ...

  6. r语言把两个折线图图像放到一个图里_图像目标检测算法总结(从R-CNN到YOLO v3)...

    基于CNN 的目标检测是通过CNN 作为特征提取器,并对得到的图像的带有位置属性的特征进行判断,从而产出一个能够圈定出特定目标或者物体(Object)的限定框(Bounding-box,下面简写为bb ...

  7. 经典回顾 | 第一个Anchor-Free、NMS-Free 3D目标检测算法!

    作者 | 小书童  编辑 | 集智书童 点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心[3D目标检测]技术交流群 后台回复[3D检测综 ...

  8. 将多个文件夹或单文件夹内的xml文件转换为一个json标签(imagenet VID等视频目标检测数据集)简单易改,有注释

    文章目录 多文件夹xml文件转json 单文件夹xml文件转json 该代码主要针对视频目标检测yolov项目需要训练基础的yolox检测器所写(需要VID的json标签文件),鉴于网上没有公开的代码 ...

  9. keras优化算法_目标检测算法 - CenterNet - 代码分析

    代码出处 吃水不忘打井人,分析github上的基于keras的实现: xuannianz/keras-CenterNet​github.com 代码主体结构 模型训练的主函数流程如下所示,该流程也是使 ...

最新文章

  1. C#实现局域网内远程开机
  2. jsp 记录1 bs/cs
  3. Spring 基于注解(annotation)的配置之@Qualifier注解
  4. 网管交换机和非网管交换机有什么区别?
  5. gson json转map_Java 中几种常用 JSON 库性能比较
  6. Qt工作笔记-使用Qt Creator运行和调试运行结果不一样(参数没有初始化)
  7. 手机变身车机导航遥控器 高德地图上线手车互联新功能
  8. python基础语法加爬虫精进_从Python安装到语法基础,这才是初学者都能懂的爬虫教程...
  9. 如何解二阶齐线性微分方程
  10. 关于数字签名驱动解决方法
  11. 寻找复杂背景下的物体轮廓 (从禾路的博客园整理学习)
  12. 空心三角形 C语言版
  13. 基于JAVA+SpringMVC+Mybatis+MYSQL的学生签到管理系统
  14. Ui设计中常用的6大工具
  15. .jar是什么文件?(转载)
  16. 水生植物的Java莫斯
  17. P5 似然函数与狄拉克函数
  18. 通达信接口的登录调试步骤
  19. Java获取IP地址的多种方法
  20. IOS 字符串中去除特殊符号 stringByTrimmingCharactersInSet 应该用于账号登录等

热门文章

  1. ubuntu 16.04安装opencv 2.4.9及其关于qt的问题
  2. 天池-新闻推荐-数据分析
  3. 机器学习笔记I: 基于逻辑回归的分类预测
  4. 【编撰】linux IPC 001 - 概述
  5. 前端开发 跨平台的构架GSOAP
  6. 【基础】嵌入式浏览器移植基本要素
  7. 【读书笔记0103】Beginning linux programming-shell programming
  8. oracle 取第三大的值,Oracle 常见的几种访问提取数据的方式!
  9. c语言实验五函数答案,C语言程序设计实验五 参考答案.doc
  10. 嵩天python123测试6_神华化工股票