星标/置顶小屋,带你解锁

最萌最前沿的NLP、搜索与推荐技术

文 | 苏剑林

编 | 夕小瑶


在训练模型的时候,我们需要损失函数一直训练到0吗?显然不用。一般来说,我们是用训练集来训练模型,但希望的是验证集的损失越小越好,而正常来说训练集的损失降低到一定值后,验证集的损失就会开始上升(即过拟合),因此没必要把训练集的损失降低到0。

为了对抗这种过拟合现象,提高模型的测试集表现(即泛化能力),一种很自然的想法是提前终止(early stopping),也就是当观测到模型的验证集表现不降反升时,果断停止训练。这也是如今大模型跑小数据时的最常用做法。

既然如此,在模型训练loss已经到达某个阈值之后,我们可不可以做点别的事情来继续提升模型的测试集性能呢?一篇发表于机器学习顶会ICML2020上的论文《Do We Need Zero Training Loss After Achieving Zero Training Error?》[1]回答了这个问题。

不过这篇论文的回答也仅局限在“是什么”这个层面上,并没很好地描述“为什么”,另外看了知乎上kid丶[2]大佬的解读,也没找到自己想要的答案。因此自己分析了一下,记录在此。

思路描述

论文提供的解决方案非常简单,假设原来的损失函数是,现在改为:

其中是预先设定的阈值。当时,这时候就是执行普通的梯度下降;而时,注意到损失函数变号了,所以这时候是梯度上升。因此,总的来说就是以为阈值,低于阈值时反而希望损失函数变大。论文把这个改动称为“Flooding”。

这样做有什么效果呢?论文显示,训练集的损失函数经过这样处理后,验证集的损失能出现“二次下降(Double Descent)”,如下图。简单来说就是最终的验证集效果可能更好些。

左图:不加Flooding的训练示意图;右图:加了Flooding的训练示意图

效果

从上图可以看出来这个方法的理想很丰满,那么实际表现如何呢?

作者这里在MNIST、CIFAR等众多CV领域的benchmark上进行了实验,且如下图所示

图中中间一栏是没有加flooding的结果(early stopping和weight decay的四种排列组合),右边一栏是加了flooding的结果(四种排列组合的基础上都加上flooding)。可以看到加了flooding后,大部分情况下模型都能比之前有更好的测试集表现。

个人分析

如何解释这个方法的有效性呢?可以想象,当损失函数达到之后,训练流程大概就是在交替执行梯度下降和梯度上升。直观想的话,感觉一步上升一步下降,似乎刚好抵消了。事实真的如此吗?我们来算一下看看。假设先下降一步后上升一步,学习率为,那么:

我们有

(滑动查看完整公式)

近似那一步是使用了泰勒展式对损失函数进行近似展开,最终的结果就是相当于损失函数为梯度惩罚、学习率为的梯度下降。更妙的是,改为“先上升再下降”,其表达式依然是一样的(这不禁让我想起“先升价10%再降价10%”和“先降价10%再升价10%”的故事)。因此,平均而言,Flooding对损失函数的改动,相当于在保证了损失函数足够小之后去最小化,也就是推动参数往更平稳的区域走,这通常能提供提高泛化性能(更好地抵抗扰动),因此一定程度上就能解释Flooding其作用的原因了。

本质上来讲,这跟往参数里边加入随机扰动、对抗训练等也没什么差别,只不过这里是保证了损失足够小后再加扰动。读者可以参考《泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练》[3]了解相关内容,也可以参考“圣经”《深度学习》第二部分第七章的“正则化”一节。

方法局限性

虽然这个方法看起来还挺work,但是不能忽视的一个细节是,作者在做上面表格里的每组flooding的实验时,都对flooding的超参b调节了20组(从0.01~0.20),如下

这在数据规模很小时实验代价还好,但单次实验代价较高时,可能就不那么实用了。

继续脑洞

有心使用这个方法的读者可能会纠结于的选择或调超参的实验代价,不过笔者倒是有另外一个脑洞:无非就是决定什么时候开始交替训练罢了,如果从一开始就用不同的学习率进行交替训练呢?也就是自始至终都执行

其中,这样我们就把去掉了(当然引入了的选择,天下没免费午餐)。重复上述近似展开,我们就得到

(滑动查看完整公式)

这就相当于自始至终都在用学习率来优化损失函数了,也就是说一开始就把梯度惩罚给加了进去。这样能提升模型的泛化性能吗?笔者简单试了一下,有些情况下会有轻微的提升,基本上都不会有负面影响,总的来说不如自己直接加梯度惩罚好,所以不建议这样做。

文章小结

本文简单介绍了ICML2020一篇论文提出的“到一定程度后就梯度上升”的训练策略,并给出了自己的推导和理解,结果显示它相当于对参数的梯度惩罚,而梯度惩罚也是常见的正则化手段之一。


文末福利

后台回复关键词入群
加入卖萌屋NLP/IR/Rec与求职讨论群
有顶会审稿人、大厂研究员、知乎大V和妹纸
等你来撩哦~

关注星标

带你解锁最前沿的NLP、搜索与推荐技术

参考文献

[1] Do We Need Zero Training Loss After Achieving Zero Training Error?: https://arxiv.org/abs/2002.08709

[2] kid丶: https://zhuanlan.zhihu.com/p/163676138

[3] 泛化性乱弹:从随机噪声、梯度惩罚到虚拟对抗训练: https://kexue.fm/archives/7466

ICML2020 | 一行代码就能实现的测试集上分技巧相关推荐

  1. c语言一行代码太长,C语言修改一行代码,运行效率居然提升数倍,这个技巧你知道吗...

    对编译.链接.OS内核.系统调优等技术感兴趣的童鞋,不妨右上角关注一下吧,近期会持续更新相关方面的专题文章!引言 近日,网上看到一篇文章,分析数组访问的性能问题.文章经过一系列"有理有据&q ...

  2. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  3. 五、在测试集上评估图像分类算法精度(Datawhale组队学习)

    文章目录 配置环境 准备图像分类数据集和模型文件 测试集图像分类预测结果 表格A-测试集图像路径及标注 表格B-测试集每张图像的图像分类预测结果,以及各类别置信度 可视化测试集中被误判的图像 测试集总 ...

  4. 在测试集上训练,还能中CVPR?这篇IEEE批判论文是否合理?

    机器之心报道 机器之心编辑部 今日,一篇论文帖子在 Reddit 的机器学习版块引起了大家的关注.该论文表示 Concetto Spampinato 等人 2017 年的 CVPR 论文存在错误.但从 ...

  5. [NLP]基于IMDB影评情感分析之BERT实战-测试集上92.24%

    系列文章目录 深度学习NLP(一)之Attention Model; 深度学习NLP(二)之Self-attention, Muti-attention和Transformer; 深度学习NLP(三) ...

  6. [深度学习-TF2实践]应用Tensorflow2.x训练ResNet,SeNet和Inception模型在cifar10,测试集上准确率88.6%

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  7. 解决:测试集上每次输出的结果不相同问题

    1 原因 可能在图片数据集加载时,shuffle设置为True了,需要改为False 在模型中,一些层中Dropout,Normalization等具有随机性,需要设置一下种子 没有开启net.eva ...

  8. python使用matplotlib对比多个模型在测试集上的效果并可视化、设置模型性能可视化结果柱状图(bar plot)标签的小数点位数(例如,强制柱状图标签0.7显示为两位小数0.70)

    python使用matplotlib对比多个模型在测试集上的效果并可视化.设置模型性能可视化结果柱状图(bar plot)标签的小数点位数(例如,强制柱状图标签0.7显示为两位小数0.70) 目录

  9. R语言随机森林模型:计算随机森林模型的特征重要度(feature importance)并可视化特征重要度、使用少数重要特征拟合随机森林模型(比较所有特征模型和重要特征模型在测试集上的表现差异)

    R语言随机森林模型:计算随机森林模型的特征重要度(feature importance)并可视化特征重要度.使用少数重要特征拟合随机森林模型(比较所有特征模型和重要特征模型在测试集上的表现差异) 目录

最新文章

  1. JS函数式编程【译】5.2 函子 (Functors)
  2. c#: 协变和逆变深度解析
  3. 全国计算机考试光盘,全国计算机一级模拟考试题(光盘).doc
  4. 集合中重写equals方法删除new的对象
  5. C++虚基类成员可见性
  6. centos安装php服务器,在CentOS上安装搭建PHP+Apache+Mysql的服务器环境方法
  7. 今天的凉爽的学习环境 录音软件
  8. [Spark]Could not locate executable null\bin\winutils.exe in the Hadoop binaries
  9. 五分钟,带你彻底掌握 MyBatis缓存 工作原理
  10. 2021哈工程计算机考研科目,2021考研大纲:哈尔滨工程大学计算机专业基础综合2021年硕士研究生自命题考试大纲...
  11. 微信小程序之移动端适配
  12. PTA-数据库作业题(二)
  13. android 社交类ui设计,基于社交类APP界面设计与创意思维的研究
  14. 打印机接无线共享服务器出现乱码,Ricoh理光复印机网络打印机出乱码的解决办法...
  15. uvm snippets
  16. clearcase下的一些常用命令
  17. softlockup原理分析
  18. 【Python+Pycharm】单词底部有波浪线,提示typo in word时
  19. Linux操作系统笔记(超详细)
  20. ug五轴编程视频教程

热门文章

  1. android中给TextView或者Button的文字添加阴影效果
  2. .NET Framework 1.1安装出现1935错误的解决办法
  3. C语言和C++的区别
  4. 写代码获取全国疫情地图
  5. 2019 高考填报志愿建议
  6. warning: function declaration isn’t a prototype(函数声明不是原型)的解决办法
  7. win10系统能做域服务器吗,Win10 LTSC 加入 Windows Server 2019 域服务器
  8. STM32——串口通信
  9. Orange-Classification,Regression
  10. python 随机名言_如何用简易代码自动生成经典语录