1 FocalLoss

a. 关于Focal loss具体的解析可以参考https://zhuanlan.zhihu.com/p/49981234
对于二分类FocalLoss ,代码参考ptorch官方https://pytorch.org/vision/stable/generated/torchvision.ops.sigmoid_focal_loss.html?highlight=focal#torchvision.ops.sigmoid_focal_loss,这里主要从数值解析上去验证:

from torchvision.ops import sigmoid_focal_lossinput = torch.tensor([0.1,0.2])
target = torch.tensor([0,1])
weight = 0.25
gamma = 2
loss = sigmoid_focal_loss(input.float(), target.float(), weight=0.25, gamma=2, reduction='none')
print(loss)
'''
loss值 tensor([0.1539, 0.0303])
对于第一个是负样本,计算过程
pt = 1-torch.sigmoid(input[0])
loss_1 = -(1-weight)*(1-pt)**gamma*torch.log(pt) #值为0.1539
对于第二个是正样本,计算过程
pt = torch.sigmoid(input[1])
loss_2 = -weight*(1-pt)**gamma*torch.log(pt) #值为0.0303
'''

b. 对于多分类focal loss (multi-class focal loss), 暂时还未找到靠谱代码,基本就是说weight对所有类别都是一致的,后续在补充
b1. 这里加一个对于multi-class 的二分类 focal loss,就是将多分类转化成二分类,然后计算focal loss(retina net计算方式)

'''
假设输入 经过sigmoid之后,一共4类(包括背景类,最后一类是背景类),二个框分类,所以输入大小是2*3,target是[3,2],进过onehot之后[[0,0,0],[0,0,1]]
prob = torch.tensor([[0.0247,0.0248,0.0249],[0.0247,0.0248,0.0249],[0.0247,0.0248,0.0249]])
targets = torch.tensor([[0,0,0],[0,0,1]])
gamma = 2
alpha = 0.25
ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * targets + (1 - alpha) * (1 - targets)loss = alpha_t * loss对于第一个框第一类(算是负样本)的计算过程就是
pt = 1-prob[0][0]
loss_0 = (1-alpha)*(1-pt)**gamma*(-torch.log(pt)) # 值是1.1416e-5
整体计算过程参照上面的a计算过程

c. 基于分割的focal loss,代码参考https://docs.monai.io/en/stable/_modules/monai/losses/focal_loss.html#FocalLoss,这里计算是按所有类别weight都是一样,这里主要从数值解析上去验证:

import torch
from monai.losses import FocalLossinput = torch.tensor([[[[0.1,0.2],[0.3,0.4]],       [[0.5,0.6],[0.7,0.8]],[[0.9,0.1],[0.2,0.4]],]])target = torch.tensor([[[[1,0],[0,1]]]])
weight = 0.25
gamma = 2
pt = torch.exp(input[0,0,0,0])/(1+torch.exp(input[0,0,0,0]))
pt1 = torch.exp(input[0,0,0,1])/(1+torch.exp(input[0,0,0,1]))
self = FocalLoss(reduction='none', gamma=gamma, weight=weight, to_onehot_y=True) #对于这个weight可以设置成个list长度和类别长度一致,表示每一类权重大小,参考源码解释
loss = self(input, target)
'''
对应loss值是tensor([[[[0.0513, 0.0303],[0.0251, 0.0818]],[[0.0169, 0.1081],[0.1231, 0.0089]],[[0.1568, 0.0513],[0.0603, 0.0818]]]])
这里是3个类别标签,大小是2*2
计算第一个类别第一个位置的loss,此为负样本
pt = 1 - torch.exp(input[0,0,0,0])/(1+torch.exp(input[0,0,0,0]))
loss1 = -weight*(1-pt)**gamma*torch.log(pt) # 值是0.0513
计算第一个类别第二个位置的loss,此为正样本
pt =  torch.exp(input[0,0,0,1])/(1+torch.exp(input[0,0,0,1]))
loss2 = -weight*(1-pt)**gamma*torch.log(pt) # 值是0.0303
'''

2 Dice Loss

关于dice loss具体的解析可以参考https://zhuanlan.zhihu.com/p/269592183,具体代码解析参考https://docs.monai.io/en/stable/_modules/monai/losses/dice.html#DiceLoss.forward

from monai.losses.dice import *  # NOQA
import torch
from monai.losses.dice import DiceLossinput = torch.tensor([[[[0.1,0.2],[0.3,0.4]],       [[0.5,0.6],[0.7,0.8]],[[0.9,0.1],[0.2,0.4]],]]) # input的shape 1*3*2*2(对应batch*num_class*h*w)
target_idx = torch.tensor([[[1,0],[0,1]]]) #label的shape 1*1*2*2
target = one_hot(target_idx[:, None, ...], num_classes=C)  #这里是转化成one-hot形式
'''
target 的值
target = torch.tensor([[[0,1],[1,0]],[[1,0],[0,1]],[[0,0],[0,0]]])
'''self = DiceLoss(reduction='none')
loss = self(input, target)'''
对应的loss 结果
loss = tensor([[[[0.6667]],[[0.4348]],[[1.0000]]]])
如何计算
首先是有3个类别的输入,对于每个类别loss计算
整体公式就是 loss = 1-2*tp/(预测的概率和+标签的和)    (tp是指label为真对应的概率值)
loss_1 = 1-2*(0.2+0.3)/(0.1+0.2+0.3+0.4+2) = 0.6667
loss_2 = 1-2*(0.5+0.8)/(0.5+0.6+0.7+0.8+2) = 0.4348
loss_3 = 1-2*0/(0.9+0.1+0.2+0.4+0)

FocalLoss解析相关推荐

  1. 【深度学习】RetinaNet 代码完全解析

    前言 本文就是大名鼎鼎的focalloss中提出的网络,其基本结构backbone+fpn+head也是目前目标检测算法的标准结构.RetinaNet凭借结构精简,清晰明了.可扩展性强.效果优秀,成为 ...

  2. yolov5 代码内容解析

    目录 一.工程目录及所需的配置文件解析 二.训练代码详解 加载模型 优化器 数据生成器 参数及类别权重 warmup和前向传播 损失函数计算 准确性和召回率计算 Yolov5 目标检测 一.工程目录及 ...

  3. YOLOv7 | 模型结构与正负样本分配解析

    如有错误,恳请指出. Yolov7的原作者就是Yolov4的原作者.看论文的时候看到比较乱,这里可能会比较杂乱的记录一下我觉得有点启发的东西.对于yolov7的代码,我也没有仔细的看,只是大概的看了下 ...

  4. Yolov5系列(3)-loss解析

    Abstract 在yolov5中,loss在训练中起到了决定性的作用,同时,yolov5的loss又与大部分传统的方法不同,它是基于网格的.在网格上生成相应的anchor框和其对应的cls以及con ...

  5. MMDetection框架的anchor_generators.py解析与船数据解析

    anchor_generators.py解析 import mmcv import numpy as np import torch from torch.nn.modules.utils impor ...

  6. golang通过RSA算法生成token,go从配置文件中注入密钥文件,go从文件中读取密钥文件,go RSA算法下token生成与解析;go java token共用

    RSA算法 token生成与解析 本文演示两种方式,一种是把密钥文件放在配置文件中,一种是把密钥文件本身放入项目或者容器中. 下面两种的区别在于私钥公钥的初始化, init方法,需要哪种取哪种. 通过 ...

  7. List元素互换,List元素转换下标,Java Collections.swap()方法实例解析

    Java Collections.swap()方法解析 jdk源码: public static void swap(List<?> list, int i, int j) {// ins ...

  8. 条形码?二维码?生成、解析都在这里!

    二维码生成与解析 一.生成二维码 二.解析二维码 三.生成一维码 四.全部的代码 五.pom依赖 直接上代码: 一.生成二维码 public class demo {private static fi ...

  9. Go 学习笔记(82)— Go 第三方库之 viper(解析配置文件、热更新配置文件)

    1. viper 特点 viper 是一个完整的 Go应用程序的配置解决方案,它被设计为在应用程序中工作,并能处理所有类型的配置需求和格式.支持特性功能如下: 设置默认值 读取 JSON.TOML.Y ...

  10. Go 学习笔记(77)— Go 第三方库之 cronexpr(解析 crontab 表达式,定时任务)

    cronexpr 支持的比 Linux 自身的 crontab 更详细,可以精确到秒级别. ​ 1. 实现方式 cronexpr 表达式从前到后的顺序如下所示: 字段类型 是否为必须字段 允许的值 允 ...

最新文章

  1. R语言构建xgboost模型:指定特征交互方式、单调性约束的特征、获取模型中的最终特征交互形式(interaction and monotonicity constraints)
  2. vue项目设置img标签的默认图片
  3. 常见浏览器兼容性问题与解决方案
  4. 04 Django之模板系统
  5. 使用IntelliJ调试Java流
  6. Spring根据包名获取包路径下的所有类
  7. 小程序组件的使用(一)创建组件
  8. shell中 if条件的格式要求
  9. 踩坑记录——ProxyServer删除问题经验分享
  10. C# 小票打印机 直接打印 无需驱动
  11. 关于SWAT模型的一些原理(一)
  12. 【lstm做文本分类保存】
  13. Ubuntu安装配置sougou输入法
  14. 数学建模overleaf模板_数学建模论文怎么写?快来pick最优万能模板,一文格式全搞定!...
  15. 计算机数字媒体专业职业规划书,如何写数字媒体技术的职业生涯规划书?
  16. Android源码目录结构详解
  17. java多线程聊天室_JAVA多线程网络聊天室代码
  18. Stm32的GPIO驱动继电器
  19. win10计算机管理器端口号,Win10设备管理器没有端口选项的解决方法
  20. 告诉你Windows PE 是什么东东?详细介绍一下winpe

热门文章

  1. 魔方机器人需要特制魔方吗_大开眼界:会玩魔方的机器人
  2. 猿创征文 | Python 开发工具进化之旅
  3. uni-app实现实时获取当前时间日期
  4. PHP-FPM的PM配置参数说明
  5. Thanos Query Frontend
  6. opencv-11-中值滤波及自适应中值滤波
  7. html 调用es2015模块,给大家分别介绍一下CommonJS和ES2015的import
  8. xamp配置虚拟域名_如何下载,安装和配置XAMP以创建网页?
  9. 利用matlab实现非线性拟合(三维、高维、参数方程)
  10. idea 编译器注释汉字变繁体字解决办法