文章内容

  • 问题提出
  • 相关研究现状
    • 1. 辅助头和连续标签
    • 2. 基于排序的损失
    • 3. 样本不平衡问题
  • 本文工作
    • AP Loss的不足之处
    • RS Loss 定义
    • 训练细节
  • 代码解读

论文链接:Rank & Sort Loss for Object Detection and Instance Segmentation
mmdet实现代码:Rank & Sort Loss for Object Detection and Instance Segmentation

问题提出

当下常用的损失函数形式如下,将任务ttt在步骤kkk的损失加权求和:
L=∑k∈K∑t∈TλtkLtkL=\sum_{k\in K}\sum_{t\in T}\lambda_t^k L_t^k L=k∈K∑​t∈T∑​λtk​Ltk​
缺点是超参数过多,很容易引起特定任务之间的不平衡,如正负样本不平衡,级联网络内部的不平衡等,最终得到次优的解决方案。

AP Loss和aLRP Loss与传统的基于分类得分的损失函数相比,具有训练过程和网络评估指标一致性(直接优化网络评估指标AP/aLRP等)、待调超参数较少、对类别不均衡不敏感等优势,但需要更长的训练时间和更多增强操作,且没有建模正样本之间的关联

有研究证明,采用辅助头对正样本定位质量进行排序,或监督分类器直接回归样本的IoU(预测定位精度)能够提高网络性能。

因此本文提出RS Loss,不仅将正样本排序在负样本之前,还基于连续的IoU值在正样本内部进行排序,这将有以下好处:

  1. 正样本内部的排序使得网络不需要辅助头实现对正样本定位质量的排序
  2. 排序性质使得网络能够在没有采样策略的前提下处理极端不平衡的数据集
  3. 借助分类得分和IoU值共同排序调优,与NMS和网络评价指标如AP具有一致性
  4. 除了学习率没有需要调优的参数

相关研究现状

1. 辅助头和连续标签

很多实验证明,用辅助器预测检测结果的定位质量、中心度、IoU、mask-IoU或置信度,并将这些预测与NMS的分类得分相结合,可以提高检测性能。还有研究发现,使用连续的IoU值比使用过辅助器监督分类器效果更好,由此产生使用连续标签训练分类器的Quality Focal Loss并表现出对类别不均衡数据集的鲁棒性。

2. 基于排序的损失

基于排序的损失不可微、难以优化。black-box solvers采用插值AP解决该问题但收效甚微;DR Loss通过对Hinge Loss引入margin实现正负排序;AP Loss和aLRP Loss对性能评估指标进行优化,通过感知学习的误差驱动算法实现不可微部分的优化,但他们需要更长的训练时间和更多的增强手段。RS Loss与之区别在于将连续的定位质量得分作为标签。

3. 样本不平衡问题

常见的解决方法是引入超参数并通过网格化搜索的方式进行调参。有实验采用自平衡策略来平衡分类和定位分支,使两者在aLRP Loss的限定范围内竞争;还有研究使用分类和定位损失的比率来平衡这些任务。
本文中,不同任务的损失值LtkL_t^kLtk​有自己的限定范围,因此不同任务之间没有竞争关系,

本文工作

  1. 重新定义了错误驱动更新与反向传播,从而解决排序不可微的问题,能够计算排序损失
  2. 定义正负样本之间和正样本内的排序方法,解决类别不均衡问题

AP Loss的不足之处

AP Loss虽然基于排名重新定义了以AP为优化目标的损失函数,并借助感知器误差驱动优化算法实现反向传播,但有以下两处不足:

  1. 产生的损失值LLL没有考虑目标Lij∗L^∗_{ij}Lij∗​,因此当Lij∗≠0L^∗_{ij}\neq 0Lij∗​​=0时不可解释
  2. 只计算i∈P,j∈Ni\in P, j\in Ni∈P,j∈N时的损失,忽略了i,ji,ji,j都是正样本时的类内误差,而对于使用连续标签的算法来说类内误差不可忽视,标签越大得分就应该越高。

RS Loss 定义

首先定义当前损失值lRS(i)l_{RS}(i)lRS​(i)为正样本iii的 ranking error 和 sorting error 之和,其中ranking error代表正负样本间的排序损失(参考aLRP Loss形式),sorting error对正样本中得分sj>sis_j\gt s_isj​>si​的样本进行惩罚,其在[0,1]内连续的标签值(如IoU)越大,惩罚项越小:
lRS(i):=NFP(i)rank(i)+∑j∈PH(xij)(1−yj)rank+(i):=得分比正样本i高的负样本个数正确排名+得分更高的正样本的labels惩罚\begin{aligned} l_{RS}(i)&:=\textcolor{blue}{\frac{N_{FP}(i)}{rank(i)}}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})(1-y_j)}{rank^+(i)}}\\ \\ &:=\textcolor{blue}{\frac{得分比正样本i高的负样本个数}{正确排名}}+\textcolor{orange}{得分更高的正样本的labels惩罚} \end{aligned} lRS​(i)​:=rank(i)NFP​(i)​+rank+(i)∑j∈P​H(xij​)(1−yj​)​:=正确排名得分比正样本i高的负样本个数​+得分更高的正样本的labels惩罚​
定义目标损失值lRS∗(i)l^*_{RS}(i)lRS∗​(i)如下,当正样本iii排序在所有负样本之前时,lR∗(i)=0l^*_{R}(i)=0lR∗​(i)=0, lS∗(i)l^*_{S}(i)lS∗​(i)对所有标签值yjy_jyj​大于样本iii得分的正样本的1−yj1-y_j1−yj​求均值,
lRS∗(i)=lR∗(i)+∑j∈PH(xij)[yj≥yi](1−yj)∑j∈PH(xij)[yj≥yi]=0+正样本中label大于i的样本(1−yi)之和正样本中label大于i的样本个数\begin{aligned} l^*_{RS}(i)&=\textcolor{blue}{l^*_{R}(i)}+\textcolor{orange}{\frac{\sum_{j\in P}H(x_{ij})[y_j\ge y_i](1-y_j)}{\sum_{j\in P}H(x_{ij})[y_j\ge y_i]}}\\ \\ &=\textcolor{blue}{0}+\textcolor{orange}{\frac{正样本中label大于i的样本(1-y_i)之和}{正样本中label大于i的样本个数}} \end{aligned} lRS∗​(i)​=lR∗​(i)+∑j∈P​H(xij​)[yj​≥yi​]∑j∈P​H(xij​)[yj​≥yi​](1−yj​)​=0+正样本中label大于i的样本个数正样本中label大于i的样本(1−yi​)之和​​
RS Loss定义为正样本的当前lRS(i)l_{RS}(i)lRS​(i)和目标值lRS∗(i)l^*_{RS}(i)lRS∗​(i)差异的均值:1∣P∣∑i∈P(lRS(i)−lRS∗(i))\frac{1}{|P|}\sum_{i\in P}\left( l_{RS}(i)-l^*_{RS}(i) \right)∣P∣1​∑i∈P​(lRS​(i)−lRS∗​(i))
参照AP Loss的三步定义,在此定义兼顾了正负样本和正样本内部误差的 primary term LijL_{ij}Lij​如下:
Lij={(lR(i)−lR∗(i))pR(j∣i)fori∈P,j∈N(lS(i)−lS∗(i))pS(j∣i)fori∈P,j∈P0otherwiseL_{ij}= \left\{ \begin{array} {rcl} (l_R(i)-l^*_R(i))p_R(j|i) & {for\space i\in P, j\in N}\\ (l_S(i)-l^*_S(i))p_S(j|i) & {for\space i\in P, j\in P}\\ 0 & {otherwise} \end{array} \right. Lij​=⎩⎨⎧​(lR​(i)−lR∗​(i))pR​(j∣i)(lS​(i)−lS∗​(i))pS​(j∣i)0​for i∈P,j∈Nfor i∈P,j∈Potherwise​
其中pR(j∣i),pS(j∣i)p_R(j|i),p_S(j|i)pR​(j∣i),pS​(j∣i)负责将样本iii上的误差分别分布在导致误差的样本jjj上(如只有得分sj>sis_j > s_isj​>si​的负样本jjj会引起 ranking error;只有得分sj>sis_j > s_isj​>si​且标签值yj<yiy_j < y_iyj​<yi​的正样本jjj会引起 sorting error),即:
pR(j∣i)=H(xij)∑k∈NH(xik)pS(j∣i)=H(xij)[yj<yi]∑k∈PH(xik)[yk<yi]\begin{aligned} &p_R(j|i)=\frac{H(x_{ij})}{\sum_{k\in N}H(x_{ik})} \\ &p_S(j|i)=\frac{H(x_{ij})[y_j<y_i]}{\sum_{k\in P}H(x_{ik})[y_k<y_i]} \end{aligned} ​pR​(j∣i)=∑k∈N​H(xik​)H(xij​)​pS​(j∣i)=∑k∈P​H(xik​)[yk​<yi​]H(xij​)[yj​<yi​]​​
因此,Identity Update过程可以分ranking error和 sorting error两部分进行:正样本同时受到两部分的影响,而负样本的定位精度对RS Loss没有影响,因此只被ranking error更新。

  • 对于正样本,其反向传播梯度∂LRS∂si\frac{\partial L_{RS}}{\partial s_i}∂si​∂LRS​​不仅包括样本iii本身的ranking error和 sorting error,这部分称为promotion update signal;此外还受到得分更高但连续标签值更小的其他正样本(missorted samples)的影响,这部分称为demotion update signal,与promotion符号相反,根据misranked样本jjj的信号推动该样本iii的信息更新。
    1∣P∣(lRS∗(i)−lRS(i)⏟promotionupdatesignal+∑j∈P(lS(j)−lS∗(j))pS(i∣j)⏟demotionupdatesignal)\frac{1}{|P|}\left( \underbrace{l^*_{RS}(i)-l_{RS}(i)}_{promotion\space update\space signal}+\underbrace{\sum_{j\in P}\left(l_{S}(j)-l^*_{S}(j)\right)p_S(i|j)}_{demotion\space update\space signal}\right) ∣P∣1​⎝⎜⎜⎜⎜⎛​promotion update signallRS∗​(i)−lRS​(i)​​+demotion update signalj∈P∑​(lS​(j)−lS∗​(j))pS​(i∣j)​​⎠⎟⎟⎟⎟⎞​
  • 对于负样本,其损失值只受到排名损失ranking error的反向传播影响,定义如下:
    ∂LRS∂si=1∣P∣∑j∈PlR(j)pR(i∣j)\frac{\partial L_{RS}}{\partial s_i}=\frac{1}{|P|}\sum_{j\in P} l_R(j)p_R(i|j) ∂si​∂LRS​​=∣P∣1​j∈P∑​lR​(j)pR​(i∣j)

训练细节

ATSS的损失定义为LATSS=Lcls+λboxLbox+λctrLctrL_{ATSS}=L_{cls}+\lambda_{box}L_{box}+\lambda_{ctr}L_{ctr}LATSS​=Lcls​+λbox​Lbox​+λctr​Lctr​,其中三个分量分别代表分类损失、定位损失(GIOU)和中心点定位损失(交叉熵损失)。
在此删除网络中的辅助头,并用RS Loss替代分类损失,其中连续标签值为目标框与真值框的IoU值,得到LRS−ATSS=LRS+λboxLboxL_{RS-ATSS}=L_{RS}+\lambda_{box}L_{box}LRS−ATSS​=LRS​+λbox​Lbox​,超参数λbox\lambda_{box}λbox​通常通过网格搜索设置为常数,在此采用两种启发式无调优算法来确定每个iteration的λbox\lambda_{box}λbox​值:

  • 基于损失值:λbox=LRS/Lbox\lambda_{box}=L_{RS}/L_{box}λbox​=LRS​/Lbox​
  • 基于梯度:λbox=∣∂LRS∂s∣/∣∂Lbox∂b∣\lambda_{box}=|\frac{\partial L_{RS}}{\partial s}|/|\frac{\partial L_{box}}{\partial b}|λbox​=∣∂s∂LRS​​∣/∣∂b∂Lbox​​∣,其中b,sb,sb,s分别是目标框回归和分类头的输出。

本文实验发现基于损失值的λbox\lambda_{box}λbox​设置方法与调参效果相近;并且本文用每个预测框的分类得分对它们的定位GIoU损失进行加权。这两个tricks(基于数值的任务平衡、基于得分的实例加权)都与超参数无关,可以应用于所有网络。

RS Loss 使用总结

  • 使用RS Loss时,通常会删除启发式/随机采样策略,并移除网络的 IoU aux. head 或 IoU Net,直接用预测结果的 IoU 来监督分类 head;
  • 会在预测框回归中使用基于分数的加权策略,并倾向于使用Dice Loss而非交叉熵损失来训练用于实例分割的mask prediction head(Dice Loss有界,且和GIOU一样对各种因素考虑周全)
  • 每个iteration后设置损失函数L=∑k∈K∑t∈TλtkLtkL=\sum_{k\in K}\sum_{t\in T}\lambda_t^k L_t^kL=∑k∈K​∑t∈T​λtk​Ltk​中各分量的权重为LclskLtk\frac{L^k_{cls}}{L^k_t}Ltk​Lclsk​​,只有RPN例外(×0.2\times0.2×0.2)

代码解读

class RankSort(torch.autograd.Function):@staticmethoddef forward(ctx, logits, targets, delta_RS=0.50, eps=1e-10):# ---------------------------------------------------------## targets: continuous label for each sample (e.g. IoU)# delta_RS: parameter of piecewise step functions# ---------------------------------------------------------#classification_grads = torch.zeros(logits.shape).cuda()# ---------------------## Filter fg logits# ---------------------#fg_labels = (targets > 0.)fg_logits = logits[fg_labels]fg_targets = targets[fg_labels]fg_num = len(fg_logits)sorting_error = torch.zeros(fg_num).cuda()ranking_error = torch.zeros(fg_num).cuda()fg_grad = torch.zeros(fg_num).cuda()# --------------------------------------## Filter non-trivial negative samples# --------------------------------------## Do not use bg with scores less than minimum fg logit# since changing its score does not have an effect on precisionthreshold_logit = torch.min(fg_logits)-delta_RSrelevant_bg_labels = ((targets == 0) & (logits >= threshold_logit))relevant_bg_logits = logits[relevant_bg_labels]relevant_bg_grad = torch.zeros(len(relevant_bg_logits)).cuda()# -----------------------------## Loop on posivite indices# -----------------------------## sort the fg logits and loops over each positive following the orderorder = torch.argsort(fg_logits)for ii in order:# --------------------------------## Difference Transforms (x_ij)# --------------------------------#fg_relations = fg_logits - fg_logits[ii]bg_relations = relevant_bg_logits - fg_logits[ii]# ----------------------## H(x_ij) \in [0,1]# ----------------------## piecewise step functionsif delta_RS > 0:fg_relations = torch.clamp(fg_relations / (2 * delta_RS) + 0.5, min=0, max=1)bg_relations = torch.clamp(bg_relations / (2 * delta_RS) + 0.5, min=0, max=1)# common functionelse:fg_relations = (fg_relations >= 0).float()bg_relations = (bg_relations >= 0).float()# Rank of ii among pos and false positive number (bg with larger scores)rank_pos = torch.sum(fg_relations)FP_num = torch.sum(bg_relations)# Rank of ii among all examplesrank = rank_pos + FP_num# ------------------## Ranking error# ------------------## Since target_ranking_error is always 0, we here store current_ranking_error as ranking_errorranking_error[ii] = FP_num/rank# ------------------------## Current sorting error# ------------------------#current_sorting_error = torch.sum(fg_relations * (1 - fg_targets)) / rank_pos# -------------------------------------------------------------------## Find examples ranking higher and targets larger than ii# -------------------------------------------------------------------#iou_relations = (fg_targets >= fg_targets[ii])target_sorted_order = iou_relations * fg_relations# The rank of ii among positives in sorted orderrank_pos_target = torch.sum(target_sorted_order)# ------------------------------------------------------## Target sorting error  (target ranking error is 0)# ------------------------------------------------------## Since target_ranking_error is always 0, target_sorting_error is also the target_errortarget_sorting_error = torch.sum(target_sorted_order * (1 - fg_targets)) / rank_pos_target# ------------------## Sorting error# ------------------#sorting_error[ii] = current_sorting_error - target_sorting_error# -------------------------------------## Identity Update for Ranking Error# -------------------------------------#if FP_num > eps:# For ii the update is the ranking errorfg_grad[ii] -= ranking_error[ii]# For negatives, distribute error via ranking pmf (i.e. bg_relations/FP_num)relevant_bg_grad += (bg_relations * (ranking_error[ii]/FP_num))# --------------------------------------------------------------## Find examples ranking higher but targets smaller than ii# --------------------------------------------------------------## Find the positives that are misranked (the ones with smaller IoU but larger logits)missorted_examples = (~ iou_relations) * fg_relations# Denominotor of sorting pmfsorting_pmf_denom = torch.sum(missorted_examples)# -------------------------------------## Identity Update for Sorting Error# -------------------------------------#if sorting_pmf_denom > eps:# For ii the update is the sorting errorfg_grad[ii] -= sorting_error[ii]# For positives, distribute error via sorting pmf (i.e. missorted_examples/sorting_pmf_denom)fg_grad += (missorted_examples * (sorting_error[ii] / sorting_pmf_denom))# Normalize gradients by number of positivesclassification_grads[fg_labels] = (fg_grad/fg_num)classification_grads[relevant_bg_labels] = (relevant_bg_grad/fg_num)ctx.save_for_backward(classification_grads)return ranking_error.mean(), sorting_error.mean()@staticmethoddef backward(ctx, out_grad1, out_grad2):g1, = ctx.saved_tensorsreturn g1*out_grad1, None, None, None

[论文解读] Rank Sort Loss for Object Detection and Instance Segmentation相关推荐

  1. 论文解读:DETR 《End-to-end object detection with transformers》,ECCV 2020

    论文解读:DETR <End-to-end object detection with transformers>,ECCV 2020 0. 论文基本信息 1. 论文解决的问题 2. 论文 ...

  2. Cascade R-CNN: High Quality Object Detection and Instance Segmentation(级联R-CNN:高质量目标检测与实例分割)

    Cascade R-CNN: High Quality Object Detection and Instance Segmentation Zhaowei Cai, and Nuno Vasconc ...

  3. 2023-一种无监督目标检测和实例分割方法【Cut and Learn for Unsupervised Object Detection and Instance Segmentation】

    Cut and Learn for Unsupervised Object Detection and Instance Segmentation 无监督目标检测和实例分割的剪切与学习 Faceboo ...

  4. [论文解读]Deep active learning for object detection

    Deep active learning for object detection 文章目录 Deep active learning for object detection 简介 摘要 初步 以前 ...

  5. 【D2Det】《 D2Det:Towards High Quality Object Detection and Instance Segmentation》

    CVPR-2020 Pytorch Code: https://github.com/JialeCao001/D2Det. 文章目录 1 Background and Motivation 2 Rel ...

  6. PAMI19 - 强大的级联RCNN架构《Cascade R-CNN: High Quality Object Detection and Instance Segmentation》

    文章目录 原文 初识 相知 Challenge to High Quality Detection Cascade RCNN 与相似工作的异同 扩展到实例分割 回顾 参考 原文 https://arx ...

  7. 论文解读 | Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation

    论文地址:Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation GitHub地址: http ...

  8. 目标检测经典论文——R-CNN论文翻译:Rich feature hierarchies for accurate object detection and semantic segmentation

    Rich feature hierarchies for accurate object detection and semantic segmentation--Tech report (v5) 用 ...

  9. 【论文精读】Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation(R-CNN)

    论文Title:Rich Feature Hierarchies for Accurate Object Detection and Semantic Segmentation.发表于2014年. 本 ...

最新文章

  1. 青少年学python第六节_青少年学Python(第2册)
  2. 【Python】浅谈 multiprocessing
  3. python实时获取子进程输出_Python 从subprocess运行的子进程中实时获取输出的例子...
  4. 基于领域知识的Docker镜像自动构建方法
  5. leetcode二分查找
  6. python列表框_Python列表框
  7. 我的应用我做主丨动手搭建招聘小应用
  8. python打开快捷方式_Python打开一个JAR快捷方式
  9. 详解苹果 macOS Mail 中的零点击漏洞
  10. docker 镜像注册【图文教程】
  11. html游戏开源代码是什么,怎么运行html5游戏的源代码
  12. 反激式开关电源设计资料
  13. pygame安装超详细讲解
  14. 《白帽子讲Web安全》| 学习笔记之访问控制
  15. itunes备份是整个手机备份吗_iTunes备份道理我都懂,但我依然不想备份的?
  16. 性能监控平台prometheus+grafana
  17. 计算机软件方法专利撰写,干货 | 计算机软件专利撰写模板
  18. Unity3D for VR 学习(7): 360°全景照片
  19. 如何分析PARSEC源码
  20. 上升了百分之几怎么算_上涨百分之多少怎么算

热门文章

  1. python 32bit? 64bit?
  2. 【技术知识】SVAC 2.0安全技术浅析
  3. 立方体在三维坐标中的旋转(3D,Spining)
  4. ride导入自定义python库
  5. 【附源码】计算机毕业设计SSM我的大学电子相册
  6. 解决:SpringBoot中使用WebSocket传输数据,提示 1009|The decoded text message was too big for the output buffer and
  7. flask peewee教程
  8. LCD液晶屏和LED液晶屏的较量
  9. 支付宝 APP登录 获取用户信息 PHP
  10. 一个摆烂年轻人对手机的需求