目录

1. 使用Smoooh L1 Loss的原因

2. Faster RCNN的损失函数

2.1 分类损失

2.2 回归损失

一些感悟


关于文章中具体一些代码及参数如何得来的请看博客:

tensorflow+faster rcnn代码解析(二):anchor_target_layer、proposal_target_layer、proposal_layer

最近又重新学习了一遍Faster RCNN有挺多收获的,在此重新记录一下。

1. 使用Smoooh L1 Loss的原因

对于边框的预测是一个回归问题。通常可以选择平方损失函数(L2损失)f(x)=x^2。但这个损失对于比较大的误差的惩罚很高。

我们可以采用稍微缓和一点绝对损失函数(L1损失)f(x)=|x|,它是随着误差线性增长,而不是平方增长。但这个函数在0点处导数不存在,因此可能会影响收敛。

一个通常的解决办法是,分段函数,在0点附近使用平方函数使得它更加平滑。它被称之为平滑L1损失函数。它通过一个参数σ 来控制平滑的区域。一般情况下σ = 1,在faster rcnn函数中σ = 3

2. Faster RCNN的损失函数

Faster RCNN的的损失主要分为RPN的损失和Fast RCNN的损失,计算公式如下,并且两部分损失都包括分类损失(cls loss)回归损失(bbox regression loss)。

下面分别讲一下RPN和fast RCNN部分的损失。

2.1 分类损失

公式:

(1)RPN分类损失

RPN网络的产生的anchor只分为前景和背景,前景的标签为1,背景的标签为0。在训练RPN的过程中,会选择256个anchor,256就是公式中的Ncls

可以看到这是一个这经典的二分类交叉熵损失,对于每一个anchor计算对数损失,然后求和除以总的anchor数量Ncls。这部分的代码tensorflow代码如下:

rpn_cls_score = tf.reshape(self._predictions['rpn_cls_score_reshape'], [-1, 2]) #rpn_cls_score = (17100,2)
rpn_label = tf.reshape(self._anchor_targets['rpn_labels'], [-1])  #rpn_label = (17100,)
rpn_select = tf.where(tf.not_equal(rpn_label, -1)) #将不等于-1的labels选出来(也就是正负样本选出来),返回序号
rpn_cls_score = tf.reshape(tf.gather(rpn_cls_score, rpn_select), [-1, 2]) #同时选出对应的分数
rpn_label = tf.reshape(tf.gather(rpn_label, rpn_select), [-1])
rpn_cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=rpn_cls_score, labels=rpn_label))

假设我们RPN网络的特征图大小为38×50,那么就会产生38×50×9=17100个anchor,然后在RPN的训练阶段会从17100个anchor中挑选Ncls个anchor用来训练RPN的参数,其中挑选为前景的标签为1,背景的标签为0。

  1. 代码第一行将其reshape变为(17100,2),行数表示anchor的数量,列数为前景和背景,表示属于前景和背景的分数。
  2. 代码第二行和第三行,将RPN的label也reshape成(17100,),即分别对应上anchor,然后从中选出不等于-1的,也就是选择出前景和背景,数量为Ncls,返回其index,为rpn_select。
  3. 代码第四行,根据index选择出对应的分数。
  4. 第五行,根据rpn_label和rpn_cls_score计算交叉熵损失。其中reduce_mean函数就是除以个数(Ncls)求平均

(2)Fast RCNN分类损失

RPN的分类损失时二分类的交叉熵损失,而Fast RCNN是多分类的交叉熵损失(当你训练的类别数>2时,这里假定类别数为5)。在Fast RCNN的训练过程中会选出128个rois,即Ncls = 128,标签的值就是0到4。代码为:

cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tf.reshape(cls_score, [-1, self._num_classes]), labels=label))

2.2 回归损失

回归损失这块就RPN和Fast RCNN一起讲,公式为:

其中:

  •  是一个向量,表示anchor,RPN训练阶段(rois,FastRCNN阶段)预测的偏移量
  • 是与ti维度相同的向量,表示anchor,RPN训练阶段(rois,FastRCNN阶段)相对于gt实际的偏移量

R是smoothL1 函数,就是我们上面说的,不同之处是这里σ = 3,RPN训练(σ = 1,Fast RCNN训练),

对于每一个anchor 计算完部分后还要乘以P*,如前所述,P*有物体时(positive)为1,没有物体(negative)时为0,意味着只有前景才计算损失,背景不计算损失。inside_weights就是这个作用。

对于和Nreg的解释在RPN训练过程中如下(之所以以RPN训练为前提因为此时batch size = 256,如果是fast rcnn,batchsize = 128):

所以就是outside_weights,没有前景(fg)也没有后景(bg)的为0,其他为1/(bg+fg)=Ncls

代码:

    def _smooth_l1_loss(self, bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights, sigma=1.0, dim=[1]):sigma_2 = sigma ** 2box_diff = bbox_pred - bbox_targets #ti-ti* in_box_diff = bbox_inside_weights * box_diff  #前景才有计算损失的资格abs_in_box_diff = tf.abs(in_box_diff) #x = |ti-ti*|smoothL1_sign = tf.stop_gradient(tf.to_float(tf.less(abs_in_box_diff, 1. / sigma_2))) #判断smoothL1输入的大小,如果x = |ti-ti*|小于就返回1,否则返回0#计算smoothL1损失in_loss_box = tf.pow(in_box_diff, 2) * (sigma_2 / 2.) * smoothL1_sign + (abs_in_box_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)out_loss_box = bbox_outside_weights * in_loss_boxloss_box = tf.reduce_mean(tf.reduce_sum(out_loss_box,axis=dim))return loss_box

一些感悟

论文中把Ncls,Nreg和都看做是平衡分类损失和回归损失的归一化权重,但是我在看tensorflow代码实现faster rcnn的损失时发现(这里以fast rcnn部分的分类损失和box回归损失为例,如下),可以看到在计算分类损失时,并没有输入Ncls这个参数,只是在计算box回归损失的时候输入了outside_weights这个参数。这时候我才意识到分类损失是交叉熵函数,求和后会除以总数量,除以Ncls已经包含到交叉熵函数本身。

为了平衡两种损失的权重,outside_weights的取值取决于Ncls,而Ncls的取值取决于batch_size。因此才会有

            # RCNN, class losscls_score = self._predictions["cls_score"]label = tf.reshape(self._proposal_targets["labels"], [-1])cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=tf.reshape(cls_score, [-1, self._num_classes]), labels=label))# RCNN, bbox lossbbox_pred = self._predictions['bbox_pred'] #(128,12)bbox_targets = self._proposal_targets['bbox_targets'] #(128,12)bbox_inside_weights = self._proposal_targets['bbox_inside_weights']#(128,12)bbox_outside_weights = self._proposal_targets['bbox_outside_weights']#(128,12)loss_box = self._smooth_l1_loss(bbox_pred, bbox_targets, bbox_inside_weights, bbox_outside_weights)

【Faster RCNN】损失函数理解相关推荐

  1. Faster RCNN代码理解(Python) ---训练过程

    最近开始学习深度学习,看了下Faster RCNN的代码,在学习的过程中也查阅了很多其他人写的博客,得到了很大的帮助,所以也打算把自己一些粗浅的理解记录下来,一是记录下自己的菜鸟学习之路,方便自己过后 ...

  2. faster rcnn的理解

    结构: faster rcnn是fast rcnn的改进版,一个更快的算法.为了理解faster rcnn,建议读者先理解fast rcnn, fast rcnn结构的理解,可以参考我的一篇博客:fa ...

  3. Faster R-CNN 深入理解 改进方法汇总

    Faster R-CNN 从2015年底至今已经有接近两年了,但依旧还是Object Detection领域的主流框架之一,虽然推出了后续 R-FCN,Mask R-CNN 等改进框架,但基本结构变化 ...

  4. [目标检测] Faster R-CNN 深入理解 改进方法汇总

    Faster R-CNN 从2015年底至今已经有接近两年了,但依旧还是Object Detection领域的主流框架之一,虽然推出了后续 R-FCN,Mask R-CNN 等改进框架,但基本结构变化 ...

  5. 对Faster R-CNN的理解(1)

    目标检测是一种基于目标几何和统计特征的图像分割,最新的进展一般是通过R-CNN(基于区域的卷积神经网络)来实现的,其中最重要的方法之一是Faster R-CNN. 1. 总体结构 Faster R-C ...

  6. faster rcnn 论文理解

    目录(?)[-] 思想 区域生成网络结构 特征提取 候选区域anchor Region Proposal Networks Translation-Invariant Anchors 窗口分类和位置精 ...

  7. faster rcnn resnet_RCNN, Fast R-CNN 与 Faster RCNN理解及改进方法

    RCNN 这个网络也是目标检测的鼻祖了.其原理非常简单,主要通过提取多个Region Proposal(候选区域)来判断位置,作者认为以往的对每个滑动窗口进行检测算法是一种浪费资源的方式.在RCNN中 ...

  8. cnn 回归 坐标 特征图_RCNN, Fast R-CNN 与 Faster RCNN理解及改进方法

    RCNN 这个网络也是目标检测的鼻祖了.其原理非常简单,主要通过提取多个Region Proposal(候选区域)来判断位置,作者认为以往的对每个滑动窗口进行检测算法是一种浪费资源的方式.在RCNN中 ...

  9. Faster R-CNN论文笔记——FR

    转载自:http://blog.csdn.net/qq_17448289/article/details/52871461 在介绍Faster R-CNN之前,先来介绍一些前验知识,为Faster R ...

  10. Faster RCNN解析

    在介绍Faster R-CNN之前,先来介绍一些前验知识,为Faster R-CNN做铺垫. 一.基于Region Proposal(候选区域)的深度学习目标检测算法 Region Proposal( ...

最新文章

  1. vue+element-ui动态生成多级表头,并且将有相同字段下不同子元素合并为同一个...
  2. python游戏设计案例实战pdf_实战案例 | 新蔡规划馆设计方案
  3. Dynamic Graph CNN for Learning on Point Clouds(DGCNN)论文阅读笔记——核心思想:EdgeConv细析
  4. eclipse中安装Tomcat
  5. Smart View 11.1.2.5配置共享连接
  6. Swing数独游戏(二):终盘生成之随机法
  7. 通过IGT-DSER网关实现各品牌PLC之间,PLC与工业机器人(ModbusTCP)之间通讯
  8. OpenCms显示默认作者
  9. 【UCSC Genome Browser】- ClinGen剂量敏感性分析
  10. 程序和算法之间,主要有什么关系?
  11. VS2019 团队资源管理器--Git的使用(二)
  12. 带计算机来学校检讨,校园检讨书
  13. R语言各种假设检验实例整理(常用)
  14. 投融资项目入门和总结
  15. pytorch中torch.isnan()和torch.isfinite()
  16. 腾讯地图只显示某一区域,覆盖图,marker自定义图标和文本标注
  17. 如何深入掌握C语言指针(详解)
  18. 全景拼接python旗舰版
  19. Spark Mllib里的分布式矩阵(行矩阵、带有行索引的行矩阵、坐标矩阵和块矩阵概念、构成)(图文详解)...
  20. 哪些食物会使皮肤变黑?

热门文章

  1. SWMM建模与案例应用
  2. Java性能优化之for循环
  3. html设置表格高宽的代码_设置html表格宽度
  4. heic图片格式转换jpg_如何在Mac上通过简单方法将HEIC图像转换为JPG
  5. Reeder 5.0.3 将RSS阅读体验发挥到极致
  6. OFD文件怎么编辑修改?
  7. WEB数据库管理平台kb-dms:功能简介【一】
  8. 视频解析工具youtube-dl
  9. 设置计算机名和ip 一键,批量设置IP地址和计算机名
  10. 免费随机图片api接口