(HEM/OHEM)hard negative(example)mining难例挖掘 与focal loss、GHM损失函数
目录
- 分类任务中的样本不均衡及hard negative mining的必要性
- hard negative example
- HEM(hard example/negative mining) 与 OHEM(online hard example mining)
- focal loss
- 二分类focal loss
- 多分类focal loss
- GHM
- 参数target, label_weight的关系
- 非one-hot情况下,labels是从0还是从1开始编码
- 难例挖掘的相关领域:长尾分布下的分类
- XMB---Cross-Batch Memory for Embedding Learning
- 由易到难 Curriculum Learning与Self-paced Learning
分类任务中的样本不均衡及hard negative mining的必要性
在训练一个分类器的时候,对数据的要求是class balance,即不同标签的样本量都要充足且相仿。然而,这个要求在现实应用中往往很难得到保证。
在目标检测算法中,对于输入的一张图像,可能会生成成千上万的预选框(region proposal),但是其中只有很少一部分是包含真实目标的,这就带来了类别不均衡问题。
类别不平衡时,无用的易分反例样本(easy negative sample)会使得模型的整体学习方向跑偏,导致无效学习,即只能分辨出没有物体的背景,而无法分辨具体的物体。(因为在使用cross-entropy loss做mini-batch SGD时,是大量的样本产生的loss average之后计算gradient以及实施参数update。这个average的操作是有问题的,因为一个batch里面easy sample占绝大多数,hard sample只占很少部分,如果没有re-weighting或者importance balancing,那么在average之后,hard sample的contribution完全就被easy samples侵蚀平均抹掉了。事实上,往往这些hard samples,才是分类器性能进一步提升的bottleneck(hard sample:很少出现但是现实存在的那些极端样本,比如车辆跟踪中出过事故的车辆。))
hard negative example
那么负样本中哪些是困难负样本(hard negative)呢?困难负样本是指哪些容易被网络预测为正样本的proposal,即假阳性(false positive),
- 分类任务:虽然是负样本,但其预测为正的概率较高(如果p=0.5则判断为正样本,那么p=0.49就属于hard negative)
- 检测任务:如roi里有二分之一个目标时,虽然它仍是负样本,却容易被判断为正样本,这块roi即为hard negative;
- 度量学习:与anchor(正样本)距离较近的负样本就是hard negative
训练hard negative对提升网络的分类性能具有极大帮助,因为它相当于一个错题集。
如何判断它为困难负样本呢?也很简单,我们先用初始样本集(即第一帧随机选择的正负样本)去训练网络,再用训练好的网络去预测负样本集中剩余的负样本,选择其中得分最高,即最容易被判断为正样本的负样本为困难样本。
HEM(hard example/negative mining) 与 OHEM(online hard example mining)
HEM核心思想:
- 分类任务:用分类器对样本进行分类,把其中错误分类的样本(hard negative)放入负样本集合再继续训练分类器。
- 度量学习:度量学习中也有hard example/negative mining的概念。度量学习中的HEM是指找出与anchor(正样本)距离较近的负样本。
如何进行HEM? 重构训练数据集
以目标检测框算法为例:对于目标检测中我们会事先标记处ground truth,然后再算法中会生成一系列proposals,proposals与ground truth的IOU超过一定阈值(通常0.5)的则认定为是正样本,低于一定阈值的则是负样本。然后扔进网络中训练。However,这也许会出现一个问题那就是正样本的数量远远小于负样本,这样训练出来的分类器的效果总是有限的,会出现许多false positive。把其中得分较高的这些false positive当做所谓的Hard negative,既然mining出了这些Hard negative,就把这些扔进网络再训练一次,从而加强分类器判别假阳性的能力。
如何进行OHEM?在训练过程中通过loss(分类loss、roi loss等)进行选取
分类loss:
制定规则去选取hard negative: DenseBox
核心思想:选取与label差别大(分类loss)的作为hard negtive
根据制定的规则选取了hard negative ,在训练的时候加强对hard negative的训练。
In the forward propagation phase, we sort the loss of output pixels in decending order, and assign the top 1% to be hard-negative. In all experiments, we keep all positive labeled pixels(samples) and the ratio of positive and negative to be 1:1. Among all negative samples, half of them are sampled from hard-negative samples, and the remaining half are selected randomly from non-hard negative.ROI loss
一个只读的RoI网络对特征图和所有RoI进行前向传播,然后Hard RoI module利用这些RoI的loss选择B个样本。这些选择出的样本(hard examples)进入RoI网络,进一步进行前向和后向传播。
focal loss
二分类和多分类的focal loss的tensorflow 代码实现
多分类focalloss的pytorch实现
先前也有一些算法来处理类别不均衡的问题,比如OHEM(online hard example mining),OHEM的主要思想可以用原文的一句话概括:In OHEM each example is scored by its loss, non-maximum suppression (nms) is then applied, and a minibatch is constructed with the highest-loss examples。OHEM算法虽然增加了错分类(正、负)样本的权重,但是OHEM算法忽略了容易分类的(正)样本。
因此针对类别不均衡问题,作者提出一种新的损失函数:focal loss。
这个损失函数是在标准交叉熵损失基础上修改得到的。这个函数可以通过减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。
我理解的focal loss 是,利用一个re-weighting factor来modulating(re-weighting)每一个样本的importance,得到一个cost sensitve的classifier。
二分类focal loss
local loss的具体实现在此以二分类cross entropy loss举例:
无focal:
L(p,y)=−(y⋅log(p)+(1−y)⋅log(1−p))L(p, y)=-(y \cdot \log (p)+(1-y) \cdot \log (1-p))L(p,y)=−(y⋅log(p)+(1−y)⋅log(1−p))
有focal(hard negative mining,加大难的负样本权重):
L(p,y)=−(y⋅(1−p)⋅log(p)+(1−y)⋅p⋅log(1−p))L(p, y)=-(y \cdot(1-p) \cdot \log (p)+(1-y) \cdot p \cdot \log (1-p))L(p,y)=−(y⋅(1−p)⋅log(p)+(1−y)⋅p⋅log(1−p))
上式中p是classifer输出的[0,1]之间实数值,为预测概率,y是非0即1的label。将一个一维实数p进行二分类化的操作是令p指示label = 1为正样本的概率,另1-p指示label = 0为负样本的概率。focal loss的核心就是直接用p作为modulating(re-weighting factor),当一个负样本很难时,p略小于0.5;easy negative则是p远小于0.5。所以hard negative mining就体现在给1-p越小的negative(hard negative)乘以一个越大的factor(p)。
更一般化的表示:
二分类任务的交叉熵损失函数公式如下:
令pt代表如下意义
则二分类focal loss为
在此基础上可以进一步引进另一个调整权重的超参数a,
ps:在每一次预测时,真实标签的one-hot向量中,只有一项为1,其它项都为0,所以其focal loss只需计算为1的那一项。
def binary_focal_loss(y_true, y_pred,gamma=2.0, alpha=0.25):# Define epsilon so that the backpropagation will not result in NaN# for 0 divisor caseepsilon = K.epsilon()# Add the epsilon to prediction value#y_pred = y_pred + epsilon# Clip the prediciton valuey_pred = K.clip(y_pred, epsilon, 1.0-epsilon)# Calculate p_tp_t = tf.where(K.equal(y_true, 1), y_pred, 1-y_pred)# Calculate alpha_talpha_factor = K.ones_like(y_true)*alphaalpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1-alpha_factor)# Calculate cross entropycross_entropy = -K.log(p_t)weight = alpha_t * K.pow((1-p_t), gamma)# Calculate focal lossloss = weight * cross_entropy# Sum the losses in mini_batchloss = K.sum(loss, axis=1)return loss
多分类focal loss
对于多分类:
无focal:
H(p,q)=−∑i=1np(xi)log(q(xi))H(p, q)=-\sum_{i=1}^{n} p\left(x_{i}\right) \log \left(q\left(x_{i}\right)\right)H(p,q)=−∑i=1np(xi)log(q(xi))
在机器学习中,将ground truth当作一个分布(P),将预测作为另一个分布(q),假设有cnum个类别(三分类问题cnum=3, 四分类问题cnum=4),那么就有:
H(p,q)=−∑i=1cnump(ci)logq(ci)H(p, q)=-\sum_{i=1}^{c n u m} p\left(c_{i}\right) \log q\left(c_{i}\right)H(p,q)=−∑i=1cnump(ci)logq(ci)
假设有一个三分类问题,某个样例的正确答案是(1,0,0)。某模型经过Softmax回归之后的预测答案是(0.5,0,4,0.1),那么这个预测和正确答案直接的交叉熵是:
H((1,0,0),(0.5,0.4,0.1))=−(1×log0.5+0×log0.4+0×log0.1)≈0.3\mathrm{H}((1,0,0),(0.5,0.4,0.1))=-(1 \times \log 0.5+0 \times \log 0.4+0 \times \log 0.1) \approx 0.3H((1,0,0),(0.5,0.4,0.1))=−(1×log0.5+0×log0.4+0×log0.1)≈0.3
有focal(hard negative mining,加大难的负样本权重):
在多分类任务中,
上述公式中的Pt用如下向量代替(其中乘法是向量内各元素的乘法,输出的pt是一个向量)
由于log函数中不能出现0,又因为Pgrountruth中只有一项不为0,所以将log函数中的Pgrountruth提到公式最前面:
该表达式输出为一个向量(只有一项不为0),在此基础上用一个求和函数将各元素相加即为最终的loss值。在numpy中用reduce_sum或recuce_max函数实现.
https://blog.csdn.net/qq_39012149/article/details/96184383
在这篇文章中,描述了多分类任务中,如何使用a参数(定义一个超参数数组,用于平衡各类别之间的权重)*
从公式和代码看,多分类并没有直接寻找到hard negative example,而是当正样本被预测道的概率较低时,将其对应的交叉熵的权重加大。(因为其公式中,groud_truth是one_hot表示,在计算交叉熵时,只有‘1’对应的正样本对应的预测概率被用上了,其它各个负样本由于对应的one_hot表示为’0’,所以没有真正计算进去)
# -*- coding: utf-8 -*-
import tensorflow as tf"""
Tensorflow实现何凯明的Focal Loss, 该损失函数主要用于解决分类问题中的类别不平衡
focal_loss_sigmoid: 二分类loss
focal_loss_softmax: 多分类loss
Reference Paper : Focal Loss for Dense Object Detection
"""def focal_loss_sigmoid(labels,logits,alpha=0.25,gamma=2):"""Computer focal loss for binary classificationArgs:labels: A int32 tensor of shape [batch_size].logits: A float32 tensor of shape [batch_size].alpha: A scalar for focal loss alpha hyper-parameter. If positive samples number> negtive samples number, alpha < 0.5 and vice versa.gamma: A scalar for focal loss gamma hyper-parameter.Returns:A tensor of the same shape as `lables`"""y_pred=tf.nn.sigmoid(logits)labels=tf.to_float(labels)L=-labels*(1-alpha)*((1-y_pred)*gamma)*tf.log(y_pred)-\(1-labels)*alpha*(y_pred**gamma)*tf.log(1-y_pred)return Ldef focal_loss_softmax(labels,logits,gamma=2):"""Computer focal loss for multi classificationArgs:labels: A int32 tensor of shape [batch_size].logits: A float32 tensor of shape [batch_size,num_classes].gamma: A scalar for focal loss gamma hyper-parameter.Returns:A tensor of the same shape as `lables`"""y_pred=tf.nn.softmax(logits,dim=-1) # [batch_size,num_classes]# To avoid divided by zeroepsilon = 1.e-7y_pred += tf.epsilon() #一个很小的数值labels=tf.one_hot(labels,depth=y_pred.shape[1])L=-labels*((1-y_pred)**gamma)*tf.log(y_pred) #输出一个向量,向量中只有一项不为0#将向量变为标量,由于向量中只有一项不为0,所以也可用reduce_max()L=tf.reduce_sum(L,axis=1) #将向量变为标量。return Lif __name__ == '__main__':logits=tf.random_uniform(shape=[5],minval=-1,maxval=1,dtype=tf.float32)labels=tf.Variable([0,1,0,0,1])loss1=focal_loss_sigmoid(labels=labels,logits=logits)logits2=tf.random_uniform(shape=[5,4],minval=-1,maxval=1,dtype=tf.float32)labels2=tf.Variable([1,0,2,3,1])loss2=focal_loss_softmax(labels==labels2,logits=logits2)with tf.Session() as sess:sess.run(tf.global_variables_initializer())print sess.run(loss1)print sess.run(loss2)
keras版本实现:
https://github.com/maozezhong/focal_loss_multi_class/blob/master/focal_loss.py
reference : https://spaces.ac.cn/archives/4493(提到了加入另外的防过拟合的函数)
GHM
官方pytorch实现
参数target, label_weight的关系
官方代码中用于分类问题的GHMC损失函数的部分代码如下:
def forward(self, pred, target, label_weight, *args, **kwargs):"""Calculate the GHM-C loss.Args:pred (float tensor of size [batch_num, class_num]):The direct prediction of classification fc layer.target (float tensor of size [batch_num, class_num]):Binary class target for each sample.label_weight (float tensor of size [batch_num, class_num]):the value is 1 if the sample is valid and 0 if ignored.Returns:The gradient harmonized loss."""# the target should be binary class labelif pred.dim() != target.dim():target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))target, label_weight = target.float(), label_weight.float()edges = self.edgesmmt = self.momentumweights = torch.zeros_like(pred)
需要注意的是,参数target要求是one-hot编码形式,如果不是one-hot形式,则要通过_expand_binary_labels扩展成one-hot形式。
而参数label_weight则表示该标签是否要进行GHM操作,默认都是全1。而且,label_weight的维度必须与target保持一致,即如果target采用one-hot(形如 [batch_num, class_num]),则label_weight的size也是 [batch_num, class_num],如果target的size是 [batch_num],那么label_weight的size也必须是[batch_num]。
非one-hot情况下,labels是从0还是从1开始编码
官方代码中,_expand_binary_labels定义如下
def _expand_binary_labels(labels, label_weights, label_channels):bin_labels = labels.new_full((labels.size(0), label_channels), 0)inds = torch.nonzero(labels >= 1).squeeze()if inds.numel() > 0:bin_labels[inds, labels[inds] - 1] = 1bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weights
官方代码这种写法,认为labels(类别)的编码是从1开始的,不是从0开始的。
但很多情况下,我们输入的labels(类别)编码是从0开始的。因此需要对代码进行修改,如下所示:
def _expand_binary_labels(self, labels, label_weights, label_channels):# expand labelsbin_labels = labels.new_full((labels.size(0), label_channels), 0)# inds = torch.nonzero(labels >= 1).squeeze() #假设labels是从1开始编号inds = torch.nonzero(labels >= 0).squeeze() #加谁labels是从0开始编号if inds.numel() > 0:# bin_labels[inds, labels[inds] - 1] = 1 #假设labels是从1开始编号bin_labels[inds, labels[inds]] = 1 #假设labels是从0开始编号# expand label_weights(label_weight should with size [batch_num], otherwise the function "expand" cannot work)bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weights
难例挖掘的相关领域:长尾分布下的分类
综述 | 长尾(不均衡)分布下图像分类(2019-2020)
Class-Balanced Loss Based on Effective Number of Samples
解读:使用一个特别设计的损失来处理类别不均衡的数据集Learning to Reweight Examples for Robust Deep Learning
源码
知乎解读:https://zhuanlan.zhihu.com/p/37477502
更简单的理解就是,以前我们算training loss的时候,选取一个mini batch。现在,为了给每个batch内的样本重新分配权重,使用一个valid set中的mini batch计算validation loss, 根据validation loss计算权重。training loss 更新的是模型参数,而validation loss更新的是权重,也就是超参数。这就是元学习。
如果训练样本分布和验证样本分布相似,它们的梯度方向也接近,那么这样的样本比较好,需要增加权重,反之则需要降低权重。这样的样本权重分配方式能够使得模型变得无偏
XMB—Cross-Batch Memory for Embedding Learning
跨越时空的难样本挖掘
代码:https://link.zhihu.com/?target=https%3A//github.com/msight-tech/research-xbm
由易到难 Curriculum Learning与Self-paced Learning
【论文精读】Curriculum Learning
自步学习(Self-paced Learning)
(HEM/OHEM)hard negative(example)mining难例挖掘 与focal loss、GHM损失函数相关推荐
- OHEM(Online Hard Example Mining)在线难例挖掘(在线困难样例挖掘) HNM (目标检测)
Hard Negatie Mining与Online Hard Example Mining(OHEM)都属于难例挖掘,它是解决目标检测老大难问题的常用办法,运用于R-CNN,fast R-CNN,f ...
- 目标检测之六:OHEM 在线难例挖掘
https://zhuanlan.zhihu.com/p/102817180 6.OHEM 在线难例挖掘 OHEM(Online Hard negative Example Mining,在线难例挖掘 ...
- 学习了解online hard example mining在线难例挖掘
对于每一个网络,相当于一个桶,总有样本效果比较好,有的样本比较差,多用效果差的样本进行训练,那提高了整个网络的短板,总体的效果也会有提升. 一. 难例挖掘是指,针对模型训练过程中导致损失值很大的一些样 ...
- 在线难例挖掘(OHEM)
OHEM(online hard example miniing) 详细解读一下OHEM的实现代码: def ohem_loss(batch_size, cls_pred, cls_target, l ...
- OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)
https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA 综述:解决目标检测中的样本不均衡问题 该综述主要介绍了OHEM,Focal loss,GHM los ...
- 目标检测中的样本不平衡处理方法——OHEM, Focal Loss, GHM, PISA
GitHub 简书 CSDN 文章目录 1. 前言 2. OHEM 3. Focal Loss 3.1 Cross Entropy 3.2 Balanced Cross Entropy 3.3 Foc ...
- 深度学习之 hard negative mining (难例挖掘)
Hard Negative Mining Method 思想 hard是困难样本,negative是负样本,hard negative就是说在对负样本分类时候,loss比较大(label与predic ...
- 跨越『时空』的难样本挖掘!
作者 | 王珣 整理 | NewBeeNLP 我们码隆科技在深度度量学习继续深耕,又做了一点点改进的工作,承蒙审稿人厚爱,被CVPR-2020接收为Oral,并进入best paper候选(共26篇文 ...
- OHEM在线难样例挖掘的两个细节
代码上如何实现Read-only Layer与R _hard-sel权限共享? https://github.com/abhi2610/ohem/blob/master/models/pascal_v ...
最新文章
- 十五天精通WCF——第六天 你必须要了解的3种通信模式
- 计算机软件可以一次摊销吗,研发用无形资产可以一次摊销吗
- github项目怎么运行_利用 GitHub 从零开始搭建一个博客
- python之sys
- 宝塔php漏洞,[安全预警]关于最近宝塔闹得很厉害的PMA漏洞BUG
- Hashtable、HashMap、TreeMap总结
- SAP License:市场需要双重SAP顾问
- 如何关闭139端口及445端口等危险端口
- 关于SqlDataReader类型的变量传值问题
- 【JAVA】利用MOM消息队列技术实现分布式随机信号分析系统
- CI框架redirect自动加上了index.php问题
- WIN10 注册表添加启动项
- winform直接控制云台_智云和快手发布重磅功能,手机云台升级,帮8成网民拍大片...
- 【English】元音辅音
- Web服务器站点设置和IIS安装设置图解
- 送给女朋友的3D立体动态相册的实现代码
- Vivado仿真报错合集(更新中)
- 从数据中台到全链路数据生产力
- PPT制作创意封面如何排版设计?
- Java爬虫之下载全世界国家的国旗图片
热门文章
- (1)Kurento之WebRTC通信架构
- (翻译)Understanding Convolutional Neural Networks for NLP
- win10推送_win10无线镜像投屏电视
- android简单实现表格布局,Android开发中TableLayout表格布局
- 百度官方:网站优化中死链处理指南与总结
- 2维正态分布-矩阵表示-推导过程
- Error: Can't find Python executable python, you can set the PYTHON env variable.解决办法
- Dijkstra算法求解单源最短路径问题
- JAVAWeb01-BS架构简述、HTML
- 【Linux】wget命令的使用