paper:OTA: Optimal Transport Assignment for Object Detection

code:https://github.com/Megvii-BaseDetection/OTA 

背景

标签分配(Label Assignment)是目标检测中重要的一环,经典的标签分配策略采用预定义的规则为每个anchor匹配对应的gt或背景类。比如RetinaNet采用IoU作为划分正负样本的阈值标准,anchor-free检测器比如FCOS将ground truth物体的bbox内或bbox中心区域内的anchor point作为正样本。这种静态分配策略忽略了这样一个事实,即对于不同大小、形状、遮挡状态的对象,最适合的正负样本划分的边界可能是不同的。

基于此很多动态分配方法被提出,比如ATSS基于统计特征为每个gt设置划分边界,Freeanchor、Autoassign、PAA等方法提出anchor的预测分数可以作为一个合适的指标用来设计动态分配策略。

但是,不考虑上下文单独的为每个gt分配正负样本的方法可能不是最优的。对于模糊的anchor,即可能作为正样本分配给多个gt的anchor,现有的策略都是基于人工定义的准则,比如Min Area或Max IoU。作者指出把ambiguous anchor分配给任一个gt,对其他gt的学习都是不利的(introduce harmful gradients w.r.t. other gts),因此分配还需要更多的信息。一个更好的分配策略应该摆脱对每个gt单独追求最优分配的思想,转而全局最优的思想,找到一张图像中所有gt的综合最优分配策略。

本文的创新点

本文提出把标签分配当做最优传输问题,具体是把每个gt定义成一个supplier,它可以提供一定数量的label。把每个anchor定义成demander,它需要一个label。如果一个anchor从某个gt那得到了足够数量的positive label,这个anchor就被当做这个gt的一个正样本。每个gt可以提供的positive label的数量可以理解为这个gt在训练过程中需要多少个正样本来更好的收敛。每对anchor-gt的传输cost定义为它们之间的分类和回归loss的加权和。此外,背景类也被定义为supplier,它提供negative label,anchor-background之间的传输cost定义为它们之间的分类loss。这样标签分配问题就被转化为了最优传输问题,最终是为了找到全局最优的分配方法而不再是为每个gt单独寻找最优anchor。

具体方法

Optimal Transport

最优传输问题可以表述为:假设有 \(m\) 个supplier和 \(n\) 个demander,第 \(i\) 个supplier有 \(s_{i}\) 个物品,第 \(j\) 个demander需要 \(d_{j}\) 个物品,每个物品从第 \(i\) 个supplier运到第 \(j\) 个demander的运输运输成本为 \(c_{ij}\),最优传输的目标是找到一个最优传输方案 \(\pi^{*}=\left \{ \pi_{i,j}|i=1,2,...m,j=1,2,...n \right \} \) 能以最小的运输成本把所有的物品从supplier运输到demander。

OT for Label Assignment

对于目标检测问题,假设一张图片有 \(m\) 个gt和 \(n\) 个anchor(所有FPN level加起来),每个gt当做一个supplier,持有 \(k\) 个正标签 \((i.e.,s_{i}=k,i=1,2,...,m)\),每个anchor当做一个demander,需要一个标签 \((i.e.,d_{j}=1,j=1,2,...,n)\)。从 \(gt_{i}\) 传输一个正标签到anchor \(a_{j}\) 的运输成本 \(f^{fg}\) 定义为它们之间的分类损失和回归损失的加权和

其中 \(\theta\) 是模型参数,\(P_{j}^{cls}\) 和 \(P_{j}^{reg}\) 分别表示anchor \(a_{j}\) 的预测的分类得分和bounding box。\(G_{i}^{cls}\) 和 \(G_{i}^{box}\)  分别表示 \(gt_{i}\) 的ground truth类别和bounding box。\(L_{cls}\) 和 \(L_{reg}\) 分别表示交叉熵loss和IoU loss,也可以分别替换成Focal loss和GIoU/Smooth L1 loss,\(\alpha\) 是权重系数。

此外,还有另一种提供负标签的supplier,背景类。在标准的最优传输问题中,supply的数量和demand的数量是相等的。因此背景类一共可以提供 \(n-m\times k\) 个负标签,从背景类传输一个负标签到 \(a_{j}\) 的成本为

其中 \(\oslash\) 表示背景类,把 \(c^{bg}\in \mathbb{R}^{1\times n}\) 拼接到 \(c^{fg}\in \mathbb{R}^{m\times n}\) 的最后一行即得到了完整的cost matrix \(c\in \mathbb{R}^{(m+1)\times n}\)。supply vector \(s\) 需要按下式更新

现在有了cost matrix \(c\),supply vector \(s\in \mathbb{R}^{m+1}\),demand vector \(d\in \mathbb{R}^{n}\),则最优传输路径 \(\pi^{*}\in \mathbb{R}^{(m+1)\times n}\) 可通过现有的Sinkhorn-Knopp Iteration算法求得。得到 \(\pi^{*}\) 后,对应的标签分配就是将每个anchor分配给传输给这个anchor最多标签的gt

Advanced Designs

Center Prior

center prior即只从gt的中心有限区域挑选正样本,而不是整个bounding box范围内选择。强迫模型关注潜在positive areas即中心区域有助于稳定训练,特别是在训练的早期阶段,模型的最终性能也会更好。作者发现center prior对OTA的训练也有帮助,因此引入了center prior策略。

具体做法是,对于每个gt,只挑选每个FPN层中距离bounding box中心最近的 \(r^{2}\) 个anchor,对于bounding box内 \(r^{2}\) 之外的anchor,cost matrix中对应的cost会加上一个额外的常数项cost,这样就减少了训练阶段它们被分配为正样本的概率。

Dynamic \(k\) Estimation

每个gt需要的正样本数量应该是不同的并且基于很多因素,比如物体大小、尺度、遮挡情况等。由于很难将这些因素和所需anchor数量直接映射起来,本文提出了一种简单有效的方法,根据预测框和对应gt的IoU值来粗略估计每个gt合适的正样本数量。具体来说,对于每个gt,选择IoU最大的 \(q\) 个个预测,将这 \(q\) 个IoU值的和作为这个gt正样本数量的粗略估计值。这样做是基于直觉:某个gt的所需合适的postive anchor数量与和这个gt拟合的很好的anchor的数量正相关。

OTA的完整流程如下图所示

包含center prior和dynamic k estimation的完整流程伪代码如下所示

代码解读

这里batch_size=2,输入shape=(2, 3, 1085, 800),前景loss权重系数 \(\alpha=1.5\),center prior超参 \(r=2.5\),dynamic \(k\) estmation中 \(q=20\)。

其中line96计算前景loss和中的 1e6*(1-is_in_boxes.float()) 就是中心区域外的anchor额外加的常数项cost,line105将背景的cost拼接到前景cost矩阵最后就得到了最终的cost matrix,这里的loss就是cost matrix。mu和nu分别是上面的supply vector \(s\) 和 demand vector \(d\)。

核心代码如下,加了一些注释,其中sinkhorn算法没有专门了解原理,这里就直接用吧。

    def get_ground_truth(self, shifts, targets, box_cls, box_delta, box_iou):# shifts# [[(13600,2),(3400,2),(850,2),(221,2),(63,2)],#  [(13600,2),(3400,2),(850,2),(221,2),(63,2)]]# targets# [Instances(num_instances=2, image_height=1085, image_width=800,#     fields=[gt_boxes = Boxes(tensor([[216.9492, 217.0000, 605.6497, 965.1979], [246.3277, 160.4896, 501.6949, 641.9583]], device='cuda:0')),#             gt_classes = tensor([12, 14], device='cuda:0'), ]),#  Instances(num_instances=2, image_height=1085, image_width=800,#     fields=[gt_boxes = Boxes(tensor([[216.9492, 217.0000, 605.6497, 965.1979], [246.3277, 160.4896, 501.6949, 641.9583]], device='cuda:0')),#             gt_classes = tensor([12, 14], device='cuda:0'), ])]gt_classes = []gt_shifts_deltas = []gt_ious = []assigned_units = []box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]# [(2,13600,20),(2,3400,20),(2,850,20),(2,221,20),(2,63,20)]box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]# [(2,13600,4),(2,3400,4),(2,850,4),(2,221,4),(2,63,4)]box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]# [(2,13600,1),(2,3400,1),(2,850,1),(2,221,1),(2,63,1)]box_cls = torch.cat(box_cls, dim=1)  # (2,18134,20)box_delta = torch.cat(box_delta, dim=1)  # (2,18134,4)box_iou = torch.cat(box_iou, dim=1)  # (2,18134,1)for shifts_per_image, targets_per_image, box_cls_per_image, \box_delta_per_image, box_iou_per_image in zip(shifts, targets, box_cls, box_delta, box_iou):shifts_over_all = torch.cat(shifts_per_image, dim=0)  # (18134,2)gt_boxes = targets_per_image.gt_boxes  # (2,4)# In gt box and center.deltas = self.shift2box_transform.get_deltas(shifts_over_all, gt_boxes.tensor.unsqueeze(1))  # (18134,2),(2,1,4) -> (2,18134,4)is_in_boxes = deltas.min(dim=-1).values > 0.01  # (2,18134)center_sampling_radius = 2.5centers = gt_boxes.get_centers()  # (2,2),# tensor([[388.7006, 591.0990],#         [425.9887, 401.2239]], device='cuda:0')# 因为数据增强的, gt_bboxes和centers每次运行结果都会变化is_in_centers = []for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):  # [8, 16, 32, 64, 128], _radius = stride * center_sampling_radiuscenter_boxes = torch.cat((torch.max(centers - radius, gt_boxes.tensor[:, :2]),torch.min(centers + radius, gt_boxes.tensor[:, 2:]),), dim=-1)  # (2,4)center_deltas = self.shift2box_transform.get_deltas(shifts_i, center_boxes.unsqueeze(1))  # (13600,2),(2,1,4) -> (2,13600,4)is_in_centers.append(center_deltas.min(dim=-1).values > 0)is_in_centers = torch.cat(is_in_centers, dim=1)  # (2,18134)del centers, center_boxes, deltas, center_deltasis_in_boxes = (is_in_boxes & is_in_centers)num_gt = len(targets_per_image)num_anchor = len(shifts_over_all)shape = (num_gt, num_anchor, -1)  # (2,18134,-1)gt_cls_per_image = F.one_hot(targets_per_image.gt_classes, self.num_classes).float()  # (2,20)with torch.no_grad():loss_cls = sigmoid_focal_loss_jit(box_cls_per_image.unsqueeze(0).expand(shape),  # (18134,20)->(1,18134,20)->(2,18134,20)gt_cls_per_image.unsqueeze(1).expand(shape),  # (2,20)->(2,1,20)->(2,18134,20)alpha=self.focal_loss_alpha,  # 0.25gamma=self.focal_loss_gamma,  # 2).sum(dim=-1)  # (2,18134,20)->(2,18134)loss_cls_bg = sigmoid_focal_loss_jit(box_cls_per_image,  # (18134,20)torch.zeros_like(box_cls_per_image),alpha=self.focal_loss_alpha,gamma=self.focal_loss_gamma,).sum(dim=-1)  # (18134,20)->(18134)gt_delta_per_image = self.shift2box_transform.get_deltas(shifts_over_all, gt_boxes.tensor.unsqueeze(1)  # (18134,2), (2,4)->(2,1,4))  # (2,18134,4)ious, loss_delta = get_ious_and_iou_loss(box_delta_per_image.unsqueeze(0).expand(shape),  # (18134,4)->(1,18134,4)->(2,18134,4)gt_delta_per_image,box_mode="ltrb",loss_type='iou')  # (2,18134),(2,18134)loss = loss_cls + self.reg_weight * loss_delta + 1e6 * (1 - is_in_boxes.float())  # 1.5# (2,18134)# Performing Dynamic k Estimationtopk_ious, _ = torch.topk(ious * is_in_boxes.float(), self.top_candidates, dim=1)  # (2,18134),20 -> (2,20)mu = ious.new_ones(num_gt + 1)  # torch.Size([3]), tensor([1., 1., 1.], device='cuda:0')mu[:-1] = torch.clamp(topk_ious.sum(1).int(), min=1).float()  # s_{i}(i=1,...,m)mu[-1] = num_anchor - mu[:-1].sum()  # s_{m+1}nu = ious.new_ones(num_anchor)  # (18134), d_{j}(j=1,..,n)loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)  # (2,18134),(18134)->(1,18134), -> (3,18134)# Solving Optimal-Transportation-Plan pi via Sinkhorn-Iteration._, pi = self.sinkhorn(mu, nu, loss)  # (3,),(18134,),(3,18134) -> (3,18134)# Rescale pi so that the max pi for each gt equals to 1.rescale_factor, _ = pi.max(dim=1)  # (3,)pi = pi / rescale_factor.unsqueeze(1)  # (3,18134)max_assigned_units, matched_gt_inds = torch.max(pi, dim=0)gt_classes_i = targets_per_image.gt_classes.new_ones(num_anchor) * self.num_classesfg_mask = matched_gt_inds != num_gtgt_classes_i[fg_mask] = targets_per_image.gt_classes[matched_gt_inds[fg_mask]]gt_classes.append(gt_classes_i)assigned_units.append(max_assigned_units)box_target_per_image = gt_delta_per_image.new_zeros((num_anchor, 4))box_target_per_image[fg_mask] = \gt_delta_per_image[matched_gt_inds[fg_mask], torch.arange(num_anchor)[fg_mask]]gt_shifts_deltas.append(box_target_per_image)gt_ious_per_image = ious.new_zeros((num_anchor, 1))gt_ious_per_image[fg_mask] = ious[matched_gt_inds[fg_mask],torch.arange(num_anchor)[fg_mask]].unsqueeze(1)gt_ious.append(gt_ious_per_image)return torch.cat(gt_classes), torch.cat(gt_shifts_deltas), torch.cat(gt_ious)

Experiments

Alation Studies and Analysis

Effects of Individual Components

OTA可以既可以用于anchor-based detector也可以用于anchor-free detector,本文采用FCOS,同时额外加入了IoU分支,从下图可以看出随着添加IoU branch、center prior、dynamic k estimation,性能持续提升,并且比对应的原始FCOS的精度要高。

Effects of \(r\)

center prior的半径 \(r\) 控制每个gt的正样本数量,\(r\) 值小,只有最靠近gt中心的高质量anchor才被当做正样本,有助于模型的学习。\(r\) 越大,引入的低质量的正样本anchor越多,导致了优化过程中潜在的不稳定。从下表可以看出,随着 \(r\) 的增大,三种模型的精度都出现了不同程度的下降,但OTA下降的最少,表明OTA对 \(r\) 值的变化不那么敏感,同时不同的 \(r\) 值下,OTA的精度也是最高的。

Ambiguous Anchors Handling

当发生遮挡或者多个对象靠的非常近时,一个anchor可能是多个ground truth的合格候选对象(比如Faster RCNN中一个anchor与多个gt的IoU都大于0.5),这种anchor定义为ambiguous anchor。之前的方法主要通过人工设定的规则来处理这种情况,比如Min Area、Max IoU、Min Loss等。本文将 \(max\ \pi^{*}_{j}<0.9\) 的anchor \(a_{j}\) 定义为ambiguous anchor,然后统计在不同的 \(r\) 值下ATSS、PAA、OTA的ambiguous anchor的数量以及对应的精度。从上表(2)中可以看出,随着 \(r\) 的增大,ATSS中ambiguous anchor的数量显著增加,AP也降了1.8个点。PAA中ambiguous anchor的数量对 \(r\) 的变化不那么敏感,但AP也降了0.8个点。而OTA中ambiguous anchor的数量既对 \(r\) 的变化不敏感,和ATSS、PAA相比数量也是最少的,同时AP也只下降了0.3个点。这是因为当多个gt试图将positive label传输到同一个anchor时,OT算法会基于全局最小传输成本的准则自动解决它们之间的冲突。

Effects of \(k\)

如下表所示,作者对比了 \(k\) 设置为不同的常数值以及采用dynamic \(k\) 时模型的精度,可以看出随着 \(k\) 的增大,模型精度越来越高,当 \(k\) 取10或12时,模型达到最高的精度,随后开始下降。但最高的精度也比采用dynamic \(k\) 的精度低。从直觉上讲,每个gt的大小、尺度、遮挡情况都不同,因此每个gt所需的postive anchor的数量应该也是不同的。

Comparison with State-of-the-art Methods

从下表可以看出,采用ResNet-101-FPN结构,OTA的AP达到了45.3%,超过了其它所有相同backbone的方法,如ATSS(43.6% AP)、AutoAssign(44.5% AP)、PAA(44.6% AP)。

OTA: Optimal Transport Assignment for Object Detection 原理与代码解读相关推荐

  1. 【目标检测】36、OTA: Optimal Transport Assignment for Object Detection

    文章目录 一.背景 二.方法 2.1 Optimal Transport 2.2 OT for label assignment 2.3 Center prior 2.4 Dynamic k Esti ...

  2. EGNet: Edge Guidance Network for Salient Object Detection 论文及代码解读

    EGNet: Edge Guidance Network for Salient Object Detection 论文及代码解读 注:本文原创作者为Jia-Xing Zhao, Jiang-Jian ...

  3. TOOD: Task-aligned One-stage Object Detection 原理与代码解析

    paper:TOOD: Task-aligned One-stage Object Detection code:https://github.com/fcjian/TOOD  存在的问题 目标检测包 ...

  4. End-to-End Semi-Supervised Object Detection with Soft Teacher 解读

    端到端的半监督目标检测 论文:https://arxiv.org/pdf/2106.09018v3.pdf 代码地址:https://github.com/microsoft/SoftTeacher ...

  5. UNITER多模态预训练模型原理加代码解读

    UNITER多模态预训练模型原理 1. 数据 ​ 过去的5年中,Vision+NLP的研究者所使用的主要数据集如下展示: ​ 本文中所使用到的4种数据集如下图所示,Conceptual Caption ...

  6. 平均符号熵的计算公式_交叉熵(Cross Entropy)从原理到代码解读

    交叉熵(Cross Entropy)是Shannon(香浓)信息论中的一个概念,在深度学习领域中解决分类问题时常用它作为损失函数. 原理部分:要想搞懂交叉熵需要先清楚一些概念,顺序如下:==1.自信息 ...

  7. stm32-DHT11原理及代码解读

    目录 一.基础知识 1.功能:温湿度检测 2.应用范围 3.硬件电路连接 二.底层代码原理分析 1.基础知识 1.单总线说明 2.单总线传送数据位定义 3.数据格式 4.校验位数据定义 2.代码分析 ...

  8. Transformer最详细的原理加代码解读

    Transformer原理 1. motivation ​ 为了解决seq2seq的问题,之前一般都是使用RNN模型进行求解.RNN的一大劣势就是无法进行并行化计算,比如要想输出b4b^4b4就必须要 ...

  9. matlab sift乘积量化,PQ(乘积量化)应用于ANN算法原理和代码解读

    背景 PQ算法全称ProductQuantization,中文名为乘积量化.该算法来源于图像检索,本质上是对向量做压缩.该算法也可以应用于ANN,本文介绍该算法在ANN的应用以及相关代码实现.算法介绍 ...

最新文章

  1. Matlab中的图形句柄(转载)
  2. ListView在列表中新增一行的操作(增加、取消)
  3. 先序序列为a、b、c、d的不同二叉树的个数是多少(卡特兰数)
  4. 超硬核!11 个非常实用的 Python 和 Shell 拿来就用脚本实例!
  5. SecureCRT自动记录日志
  6. 分别使用 XHR、jQuery 和 Fetch 实现 AJAX
  7. 工作中任务管理的四个原则和四个技能
  8. mysql5.4升级5.6_Laravel5.4 升级到 5.6
  9. 在Kubernetes上使用Sateful Set部署RabbitMQ集群
  10. 前端删除表格某一行信息怎么实现
  11. 为解放程序员而生,网易重磅推“场景化云服务”,强势进军云计算市场
  12. Python 每日一记248Java二叉树实现折纸问题
  13. 计算机界面视频录制软件,视频录制工具怎么用?这样的电脑录屏方法超实用!...
  14. 基于ThinkPHP的图书馆管理系统 毕业设计-附源码311833
  15. 医疗图像论文笔记二:《Learning to recognize Abnormalities in Chest X-Rays with Location-Aware Dense Networks》
  16. hihocoder#1369 : 网络流算法的一些小结
  17. 11. 符号和符号解析
  18. QUESTION: 由于文件 无法被用户‘_apt‘访问,已脱离沙盒并提权为根用户来进行下载。 - pkgAcquire::Run (13: 权限不够)
  19. 奶牛慢跑 (寒假每日一题 18)
  20. 配置基于IPv6的单节点Ceph

热门文章

  1. 千锋教育嵌入式物联网教程之系统编程篇学习-01
  2. 金立的Gpad G2
  3. Nginx介绍 安装
  4. python文件之间如何互相通信_python学习1-网络编程之udp_创建socket实现两电脑之间的通信...
  5. pycharm执行文件时报错can't find '__main__' module解决方法
  6. 移动App开发日志开发实例
  7. 《三国志幻想大陆》选神兵攻略,哪些神兵适合邓艾钟会?
  8. Iphone表达式计算器
  9. The Standard C Library电子书pdf下载
  10. Origin同时画柱状图和折线图(解决图层问题)