ps:咱们继续.

先贴一下交叉熵的公式:

在贴一下我的尝试:

>>> import torch
>>> input=torch.randn(4,2)
>>> input
tensor([[ 0.0543,  0.5641],[ 1.2221, -0.5496],[-0.7951, -0.1546],[-0.4557,  1.4724]])
>>> pt=torch.softmax(input,dim=1)
>>> pt
tensor([[0.3753, 0.6247],[0.8547, 0.1453],[0.3451, 0.6549],[0.1270, 0.8730]])
>>> target=torch.tensor([1,0,1,1])
>>> ones=torch.eye(2)
>>> targetones=ones.index_select(0,target)
>>> targetones
tensor([[0., 1.],[1., 0.],[0., 1.],[0., 1.]])
>>> torch.log(pt)
tensor([[-0.9802, -0.4704],[-0.1570, -1.9287],[-1.0639, -0.4233],[-2.0639, -0.1358]])
>>> -(targetones*torch.log(pt))
tensor([[0.0000, 0.4704],[0.1570, 0.0000],[0.0000, 0.4233],[0.0000, 0.1358]])
>>> torch.log(1-pt)
tensor([[-0.4704, -0.9802],[-1.9287, -0.1570],[-0.4233, -1.0639],[-0.1358, -2.0639]])
>>> -((1-targetones)*torch.log(1-pt))
tensor([[0.4704, 0.0000],[0.0000, 0.1570],[0.4233, 0.0000],[0.1358, 0.0000]])
>>> -(targetones*torch.log(pt))-((1-targetones)*torch.log(1-pt))
tensor([[0.4704, 0.4704],[0.1570, 0.1570],[0.4233, 0.4233],[0.1358, 0.1358]])

二分类的-(targetones*torch.log(pt))和-((1-targetones)*torch.log(1-pt))怎么感觉一样啊.接着尝试如下:

>>> target=torch.Tensor([1,0,1,1])
>>> pt
tensor([[0.3753, 0.6247],[0.8547, 0.1453],[0.3451, 0.6549],[0.1270, 0.8730]])
>>> p=pt[:,1]
>>> p
tensor([0.6247, 0.1453, 0.6549, 0.8730])
>>> -(target*torch.log(p))-((1-target)*torch.log(1-p))
tensor([0.4704, 0.1570, 0.4233, 0.1358])>>> target=torch.Tensor([0,1,0,0])
>>> p=pt[:,0]
>>> p
tensor([0.3753, 0.8547, 0.3451, 0.1270])
>>> -(target*torch.log(p))-((1-target)*torch.log(1-p))
tensor([0.4704, 0.1570, 0.4233, 0.1358])>>> celoss(input,torch.tensor([1,0,1,1]))
tensor(0.2966)
>>> loss(torch.log(torch.softmax(input,dim=1)),torch.tensor([1,0,1,1]))
tensor(0.2966)>>> (-(target*torch.log(p))-((1-target)*torch.log(1-p))).mean()
tensor(0.2966)

果然是这样.上面这步已经在NLLLoss函数里了做了(真实值1和真实值0).

所以要把alpha和(1-alpha)放到用NLLLoss函数之前,并且同样变成Tensor(4,2)(Tensor生成torch.FloatTensor,而tensor生成torch.LongTensor).进一步尝试先把Focal Loss思路尝试完如下:

>>> target=torch.Tensor([1,0,1,1])
>>> p=pt[:,1]
>>> p
tensor([0.6247, 0.1453, 0.6549, 0.8730])
>>> alpha=0.25
>>> gamma=2
>>> -alpha*(1-p)**gamma*(target*torch.log(p))-(1-alpha)*p**gamma*((1-target)*torch.log(1-p))
tensor([0.0166, 0.0025, 0.0126, 0.0005])

在来尝试使用NLLLoss函数

 alpha=0.25
>>> aa=torch.Tensor([1-alpha,alpha])
>>> aa
tensor([0.7500, 0.2500])
>>> bb=aa.repeat(4,1)
>>> bb
tensor([[0.7500, 0.2500],[0.7500, 0.2500],[0.7500, 0.2500],[0.7500, 0.2500]])
>>> loss=torch.nn.NLLLoss()
>>> target=torch.tensor([1,0,1,1])
>>> loss(torch.log(bb*(pt**gamma)),target)
tensor(1.7049)

但是结果不太对啊.还要再想想.

算了好像行不通啊,只能用前一种方式去实现它.根据思路去实现如下:

import torch
import torch.nn as nn#二分类
class FocalLoss(nn.Module):def __init__(self, gamma=2,alpha=0.25):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha=alphadef forward(self, input, target):# input:size is M*2. M is the batch number# target:size is M.pt=torch.softmax(input,dim=1)p=pt[:,1]loss = -self.alpha*(1-p)**self.gamma*(target*torch.log(p))-\(1-self.alpha)*p**self.gamma*((1-target)*torch.log(1-p))return loss.mean()

Focal Loss 分类问题 pytorch实现代码(续2)相关推荐

  1. Focal Loss 分类问题 pytorch实现代码(简单实现)

    ps:由于降阳性这步正负样本数量在差距巨大.正样本1500多个,而负样本750000多个.要用 Focal Loss来解决这个问题. 首先感谢Code_Mart的博客把理论汇总了下https://bl ...

  2. Focal Loss 分类问题 pytorch实现代码(续3)

    ps:虽然无法用NLLLoss函数来实现.但好歹最后实现了自己的想法.现在再来测试下最后和最开始的Focal Loss如下: import torch import torch.nn as nn#二分 ...

  3. Focal Loss 分类问题 pytorch实现代码(续1)

    ps:感谢Code_Mart的解答,肯定了思路,不过他也不确定是否可以在pytorch中那么写.事情这样模棱两可让我很烦躁决定深究一下.看到博客https://blog.csdn.net/qq_222 ...

  4. quality focal loss distribute focal loss 详解(paper, 代码)

    参见generalized focal loss paper 其中包含有Quality Focal Loss 和 Distribution Focal Loss. 目录 背景 Focal Loss Q ...

  5. OHEM,Focal loss,GHM loss二分类pytorch代码实现(减轻难易样本不均衡问题)

    https://mp.weixin.qq.com/s/iOAICJege2b0pCVxPkvNiA 综述:解决目标检测中的样本不均衡问题 该综述主要介绍了OHEM,Focal loss,GHM los ...

  6. focal loss 通俗理解

    文章目录 什么是focal loss? 控制正负样本的权重 控制容易分类和难分类样本的权重 两种权重控制方法合并 关于focal loss如果看过此文还不理解,可以看这篇文章: focal loss ...

  7. 【CV】10分钟理解Focal loss数学原理与Pytorch代码

    原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...

  8. pytorch gather_【CV】10分钟理解Focal loss数学原理与Pytorch代码

    原文链接:https://amaarora.github.io/2020/06/29/FocalLoss.html 原文作者:Aman Arora Focal loss 是一个在目标检测领域常用的损失 ...

  9. 前景背景样本不均衡解决方案:Focal Loss,GHM与PISA(附python实现代码)

    参考文献:Imbalance Problems in Object Detection: A Review 1 定义 在前景-背景类别不平衡中,背景占有很大比例,而前景的比例过小,这类问题是不可避免的 ...

最新文章

  1. WPF实现背景透明磨砂,并通过HandyControl组件实现弹出等待框
  2. jquery mobile页面切换效果(Flip toggle switch)(注:jQuery移动使用的数据属性的列表。 )...
  3. 浅谈web应用的负载均衡、集群、高可用(HA)解决方案
  4. WebForm 分页与组合查询
  5. 关于udelay(); mdelay(); ndelay(); msleep();
  6. win10 linux开发环境搭建,win10子系统linux.ubuntu开发环境搭建
  7. 吴恩达神经网络和深度学习-学习笔记-37-inception网络
  8. linux的shell脚本接收参数
  9. KEIL使用教程——KEIL常用配置技巧
  10. 英语计算机简历模板,计算机研究生英文简历模板
  11. STC89C52RC的AD7705读写实验(软件SPI)
  12. python数据分析18-21
  13. 烤地瓜(PYTHON 学习类和对象)
  14. 域名注册_申请证书\SSL证书\tls证书
  15. 网络信息安全管理要素和安全风险评估
  16. echarts条形图
  17. 录屏程序之屏幕实时录制保存成AVI视频文件
  18. 首席新媒体黎想教程:活动推广提升线下活动转化率?
  19. 从零搭建仿抖音短视频APP-后端开发短视频业务模块(1)
  20. npm ERR! enoent This is related to npm not being able to find a file.解决

热门文章

  1. LeetCode每日一题:比特位计数(No.338)
  2. Keepalived相关参数说明
  3. 【面试】iOS 开发面试题(二)
  4. Slimer软工课设日报-2016年6月30日
  5. Delphi XE10.1 引用计数
  6. WordPress Plupload插件未明跨站脚本漏洞
  7. sphinx的配置和管理
  8. 【hive】如何设置hive以及MapReduce的压缩方式?
  9. nginx实现负载均衡配置
  10. Jmeter基础之JMeter参数化补充练习