A major reason one-stage detectors lag in accuracy is the imbalance between positive and negative samples. Take YOLO as an example: each grid cell makes 5 predictions, so the already lopsided positive/negative ratio is effectively amplified roughly fivefold.

The paper proposes a new classification loss, Focal Loss, which suppresses the weight of easily classified samples and focuses training on the hard-to-distinguish ones, keeping the positive/negative ratio under control. In short, focal loss addresses two problems at once: positive/negative imbalance and easy/hard imbalance.

The loss is FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t), where p_t = p for positives and p_t = 1 - p for negatives. Here α controls the positive/negative imbalance and γ controls the easy/hard imbalance; typical values are α = 0.25 and γ = 2. The weight α_t is α = 0.25 for positives and 1 - α = 0.75 for negatives, a 3× ratio, while the (1 - p_t)^γ factor relatively increases the loss of hard samples and decreases that of easy samples. Since the easy samples are overwhelmingly negatives, the net effect is that positives and hard samples gain relative weight and easy negatives are strongly suppressed.
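As a minimal numeric sketch of this weighting (assuming the standard binary formulation FL(p_t) = -α_t(1 - p_t)^γ log(p_t) from the paper), the following compares an easy and a hard negative sample:

```python
import math

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    """Focal loss for a single binary prediction p with label y (1 = positive)."""
    p_t = p if y == 1 else 1.0 - p              # probability of the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# An easy negative (p = 0.1) is suppressed far more than a hard one (p = 0.9):
print(focal_loss(0.1, 0))  # tiny: the (1 - 0.9)^2 = 0.01 factor shrinks it
print(focal_loss(0.9, 0))  # large: the hard sample keeps most of its loss
```

Setting gamma = 0 removes the easy/hard modulation and leaves only the α-weighted cross-entropy.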

The model uses an FPN backbone with levels P3 through P7; the extra P7 level improves detection of large objects.

Anchors are placed on FPN levels P3-P7 with base sizes from 32×32 up to 512×512, at aspect ratios {1:2, 1:1, 2:1} and three scale multipliers {2^0, 2^(1/3), 2^(2/3)}, giving 9 anchors per level; across levels the anchors cover sizes from 32 to 813 pixels. Each anchor is associated with a K-dimensional one-hot classification target (K is the number of classes) and a 4-dimensional box-regression target.
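The anchor layout above can be sketched numerically (the three scale octaves {2^0, 2^(1/3), 2^(2/3)} are the paper's setup; 512 · 2^(2/3) ≈ 812.7 explains the quoted upper bound of 813):

```python
ratios = [1 / 2, 1, 2]                         # aspect ratios 1:2, 1:1, 2:1
scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]  # three scale octaves per level
base_sizes = [32, 64, 128, 256, 512]           # P3 ... P7

for level, base in zip(range(3, 8), base_sizes):
    # 3 scales x 3 ratios = 9 anchors at every spatial position of this level
    anchors = [(base * s, r) for s in scales for r in ratios]
    print(f'P{level}: {len(anchors)} anchors, '
          f'sizes {base:.0f}-{base * scales[-1]:.1f}')
```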

The classification subnet predicts, for each of the A anchors at every position, the probability of each of the K classes. As shown in the figure, for every FPN output level the classification subnet is a small FCN: four 3×3 conv layers with 256 channels, followed by a final 3×3 conv with K·A outputs. For each anchor the output is thus a K-dimensional vector of per-class probabilities; because of the one-hot target, the class with the highest score is taken as 1 and the remaining K-1 as 0. A traditional RPN uses a single 1×1×18 conv layer for classification, whereas RetinaNet uses a deeper head of 5 conv layers in total; experiments show this deeper head helps.

In parallel with the classification subnet, each FPN output also feeds a box-regression subnet, likewise an FCN, which predicts the offset between each anchor and its matched ground-truth box: four 256-channel conv layers followed by a final layer with 4A outputs, i.e. one (x, y, w, h) vector per anchor. Note that this regression is class-agnostic. Although the two subnets share the same structure, their parameters are not shared.
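The head structure described above can be sketched in PyTorch as follows. This is only an illustration of the layer counts and widths (the bias initialization with the prior π, the sigmoid, and the FPN itself are omitted; `make_head` and K = 80 are illustrative choices, not the paper's code):

```python
import torch
import torch.nn as nn

def make_head(out_channels, in_channels=256):
    # four 3x3x256 conv layers, then a final 3x3 conv to the output width
    layers = []
    for _ in range(4):
        layers += [nn.Conv2d(in_channels, in_channels, 3, padding=1),
                   nn.ReLU(inplace=True)]
    layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
    return nn.Sequential(*layers)

K, A = 80, 9                 # e.g. 80 COCO classes, 9 anchors per position
cls_head = make_head(K * A)  # K*A class scores per spatial position
reg_head = make_head(4 * A)  # 4 offsets (x, y, w, h) per anchor, class-agnostic

# The same heads run over every FPN level; cls and reg share no parameters.
p3 = torch.randn(1, 256, 64, 64)
print(cls_head(p3).shape, reg_head(p3).shape)
```

Both heads keep the spatial resolution of the input feature map; only the channel dimension changes (720 = 80·9 for classification, 36 = 4·9 for regression).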

Code:

Three ways of computing the loss over both positive and negative samples:


import torch
import torch.nn.functional as F


def focal_loss_one(alpha, beta, cls_preds, gts):
    print('====== Implementation 1 ======')
    num_pos = gts.sum()
    print('==num_pos:', num_pos)
    # alpha for positives, 1 - alpha for negatives
    alpha_tensor = torch.ones_like(cls_preds) * alpha
    alpha_tensor = torch.where(torch.eq(gts, 1.), alpha_tensor, 1. - alpha_tensor)
    print('===alpha_tensor===', alpha_tensor)
    # p_t: p for positives, 1 - p for negatives
    preds = torch.where(torch.eq(gts, 1.), cls_preds, 1. - cls_preds)
    print('===1. - preds===', 1. - preds)
    focal_weight = alpha_tensor * torch.pow((1. - preds), beta)
    print('==focal_weight:', focal_weight)
    batch_bce_loss = -(gts * torch.log(cls_preds) + (1. - gts) * torch.log(1. - cls_preds))
    batch_focal_loss = focal_weight * batch_bce_loss
    print('==batch_focal_loss:', batch_focal_loss)
    batch_focal_loss = batch_focal_loss.sum()
    print('==batch_focal_loss:', batch_focal_loss)
    print('==batch_focal_loss.item():', batch_focal_loss.item())
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


def focal_loss_two(alpha, beta, cls_preds, gts):
    print('====== Implementation 2 ======')
    pos_inds = (gts == 1.0).float()
    print('==pos_inds:', pos_inds)
    neg_inds = (gts != 1.0).float()
    print('===neg_inds:', neg_inds)
    # split the loss into positive and negative terms
    pos_loss = -pos_inds * alpha * (1.0 - cls_preds) ** beta * torch.log(cls_preds)
    neg_loss = -neg_inds * (1 - alpha) * (cls_preds ** beta) * torch.log(1.0 - cls_preds)
    num_pos = pos_inds.float().sum()
    print('==num_pos:', num_pos)
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        mean_batch_focal_loss = neg_loss
    else:
        mean_batch_focal_loss = (pos_loss + neg_loss) / num_pos
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


def focal_loss_three(alpha, beta, cls_preds, gts):
    print('====== Implementation 3 ======')
    num_pos = gts.sum()
    pred_sigmoid = cls_preds
    target = gts.type_as(pred_sigmoid)
    # 1 - p_t: 1 - p for positives, p for negatives
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(beta)
    batch_focal_loss = F.binary_cross_entropy(pred_sigmoid, target, reduction='none') * focal_weight
    batch_focal_loss = batch_focal_loss.sum()
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)

bs = 2
num_class = 3
alpha = 0.25
beta = 2
# (B, cls)
cls_preds = torch.rand([bs, num_class], dtype=torch.float)
print('==cls_preds:', cls_preds)
gts = torch.tensor([0, 2])
# (B, cls)
gts = F.one_hot(gts, num_classes=num_class).type_as(cls_preds)
print('===gts===', gts)
focal_loss_one(alpha, beta, cls_preds, gts)
focal_loss_two(alpha, beta, cls_preds, gts)
focal_loss_three(alpha, beta, cls_preds, gts)

Computing the loss on positive samples only:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """An implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection".

        Loss(x, class) = -alpha * (1 - softmax(x)[class])^gamma * log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        class_num (int): number of classes
        alpha (Tensor, optional): per-class scalar weighting factor
        gamma (float): gamma > 0 reduces the relative loss for well-classified
            examples (p > .5), putting more focus on hard, misclassified examples
        size_average (bool): if True (default), average the losses over the
            minibatch; otherwise sum them
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=-1)
        print('===P:', P)
        # build a one-hot mask from the target class indices
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)
        print('==class_mask:', class_mask)
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.view(-1)]
        print('==alpha:', alpha)
        # probability of the ground-truth class for each sample
        probs = (P * class_mask).sum(1).view(-1, 1)
        print('==probs:', probs)
        log_p = probs.log()
        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss


def debug_focal():
    import numpy as np
    # the loss is computed only for the ground-truth (hard) samples
    loss = FocalLoss(class_num=8)
    # loss = FocalLoss(class_num=8,
    #                  alpha=torch.tensor([0.25, 0.25, 0.25, 0.25,
    #                                      0.25, 0.25, 0.25, 0.25]).reshape(-1, 1))
    inputs = torch.rand(2, 8)
    print('==inputs:', inputs)
    # targets = torch.from_numpy(np.array([[1, 0, 0, 0, 0, 0, 0, 0],
    #                                      [0, 1, 0, 0, 0, 0, 0, 0]]))
    targets = torch.from_numpy(np.array([0, 1]))
    cost = loss(inputs, targets)
    print('===cost===:', cost)


if __name__ == '__main__':
    debug_focal()
