©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

最近 arXiv 上的一篇论文《EXACT: How to Train Your Accuracy》[1] 引起了笔者的兴趣,顾名思义这是介绍如何直接以准确率为训练目标来训练模型的。正好笔者之前也对此有过一些分析,如《函数光滑化杂谈:不可导函数的可导逼近》[2]、《再谈类别不平衡问题:调节权重与魔改 Loss 的对比联系》等, 所以带着之前的研究经验很快完成了论文的阅读,写下了这篇总结,并附上了最近关于这个主题的一些新思考。

失实的例子

论文开头指出,我们平时用的分类损失函数是交叉熵或者像 SVM 中的 Hinge Loss,这两个损失均不能很好地拟合最终的评价指标准确率。为了说明这一点,论文举了一个很简单的例子:假设数据只有 三个点,-1 和 1 分别代表负类和正类,待拟合模型是 f(x)=x-b,b 是参数,我们希望通过 来预测类别。如果用“sigmoid + 交叉熵”,那么损失函数就是 , 代表一对标签数据;如果用 Hinge Loss,则是 。

由于只是一个一维模型,我们可以直接网格搜索出它的最优解,可以发现如果用“sigmoid + 交叉熵”的话,损失函数的最小值在 b=0.7 取到,而如果是 Hinge Loss,那么 。然而,如果要通过 完全分类正确,那么 才行,因此这说明了交叉熵或 Hinge Loss 与最后评测指标准确率的不一致性。

看上去是一个很简明漂亮的例子,但笔者认为它是不符合事实的。其中,最大的问题是模型设置温度参数,即一般出现的模型是 而不是 ,刻意去掉温度参数来构造不符合事实的反例是没有说服力的,事实上补上可调的温度参数后,这两个损失都可以学到正确的答案。更不公平的是,后面作者在提出自己的方案 EXACT 时,是自带温度参数的,并且温度参数是关键一环,换句话说,在这个例子中,EXACT 比其他两个损失好,纯粹是因为 EXACT 有温度参数。

新瓶装旧酒

然后我们来看论文所提出的方案——EXACT(EXpected ACcuracy opTimization)。从事后来看,EXACT 很是莫名其妙,因为作者是直接不加任何解释地从重参数的角度重新定义了一个条件概率分布 :

其中 是一个向量网络, 是一个标量网络, 跟 维度相同,每个分量是独立同分布地从 采样得到。关于用重参数来定义概率分布的做法,我们在上一篇文章《从重参数的角度看离散概率分布的构建》已经讨论过,这里不重复。

紧接着,有了这个新的 ,作者直接以

作为损失函数,全文的理论框架基本上到此结束。

由此,我们可以总结 EXACT 的莫名其妙之处了。在《从重参数的角度看离散概率分布的构建》我们知道,从重参数角度来看,Softmax 对应的噪声分布是 Gumbel 分布,而 EXACT 换成了正态分布,那么好在哪?为什么会好?这些全无解释。

此外,式 (2) 的相反数是准确率的光滑近似,这本已“广为人知”,但同时也有一个广为人知的结论是在 Softmax 情况下直接优化式 (2) 的效果通常都是不如优化交叉熵的,现在只是换了一个“新瓶”(新概率分布的构建方法)装“旧酒”(同样的准确率光滑近似),真的就能有提升吗?

实验难复现

原论文给出了非常惊人的实验结果,显示 EXACT 几乎总是 SOTA:

然而,笔者根据自己的理解尝试实现了 EXACT,并在 NLP 任务上测试,结果显示 EXACT 完全不能达到“Softmax+交叉熵”的水平。此外,原论文还提到优化 会比 (2) 更好,但笔者的结果是该变体连 (2) 都比不上。总的来说,笔者的测试结论与原论文是大相径庭的。

由于原论文还没有开源代码,因此笔者还不能对论文实验的可靠性做进一步的判断。但从笔者的理论理解和初步的实验结果来看,直接优化式 (2) 是很不可能达到优化交叉熵的效果的,仅仅修改构建概率分布的方式,应该很难形成实质的提升。如果读者有新的实验结果,欢迎进一步交流分享。

一个新视角

从数值上来比较,式 (2) 确实比交叉熵 更贴合准确率。但为什么优化交叉熵往往能获得更好的的准确率?笔者原来也百思不得其解,在《再谈类别不平衡问题:调节权重与魔改 Loss 的对比联系》中,笔者设置将它视为“公理”来使用,实属无奈。

直到有一天,笔者突然意识到了一个关系:随着训练,多数 会慢慢接近于 1,于是可以用近似 得到:

于是我们就能解释为什么优化交叉熵也能获得很好的准确率了,因为从上式我们可以发现,交叉熵优化到中后期跟式 (2) 基本是等价的,也就是同样在优化准确率的光滑近似!

那交叉熵相比式 (2) 的好处在哪呢?差别就在于当 时, 与 的差距。当 时,即目标类的概率很小,意味着分类可能很不准确,这时候 给出的是一个会趋于无穷大的结果,但 最多就只能给出 1。这样一比较,我们就发现交叉熵的 对错误分类的样本的惩罚更大,因此它会更倾向于修正分类错误的样本,同时最终分类结果又跟直接优化准确率的光滑近似相近。

由此,我们可以得到一个优秀的损失函数的新视角:

首先寻找评测指标的一个光滑近似,最好能表达成每个样本的期望形式,然后将错误方向的误差逐渐拉到无穷大(保证模型能更关注错误样本),但同时在正确方向保证与原始形式是一阶近似。

文章小结

本文主要讨论了如何优化准确率的问题,其中先简单介绍和评述了一下最近的论文《EXACT: How to Train Your Accuracy》[1],然后就“为什么优化交叉熵能获得更好的准确率结果”给出了自己的分析。

参考文献

[1] https://arxiv.org/abs/2205.09615

[2] https://kexue.fm/archives/6620

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

如何训练你的准确率?相关推荐

  1. ImageNet训练再创纪录!谷歌提出1个小时训练EfficientNet,准确率高达83%!

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 本文转载自:AI科技评论  |  注:文末附论文PDF下载 作者 | 青暮 近年来,随着深度学 ...

  2. 训练集山准确率高测试集上准确率很低_推荐算法改版前的AB测试

    编辑导语:所谓推荐算法就是利用用户的一些行为,通过一些数学算法,推测出用户可能喜欢的东西:如今很多软件都有这样的操作,对于此系统的设计也会进行测试:本文作者分享了关于推荐算法改版前的AB测试,我们一起 ...

  3. 训练集山准确率高测试集上准确率很低_拒绝DNN过拟合,谷歌准确预测训练集与测试集泛化差异,还开源了数据集 | ICLR 2019...

    鱼羊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 深度神经网络(DNN)如今已经无处不在,从下围棋到打星际,DNN已经渗透到图像识别.图像分割.机器翻译等各种领域,并且总是表现惊艳. 然而, ...

  4. Python 教你训练一个98%准确率的微博抑郁文本分类模型(含数据)

    Paddle是一个比较高级的深度学习开发框架,其内置了许多方便的计算单元可供使用,我们之前写过PaddleHub相关的文章: 1.Python 识别文本情感就这么简单 2.比PS还好用!Python ...

  5. 【问题解决】训练和验证准确率很高,但测试准确率很低

    前情提要: 采用ResNet50预训练模型训练自己的图像分类模型.训练和验证阶段准确率很高,但随机输入一张图片时,大多数情况下依旧预测得不准确. (于是开始搜索各种"验证准确率高但测试准确率 ...

  6. keras训练模型,训练集的准确率很高,但是测试集准确率很低的原因

    今天在测试模型时发现一个问题,keras训练模型,训练集准确率很高,测试集准确率很低,因此记录一下希望能帮助大家也避坑: 首先keras本身不同的版本都有些不同的或大或小的bug,包括之前也困扰过我的 ...

  7. 在服务器上远程使用tensorboard查看训练loss和准确率

    本人使用的是vscode 很简单 from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter('./logs')w ...

  8. Java统计做题正确率_ResNet:训练期间的准确率为100%,但使用相同数据的预测准确率为33%...

    我之前遇到过类似的问题,但解决方案非常简单 . 你需要增加时代数 . 这是1000个纪元后的输出 [[ 9.99999881e-01 8.58182432e-08 9.54004670e-12] [ ...

  9. tensorflow2.0中valid_data的作用是在训练的过程对对比训练数据与测试数据的准确率 损失率,便于判断模型的训练效果:是过拟合还是欠拟合(过拟合)

    tensorflow2.0中valid_data的作用是在训练的过程对对比训练数据与测试数据的准确率,便于判断模型的训练效果:是过拟合还是欠拟合 过拟合:训练数据的准确率较高而测试数据的准确率较低 欠 ...

最新文章

  1. java生成sql语句_java生成SQL语句
  2. C#调用USER32.DLL的API函数
  3. 35所大学获批新增「人工智能」本科专业,工学学位、四年制
  4. 2020 年最新版 68 道Redis面试题,20000 字干货,赶紧收藏起来备用!
  5. Python __str__() 方法
  6. Symbol学习与回忆
  7. c语言程序设计网络作业,北语网院17春《C语言程序设计》作业_2满分答案
  8. 利用webBrowser来实现自动登录网站
  9. 金立S6:因“耀”开启金属手机2.0时代
  10. MyQR库自动为网址生成二维码
  11. C++ 类中的静态成员变量,静态成员函数
  12. 苹果绕过ID_亲测:苹果手机绕过ID,到底能不能用?结果不太理想
  13. 使用pdfFactory Pro虚拟打印机给文档加上水印
  14. 企业做营销型网站的目的
  15. java.lang.ClassNotFoundException问题的解决
  16. 【年终总结】回顾我平凡且不平凡的 2021
  17. 基于共享单车轨迹的自行车道规划(读书笔记)
  18. SpringBoot 实现大文件视频转码(转码基于FFMPEG实现)
  19. C# IDE SharpDevelop的一些缺陷
  20. EF6 批量更新删除数据

热门文章

  1. 【blender】材质球参数及各种问题
  2. 酶切位点分析(the analysis of enzyme sites)
  3. 保持忠贞是不容易的,需要持续付出努力
  4. 离散数学练习,赵钱孙李周,何人出国?理论+代码
  5. Python 处理超大 JSON 文件,这个方法简单!
  6. asm使用指南中文-md版快速入门详解
  7. oracle表名中带@什么意思
  8. OSPF高级特性 —— 路由通告
  9. 在路上---一个平凡人的2016年总结及2017年展望
  10. 清泉HAL库开发STM32之GPIO