Pytorch中Balance binary cross entropy自定义实现

balance binary cross entropy损失函数在分割任务中很有用,因为分割任务会遇到正负样本不均的问题,甚至在边缘的分割任务重,样本不均衡达到了很高的比例。

故此,个人在基于分割任务中,自实现了该损失函数,亲测有效!

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom ..builder import LOSSES
from .utils import weight_reduce_lossdef cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):# element-wise lossesloss = F.cross_entropy(pred, label, reduction='none')# apply weights and do the reductionif weight is not None:weight = weight.float()loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)return lossdef _expand_binary_labels(labels, label_weights, label_channels):bin_labels = labels.new_full((labels.size(0), label_channels), 0)inds = torch.nonzero(labels >= 1).squeeze()if inds.numel() > 0:bin_labels[inds, labels[inds] - 1] = 1if label_weights is None:bin_label_weights = Noneelse:bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weightsdef binary_cross_entropy(pred,label,weight=None,reduction='mean',avg_factor=None):if pred.dim() != label.dim():label, weight = _expand_binary_labels(label, weight, pred.size(-1))# weighted element-wise lossesif weight is not None:weight = weight.float()loss = F.binary_cross_entropy_with_logits(pred, label.float(), weight, reduction='none')# do the reduction for the weighted lossloss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)return lossdef balanced_mask_cross_entropy(pred, label, mask=None, negative_ratio=3.0, eps=1e-10):positive = label.byte()negative = (1-label).byte()positive_count = int(positive.float().sum())negative_count = min(int(negative.float().sum()), int(positive_count * negative_ratio))loss = F.binary_cross_entropy(pred, label, reduction='none')[:,0,:,:]positive_loss = loss * positive.float()negative_loss = loss * negative.float()negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + eps)return balance_loss@LOSSES.register_module()
class BalancedCrossEntropyLoss(nn.Module):def __init__(self,negative_ratio=3.0,eps=1e-10,loss_weight=1.0):super(BalancedCrossEntropyLoss, self).__init__()self.negative_ratio = negative_ratioself.eps = epsself.loss_weight = loss_weightself.cls_criterion = balanced_mask_cross_entropydef forward(self,pred,label,mask=None,**kwargs):loss_cls = self.loss_weight * self.cls_criterion(pred, label, mask=None, negative_ratio=self.negative_ratio, eps=self.eps, **kwargs)return loss_cls

【PyTorch】Balanced_CE_loss 实现相关推荐

  1. pytorch中实现Balanced Cross-Entropy

    当你明白了pytorch中F.cross_entropy以及F.binary_cross_entropy是如何实现的之后,你再基于它们做改进重新实现一个损失函数就很容易了. 1.背景 变化检测中,往往 ...

  2. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  3. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  4. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  5. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

  6. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  7. API pytorch tensorflow

    pytorch与tensorflow API速查表 方法名称 pytroch tensorflow numpy 裁剪 torch.clamp(x, min, max) tf.clip_by_value ...

  8. tensor转换 pytorch tensorflow

    一.tensorflow的numpy与tensor互转 1.数组(numpy)转tensor 利用tf.convert_to_tensor(numpy),将numpy转成tensor >> ...

  9. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

最新文章

  1. Python 在 命令行中 安装 matplotlib
  2. webstorm javascript IDE调试
  3. 前端JavaScripts
  4. MFC 窗体样式修改
  5. 蓝桥杯2017初赛-分巧克力-二分
  6. 设计模式(五)行为型模式
  7. JS之前台参数提交到后台,双引号转义为解决办法
  8. Chess Queen【数学】
  9. POJ 1185 炮兵阵地(状压dp)
  10. 【码云】git简单使用总结
  11. SARscape操作:Sentinel-1 SLC影像镶嵌、裁切
  12. javaSpring面试题,安排
  13. mysql读写分离_SpringBoot+MyBatis+MySQL读写分离
  14. fc安卓模拟器_MAME街机模拟器0.224经典游戏全收藏
  15. opencv学习(四十四)之图像角点检测Harris
  16. Referrer 还是 Referer?
  17. 局域网即时通讯软件的实现
  18. 中央电视台硬盘播出系统的扩展应用与维护经验(mxf 格式)
  19. 精确控制Origin to Word图片格式、大小及主题使用技巧
  20. spark数据处理-RDD

热门文章

  1. 人生一世,草木一秋,再伟大的人在历史长河中也只是一个匆匆过客
  2. 利用AJAX做天气预报
  3. 【Unity】脚本实现动态模型切割
  4. 华为云桌面,带你见识不一样的系统桌面
  5. 驱动开发:挂接SSDT内核钩子
  6. 《人性的弱点》观后感
  7. 2020中山大学计算机学院保研,我校举行中山大学2020级研究生招生宣讲会
  8. Bluetooth 蓝牙介绍(四):低功耗蓝牙BLE Mesh网络Ⅱ —— Mesh网络
  9. TWaver三维可视化管理软件、3D和2D开发工具软件的试用(申请试用的回复邮件)
  10. 【交互设计】什么是微交互