【本期推荐专题】物联网从业人员必读:华为云专家为你详细解读LiteOS各模块开发及其实现原理。

摘要:Focal Loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失的贡献。

本文分享自华为云社区《技术干货 | 基于MindSpore更好的理解Focal Loss》,原文作者:chengxiaoli。

今天更新一下恺明大神的Focal Loss,它是 Kaiming 大神团队在他们的论文Focal Loss for Dense Object Detection提出来的损失函数,利用它改善了图像物体检测的效果。ICCV2017RBG和Kaiming大神的新作(https://arxiv.org/pdf/1708.02002.pdf)。

使用场景

最近一直在做人脸表情相关的方向,这个领域的 DataSet 数量不大,而且往往存在正负样本不均衡的问题。一般来说,解决正负样本数量不均衡问题有两个途径:

1. 设计采样策略,一般都是对数量少的样本进行重采样

2. 设计 Loss,一般都是对不同类别样本进行权重赋值

本文讲的是第二种策略中的 Focal Loss。

理论分析

论文分析

我们知道object detection按其流程来说,一般分为两大类。一类是two stage detector(如非常经典的Faster R-CNN,RFCN这样需要region proposal的检测算法),第二类则是one stage detector(如SSD、YOLO系列这样不需要region proposal,直接回归的检测算法)。

对于第一类算法可以达到很高的准确率,但是速度较慢。虽然可以通过减少proposal的数量或降低输入图像的分辨率等方式达到提速,但是速度并没有质的提升。

对于第二类算法速度很快,但是准确率不如第一类。

所以目标就是:focal loss的出发点是希望one-stage detector可以达到two-stage detector的准确率,同时不影响原有的速度

So,Why?and result?

这是什么原因造成的呢?the Reason is:Class Imbalance(正负样本不平衡),样本的类别不均衡导致的。

我们知道在object detection领域,一张图像可能生成成千上万的candidate locations,但是其中只有很少一部分是包含object的,这就带来了类别不均衡。那么类别不均衡会带来什么后果呢?引用原文讲的两个后果:

(1) training is inefficient as most locations are easy negatives that contribute no useful learning signal;
(2) en masse, the easy negatives can overwhelm training and lead to degenerate models.

意思就是负样本数量太大(属于背景的样本),占总的loss的大部分,而且多是容易分类的,因此使得模型的优化方向并不是我们所希望的那样。这样,网络学不到有用的信息,无法对object进行准确分类。其实先前也有一些算法来处理类别不均衡的问题,比如OHEM(online hard example mining),OHEM的主要思想可以用原文的一句话概括:In OHEM each example is scored by its loss, non-maximum suppression (nms) is then applied, and a minibatch is constructed with the highest-loss examples。OHEM算法虽然增加了错分类样本的权重,但是OHEM算法忽略了容易分类的样本。

因此针对类别不均衡问题,作者提出一种新的损失函数:Focal Loss,这个损失函数是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。为了证明Focal Loss的有效性,作者设计了一个dense detector:RetinaNet,并且在训练时采用Focal Loss训练。实验证明RetinaNet不仅可以达到one-stage detector的速度,也能有two-stage detector的准确率。

公式说明

介绍focal loss,在介绍focal loss之前,先来看看交叉熵损失,这里以二分类为例,原来的分类loss是各个训练样本交叉熵的直接求和,也就是各个样本的权重是一样的。公式如下: 因为是二分类,p表示预测样本属于1的概率(范围为0-1),y表示label,y的取值为{+1,-1}。当真实label是1,也就是y=1时,假如某个样本x预测为1这个类的概率p=0.6,那么损失就是-log(0.6),注意这个损失是大于等于0的。如果p=0.9,那么损失就是-log(0.9),所以p=0.6的损失要大于p=0.9的损失,这很容易理解。这里仅仅以二分类为例,多分类分类以此类推为了方便,用pt代替p,如下公式2:。这里的pt就是前面Figure1中的横坐标。 为了表示简便,我们用p_t表示样本属于true class的概率。所以(1)式可以写成: 接下来介绍一个最基本的对交叉熵的改进,也将作为本文实验的baseline,既然one-stage detector在训练的时候正负样本的数量差距很大,那么一种常见的做法就是给正负样本加上权重,负样本出现的频次多,那么就降低负样本的权重,正样本数量少,就相对提高正样本的权重。因此可以通过设定  的值来控制正负样本对总的loss的共享权重。  取比较小的值来降低负样本(多的那类样本)的权重。 

显然前面的公式3虽然可以控制正负样本的权重,但是没法控制容易分类和难分类样本的权重,于是就有了Focal Loss,这里的γ称作focusing parameter,γ>=0,称为调制系数:

为什么要加上这个调制系数呢?目的是通过减少易分类样本的权重,从而使得模型在训练时更专注于难分类的样本。

通过实验发现,绘制图看如下Figure1,横坐标是pt,纵坐标是loss。CE(pt)表示标准的交叉熵公式,FL(pt)表示focal loss中用到的改进的交叉熵。Figure1中γ=0的蓝色曲线就是标准的交叉熵损失(loss)。

这样就既做到了解决正负样本不平衡,也做到了解决easy与hard样本不平衡的问题。

结论

作者将类别不平衡作为阻碍one-stage方法超过top-performing的two-stage方法的主要原因。为了解决这个问题,作者提出了focal loss,在交叉熵里面用一个调整项,为了将学习专注于hard examples上面,并且降低大量的easy negatives的权值。是同时解决了正负样本不平衡以及区分简单与复杂样本的问题

MindSpore代码实现

我们来看一下,基于MindSpore实现Focal Loss的代码:

import mindsporeimport mindspore.common.dtype as mstypefrom mindspore.common.tensor import Tensorfrom mindspore.common.parameter import Parameterfrom mindspore.ops import operations as Pfrom mindspore.ops import functional as Ffrom mindspore import nnclass FocalLoss(_Loss):def __init__(self, weight=None, gamma=2.0, reduction='mean'):super(FocalLoss, self).__init__(reduction=reduction)# 校验gamma,这里的γ称作focusing parameter,γ>=0,称为调制系数self.gamma = validator.check_value_type("gamma", gamma, [float])if weight is not None and not isinstance(weight, Tensor):raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))self.weight = weight# 用到的mindspore算子self.expand_dims = P.ExpandDims()self.gather_d = P.GatherD()self.squeeze = P.Squeeze(axis=1)self.tile = P.Tile()self.cast = P.Cast()def construct(self, predict, target):targets = target# 对输入进行校验_check_ndim(predict.ndim, targets.ndim)_check_channel_and_shape(targets.shape[1], predict.shape[1])_check_predict_channel(predict.shape[1])# 将logits和target的形状更改为num_batch * num_class * num_voxels.if predict.ndim > 2:predict = predict.view(predict.shape[0], predict.shape[1], -1) # N,C,H,W => N,C,H*Wtargets = targets.view(targets.shape[0], targets.shape[1], -1) # N,1,H,W => N,1,H*W or N,C,H*Welse:predict = self.expand_dims(predict, 2) # N,C => N,C,1targets = self.expand_dims(targets, 2) # N,1 => N,1,1 or N,C,1# 计算对数概率log_probability = nn.LogSoftmax(1)(predict)# 只保留每个voxel的地面真值类的对数概率值。if target.shape[1] == 1:log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))log_probability = self.squeeze(log_probability)# 得到概率probability = F.exp(log_probability)if self.weight is not None:convert_weight = self.weight[None, :, None]  # C => 1,C,1convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2])) # 1,C,1 => N,C,H*Wif target.shape[1] == 1:convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))  # selection of the weights  => N,1,H*Wconvert_weight = self.squeeze(convert_weight)  # N,1,H*W => N,H*W# 将对数概率乘以它们的权重probability = log_probability * convert_weight# 计算损失小批量weight = F.pows(-probability + 1.0, self.gamma)if target.shape[1] == 1:loss = (-weight * log_probability).mean(axis=1)  # Nelse:loss = (-weight * targets * log_probability).mean(axis=-1)  # N,Creturn self.get_loss(loss)

使用方法如下:

from mindspore.common import dtype as mstypefrom mindspore import nnfrom mindspore import Tensorpredict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)target = Tensor([[1], [1], [0]], mstype.int32)focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')output = focalloss(predict, target)print(output)0.33365273

Focal Loss的两个重要性质

1. 当一个样本被分错的时候,pt是很小的,那么调制因子(1-Pt)接近1,损失不被影响;当Pt→1,因子(1-Pt)接近0,那么分的比较好的(well-classified)样本的权值就被调低了。因此调制系数就趋于1,也就是说相比原来的loss是没有什么大的改变的。当pt趋于1的时候(此时分类正确而且是易分类样本),调制系数趋于0,也就是对于总的loss的贡献很小。

2. 当γ=0的时候,focal loss就是传统的交叉熵损失,当γ增加的时候,调制系数也会增加。 专注参数γ平滑地调节了易分样本调低权值的比例。γ增大能增强调制因子的影响,实验发现γ取2最好。直觉上来说,调制因子减少了易分样本的损失贡献,拓宽了样例接收到低损失的范围。当γ一定的时候,比如等于2,一样easy example(pt=0.9)的loss要比标准的交叉熵loss小100+倍,当pt=0.968时,要小1000+倍,但是对于hard example(pt < 0.5),loss最多小了4倍。这样的话hard example的权重相对就提升了很多。

这样就增加了那些误分类的重要性Focal Loss的两个性质算是核心,其实就是用一个合适的函数去度量难分类和易分类样本对总的损失的贡献。

点击关注,第一时间了解华为云新鲜技术~

技术干货 | 基于MindSpore更好的理解Focal Loss相关推荐

  1. 技术干货 | 基于MindSpore详解Perplexity语言模型评价指标

    01 原理介绍 在研究生实习时候就做过语言模型的任务,当时让求PPL值,当时只是调包,不求甚解,哈哈哈,当时也没想到现在会开发这个评价指标,那现在我来讲一下我对这个指标的了解,望各位大佬多多指教. 这 ...

  2. 技术干货 | 基于 MindSpore 实现图像分割之豪斯多夫距离

    今天带来的内容是Hausdorff distance 豪斯多夫距离的原理介绍及MindSpore的实现代码. 当我们评价图像分割的质量和模型表现时,经常会用到各类表面距离的计算.比如: · Mean ...

  3. 技术干货 | 基于MindSpore的图算融合探索和实践

    还记得一年前的MindSpore 0.5版本中,我们怀着忐忑的心情首次发布了图算融合特性,主要支持了昇腾硬件平台.在该版本中,图算融合初步实现了图算一体的DSL表达能力,并打开了图算协同优化的大门. ...

  4. 技术干货 | 基于 MindSpore 实现图像分割之平均表面距离

    今天为大家带来的内容是Mean surface distance 平均表面距离的原理介绍及MindSpore的实现代码. 当我们评价图像分割的质量和模型表现时,经常会用到各类表面距离的计算.比如: M ...

  5. 技术干货 | 基于 Qt Quick Plugin 快速构建桌面端跨平台组件

    导读:桌面端的 UI 开发框架对比移动端.Web 端的成熟方案,一直处于不温不火的状态.随着疫情掀起的风波,桌面端在线教育.视频会议等需求不断涌现.本文将围绕 Qt Quick 的优势来介绍如何快速创 ...

  6. 如何理解Focal Loss?

    Focal Loss 预科知识点 本处使用英文更加准确,中文过于简洁精炼不利于初学者理解.其实各类专业术语的翻译,只要我们理解其本质,不同的表述都不会影响我们对问题的掌控.有些翻译为晦涩绕口的中文,其 ...

  7. 【CV】10分钟理解Focal loss数学原理与Pytorch代码

    原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...

  8. pytorch gather_【CV】10分钟理解Focal loss数学原理与Pytorch代码

    原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...

  9. 简单理解Focal Loss

    Focal Loss用来解决的是类别不均衡问题,其 α \alpha α变体的公式长下面这样: F L ( p t ) = − α t ( 1 − p t ) γ log ⁡ ( p t ) \mat ...

最新文章

  1. Win8 x64 + Office Word 2013 x64 无法自动加载 Endnote X6 的解决方案
  2. Xamarin.Forms项目无法添加服务引用
  3. ffmpeg硬解码视频文件播放器
  4. 每日程序C语言40-不带头结点的尾插法创建链表
  5. 移动开发Html 5前端性能优化指南
  6. 【Ubuntu-Docker】ubuntu16.04(18.04)Docker安装配置与卸载
  7. 【CF487E】Tourists【圆方树】【树链剖分】【multiset】
  8. Map集合使用get方法返回null抛出空指针异常问题
  9. vuex的命名空间有哪些_javascript – vuex中模块的命名空间究竟是什么
  10. 【linux指令】sed指令
  11. linux 项目内存吃掉,Linux内存被吃掉了,它去哪里了?
  12. 20145307《信息安全系统设计基础》第二周学习总结
  13. JAVA线程池的创建
  14. 服务器ip维护无法登录,用DHCP解决服务器硬件管理口没有设置IP无法登录的问题...
  15. 一位38岁的老码农在退休前和年轻码农们的互动
  16. 2 EDA技术实用教程【Verilog 语句基本类型】
  17. Gsonformat插件安装与使用
  18. 【Unity3D日常BUG】Unity3D 中听不到声音解决方案
  19. python 趋势跟踪算法_DualThrust区间突破策略Python版
  20. 客户画像中的聚类分析

热门文章

  1. CSS 制作下拉导航
  2. A*算法一个简单的记录
  3. 学习笔记 vs19 报错:E1696 C++ 无法打开 源 文件
  4. php网站商品图片上传代码,PHP实现图片上传代码
  5. python读取txt中的一列称为_python读取中文txt文本的方法
  6. vscode比较整个文件夹_vscode开发ROS1(5)-ROS工程目录结构
  7. php version 5.5.17-1~dotdeb.1,Ubuntu 12.04使用Dotdeb安装PHP5.4 / Nginx1.4/Redis2.6等新版本
  8. Echarts初体验
  9. Linux下ftp的安装配置
  10. JavaScript事件基础知识总结【思维导图】