https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA
综述:解决目标检测中的样本不均衡问题
该综述主要介绍了OHEM,Focal loss,GHM loss;由于我这的二分类数据集不存在正负样本不均衡的问题,所以着重看了处理难易样本不均衡(正常情况下,容易的样本较多,困难的样本较少);由于我只是分类问题,所以写了各种分类的loss,且网络的最后一层为softmax,所以网络输出的pred是softmax层前的logits经过softmax后的结果,普通的交叉熵损失即为sum(-gt*log(pred)),但torch.nn.CrossEntropyLoss()中会对于输入的pred再进行一次softmax,所以这里使用torch.nn.NLLLoss代替,当然经测试,即使网络最后一层使用softmax损失函数还是使用torch.nn.CrossEntropyLoss(),效果和使用torch.nn.NLLLoss差不多。。。

OHEM:
代码参考:https://www.codeleading.com/article/7442852142/

def ohem_loss(pred, target, keep_num):loss = torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)print(loss)loss_sorted, idx = torch.sort(loss, descending=True)loss_keep = loss_sorted[:keep_num]return loss_keep.sum() / keep_num

Focal loss:
详解:原论文Focal Loss for Dense Object Detection
代码参考:https://zhuanlan.zhihu.com/p/80594704

def focal_loss(pred,target,gamma=0.5):pred_temp=pred.detach().cpu()target_temp=target.detach().cpu()pt = torch.tensor([pred_temp[i,target_temp[i]] for i in range(target_temp.shape[0])])focal_weight = (1-pt).pow(gamma)return torch.mean((torch.nn.NLLLoss(reduce=False)(torch.log(pred), target)).mul(focal_weight.to(device).detach()))

GHM loss:
详解:https://zhuanlan.zhihu.com/p/80594704
代码参考:https://github.com/DHPO/GHM_Loss.pytorch/blob/master/GHM_loss.py

class GHM_Loss(nn.Module):def __init__(self, bins, alpha):super(GHM_Loss, self).__init__()self._bins = binsself._alpha = alphaself._last_bin_count = Nonedef _g2bin(self, g):return torch.floor(g * (self._bins - 0.0001)).long()def _custom_loss(self, x, target, weight):raise NotImplementedErrordef _custom_loss_grad(self, x, target):raise NotImplementedErrordef forward(self, x, target):g = torch.abs(self._custom_loss_grad(x, target))bin_idx = self._g2bin(g)bin_count = torch.zeros((self._bins))for i in range(self._bins):bin_count[i] = (bin_idx == i).sum().item()N = x.size(0)nonempty_bins = (bin_count > 0).sum().item()gd = bin_count * nonempty_binsgd = torch.clamp(gd, min=0.0001)beta = N / gdreturn self._custom_loss(x, target, beta[bin_idx])class GHMC_Loss(GHM_Loss):def __init__(self, bins, alpha):super(GHMC_Loss, self).__init__(bins, alpha)def _custom_loss(self, x, target, weight):return torch.sum((torch.nn.NLLLoss(reduce=False)(torch.log(x),target)).mul(weight.to(device).detach()))/torch.sum(weight.to(device).detach())def _custom_loss_grad(self, x, target):x=x.cpu().detach()target=target.cpu()return torch.tensor([x[i,target[i]] for i in range(target.shape[0])])-target

OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)相关推荐

  1. 类别不均衡问题之loss大集合:focal loss, GHM loss, dice loss 等等

    数据类别不均衡问题应该是一个极常见又头疼的的问题了.最近在工作中也是碰到这个问题,花了些时间梳理并实践了类别不均衡问题的解决方式,主要实践了"魔改"loss(focal loss, ...

  2. 《Focal Loss GHM Loss Dice Los》论文笔记

    Focal Loss 在二分类问题中,交叉熵损失定义如下: yyy 表示真实值,取值0与1,ppp表示模型预测正类的概率,取值0到1. 为了表述方便,将上述公式重新表述为: 对于类别不平衡问题,我们可 ...

  3. 吴恩达《机器学习》学习笔记七——逻辑回归(二分类)代码

    吴恩达<机器学习>学习笔记七--逻辑回归(二分类)代码 一.无正则项的逻辑回归 1.问题描述 2.导入模块 3.准备数据 4.假设函数 5.代价函数 6.梯度下降 7.拟合参数 8.用训练 ...

  4. 深度学习 神经网络(5)逻辑回归二分类-Pytorch实现乳腺癌预测

    深度学习 神经网络 逻辑回归二分类-乳腺癌预测 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载并查看数据集 2.3 数据处理 2.4 数据分割 2.5 迭代训练 2.6 数据验证 一.前言 ...

  5. 寻找解决样本不均衡方法之Focal Loss与GHM

    寻找解决样本不均衡方法之Focal Loss与GHM 主要参考资料:5分钟理解Focal Loss与GHM--解决样本不平衡利器 - 知乎 (zhihu.com) Focal Loss的引入主要是为了 ...

  6. 样本不均衡问题 (OHEM, Focal loss)

    目录 不均衡问题分析 正负样本不均衡 难易样本不均衡 类别间样本不均衡 常用的解决方法 在线难样本挖掘: OHEM 难负样本挖掘 (Hard Negative Mining, HNM) 在线难样本挖掘 ...

  7. 目标检测中的样本不平衡处理方法——OHEM, Focal Loss, GHM, PISA

    GitHub 简书 CSDN 文章目录 1. 前言 2. OHEM 3. Focal Loss 3.1 Cross Entropy 3.2 Balanced Cross Entropy 3.3 Foc ...

  8. Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估

    Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估 前言 二分类 focal loss 多分类 focal loss 测试结果 二分类focal_loss结果 ...

  9. pytorch 12 支持任意维度数据的梯度平衡机制GHM Loss的实现(支持ignore_index、class_weight,支持反向传播训练,支持多分类)

    梯度平衡机制GHM(Gradient Harmonized Mechanism) Loss是Focal loss的升级版,源自论文Gradient Harmonized Single-stage De ...

最新文章

  1. UIViewController、UINavigationController与UITabBarController的整合使用
  2. Git——如何将本地项目提交至远程仓库(第一次)
  3. 【Paper】2021_领导-跟随多智能体系统容错一致性控制
  4. php解析multilpart,使用PHP语言实现POP3邮件的解码
  5. 为了使界面组件更圆滑,Swing,且跨系统
  6. CTL_CODE说明
  7. mt4 指标 涨跌幅 颜色k线_通达信精选指标——彩色K线指标
  8. 计算机图形学E5——OpenGL 扫描线填充
  9. mysql表空间过大_详解MySQL表空间以及ibdata1文件过大问题
  10. android实现nfc支付宝,支付宝首次支持NFC与LBS 实现快速手机支付
  11. 第七章文件与数格式化
  12. 帝国CMS灵动标签如何调用父子栏目连接和名称导航
  13. 论文速递-ANALYSIS OF VISUAL REASONING ON ONE-STAGE OBJECT DETECTION
  14. 从头开始实现Java多人联机游戏(飞机大战)源码粘贴即用
  15. 实验吧-天网管理系统
  16. Activiti工作流会签与获取下一节点任务信息
  17. JavaSE探赜索隐之乾坤袋(集合)
  18. arcgis 圈选获取图层下点位_关于Arcgis这62个常用技巧,你造吗
  19. C语言(二级基础知识2)
  20. python dfs

热门文章

  1. 如何查看yandex文字搜索广告的搜索词?
  2. Framer:开源原型设计工具,巨头们的心头好
  3. 交互设计软件Framer X for mac软件测评
  4. 最完美的公式——欧拉公式
  5. java 基础知识学习2
  6. Android的16ms和垂直同步以及三重缓存
  7. VS2012编译和调用gdal
  8. word 2007 删除表格内容
  9. JavaEE体系架构
  10. 直播平台基本功能解读:以呆萌直播为例的技术剖析