文 | 苏剑林
编 | 智商掉了一地
单位 | 追一科技

思想朴素却不平凡的分类问题后处理技巧,浅显易懂的讲解,拿来吧你!

顾名思义,本文将会介绍一种用于分类问题的后处理技巧——CAN(Classification with Alternating Normalization)。经过笔者的实测,CAN确实多数情况下能提升多分类问题的效果,而且几乎没有增加预测成本,因为它仅仅是对预测结果的简单重新归一化操作。

有趣的是,其实CAN的思想是非常朴素的,朴素到每个人在生活中都应该用过同样的思想。然而,CAN的论文却没有很好地说清楚这个思想,只是纯粹形式化地介绍和实验这个方法。本文的分享中,将会尽量将算法思想介绍清楚。

论文标题:
When in Doubt: Improving Classification Performance with Alternating Normalization

论文链接:
https://arxiv.org/abs/2109.13449

思想例子

假设有一个二分类问题,模型对于输入给出的预测结果是,那么我们就可以给出预测类别为;接下来,对于输入,模型给出的预测结果是,这时候处于最不确定的状态,我们也不知道输出哪个类别好。

但是,假如我告诉你:

  1. 类别必然是0或1其中之一;

  2. 两个类别的出现概率各为0.5。

在这两点先验信息之下,由于前一个样本预测结果为1,那么基于朴素的均匀思想,我们是否更倾向于将后一个样本预测为0,以得到一个满足第二点先验的预测结果?

这样的例子还有很多,比如做10道选择题,前9道你都比较有信心,第10题完全不会只能瞎蒙,然后你一看发现前9题选A、B、C的都有就是没有一个选D的,那么第10题在蒙的时候你会不会更倾向于选D?

这些简单例子的背后,有着跟CAN同样的思想,它其实就是用先验分布来校正低置信度的预测结果,使得新的预测结果的分布更接近先验分布。

2 不确定性

准确来说,CAN是针对低置信度预测结果的后处理手段,所以我们首先要有一个衡量预测结果不确定性的指标。常见的度量是“熵”[1],对于,定义为:

然而,虽然熵是一个常见选择,但其实它得出的结果并不总是符合我们的直观理解。比如对于和,直接套用公式得到,但就我们的分类场景而言,显然我们会认为比更不确定,所以直接用熵还不够合理。

一个简单的修正是只用前top-个概率值来算熵,不失一般性,假设是概率最高的个值,那么

其中。为了得到一个0~1范围内的结果,我们取为最终的不确定性指标。

算法步骤

现在假设我们有个样本需要预测类别,模型直接的预测结果是个概率分布,假设测试样本和训练样本是同分布的,那么完美的预测结果应该有:

其中是类别的先验分布,我们可以直接从训练集估计。也就是说,全体预测结果应该跟先验分布是一致的,但受限于模型性能等原因,实际的预测结果可能明显偏离上式,这时候我们就可以人为修正这部分。

具体来说,我们选定一个阈值,将指标小于的预测结果视为高置信度的,而大于等于的则是低置信度的,不失一般性,我们假设前个结果属于高置信度的,而剩下的个属于低置信度的。我们认为高置信度部分是更加可靠的,所以它们不用修正,并且可以用它们来作为“标准参考系”来修正低置信度部分。

具体来说,对于,我们将与高置信度的一起,执行一次 “行间标准化

这里的,其中乘除法都是element-wise的。不难发现,这个标准化的目的是使得所有新的的平均向量等于先验分布,也就是促使式(3)的成立。然而,这样标准化之后,每个就未必满足归一化了,所以我们还要执行一次 行内标准化

理论上,这两步可以交替迭代几次(不过实验结果显示一次的效果就挺好了)。最后,我们只保留最新的作为原来第个样本的预测结果,其余的均弃之不用。

注意,这个过程需要我们遍历每个低置信度结果执行,也就是说是逐个样本进行修正,而不是一次性修正的,每个都借助原始的高置信度结果组合来按照上述步骤迭代,虽然迭代过程中对应的都会随之更新,但那只是临时结果,最后都是弃之不用的,每次修正都是用原始的。

参考实现

这是笔者给出的参考实现代码:

# 预测结果,计算修正前准确率
y_pred = model.predict(valid_generator.fortest(), steps=len(valid_generator), verbose=True
)
y_true = np.array([d[1] for d in valid_data])
acc_original = np.mean([y_pred.argmax(1) == y_true])
print('original acc: %s' % acc_original)# 评价每个预测结果的不确定性
k = 3
y_pred_topk = np.sort(y_pred, axis=1)[:, -k:]
y_pred_topk /= y_pred_topk.sum(axis=1, keepdims=True)
y_pred_uncertainty = -(y_pred_topk * np.log(y_pred_topk)).sum(1) / np.log(k)# 选择阈值,划分高、低置信度两部分
threshold = 0.9
y_pred_confident = y_pred[y_pred_uncertainty < threshold]
y_pred_unconfident = y_pred[y_pred_uncertainty >= threshold]
y_true_confident = y_true[y_pred_uncertainty < threshold]
y_true_unconfident = y_true[y_pred_uncertainty >= threshold]# 显示两部分各自的准确率
# 一般而言,高置信度集准确率会远高于低置信度的
acc_confident = (y_pred_confident.argmax(1) == y_true_confident).mean()
acc_unconfident = (y_pred_unconfident.argmax(1) == y_true_unconfident).mean()
print('confident acc: %s' % acc_confident)
print('unconfident acc: %s' % acc_unconfident)# 从训练集统计先验分布
prior = np.zeros(num_classes)
for d in train_data:prior[d[1]] += 1.prior /= prior.sum()# 逐个修改低置信度样本,并重新评价准确率
right, alpha, iters = 0, 1, 1
for i, y in enumerate(y_pred_unconfident):Y = np.concatenate([y_pred_confident, y[None]], axis=0)for j in range(iters):Y = Y**alphaY /= Y.sum(axis=0, keepdims=True)Y *= prior[None]Y /= Y.sum(axis=1, keepdims=True)y = Y[-1]if y.argmax() == y_true_unconfident[i]:right += 1# 输出修正后的准确率
acc_final = (acc_confident * len(y_pred_confident) + right) / len(y_pred)
print('new unconfident acc: %s' % (right / (i + 1.)))
print('final acc: %s' % acc_final)

实验结果

那么,这样的简单后处理,究竟能带来多大的提升呢?原论文给出的实验结果是相当可观的:

▲原论文的实验结果之一

笔者也在CLUE上的两个中文文本分类任务上做了实验,显示基本也有点提升,但没那么可观(验证集结果):

IFLYTEK(类别数:119) TNEWS(类别数:15)
BERT 60.06% 56.80%
BERT + CAN 60.52% 56.86%
RoBERTa 60.64% 58.06%
RoBERTa + CAN 60.95% 58.00%

大体上来说,类别数目越多,效果提升越明显,如果类别数目比较少,那么可能提升比较微弱甚至会下降(当然就算下降也是微弱的),所以这算是一个“几乎免费的午餐”了。超参数选择方面,上面给出的中文结果,只迭代了1次,的选择为3、的选择为0.9,经过简单的调试,发现这基本上已经是比较优的参数组合了。

还有的读者可能想问前面说的“高置信度那部分结果更可靠”这个情况是否真的成立?至少在笔者的两个中文实验上它是明显成立的,比如IFLYTEK任务,筛选出来的高置信度集准确率为0.63+,而低置信度集的准确率只有0.22+;TNEWS任务类似,高置信度集准确率为0.58+,而低置信度集的准确率只有0.23+。

个人评价

最后再来综合地思考和评价一下CAN。

首先,一个很自然的疑问是为什么不直接将所有低置信度结果跟高置信度结果拼在一起进行修正,而是要逐个进行修正?笔者不知道原论文作者有没有对比过,但笔者确实实验过这个想法,结果是批量修正有时跟逐个修正持平,但有时也会下降。其实也可以理解,CAN本意应该是借助先验分布,结合高置信度结果来修正低置信度的,在这个过程中,如果掺入越多的低置信度结果,那么最终的偏差可能就越大,因此理论上逐个修正会比批量修正更为可靠。

说到原论文,读过CAN论文的读者,应该能发现本文介绍与CAN原论文大致有三点不同:

  1. 不确定性指标的计算方法不同。按照原论文的描述,它最终的不确定性指标计算方式应该是

也就是说,它也是top-个概率算熵的形式,但是它没有对这个概率值重新归一化,并且它将其压缩到0~1之间的因子是而不是(因为它没有重新归一化,所以只有除才能保证0~1之间)。经过笔者测试,原论文的这种方式计算出来的结果通常明显小于1,这不利于我们对阈值的感知和调试。

  1. 对CAN的介绍方式不同。原论文是纯粹数学化、矩阵化地陈述CAN的算法步骤,而且没有介绍算法的思想来源,这对理解CAN是相当不友好的。如果读者没有自行深入思考算法原理,是很难理解为什么这样的后处理手段就能提升分类效果的,而在彻底弄懂之后则会有一种故弄玄虚之感。

  2. CAN的算法流程略有不同。原论文在迭代过程中还引入了参数,使得式(4)变为

也就是对每个结果进行次方后再迭代。当然,原论文也没有对此进行解释,而在笔者看来,该参数纯粹是为了调参而引入的(参数多了,总能把效果调到有所提升),没有太多实际意义。而且笔者自己在实验中发现,基本已经是最优选择了,精调也很难获得是实质收益。

文章小结

本文介绍了一种名为CAN的简单后处理技巧,它借助先验分布来将预测结果重新归一化,几乎没有增加多少计算成本就能提高分类性能。经过笔者的实验,CAN确实能给分类效果带来一定提升,并且通常来说类别数越多,效果越明显

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

 

[1]  苏剑林. (Dec. 1, 2015). 《“熵”不起:从熵、最大熵原理到最大熵模型(一)》[Blog post]. Retrieved from https://kexue.fm/archives/3534

分类问题后处理技巧CAN,近乎零成本获取效果提升相关推荐

  1. 微信运营零成本拉新技巧,会与不会之间只差看与不看!

    运营微信公众号应该是每一家企业都应该重视的,如果你是微信公众号运营者,正好发愁不知道如何行之有效的运营涨粉,不妨静下心来花个10分钟左右的时间读读这篇诚意满满的文章--三年微信运营摸打滚爬后的体会总结 ...

  2. eclipse的tomcat如何运行自动弹网页_如何在 3 天内零成本完成 AI 小程序开发

    基于对 AI 的爱好与兴趣,我走上了独自钻研机器学习的道路.和所有热爱 AI 的人们一样,在一段孤独的摸索旅程中,我勉强完成了几次深度学习模型的训练.作者:泰斯特想说 故事的起源 基于对 AI 的爱好 ...

  3. 电脑qq收藏在哪里打开_外贸询盘怎么来?零成本用Linkedin批量开发客户!外贸人收藏...

    前言 做外贸开发国外客户的方法有很多,各种渠道以及方法我往期文章中都有详细干货解读,通过网络寻找客户已经是很普遍很流行的方法,对于网络开发的效果而言,每个公司每个人看法不同.今天介绍一下,怎么零成本利 ...

  4. 拳王公社:网赚高手的零成本引流秘籍,这4个才是核心思维!

    今天给大家分享一个零成本引流的一个互联网思维. 现在不管是做互联网行业还是线下生意,最重要的就是流量,创业的绝大多数人都卡在引流上,特别是新手同学! 很多人的都卡在如何引流上,特别是很多刚刚进入互联网 ...

  5. 百度贴吧自动发帖_引流网赚之百度贴吧引流窍门:实操引流教程百度贴吧零成本自动顶帖+10分钟学会豆瓣顶帖引流...

    引流网赚之百度贴吧窍门:实操引流教程<百度贴吧零成本自动顶帖>+<10分钟学会豆瓣顶帖引流> 关于百度贴吧的引流方式有很多,像常见的关键词排名引流,比如,百度贴吧引流效果好不好 ...

  6. 零成本拓客秘籍丨券商玩转社群营销的5个步骤

    近几年,随着互联网的发展以及智能移动终端的普及,券商业务逐渐成为高度线上化的行业.然而,随着移动互联网进入下半场,券商获取流量越来越难,成本越来越高.加之,股市牛短熊长更是让各大券商饱尝巅峰之后的漫长 ...

  7. 爆款打造之中小卖家如何做到零成本选/测款?(一)

    虽然从某种意义上来说,淘宝在有意识的去爆款化,比如说流量的碎片化分配,把原本集中在有限宝贝上的流量(尤其是自然搜索流量),为了让更多的商家都活跃起来,为了让更多的优秀产品都有机会,开始分散的匹配给其他 ...

  8. 免费软件中的零成本营销

    目前,计算机软件行业已经发展成为一个成熟的行业,它以技术密集.资金密集.用户多样.产品生存期短.竞争压力强等特点成为一个独特的行业.作为不以营利为主要目的免费软件 (Freeware) ,在茫茫的软件 ...

  9. 零成本创业项目,收入远比打工高,很值得推荐

    科思创业汇 大家好,这里是科思创业汇,一个轻资产创业孵化平台.赚钱的方式有很多种,我希望在科思创业汇能够给你带来最快乐的那一种! 如何赚钱?这一直是人们最喜欢讨论的话题之一,因为人们生活在这个世界上, ...

最新文章

  1. 彩图完美解释:麦克斯韦方程组
  2. java der pem_JAVA解析各种编码密钥对(DER、PEM、openssh公钥)
  3. 云端资源,“掌”握手中 ——关于 阿里云 App你不知道的五件事
  4. html自适应_web前端入门到实战:HTML 文档流,设置元素浮动,导致父元素高度无法自适应的解决方法...
  5. Windows10安装Anaconda和Pytorch(CPU版,无GPU加速)
  6. 硬件基础知识---(3)电阻2
  7. k8s ubuntu cni_周一见 | CNCF 人事变动、最新安全漏洞、K8s 集群需警惕中间人攻击...
  8. linux安装vnc4server,Ubuntu 18.04安装vnc4server
  9. F. Gourmet and Banquet(贪心加二分求值)
  10. “System.InvalidOperationException”类型的未经处理的异常在 ESRI.ArcGIS.AxControls.dll 中发生...
  11. 基于SSM的选课系统
  12. django.forms生成HTML,python – 在django中为表单自动生成表单字段
  13. Atitit spring springboot 集成mybatis法 目录 1.1. 使用spring管理数据源。。需要修改spring、 配置 1 1.2. 直接代码集成,无需修改任何配置 1
  14. Android 打开WIFI并快速获取WIFI的信息
  15. 前端打印复选框的打勾时问题求教
  16. 职业倾向测试脸型软件,气质类型测试适合职业
  17. HDS 高端存储TC原理和配置总结
  18. vue实现按钮倒计时功能
  19. 【Docker学习笔记 一】Docker基本概念及理论基础
  20. 使用手机访问电脑上开发的html页面

热门文章

  1. linux tar order
  2. C++ 添加程序图标到我的电脑
  3. 一个简单又高效的日志系统
  4. STL中vectortype的复制
  5. 使用Epoll 在 Linux 上开发高性能应用服务器
  6. 在Linux内核使用Kasan
  7. 内核抢占,让世界变得更美好 | Linux 内核
  8. 视频操作_01视频读写:视频读写+读取视频+保存视频
  9. LeetCode 2016. 增量元素之间的最大差值
  10. LeetCode 1712. 将数组分成三个子数组的方案数(前缀和 + 二分查找)