摘要:现有大部分机器学习或者深度学习的研究工作大多着眼于模型或应用,而忽略对数据本身的研究。今天给大家介绍的几个文章就关注于在机器学习中如何通过对训练集的选择和加权取得更好的测试性能。

在开始之前,先和大家简单回顾一下我个人觉得相关的几方面工作。其实远在深度学习时代之前,根据loss对样本加权的工作就已经有很多。神奇的是,其实在一条线上有着截然相反的想法的研究:第一类工作的想法是如果一个样本训练得不够好,也就是loss高的话,那么说明现在的模型没有很好fit到这样的数据,所以应该对这样的样本给予更高的权重。这一类工作就对应到经典的Hard Negative (Example)Mining,近期的工作如Focal Loss也是这个思想。另一类工作的想法是学习需要循序渐进,应该先学习简单的样本,逐渐加大难度,最终如果仍然后Loss很大的样本,那么认为这些样本可能是Outlier,强行fit这些样本反而可能会使泛化性能下降。这一类中对应的是Curriculum Learning或者Self-Paced Learning类型的工作。本质上,这两个极端对应的是对训练数据本身分布的不同假设。第一类方法认为那些fit不好的样本恰恰是模型应当着重去学习的,第二类方法认为那些fit不上的样本则很可能是训练的label有误。

所以,一个很有趣的问题是:我们应该何时在这两种极端之间选择?在这两个极端之间是不是会有更好的权衡?这个问题乍看上去没什么简单的办法,今天要介绍的文章就是引入了一个新的信息源——一个无偏的验证集来解决这个问题。有了这样额外的信息源之后,这个问题就变成了如何对每个样本加权,使得验证集上的loss下降。一个naive的办法自然是用leave one out,去掉每个样本训练一个model,但是这个cost会非常地大,实际上是不可行的。所以核心就在于如何对model进行近似,用尽量低的代价尽量准确地获得这样的信息。

在[1]中,作者使用了一个统计学中经典工具Influence Function。作者首先从一个twice-differentiable的trictly convex函数出发一步步拓展结论。基本思路是考虑如果我们增加eps某一个样本的weight,会对model的参数有怎样的影响:

中H在这里是二阶Hessian矩阵。如果有熟悉优化的朋友可以看出来这个形式其实和Newton法很像。实际推导也没有用到很深奥的数学知识,有兴趣的读者可以参照下文章中的附录。我们更进一步可以使用链式法则得出对z加大eps的weight后对于某个测试样本\z_test的loss变化:

这个结论其实很有指导意义,告诉了我们在一个训练好的model上,如何不重新训练就能评估一个样本对某个测试样本的重要性。然而想直接使用这个办法还有最后一个障碍就是Hessian矩阵的计算,对于CNN这样参数量巨大的模型来说,想要完全准确计算代价依旧很高。所以作者又提出了两种近似Hessian矩阵的方法,分别是使用Conjugate Gradient和Stochastic Approximation。由于这不是文章的重点,所以不在这里展开。

作者在文中还将这样一个结论推广,一方面分别给出了在非凸函数、非收敛的情形以及不可微分的loss下的分析,此方法或此方法的变种都能很有效地预测在验证集上测试性能的变化。另一方面还类似地推导了对某一个训练样本本身扰动带来的变化,这个结论自然地和adverserial sample这个问题联系了起来。为了验证这样近似的有效性,作者还去和leave one out的exact influence做了比较,可以看到,在线性模型下这个模型近似的效果基本和GT一致,在CNN这种非凸情形下,虽然肯定不如线性模型下的效果,但是和GT仍然保持了高度相关性。

有了上面的结论,一个自然的想法便是,找出那些对于降低验证集loss没有帮助的样本,把他们从训练集中排除,从而提升model的性能。[2]正就是做了这样的工作,作者们也给这个工作起了一个很有意思的名字——Data Dropout。具体做法是首先使用全部数据得到一个初始模型,然后在这样的初始模型上计算每个样本的influence,去掉那些对降低验证集loss的样本后,使用新的训练集再次训练得到最终的模型。作者也分别在不同规模(CIFAR, ImageNet)和不同应用(Classification, Denoising)中证明了这样做的有效性。

此方法筛掉的样本数其实也并不多,作者在下表中也报告去除掉的“不好”样本的个数,这几个数据集看来基本是在1%到3%左右,但是带来的性能提升却是显著的。

另一个相关的工作[3]也是同样的出发点,但是更好地利用了Deep Learning中已有操作,使用meta-learning的办法去学习样本的weight。首先这个问题可以写成一个two-level交替优化的目标函数:

其中f_i为原始的training loss,f_i^v为validation loss。但是很显然,直接优化这样一个目标需要在两步之间不停交替迭代,代价颇高。所以作者提出了对于第一个目标函数使用一步gradient descent的近似,同时和上文一样对样本加入eps的weight,上面的优化目标即可变为:

更进一步,对于eps的优化同样可以进行一步gradient descent的近似,如果在eps=0附近展开的话,可以得到:

为了保证训练时每个batch的effective learning rate一致,作者还对每个batch下的weight做了normalize:

使用这样的sample weight,对于model的参数theta再次进行一次gradient计算即可完成对于此batch的更新。

文中使用MLP为例,给出了一个导出的weight示例:

其中z{i, l}代表的是第i个样本第l层的feature,上标v代表的是validation set。同样,g{i, l}代表是的第i个样本在第l层收到的gradient。这个结果有着十分直观的含义:当一个样本和validation set中的样本feature接近,且gradient方向接近的时候,那么我们会增加这个样本的weight。换句话说,如果一个样本和validation set中的样本接近,且训练的目标一致,那么我们就应该更好地fit这个样本,因为它能直接帮助validation set降低loss。

作者基于这样的方法,还证明了reweighted training在mild condition下的收敛性和传统的SGD算法一致,且可以收敛到validation loss的一个critical point。

在实验部分,作者分别在class imbalanced和noisy label的情况下测试了这个算法。都分别证明了其有效性,但是比较遗憾的是没有和前面提到的[1]进行比较,实验使用的数据集规模也都比较小。

其实还有一些没有覆盖到的paper[4],我个人觉得没有这两篇有代表性所有就不展开了。总结一下,我觉得这是一个和模型与应用同等重要的问题,其实自己也曾经思索过一段时间但是没有很好的想法。这几个工作提供了一个很不错的思路,即引入一个新的无偏验证集来提供更多的信息。然而这个验证集在实际应用中是否会引入一些overfitting的风险其实还有待更多应用的验证。希望这个方向后续有更多exciting的工作出现。

[1] Koh, Pang Wei, andPercy Liang. "Understanding black-box predictions via influencefunctions."ICML (2017).

[2] Wang, Tianyang, JunHuan, and Bo Li. "Data dropout: Optimizing training data for convolutionalneural networks." 2018 IEEE 30th International Conference on Tools withArtificial Intelligence (ICTAI). IEEE, 2018.

[3] Ren, Mengye, et al."Learning to reweight examples for robust deep learning." ICML(2018).

[4] Fan, Yang, et al. "Learning What Data to Learn." arXivpreprint arXiv:1702.08635 (2017).

微信公众号: 极市平台(ID: extrememart )

每天推送最新CV干货

loss 加权_样本生而不等——聊聊那些对训练数据加权的方法相关推荐

  1. 样本生而不等——聊聊那些对训练数据加权的方法

    现有大部分机器学习或者深度学习的研究工作大多着眼于模型或应用,而忽略对数据本身的研究.今天给大家介绍的几个文章就关注于在机器学习中如何通过对训练集的选择和加权取得更好的测试性能. 在开始之前,先和大家 ...

  2. 在envi做随机森林_基于模糊孤立森林算法的多维数据异常检测方法

    引用:李倩, 韩斌, 汪旭祥. 基于模糊孤立森林算法的多维数据异常检测方法[J]. 计算机与数字工程, 2020, 48(4): 862-866. 摘要:针对孤立森林算法在进行异常检测时,忽略了每一条 ...

  3. python pypdf2另存为图片_最全总结!聊聊 Python 操作PDF的几种方法

    作者 | 陈熹 来源 | 早起Python前言本文主要涉及: os 模块综合应用 glob 模块综合应用 PyPDF2 模块操作 基本操作 PyPDF2 导入模块的代码常常是: from PyPDF2 ...

  4. java 解析数据包_一种基于Java语言的网络通讯数据包解析方法与流程

    本发明涉及网络通讯领域,特别涉及一种基于Java语言的网络通讯数据包解析方法. 背景技术: 计算机系统和网络的大量普及使用使全球跨入了信息化时代.但是,正由于现代社会中几乎一切都在"计算机化 ...

  5. 分包组包 北斗通信_一种利用北斗短报文实现第三方数据双向传输的方法与流程...

    本发明涉及通信技术领域,特别涉及一种利用北斗短报文实现第三方数据双向传输的方法. 背景技术: 中国北斗卫星导航系统是中国自行研制的全球卫星导航系统,北斗RDSS是北斗系统区别于其他导航系统的特点之一, ...

  6. python规模大小的指标是_训练数据多少才够用

    [导读]机器学习获取训练数据可能很昂贵.因此,机器学习项目中的关键问题是确定实现特定性能目标需要多少训练数据.在这篇文章中,我们将对从回归分析到深度学习等领域的训练数据大小的经验和研究文献结果进行快速 ...

  7. 神经网络测试集loss不变_神经网络训练过程中不收敛或者训练失败的原因

    在面对模型不收敛的时候,首先要保证训练的次数够多.在训练过程中,loss并不是一直在下降,准确率一直在提升的,会有一些震荡存在.只要总体趋势是在收敛就行.若训练次数够多(一般上千次,上万次,或者几十个 ...

  8. 计算智能学习1_感知机原理_水果生熟分类器_matlab实现

    文章目录 感知机(Perceptron)原理 实验部分 matlab实现 小实验一:分析水果是生是熟 一.需求分析 二.概要设计 三.详细设计(完整代码) 四.实验结果总结 小实验二:测试上课是否迟到 ...

  9. 睡眠 应该用 a加权 c加权_时间加权平均价格算法(TWAP)和成交量平均算法(VWAP)在量化回测的应用...

    本应用实践平台为BigQuant人工智能量化平台 为什么要引入TWAP和 VWAP? 为了评估策略的资金容量,我们对M.trade模块里买入点和卖出点这两个参数进行了更丰富的扩展,支持了策略能够按更丰 ...

最新文章

  1. 炫酷,SpringBoot+Echarts实现用户访问地图可视化(附源码)
  2. celery中间件:broker
  3. oracle备份镜像,Oracle RMAN两种备份方式 – 备份集备份与镜像复制备份
  4. µC/OS-II和µC/OS-III比较
  5. HTML之position:absolute relative static fixed的区别和理解
  6. OHCI,UHCI,EOHCI,XHCI
  7. html 分页_MySQL——优化嵌套查询和分页查询
  8. python 不等于_python怎么一次输入两个数
  9. 2017 ACM-ICPC南宁网络赛: G. Finding the Radius for an Inserted Circle
  10. c++接口与实现的分离
  11. 网页传奇服务器端,拍拍科技武易传奇神鸟归来商业版+网站
  12. excel如何比对两列数据是否相同
  13. 顶级摄影师的磨皮美白利器Portraiture,支持搭配微设证件大师使用
  14. 台式计算机开机不自检不起动,台式机开机不自检怎么办
  15. JSON与csv哪一个更自描述_徒步进藏和骑行进藏旅行,哪一个更辛苦
  16. Cocos2d-x 3.2 lua飞机大战开发实例(三)道具的掉落,碰撞检测,声音,分数,爆炸效果,完善游戏的功能细节
  17. SSM框架之数据分页,模糊查询
  18. 景观生态学原理| 8 景观生态学与生物多样性保护
  19. 如何用turtle画椭圆?
  20. 微机原理与接口技术[第三版]——第五章课后习题答案

热门文章

  1. 主板知识详解:主板结构
  2. I帧,P帧,B帧 压缩率对比
  3. Win11 25188.1000补丁包介绍及下载地址
  4. 最新最佳最重要的计算机相关网站推荐(更新版)
  5. win操作iOS UI自动化(tidevice+appium)
  6. 文本记录任意时刻的ping值
  7. 吃饭的时候吃饭,睡觉的时候睡觉。 (转)
  8. 论文精度MISC: A MIxed Strategy-Aware Model Integrating COMET for Emotional Support Conversation
  9. RTT Nano学习笔记 8 - 信号量
  10. UNICODE与UTF-8的转换详解