Focal Loss 分类问题 pytorch实现代码(续2)
ps:咱们继续.
先贴一下交叉熵的公式:
在贴一下我的尝试:
>>> import torch
>>> input=torch.randn(4,2)
>>> input
tensor([[ 0.0543, 0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557, 1.4724]])
>>> pt=torch.softmax(input,dim=1)
>>> pt
tensor([[0.3753, 0.6247],[0.8547, 0.1453],[0.3451, 0.6549],[0.1270, 0.8730]])
>>> target=torch.tensor([1,0,1,1])
>>> ones=torch.eye(2)
>>> targetones=ones.index_select(0,target)
>>> targetones
tensor([[0., 1.],[1., 0.],[0., 1.],[0., 1.]])
>>> torch.log(pt)
tensor([[-0.9802, -0.4704],[-0.1570, -1.9287],[-1.0639, -0.4233],[-2.0639, -0.1358]])
>>> -(targetones*torch.log(pt))
tensor([[0.0000, 0.4704],[0.1570, 0.0000],[0.0000, 0.4233],[0.0000, 0.1358]])
>>> torch.log(1-pt)
tensor([[-0.4704, -0.9802],[-1.9287, -0.1570],[-0.4233, -1.0639],[-0.1358, -2.0639]])
>>> -((1-targetones)*torch.log(1-pt))
tensor([[0.4704, 0.0000],[0.0000, 0.1570],[0.4233, 0.0000],[0.1358, 0.0000]])
>>> -(targetones*torch.log(pt))-((1-targetones)*torch.log(1-pt))
tensor([[0.4704, 0.4704],[0.1570, 0.1570],[0.4233, 0.4233],[0.1358, 0.1358]])
二分类的-(targetones*torch.log(pt))和-((1-targetones)*torch.log(1-pt))怎么感觉一样啊.接着尝试如下:
>>> target=torch.Tensor([1,0,1,1])
>>> pt
tensor([[0.3753, 0.6247],[0.8547, 0.1453],[0.3451, 0.6549],[0.1270, 0.8730]])
>>> p=pt[:,1]
>>> p
tensor([0.6247, 0.1453, 0.6549, 0.8730])
>>> -(target*torch.log(p))-((1-target)*torch.log(1-p))
tensor([0.4704, 0.1570, 0.4233, 0.1358])>>> target=torch.Tensor([0,1,0,0])
>>> p=pt[:,0]
>>> p
tensor([0.3753, 0.8547, 0.3451, 0.1270])
>>> -(target*torch.log(p))-((1-target)*torch.log(1-p))
tensor([0.4704, 0.1570, 0.4233, 0.1358])>>> celoss(input,torch.tensor([1,0,1,1]))
tensor(0.2966)
>>> loss(torch.log(torch.softmax(input,dim=1)),torch.tensor([1,0,1,1]))
tensor(0.2966)>>> (-(target*torch.log(p))-((1-target)*torch.log(1-p))).mean()
tensor(0.2966)
果然是这样.上面这步已经在NLLLoss函数里了做了(真实值1和真实值0).
所以要把alpha和(1-alpha)放到用NLLLoss函数之前,并且同样变成Tensor(4,2)(Tensor生成torch.FloatTensor,而tensor生成torch.LongTensor).进一步尝试先把Focal Loss思路尝试完如下:
>>> target=torch.Tensor([1,0,1,1])
>>> p=pt[:,1]
>>> p
tensor([0.6247, 0.1453, 0.6549, 0.8730])
>>> alpha=0.25
>>> gamma=2
>>> -alpha*(1-p)**gamma*(target*torch.log(p))-(1-alpha)*p**gamma*((1-target)*torch.log(1-p))
tensor([0.0166, 0.0025, 0.0126, 0.0005])
在来尝试使用NLLLoss函数
alpha=0.25
>>> aa=torch.Tensor([1-alpha,alpha])
>>> aa
tensor([0.7500, 0.2500])
>>> bb=aa.repeat(4,1)
>>> bb
tensor([[0.7500, 0.2500],[0.7500, 0.2500],[0.7500, 0.2500],[0.7500, 0.2500]])
>>> loss=torch.nn.NLLLoss()
>>> target=torch.tensor([1,0,1,1])
>>> loss(torch.log(bb*(pt**gamma)),target)
tensor(1.7049)
但是结果不太对啊.还要再想想.
算了好像行不通啊,只能用前一种方式去实现它.根据思路去实现如下:
import torch
import torch.nn as nn#二分类
class FocalLoss(nn.Module):def __init__(self, gamma=2,alpha=0.25):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha=alphadef forward(self, input, target):# input:size is M*2. M is the batch number# target:size is M.pt=torch.softmax(input,dim=1)p=pt[:,1]loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\(1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))return loss.mean()
Focal Loss 分类问题 pytorch实现代码(续2)相关推荐
- Focal Loss 分类问题 pytorch实现代码(简单实现)
ps:由于降阳性这步正负样本数量在差距巨大.正样本1500多个,而负样本750000多个.要用 Focal Loss来解决这个问题. 首先感谢Code_Mart的博客把理论汇总了下https://bl ...
- Focal Loss 分类问题 pytorch实现代码(续3)
ps:虽然无法用NLLLoss函数来实现.但好歹最后实现了自己的想法.现在再来测试下最后和最开始的Focal Loss如下: import torch import torch.nn as nn#二分 ...
- Focal Loss 分类问题 pytorch实现代码(续1)
ps:感谢Code_Mart的解答,肯定了思路,不过他也不确定是否可以在pytorch中那么写.事情这样模棱两可让我很烦躁决定深究一下.看到博客https://blog.csdn.net/qq_222 ...
- quality focal loss distribute focal loss 详解(paper, 代码)
参见generalized focal loss paper 其中包含有Quality Focal Loss 和 Distribution Focal Loss. 目录 背景 Focal Loss Q ...
- OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)
https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA 综述:解决目标检测中的样本不均衡问题 该综述主要介绍了OHEM,Focal loss,GHM los ...
- focal loss 通俗理解
文章目录 什么是focal loss? 控制正负样本的权重 控制容易分类和难分类样本的权重 两种权重控制方法合并 关于focal loss如果看过此文还不理解,可以看这篇文章: focal loss ...
- 【CV】10分钟理解Focal loss数学原理与Pytorch代码
原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...
- pytorch gather_【CV】10分钟理解Focal loss数学原理与Pytorch代码
原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...
- 前景背景样本不均衡解决方案:Focal Loss,GHM与PISA(附python实现代码)
参考文献:Imbalance Problems in Object Detection: A Review 1 定义 在前景-背景类别不平衡中,背景占有很大比例,而前景的比例过小,这类问题是不可避免的 ...
最新文章
- WPF实现背景透明磨砂,并通过HandyControl组件实现弹出等待框
- jquery mobile页面切换效果(Flip toggle switch)(注:jQuery移动使用的数据属性的列表。 )...
- 浅谈web应用的负载均衡、集群、高可用(HA)解决方案
- WebForm 分页与组合查询
- 关于udelay(); mdelay(); ndelay(); msleep();
- win10 linux开发环境搭建,win10子系统linux.ubuntu开发环境搭建
- 吴恩达神经网络和深度学习-学习笔记-37-inception网络
- linux的shell脚本接收参数
- KEIL使用教程——KEIL常用配置技巧
- 英语计算机简历模板,计算机研究生英文简历模板
- STC89C52RC的AD7705读写实验(软件SPI)
- python数据分析18-21
- 烤地瓜(PYTHON 学习类和对象)
- 域名注册_申请证书\SSL证书\tls证书
- 网络信息安全管理要素和安全风险评估
- echarts条形图
- 录屏程序之屏幕实时录制保存成AVI视频文件
- 首席新媒体黎想教程:活动推广提升线下活动转化率?
- 从零搭建仿抖音短视频APP-后端开发短视频业务模块(1)
- npm ERR! enoent This is related to npm not being able to find a file.解决