every blog every motto: You can do more than you think.

0. 前言

简单记录损失函数,dice loss、focal loss
说明: 后续增补

1. 正文

1.1 基础概念

举个栗子:
用模型对100人进行身体健康状况预测,已知30人患肿瘤。规定肿瘤为阳性,正常为阴性。
预测结果:25人阳性,其中5人实际为阴性。则,
TP = 20,(True Positive,正确预测为阳性
FP = 5,False Positive,错误预测为阳性
FN=10, (False Negative,错误预测为阴性
TN = 65,(True Negative,正确预测为阴性

小结:
第二字母为预测结果(P或N,预测为阳性或阴性),第一个字母为对预测结果的判断(T或F,预测正确或错误)
第二字母为预测结果(P或N,预测为阳性或阴性),第一个字母为对预测结果的判断(T或F,预测正确或错误)
第二字母为预测结果(P或N,预测为阳性或阴性),第一个字母为对预测结果的判断(T或F,预测正确或错误)

混淆矩阵形式:

图示:

说明: 不管是肿瘤还是建筑物预测,我们一般将我们所关心的类别归为阳性,剩下的归为阴性

1.2 正文

1.2.1 dice loss

dice loss 来自 dice coefficient(一种用于评估两个样本相似性的度量函数参考文献1),取值范围0-1。
dice coefficient定义如下:

dice=2∣X∩Y∣∣X∣+∣Y∣dice = {2 |X \cap Y| \over |X| + |Y|} dice=∣X∣+∣Y∣2∣X∩Y∣​

dice loss 定义如下:
Ldice=1−dice=1−2∣X∩Y∣∣X∣+∣Y∣L_{dice} = 1 - dice = 1 - {2 |X \cap Y| \over |X| + |Y|} Ldice​=1−dice=1−∣X∣+∣Y∣2∣X∩Y∣​
对于二分类问题,用混淆矩阵计算如下,
dice=2TP2TP+FP+FNdice = {2TP \over 2TP + FP + FN} dice=2TP+FP+FN2TP​
由参考文献1知,
精确率:
Precision=TPTP+FPPrecision = {TP \over TP + FP} Precision=TP+FPTP​
召回率:
Recall=TPTP+FNRecall = {TP \over TP + FN} Recall=TP+FNTP​
其中,F1:
F1=2∗Precision∗RecallPrecision+RecallF1 = {2*Precision*Recall \over Precision + Recall} F1=Precision+Recall2∗Precision∗Recall​
又,
F1=2∗Precision∗RecallPrecision+Recall=2TP2TP+FP+FN=diceF1 = {2*Precision*Recall \over Precision + Recall} = {2TP \over 2TP + FP + FN} = dice F1=Precision+Recall2∗Precision∗Recall​=2TP+FP+FN2TP​=dice
F1和我们的评价指标dice本质是一个意思,即,我们优化dice loss是直接优化F1

def dice_loss(target,predictive,ep=1e-8):intersection = 2 * torch.sum(predictive * target) + epunion = torch.sum(predictive) + torch.sum(target) + eploss = 1 - intersection / unionreturn loss

小结:
dice loss 对于正负样本不平衡问题有着不错的性能,训练过程侧重前景的挖掘。但训练loss容易不稳定,改进操作包括和其他loss结合,包括:
dice loss + ce loss
dice loss + focal loss
具体参考文献1

1.2.1 focal loss

1. 二分类交叉熵回顾

主要解决样本不平衡问题提出的。
我们首先回顾一下二分类交叉熵
loss=−1N[p∗logq+(1−p)∗log(1−q)]loss =-{1 \over N} [p*logq + (1-p)*log(1-q)] loss=−N1​[p∗logq+(1−p)∗log(1−q)]
其中,p为标签,q为预测。对于二分类,我们规定上述p为正样本,则,1-p代表负样本。如果我们对遥感影像进行建筑物提取,那么p代表建筑物(像素),q代表背景(像素)。则,上式又可改写为:

loss={−logq,p=1−log(1−q),p=0loss = \begin{cases} -logq, \quad \quad \quad p=1 \\ -log(1-q), \quad p=0 \end{cases} loss={−logq,p=1−log(1−q),p=0​
-log函数图如下,

我们对上式进行解释:
对一张遥感进行建筑物提取,简化图如下,
(我们规定,像素值为0代表背景,像素值为1代表建筑物)

我们分析左上角的像素点的计算过程,其中标签值为1,我们预测结果为0.45。那么我们代入公式
l1=−log(0.45)l1 = -log(0.45) l1=−log(0.45)
我们再分析右下角像素的计算过程,其中,右下角点为背景,所以我们的标签值为0,如下图所示,

代入公式
l2=−log(1−0.7)l2 = -log(1-0.7) l2=−log(1−0.7)
0.7位我们预测为建筑物的概率
更进一步关于数据验证部分,参考文献3
参考文献4

2. focal loss

对于上述二分类交叉熵而言,对正负样本是同等考虑的,
同时,由公式我们发现一个现象,
对于正样本(标签中正样本像素位置),输出概率概率越大,损失越小
对于负样本(标签中负样本像素位置),输出概率越小,损失越小
为了抑制样本不平衡问题(背景占比多),添加平衡因子α\alphaα,论文中取值为0.25

loss={−α×logq,p=1−(1−α)×log(1−q),p=0loss = \begin{cases} -\alpha×logq ,\quad p=1 \\ -(1-\alpha) × log(1-q) ,\quad\quad\quad\quad p=0 \end{cases} loss={−α×logq,p=1−(1−α)×log(1−q),p=0​

为了 减少易分类样本的损失,更加关注困难的、错分样本,又添加了调制系数γ\gammaγ,
loss={−α×(1−q)γ×logq,p=1−α×qγ×log(1−q),p=0loss = \begin{cases} -\alpha×(1-q)^\gamma ×logq ,\quad p=1 \\ -\alpha×q^\gamma × log(1-q) ,\quad\quad\quad\quad p=0 \end{cases} loss={−α×(1−q)γ×logq,p=1−α×qγ×log(1−q),p=0​

参考文献

[1] https://zhuanlan.zhihu.com/p/269592183
[2] https://blog.csdn.net/qq_34107425/article/details/110119894
[3] https://blog.csdn.net/weixin_39190382/article/details/114922578
[4] https://blog.csdn.net/weixin_39190382/article/details/114681163
[5] https://zhuanlan.zhihu.com/p/49981234
[6] https://www.cnblogs.com/king-lps/p/9497836.html
[7] https://www.aiuai.cn/aifarm1159.html

【深度学习】损失函数记录相关推荐

  1. 深度学习损失函数大全

    深度学习损失函数大全 比focal loss好的GHM 常见损失函数汇总 - 知乎 有图展示,效果还可以: 激活函数/损失函数汇总 - 知乎 这个也不错: 损失函数汇总(全网最全) - WSX_199 ...

  2. python多分类混淆矩阵代码_深度学习自学记录(3)——两种多分类混淆矩阵的Python实现(含代码)...

    深度学习自学记录(3)--两种多分类混淆矩阵的Python实现(含代码),矩阵,样本,模型,类别,真实 深度学习自学记录(3)--两种多分类混淆矩阵的Python实现(含代码) 深度学习自学记录(3) ...

  3. 【深度学习】深度学习实验记录--自编码+分类器

    深度学习课程记录 自编码分类器神经网络记录 1.Train the autoencoder by using unlabeled data 训练1(fail) 训练2(fail) 训练3(fail) ...

  4. c++ log函数_认识这19种深度学习损失函数,才能说你了解深度学习!

    编辑:深度学习自然语言处理小编zenRRan 损失函数是深度学习中重要的概念,选择合适的损失函数是系统能够得到理想结果的保证,本文将以pytorch工具为例,介绍这19中损失函数与实现方法. 19种损 ...

  5. 收藏 | 深度学习损失函数大全(附代码实现)

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 来源:机器学习与数学 Tensorflow 和 PyTorch 很多都是相似的,这里以 PyT ...

  6. 深度学习——损失函数(Regression Loss、Classification Loss)

    简介 Loss function 损失函数 用于定义单个训练样本与真实值之间的误差 Cost function 代价函数 用于定义单个批次/整个训练集样本与真实值之间的误差 Objective fun ...

  7. 深度学习损失函数不下降的解决方法

    https://blog.csdn.net/qq_37764129/article/details/94042702 当我们训练一个神经网络模型的时候,我们经常会遇到这样的一个头疼的问题,那就是,神经 ...

  8. 深度学习——损失函数推导过程(三个方面诠释损失函数的由来意义)

    三个维度诠释损失函数 我们在学习机器学习的过程中,通常利用损失函数来衡量模型所做出的预测离真实值之间的偏离程度. 损失函数大致分为3类方法 最小二乘法 极大似然估计法 交叉熵 1.最小二乘法 这个方法 ...

  9. ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch深度学习搭建记录(一路爬坑的一天...)

    ubuntu20.04+gpu驱动下载+cuda10.2+cudnn+pytorch 深度学习环境搭建记录(一路爬坑的一天-) 1.gpu驱动下载 参考:https://blog.csdn.net/f ...

  10. 安装 Win10 Ubuntu 16.04 双系统以及 Ubuntu 配置深度学习环境记录

    0. 前言 坑爹的Ubuntu晚上运行还是好好的,第二天中午的时候打开机器发现屏幕分辨率不正常了:2K屏显示800*600左右的分辨率(无法调节),一个图标一拳头大,窗口和网页显示不全.Google查 ...

最新文章

  1. MindCon | 5天啦,你有领取MSG城市专属徽章吗?
  2. Python:python语言中与时间有关的库函数简介、安装、使用方法之详细攻略
  3. iphone固件降级_我在iPhone上装了个安卓
  4. 【转】C#命名空间与java包的区别分析
  5. C++11 并发指南六(atomic 类型详解一 atomic_flag 介绍)
  6. Cisco路由器安全配置必用的10条命令
  7. axios流输出excel
  8. SolarWinds 供应链攻击中的第四款恶意软件及其它动态
  9. linux参考文献_小白爱折腾·其二:手机Linux部署DiscuzX论坛
  10. 网页后门危害大 网站安全狗帮助查杀
  11. android模拟器检测常用方法,Android模拟器检测方案优化
  12. dellemc服务器的显示器连接,通过显示器实现合作-DellEMC.PDF
  13. 一行JS代码实现ie浏览器升级弹窗
  14. SLAM Evaluation 之轨迹对齐论文翻译Closed-Form Solution of Absolute Orientation Using Orthonormal Matrices
  15. 【踩坑】python: This install of SoX cannot process .mp3 files
  16. linux下的外接显示器设置成竖屏
  17. 如何使用gcore以及viewcore排查问题
  18. 13个小众有趣的网站,只有程序员才看得懂
  19. swi-prolog安装及使用(基于)
  20. kylin打开Dashboard教程

热门文章

  1. qtvs添加qchart_如何使用Qt Designer在表单中插入QChartView?
  2. 荣耀智慧屏 55英寸屏幕 搭载鸿蒙OS,3799元起!荣耀智慧屏发布:55英寸4K全面屏+首发鸿蒙OS+无广告...
  3. 中文只占一个字符_男人宠妻的三大表现,就算只占一个,你都是嫁对了人!
  4. 英特尔 超核芯显卡 620mac_2020双11装机。科学计算工作站配置推荐。i9-10980XE加3080显卡加64G ECC内存...
  5. SQL:postgresql查询结果加一个自定义的列
  6. html5图片动且平移,HTML5 Canvas平移,放缩,旋转演示
  7. python爬虫面向对象_Python爬虫技术--基础篇--面向对象编程(上)
  8. 用setTimeout代替setInterval
  9. Java集合框架源码解读(2)——HashMap
  10. LMM(LightMoonMovie)亮月湾电影分享管理系统;