介绍

RetinaNet是2018年Facebook AI团队在目标检测领域新的贡献。它的重要作者名单中Ross Girshick与Kaiming He赫然在列。来自Microsoft的Sun Jian团队与现在Facebook的Ross/Kaiming团队在当前视觉目标分类、检测领域有着北乔峰、南慕容一般的独特地位。这两个实验室的文章多是行业里前进方向的提示牌。

RetinaNet只是原来FPN网络与FCN网络的组合应用,因此在目标网络检测框架上它并无特别亮眼创新。文章中最大的创新来自于Focal loss的提出及在单阶段目标检测网络RetinaNet(实质为Resnet + FPN + FCN)的成功应用。Focal loss是一种改进了的交叉熵(cross-entropy, CE)loss,它通过在原有的CE loss上乘了个使易检测目标对模型训练贡献削弱的指数式,从而使得Focal loss成功地解决了在目标检测时,正负样本区域极不平衡而目标检测loss易被大批量负样本所左右的问题。此问题是单阶段目标检测框架(如SSD/Yolo系列)与双阶段目标检测框架(如Faster-RCNN/R-FCN等)accuracy gap的最大原因。在Focal loss提出之前,已有的目标检测网络都是通过像Boot strapping/Hard example mining等方法来解决此问题的。作者通过后续实验成功表明Focal loss可在单阶段目标检测网络中成功使用,并最终能以更快的速率实现与双阶段目标检测网络近似或更优的效果。

类别不平衡问题

常规的单阶段目标检测网络像SSD一般在模型训练时会先大密度地在模型终端的系列feature maps上生成出10,000甚至100,0000个目标候选区域。然后再分别对这些候选区域进行分类与位置回归识别。而在这些生成的数万个候选区域中,绝大多数都是不包含待检测目标的图片背景,这样就造成了机器学习中经典的训练样本正负不平衡的问题。它往往会造成最终算出的training loss为占绝对多数但包含信息量却很少的负样本所支配,少样正样本提供的关键信息却不能在一般所用的training loss中发挥正常作用,从而无法得出一个能对模型训练提供正确指导的loss。

常用的解决此问题的方法就是负样本挖掘。或其它更复杂的用于过滤负样本从而使正负样本数维持一定比率的样本取样方法。而在此篇文章中作者提出了可通过候选区域包含潜在目标概率进而对最终的training loss进行较正的方法。实验表明这种新提出的focal loss在单阶段目标检测任务上表现突出,有效地解决了此领域里面潜在的类别不平衡问题。

Focal loss

CE(cross-entropy) loss

以下为典型的交叉熵loss,它广泛用于当下的图像分类、检测CNN网络当中。

Cross-entropy_loss

Balanced CE loss

考虑到上节中提到的类别不平衡问题对最终training loss的不利影响,我们自然会想到可通过在loss公式中使用与目标存在概率成反比的系数对其进行较正。如下公式即是此朴素想法的体现。它也是作者最终Focus loss的baseline。

Balanced_CE-loss

Focal loss定义

以下是作者提出的focal loss的想法。

Focal_loss定义

下图为focal loss与常规CE loss的对比。从中,我们易看出focal loss所加的指数式系数可对正负样本对loss的贡献自动调节。当某样本类别比较明确些,它对整体loss的贡献就比较少;而若某样本类别不易区分,则对整体loss的贡献就相对偏大。这样得到的loss最终将集中精力去诱导模型去努力分辨那些难分的目标类别,于是就有效提升了整体的目标检测准度。不过在此focus loss计算当中,我们引入了一个新的hyper parameter即γ。一般来说新参数的引入,往往会伴随着模型使用难度的增加。在本文中,作者有试者对其进行调节,线性搜索后得出将γ设为2时,模型检测效果最好。

几种loss的对比

在最终所用的focal loss上,作者还引入了α系数,它能够使得focal loss对不同类别更加平衡。实验表明它会比原始的focal loss效果更好。

最终所用的Focal_loss

模型的初始化参数选择

一般我们初始化CNN网络模型时都会使用无偏的参数对其初始化,比如Conv的kernel 参数我们会以bias 为0,variance为0.01的某分布来对其初始化。但是如果我们的模型要去处理类别极度不平衡的情况,那么就会考虑到这样对训练数据分布无任选先验假设的初始化会使得在训练过程中,我们的参数更偏向于拥有更多数量的负样本的情况去进化。作者观察下来发现它在训练时会出现极度的不稳定。于是作者在初始化模型最后一层参数时考虑了数据样本分布的不平衡性,这样使得初始训练时最终得出的loss不会对过多的负样本数量所惊讶到,从而有效地规避了初始训练时模型的震荡与不稳定。

RetinaNet检测框架

RetinaNet本质上是Resnet + FPN + 两个FCN子网络。

以下为RetinaNet目标框架框架图。有了之前blog里面提到的FPN与FCN的知识后,我们很容易理解此框架的设计含义。

RetinaNet目标检测框架

一般主干网络可选用任一有效的特征提取网络如vgg16或resnet系列,此处作者分别尝试了resnet-50与resnet-101。而FPN则是对resnet-50里面自动形成的多尺度特征进行了强化利用,从而得到了表达力更强、包含多尺度目标区域信息的feature maps集合。最后在FPN所吐出的feature maps集合上,分别使用了两个FCN子网络(它们有着相同的网络结构却各自独立,并不share参数)用来完成目标框类别分类与位置回归任务。

模型的推理与训练

模型推理

一旦我们有了训练好的模型,在正式部署时,只需对其作一次forward,然后对最终生成的目标区域进行过渡。然后只对每个FPN level上目标存在概率最高的前1000个目标框进一步地decoding处理。接下来再将所有FPN level上得到的目标框汇集起来,统一使用极大值抑制的方法进一步过渡(其中极大值抑制时所用的阈值为0.5)。这样,我们就得到了最终的目标与其位置框架。

模型训练

模型训练中主要在后端Loss计算时采用了Focal loss,另外也在模型初始化时考虑到了正负样本极度不平衡的情况进而对模型最后一个conv layer的bias参数作了有偏初始化。

训练时用了SGD,mini batch size为16,在8个GPU上一块训练,每个GPU上local batch size为2。最大iterations数目为90,000;模型初始lr为0.01,接下来随着训练进行分step wisely 降低。真正的training loss则为表达目标类别的focus loss与表达目标框位置回归信息的L1 loss的和。

下图为RetinaNet模型的检测准度与性能。

RetinaNet的检测准度与性能

代码实例

以下函数用于从FPN的各个level的feature maps上提取各种scale的anchor box。

def _create_cell_anchors():

"""

Generate all types of anchors for all fpn levels/scales/aspect ratios.

This function is called only once at the beginning of inference.

"""

k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL

scales_per_octave = cfg.RETINANET.SCALES_PER_OCTAVE

aspect_ratios = cfg.RETINANET.ASPECT_RATIOS

anchor_scale = cfg.RETINANET.ANCHOR_SCALE

A = scales_per_octave * len(aspect_ratios)

anchors = {}

for lvl in range(k_min, k_max + 1):

# create cell anchors array

stride = 2. ** lvl

cell_anchors = np.zeros((A, 4))

a = 0

for octave in range(scales_per_octave):

octave_scale = 2 ** (octave / float(scales_per_octave))

for aspect in aspect_ratios:

anchor_sizes = (stride * octave_scale * anchor_scale, )

anchor_aspect_ratios = (aspect, )

cell_anchors[a, :] = generate_anchors(

stride=stride, sizes=anchor_sizes,

aspect_ratios=anchor_aspect_ratios)

a += 1

anchors[lvl] = cell_anchors

return anchors

下面函数则描述了如何使用train好的RetinaNet来进行图片目标检测。

def im_detect_bbox(model, im, timers=None):

"""Generate RetinaNet detections on a single image."""

if timers is None:

timers = defaultdict(Timer)

# Although anchors are input independent and could be precomputed,

# recomputing them per image only brings a small overhead

anchors = _create_cell_anchors()

timers['im_detect_bbox'].tic()

k_max, k_min = cfg.FPN.RPN_MAX_LEVEL, cfg.FPN.RPN_MIN_LEVEL

A = cfg.RETINANET.SCALES_PER_OCTAVE * len(cfg.RETINANET.ASPECT_RATIOS)

inputs = {}

inputs['data'], im_scale, inputs['im_info'] = \

blob_utils.get_image_blob(im, cfg.TEST.SCALE, cfg.TEST.MAX_SIZE)

cls_probs, box_preds = [], []

for lvl in range(k_min, k_max + 1):

suffix = 'fpn{}'.format(lvl)

cls_probs.append(core.ScopedName('retnet_cls_prob_{}'.format(suffix)))

box_preds.append(core.ScopedName('retnet_bbox_pred_{}'.format(suffix)))

for k, v in inputs.items():

workspace.FeedBlob(core.ScopedName(k), v.astype(np.float32, copy=False))

workspace.RunNet(model.net.Proto().name)

cls_probs = workspace.FetchBlobs(cls_probs)

box_preds = workspace.FetchBlobs(box_preds)

# here the boxes_all are [x0, y0, x1, y1, score]

boxes_all = defaultdict(list)

cnt = 0

for lvl in range(k_min, k_max + 1):

# create cell anchors array

stride = 2. ** lvl

cell_anchors = anchors[lvl]

# fetch per level probability

cls_prob = cls_probs[cnt]

box_pred = box_preds[cnt]

cls_prob = cls_prob.reshape((

cls_prob.shape[0], A, int(cls_prob.shape[1] / A),

cls_prob.shape[2], cls_prob.shape[3]))

box_pred = box_pred.reshape((

box_pred.shape[0], A, 4, box_pred.shape[2], box_pred.shape[3]))

cnt += 1

if cfg.RETINANET.SOFTMAX:

cls_prob = cls_prob[:, :, 1::, :, :]

cls_prob_ravel = cls_prob.ravel()

# In some cases [especially for very small img sizes], it's possible that

# candidate_ind is empty if we impose threshold 0.05 at all levels. This

# will lead to errors since no detections are found for this image. Hence,

# for lvl 7 which has small spatial resolution, we take the threshold 0.0

th = cfg.RETINANET.INFERENCE_TH if lvl < k_max else 0.0

candidate_inds = np.where(cls_prob_ravel > th)[0]

if (len(candidate_inds) == 0):

continue

pre_nms_topn = min(cfg.RETINANET.PRE_NMS_TOP_N, len(candidate_inds))

inds = np.argpartition(

cls_prob_ravel[candidate_inds], -pre_nms_topn)[-pre_nms_topn:]

inds = candidate_inds[inds]

inds_5d = np.array(np.unravel_index(inds, cls_prob.shape)).transpose()

classes = inds_5d[:, 2]

anchor_ids, y, x = inds_5d[:, 1], inds_5d[:, 3], inds_5d[:, 4]

scores = cls_prob[:, anchor_ids, classes, y, x]

boxes = np.column_stack((x, y, x, y)).astype(dtype=np.float32)

boxes *= stride

boxes += cell_anchors[anchor_ids, :]

if not cfg.RETINANET.CLASS_SPECIFIC_BBOX:

box_deltas = box_pred[0, anchor_ids, :, y, x]

else:

box_cls_inds = classes * 4

box_deltas = np.vstack(

[box_pred[0, ind:ind + 4, yi, xi]

for ind, yi, xi in zip(box_cls_inds, y, x)]

)

pred_boxes = (

box_utils.bbox_transform(boxes, box_deltas)

if cfg.TEST.BBOX_REG else boxes)

pred_boxes /= im_scale

pred_boxes = box_utils.clip_tiled_boxes(pred_boxes, im.shape)

box_scores = np.zeros((pred_boxes.shape[0], 5))

box_scores[:, 0:4] = pred_boxes

box_scores[:, 4] = scores

for cls in range(1, cfg.MODEL.NUM_CLASSES):

inds = np.where(classes == cls - 1)[0]

if len(inds) > 0:

boxes_all[cls].extend(box_scores[inds, :])

timers['im_detect_bbox'].toc()

# Combine predictions across all levels and retain the top scoring by class

timers['misc_bbox'].tic()

detections = []

for cls, boxes in boxes_all.items():

cls_dets = np.vstack(boxes).astype(dtype=np.float32)

# do class specific nms here

keep = box_utils.nms(cls_dets, cfg.TEST.NMS)

cls_dets = cls_dets[keep, :]

out = np.zeros((len(keep), 6))

out[:, 0:5] = cls_dets

out[:, 5].fill(cls)

detections.append(out)

# detections (N, 6) format:

# detections[:, :4] - boxes

# detections[:, 4] - scores

# detections[:, 5] - classes

detections = np.vstack(detections)

# sort all again

inds = np.argsort(-detections[:, 4])

detections = detections[inds[0:cfg.TEST.DETECTIONS_PER_IM], :]

# Convert the detections to image cls_ format (see core/test_engine.py)

num_classes = cfg.MODEL.NUM_CLASSES

cls_boxes = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]

for c in range(1, num_classes):

inds = np.where(detections[:, 5] == c)[0]

cls_boxes[c] = detections[inds, :5]

timers['misc_bbox'].toc()

return cls_boxes

以下为RetinaNet中training loss的具体计算。可以看出它包含了两个部分分别为反映位置信息的L1 loss与反映类别信息的focus loss。

def add_fpn_retinanet_losses(model):

loss_gradients = {}

gradients, losses = [], []

k_max = cfg.FPN.RPN_MAX_LEVEL # coarsest level of pyramid

k_min = cfg.FPN.RPN_MIN_LEVEL # finest level of pyramid

model.AddMetrics(['retnet_fg_num', 'retnet_bg_num'])

# ==========================================================================

# bbox regression loss - SelectSmoothL1Loss for multiple anchors at a location

# ==========================================================================

for lvl in range(k_min, k_max + 1):

suffix = 'fpn{}'.format(lvl)

bbox_loss = model.net.SelectSmoothL1Loss(

[

'retnet_bbox_pred_' + suffix,

'retnet_roi_bbox_targets_' + suffix,

'retnet_roi_fg_bbox_locs_' + suffix, 'retnet_fg_num'

],

'retnet_loss_bbox_' + suffix,

beta=cfg.RETINANET.BBOX_REG_BETA,

scale=model.GetLossScale() * cfg.RETINANET.BBOX_REG_WEIGHT

)

gradients.append(bbox_loss)

losses.append('retnet_loss_bbox_' + suffix)

# ==========================================================================

# cls loss - depends on softmax/sigmoid outputs

# ==========================================================================

for lvl in range(k_min, k_max + 1):

suffix = 'fpn{}'.format(lvl)

cls_lvl_logits = 'retnet_cls_pred_' + suffix

if not cfg.RETINANET.SOFTMAX:

cls_focal_loss = model.net.SigmoidFocalLoss(

[

cls_lvl_logits, 'retnet_cls_labels_' + suffix,

'retnet_fg_num'

],

['fl_{}'.format(suffix)],

gamma=cfg.RETINANET.LOSS_GAMMA,

alpha=cfg.RETINANET.LOSS_ALPHA,

scale=model.GetLossScale(),

num_classes=model.num_classes - 1

)

gradients.append(cls_focal_loss)

losses.append('fl_{}'.format(suffix))

else:

cls_focal_loss, gated_prob = model.net.SoftmaxFocalLoss(

[

cls_lvl_logits, 'retnet_cls_labels_' + suffix,

'retnet_fg_num'

],

['fl_{}'.format(suffix), 'retnet_prob_{}'.format(suffix)],

gamma=cfg.RETINANET.LOSS_GAMMA,

alpha=cfg.RETINANET.LOSS_ALPHA,

scale=model.GetLossScale(),

num_classes=model.num_classes

)

gradients.append(cls_focal_loss)

losses.append('fl_{}'.format(suffix))

loss_gradients.update(blob_utils.get_loss_gradients(model, gradients))

model.AddLosses(losses)

return loss_gradients

参考文献

Focal Loss for Dense Object Detection, Tsung-Yi Lin, 2018

retinanet 部署_RetinaNet: Focal loss在目标检测网络中的应用相关推荐

  1. 目标检测网络中的Backbone,Neck和Head - 以YOLOv4为例

    目标检测网络中的Backbone,Neck和Head - 以YOLOv4为例 目标检测网络中常见到的三个概念: Backbone:在不同图像细粒度上聚合并形成图像特征的卷积神经网络: Neck:一系列 ...

  2. 目标检测网络中的 bottom-up 和 top-down理解

    看目标检测网络方面的论文时,出现了一组对比词汇: bottom-up和top-down,查了一些资料,结合个人理解,得到的看法是: top-down: 顾名思义是自上而下进行,最初来源于行人检测框架, ...

  3. 深度目标检测网络中关于anchor的神之问(配代码详解)(二)

    目录 将所求出的所有anchor都用于计算吗?如何将筛选所用于计算proposal的anchor点? 如何用anchor来计算proposal(分类与边框回归)? 如何根据前景anchor和GT作Bo ...

  4. 计算机视觉算法——目标检测网络总结

    计算机视觉算法--目标检测网络总结 计算机视觉算法--目标检测网络总结 1. RCNN系列 1.1 RCNN 1.1.1 关键知识点--网络结构及特点 1.1.2 关键知识点--RCNN存在的问题 1 ...

  5. 收藏 | 如何定义目标检测网络的正负例:Anchor-based

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:知乎-Lighthouse 地址:https://www.zhihu.com/org/lighthous ...

  6. RefineDetLite:腾讯提出轻量级高精度目标检测网络

    点击我爱计算机视觉标星,更快获取CVML新技术 前几天腾讯公布了一篇论文RefineDetLite: A Lightweight One-stage Object Detection Framewor ...

  7. 从L1 loss到EIoU loss,目标检测边框回归的损失函数一览

    本文转载自知乎,已获作者授权转载. 链接:https://zhuanlan.zhihu.com/p/342991797 目标检测任务的损失函数由Classificition Loss和BBox Reg ...

  8. 【camera】全景驾驶感知网络YOLOP部署与实现(交通目标检测、可驾驶区域分割、车道线检测)

    全景驾驶感知网络YOLOP部署与实现(交通目标检测.可驾驶区域分割.车道线检测) 项目下载地址 包含C++和Python两种版本的程序实现:下载地址 YOLOP开源项目: https://github ...

  9. 丢弃Transformer!旷视和西安交大提出基于FCN的端到端目标检测网络

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 本文作者:王剑锋  | 编辑:Amusi https://zhuanlan.zhihu.com ...

最新文章

  1. SSM整合pom.xml和导包
  2. Codefest 18 (rated, Div. 1 + Div. 2)-D-Valid BFS--思维--已知bfs序,求是否正确
  3. android版本如何修改时间,如何修改Android系统默认时间
  4. CAS实现单点登录方案(SSO完整版)
  5. 飞凌 ok6410 按键驱动源码及测试代码
  6. 使用虚拟机VMware12定制安装redhat6企业版
  7. 设计灵感|优秀案例教你如何像杂志一样排版?
  8. 图解源码之java锁的获取和释放(AQS)篇
  9. 单机到集群的WEB架构演变
  10. ios 设置字体家族
  11. 网易云音乐广告CTR预估模型演进过程
  12. AT24CXX、DS1302、红外HS0038 20170610 周六
  13. 周四007欧联杯 佛罗伦萨 VS 门兴[11]
  14. LAZARUS APT利用恶意word文档攻击MAC用户
  15. 曼哈顿距离,欧式距离,余弦距离
  16. Kali社会工程学套件入侵Windows
  17. 关于STM32 GPIO配置基础概括
  18. PHP常见的设计模式之:适配器模式
  19. 一加手机怎么root权限_一加五,怎么获取ROOT权限
  20. 【NLP】Praat库(1) 安装及初步使用

热门文章

  1. 当AI机器人闯进了艺术创作,下一个被威胁的是画家?
  2. 小鱼授权系统源码_无加密
  3. java长整型转换为整型_java ip地址转换为长整型
  4. UE5 5.0正式版 新功能详解
  5. linux mint下安装企鹅输入法
  6. App识别微信小程序二维码、太阳码调研
  7. ​力扣解法汇总2347. 最好的扑克手牌
  8. 什么是 web 语义化,有什么好处
  9. 基于Vue和SpringBoot的毕业生追踪系统的设计和实现
  10. Matlab:数据可视化