《Multi-Head Multi-Loss Model Calibration》

摘要

  • 提供有意义的不确定性估计对于在临床实践中成功部署机器学习模型至关重要。
  • 不确定性量化的一个核心方面是模型返回与模型正确的实际概率一致的预测的能力,也称为模型校准。
  • 没有一种技术可以与简单但昂贵的训练深度神经网络集成的方法相匹配。
  • 本文介绍了一种简化的集成形式,绕过了昂贵的深度集成的训练和推断,保持了其校准能力。
  • 用一组头部来代替网络末端的常见线性分类器,这些头部被不同的损失函数监督,以加强其预测的多样性。
  • 每个头都被训练成最小化加权交叉熵损失,但不同分支之间的权重是不同的。
  • 所得到的平均预测可以在不牺牲准确性的情况下,在两个具有挑战性的数据集中实现出色的校准,用于组织病理学和内窥镜图像分类。
  • 实验表明,MultiHead Multi-Loss分类器本质上是校准良好的,优于其他近期校准技术,甚至挑战Deep Ensembles的性能。

引言

在训练有监督的计算机视觉模型时,我们通常专注于提高它们的预测性能,但对于安全关键任务来说,同样重要的是它们对自己的预测表达有意义的不确定性的能力。在机器学习的背景下,我们经常区分两种类型的不确定性:认知的和任意的。简单地说,认知不确定性来自于模型对它被训练来解决的问题的不完全知识,而任意不确定性描述了对数据的无知用于学习和预测。

例如,如果一个分类器已经学会了在结肠组织病理学上预测癌组织的存在,并且它的任务是对乳房活检进行预测,它可能会显示出认知的不确定性,因为它从未针对这个问题进行过训练。尽管如此,如果我们向模型询问具有模糊视觉内容的结肠活检,即难以诊断的图像,那么它可以表达任意的不确定性,因为它可能不知道如何解决问题,但模糊性来自于数据。

认识性和任意性之间的区别通常是模糊的,因为其中一个的存在并不意味着另一个的不存在。此外,在强烈的认知不确定性下,任意不确定性估计可能变得不可靠

产生良好的不确定性估计是有用的,例如,识别模型预测可信度低的测试样本,应该对其进行审查。

报告不确定性估计的一种直接方法是将模型的输出(其软最大概率的最大值)解释为其预测置信度。当这个置信度与实际精度一致时,我们说模型被校准。

训练校准(Training-Time Calibration)

流行的训练时间方法包括通过正则化来降低预测熵,例如标签平滑或MixUp,或平滑预测的损失函数。这些技术通常依赖于正确调优超参数,控制辨别能力和置信度之间的权衡,并且可以以降低预测性能为代价轻松实现更好的校准。或基于空间变化和边缘的标签平滑,它们扩展和改进了用于生物医学图像分割任务的标签平滑

事后校准(Post-Hoc Calibration)

提出了诸如Temperature Scaling及其变体等事后校准技术,通过在模型的输出概率上应用简单的单调映射来纠正过度或不自信的预测。它们最大的缺点是依赖于使用验证数据学习映射时隐含的假设:这些方法难以推广到未见数据。

除此之外,这些技术可以与训练时方法相结合,并带来复合性能改进。

模型集成(Model Ensembling)

改进校准的第三种方法是聚合几个模型的输出,这些模型事先经过训练,因此它们的预测具有一定的多样性。在深度学习中,模型集成被认为是产生有意义的不确定性的最成功的方法估计。深度集成的一个明显弱点是需要训练,然后为推理目的保留一组模型,这导致了对于较大的体系结构来说相当大的计算开销。在医学图像计算中应用集合的例子包括。

在这项工作中,我们通过不同损失函数训练的多头模型来实现模型校准。从这个意义上说,我们的方法最接近于最近在多输出架构上的一些工作,其中多分支CNN在组织病理学数据上进行训练,通过通过损失最小的分支反向传播梯度来强制不同头部的专门化。

与我们的方法相比,确保正确的梯度流以避免死头需要特别的计算技巧;此外,没有对域内数据和任意不确定性的模型校准进行分析,主要集中在异常检测上。

主要贡献

  • 利用多损失分集来实现比其他基于学习的方法更大的置信度校准,
  • 避免使用训练数据来学习后处理映射,就像大多数事后校准方法所做的那样
  • 避免深度集成的计算开销

校准多头模型

多头集成差异(Multi-Head Ensemble Diversity)

考虑一个k类分类问题,一个神经网络Uθ取一个图像x并将其映射到一个表示Uθ(x)∈RN上,它被f线性变换成一个logits向量z = f(Uθ(x))∈RK。然后通过软最大运算p = σ(z)将其映射为概率p∈[0,1]K的向量。

如果x的标签是y∈{1,…, K},我们可以用交叉熵损失来测量与预测p相关的误差。

将f替换为M个不同的分支f 1,…, f M,它们中的每一个仍然取相同的输入,但将其映射到不同的logits。

然后对得到的概率向量pm = σ(zm)求平均,得到最终的预测pµ= (1/M) p M pm。我们感兴趣的是反向传播损耗,以找到每个分支的梯度。

性质1:对于图1中的m头分类器,fm的横移损失对zm的导数为

可以看到分支m的梯度将缩放,这取决于质量pmy被fm放置在正确类别上的概率是多少,相对于所有正面放置的总质量。换句话说,如果每个头都学会了对特定样本产生类似的预测(不一定正确),那么这个网络的优化过程将导致所有人都得到相同的更新。因此,构成网络输出pµ的预测的多样性将被破坏

多头多损失模型

在训练过程中,在多头模型中获得更多样化的梯度更新的一种方法是用不同的损失函数来监督每个头部。

应用加权交叉熵损失,为每个头分配了不同的权重向量,以这样的方式,不同的损失函数Lωm-CE将监督每个分支fm的中间输出,类似于深度监督策略,但强制多样性。完整模型的总损失为作用于平均预测的单头损失和总损失之和:



式中,p = (p1,…, pM)是一个数组,是网络做出的所有预测。

分支fm处的多头损失梯度为:

在所有分支中拥有相同的权重向量并不能打破所有正面做出类似预测的情况下的对称性。

对于任意两个给定的头fmi, fmj,有ωmi = ωmj和pmi≈pmj,即pm≈pµ,任意m,两个头的梯度的范数之差为:

假设多头模型的分支比我们问题中的类的数量少,即M≤K,否则我们需要在同一类别中有不同的分支。

随机选择的N/K个类别的子集与分支fm相关联,这些类别的权重为ω= K。ω中其余类别的权重为ω = 1/K。

在一个有4个类别和2个分支的问题中,我们可以有ω1 =[2,1 /2, 2,1 /2]和ω2 =[1/2, 2,1 /2, 2]。如果N不能被K整除,则提醒类别将被分配给随机分支进行专门化。

模型评估标准

在测量模型校准时,标准方法依赖于观察不同置信度波段b下的测试集精度。例如,采用所有预测置信度约为c = 0.8的测试样本,一个校准良好的分类器将在该测试子集中显示约80%的精度。这可以通过预期校准误差(ECE)来量化,由:


在实践中,就实际可用性而言,ECE本身并不是一个很好的衡量标准,可以拥有一个完美的ECE校准模型,但没有预测能力。

利用负对数似然(NLL)和标准精度,与ECE相反,即使是校准不良的模型也可以很高。最后,将ECE、NLL和准确性综合排名时的平均排名作为总结指标。

实验结果

数据集

  • the Chaoyang dataset
  • Kvasir

性能分析

训练了三种不同的多头分类器:

  • 2-head模型,其中每个头优化为标准(未加权)CE,称为2HSL(2头-单损失)
  • 2head模型,但每个磁头最小化不同的CE损失,我们称此模型为2HML (2 Heads-Multi Loss))
  • 将正面的数量增加到4个,我们将这个模型称为4HML。

为了进行比较,包括一个标准的无损耗单头分类器(SL1H),加上使用标签平滑(LS)、基于边缘的标签平滑(MbLS)、MixUp和使用DCA损失训练的模型。还展示了Deep Ensembles (D-Ens)的性能。

实验比较简单,不展开说明

总结

性能分析

训练了三种不同的多头分类器:

  • 2-head模型,其中每个头优化为标准(未加权)CE,称为2HSL(2头-单损失)
  • 2head模型,但每个磁头最小化不同的CE损失,我们称此模型为2HML (2 Heads-Multi Loss))
  • 将正面的数量增加到4个,我们将这个模型称为4HML。

为了进行比较,包括一个标准的无损耗单头分类器(SL1H),加上使用标签平滑(LS)、基于边缘的标签平滑(MbLS)、MixUp和使用DCA损失训练的模型。还展示了Deep Ensembles (D-Ens)的性能。

[外链图片转存中…(img-fmL3aeya-1678070054916)]

实验比较简单,不展开说明

总结

多头多损失网络是具有增强校准的分类器,与单头网络相比,预测性能没有下降。这是通过同时优化几个输出分支来实现的,每个分支最小化不同加权的交叉熵损失。权重是互补的,确保每个分支在专攻原始数据类别的子集时得到奖励。

《Multi-Head Multi-Loss Model Calibration》相关推荐

  1. 少样本学习原理快速入门,并翻译《Free Lunch for Few-Shot Learning: Distribution Calibration》

    ICLR2021 Oral<Free Lunch for Few-Shot Learning: Distribution Calibration> 利用一个样本估计类别数据分布 9行代码提 ...

  2. 《深度探索C++对象模型(Inside The C++ Object Model )》学习笔记

    来源:http://dsqiu.iteye.com/blog/1669614 之前一直对C++内部的原理的完全空白,然后找到<Inside The C++ Object Model>这本书 ...

  3. 读论文《A Neural Probabilistic Language Model》

    读论文<A Neural Probabilistic Language Model> 原文地址:http://blog.csdn.net/qq_31456593/article/detai ...

  4. 论文《A Neural Influence Diffusion Model for Social Recommendation》阅读

    论文<A Neural Influence Diffusion Model for Social Recommendation>阅读 论文概况 Abstract Introduction ...

  5. 整数智能参编全球首个《人工智能研发运营一体化(Model/MLOps)能力成熟度模型》标准

    今日,<人工智能研发运营一体化(Model/MLOps)能力成熟度模型>标准[第一部分:开发管理]正式发布,整数智能作为中国通信标准化协会会员,参与该行业权威标准的编写工作,这是对整数智能 ...

  6. 问题六十八: 着色模型(shading model)(0)——《Ray Tracing from the Ground Up》代码的移植

    用ray tracing的方式来生成图形,主要是分两步: 1,几何建模.即为"光线撞击物体",求得撞击点. 2,给撞击点着色.我们之前的做法是:根据被撞击物体的材质(材质的颜色.材 ...

  7. 【论文学习】《A Survey on Neural Speech Synthesis》

    <A Survey on Neural Speech Synthesis>论文学习 文章目录 <A Survey on Neural Speech Synthesis>论文学习 ...

  8. 表情识别综述论文《Deep Facial Expression Recognition: A Survey》中文翻译

    本篇博客为论文<Deep Facial Expression Recognition: A Survey>的中文翻译,如有翻译错误请见谅,同时希望您能为我提出改正建议,谢谢! 论文链接:h ...

  9. 《YOLOX: Exceeding YOLO Series in 2021》阅读

    文章下载: YOLOX-Exceeding YOLO Series in 2021.pdf 摘要 本篇文章中,我们展示了在 YOLO 系列检测器上的改进,并获得了一个高性能的目标检测器 -- YOLO ...

最新文章

  1. ASP.NET Core Web Razor Pages系列教程一:使用ASP.NET Core 创建一个Razor Pages网络应用程序
  2. comsol计算数据导出matlab,comsol4.2怎样在matlab中通过函数输出数据
  3. Java第一次读文件慢_Java 关于文件读取速度问题,求助,谢谢啦
  4. 计算机网络简易测试仪,测线仪
  5. Flink SQL的N way join
  6. SAP ABAP实用技巧介绍系列之 通过ST03G查询指定transaction的trace data
  7. 开发基础框架:mybatis-3.2.8 +hibernate4.0+spring3.0+struts2.3
  8. 游戏开发需要掌握的法则有哪些?
  9. 阿里巴巴P3C java编程规范(最新版github下载)
  10. 20款常用的商业智能(BI)工具分享(最新)
  11. 那位仁兄或者仁姐能给小弟一个菊花论坛的邀请码
  12. neon 指令 c语言,Neon指令集优化快速入门教程
  13. linux沙盒机制6,详解Android应用沙盒机制
  14. 【python圆周率计算】python计算圆周率π的值到任意位
  15. 实施MES系统的七大核心要点,每一点都很重要,不看后悔
  16. chrome 插件 导出与导入,以apizza SQ为例
  17. Qt如何在QTabWidget上绘图
  18. SPark高集群从头到尾踩坑记录
  19. 水下目标检测——论文阅读与整理
  20. ppst——技术视频 jquery ajax 请求 同步异步的执行的设置

热门文章

  1. 那些年我们一起的CSDN
  2. win10 cmd打开php文件,win10如何打开php文件
  3. oracle查回收站大小,ORACLE回收站
  4. centos7安装docker
  5. Matlab高光谱遥感数据
  6. 设计师:设计师知识储备之欧式雕花家具(欧式雕花家具-圆雕、透雕、浮雕、平刻)之详细攻略
  7. 飞蛾火焰优化(MFO)算法——原理分析
  8. J-LINK直接烧录.bin文件到开发板
  9. Ubuntu限制本地上传、下载网速(限速)
  10. 李昱:腾讯产品登录协议详解