cb loss pytorch 实现,可直接调用
参考:https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py

import numpy as np
import torch
import torch.nn.functional as Fdef focal_loss(logits, labels, alpha, gamma):"""Compute the focal loss between `logits` and the ground truth `labels`.Focal loss = -alpha_t * (1-pt)^gamma * log(pt)where pt is the probability of being classified to the true class.pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).Args:logits: A float tensor of size [batch, num_classes].labels: A float tensor of size [batch, num_classes].alpha: A float tensor of size [batch_size]specifying per-example weight for balanced cross entropy.gamma: A float scalar modulating loss from hard and easy examples.Returns:focal_loss: A float32 scalar representing normalized total loss."""bce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")if gamma == 0.0:modulator = 1.0else:modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))loss = modulator * bce_lossweighted_loss = alpha * lossloss = torch.sum(weighted_loss)loss /= torch.sum(labels)return lossclass ClassBalancedLoss(torch.nn.Module):def __init__(self, samples_per_class=None, beta=0.9999, gamma=0.5, loss_type="focal"):super(ClassBalancedLoss, self).__init__()if loss_type not in ["focal", "sigmoid", "softmax"]:loss_type = "focal"if samples_per_class is None:num_classes = 5000samples_per_class = [1] * num_classeseffective_num = 1.0 - np.power(beta, samples_per_class)weights = (1.0 - beta) / np.array(effective_num)self.constant_sum = len(samples_per_class)weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)self.class_weights = weightsself.beta = betaself.gamma = gammaself.loss_type = loss_typedef update(self, samples_per_class):if samples_per_class is None:returneffective_num = 1.0 - np.power(self.beta, samples_per_class)weights = (1.0 - self.beta) / np.array(effective_num)self.constant_sum = len(samples_per_class)weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)self.class_weights = weightsdef forward(self, x, y):_, num_classes = x.shapelabels_one_hot = F.one_hot(y, num_classes).float()weights = torch.tensor(self.class_weights, device=x.device).index_select(0, y)weights = weights.unsqueeze(1)if self.loss_type == "focal":cb_loss = focal_loss(x, labels_one_hot, weights, self.gamma)elif self.loss_type == "sigmoid":cb_loss = F.binary_cross_entropy_with_logits(x, labels_one_hot, weights)else:  # softmaxpred = x.softmax(dim=1)cb_loss = F.binary_cross_entropy(pred, labels_one_hot, weights)return cb_lossdef test():torch.manual_seed(123)batch_size = 10num_classes = 5x = torch.rand(batch_size, num_classes)y = torch.randint(0, 5, size=(batch_size,))samples_per_class = [1, 2, 3, 4, 5]loss_type = "focal"loss_fn = ClassBalancedLoss(samples_per_class, loss_type=loss_type)loss = loss_fn(x, y)print(loss)if __name__ == '__main__':test()

class balanced loss pytorch 实现相关推荐

  1. [论文解读] A Ranking-based, Balanced Loss Function Unifying Classification and Localisation in Object De

    文章内容 相关研究现状 1. 定位任务和分类任务的平衡/耦合 2. 基于排名的目标检测算法 本文工作 基于排序损失的误差驱动优化方法推广 定理1:基于概率分布的损失函数重定义 定理2:正负样本梯度总和 ...

  2. pytorch训练Class-Balanced Loss

    1. 提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1. pytorch版Class-Balanced Loss训练模型 一.数据准备 二.模型训练 三.模型预测 总结 ...

  3. pytorch版Class-Balanced Loss训练模型

    pytorch版Class-Balanced Loss训练模型 1.论文参考原文 https://arxiv.org/pdf/1901.05555.pdf 2.数据准备 将自己的数据集按照一下格式进行 ...

  4. SRGAN loss部分的pytorch代码实现

    转载地址:https://bbs.huaweicloud.com/forum/thread-137101-1-1.html 作者: 雨丝儿 最近在参加华为与高校合做开发mindspore模型的活动,使 ...

  5. 深度学习100+经典模型TensorFlow与Pytorch代码实现大合集

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! [导读]深度学习在过去十年获得了极大进展,出现很多新的模型,并且伴随TensorF ...

  6. BCE loss和 CE理解

    1. BCE loss:Binary Cross Entropy Loss BCE loss pytorch官网链接 1.1 解释 pytorch中调用如下.设置weight,使得不同类别的损失权值不 ...

  7. 【Pytorch】常见的人脸身份识别损失函数

    [Pytorch]常见的人脸身份识别损失函数 实验环境准备:人脸多角度多光照的图像数据集MUCT(276个受试者)+ MobileNetV3 说明:对于人脸身份数据集MUCT,是少样本数据集,应该使用 ...

  8. Style Transfer(PyTorch)

    Style Transfer-PyTorch Content Loss content loss用来计算原图片和生成的图片之间像素的差距,这里用的是卷积层获取的 feature map 之间的差距 通 ...

  9. Pytorch以及tensorflow中KLdivergence的计算

    1. KL divergence是什么 KL 散度是一个距离衡量指标,衡量的是两个概率分布之间的差异. y p r e d y_{pred} ypred​指的是模型的输出的预测概率,形如[0.35,0 ...

最新文章

  1. 电脑显示未安装任何音频输出设备_一套完整的台式电脑有哪些配置
  2. 运维笔记--postgresql占用CPU问题定位
  3. 学习 OpenStack 的方法论 - 每天5分钟玩转 OpenStack(150)
  4. java(1)——用notepad++编译java(javac.exe)
  5. 设置View单个圆角
  6. pycharm导入本地py文件时,模块下方出现红色波浪线
  7. 【渝粤教育】国家开放大学2018年春季 0538-21T社区护理 参考试题
  8. VC下关于debug和release的不同的讨论(收藏-转载)
  9. 跨时代比较:工业化因素是关键
  10. 塞班手机刷linux,向 诺基亚 塞班手机中 批量导入 通讯录(csplit iconv)
  11. STM32F401的RCC时钟配置
  12. TIA博途WINCC中英文切换的项目中摄氏度符号无法正常显示的解决办法
  13. win10安装打印机驱动程序失败“试图将读懂程序添加到存储区时遇到问题”
  14. BDH,CDH,DDH,DLP是什么?
  15. 罗马数字(Python)
  16. R语言制作Meta分析偏倚风险评估(ROB)图
  17. Ubuntu16.0.4 安装rebar3指南
  18. Richard Stallman的演讲:「A Free Digital Society」
  19. 计算机信息的容量单位是什么,信息的基本容量单位是
  20. SegmentFault 讲堂一周岁:Keep learning

热门文章

  1. 八股文-ArrayList
  2. QQ邮箱的POP3与SMTP服务器是什么?
  3. python 将数据库的 utc时间转换成本地时间
  4. drcom for linux,Drcom for Ubuntu上网解决经验
  5. 连续系统的复频域分析 matlab,(连续系统复频域分析.doc
  6. 进程间的通信(管道通信)
  7. 计算机毕业设计之java+javaweb的蛋糕甜品商城系统
  8. 最新版Eclipse2020创建项目红叉问题(“Failed to init ct.sym ...\jrt-fs.jar )
  9. 小丁带你走进git的世界二-工作区暂存区分支
  10. Matlab之classification learner app无法从workspace导入label (response variable)