每天给你送来NLP技术干货!


来自:SimpleAI

  • 标题:An Empirical Study of Example Forgetting during Deep Neural Network Learning

  • 会议:ICLR-2019

  • 机构:CMU,MSR,MILA

一句话总结:
学了忘,忘了学。你我如此,神经网络也如此。在深度模型训练过程中,可能发生了大量的、反复的样本遗忘现象。

本论文的写法很特别,跟大家常读的八股文不同,本文更像一个实验报告,标题也说了是一个Empirical Study,我觉得是一个很好的写Empirical Study的范本,值得收藏。

一、跟“灾难性遗忘”的关系

灾难性遗忘(catastrophic forgetting),是一个在深度学习中常被提起的概念,也是lifelong learning, continual learning中研究的主要问题之一。

灾难性遗忘,描述的是在一个任务上训练出来的模型,如果在一个新任务上进行训练,就会大大降低原任务上的泛化性能,即之前的知识被严重遗忘了。 在论文Attention-Based Selective Plasticity中的一幅图很形象地描述了这个概念:

灾难性遗忘,来源:论文Attention-Based Selective Plasticity

而本文提出的样本遗忘(example forgetting),则是受到灾难性遗忘现象的启发而提出的,即在同一个任务的训练过程中,也可能会有遗忘现象,一个样本可能在训练过程中反复地学了忘,忘了学。

实际上,如果我们把任务的概念放宽,那么我每一个mini-batch都可以看做一个小task,所以这里的example forgetting,就是更微观视角的catastrophic forgetting.

二、概念定义

1. Forgetting & Learning events

当一个样本本来预测对的,现在预测错了,就是一次forgetting event;相反的就是learning event.

我们会初始化一开始每个样本的预测都是不对的,但是在经过训练后(比如一个batch之后)进行上述的检查。

2. Classification margin

分类边际,被定义为:正确的类别对应的logit,跟其他类别中最大的logit的差。

3. Forgettable & Unforgettable examples

  • 被遗忘至少一次的,就叫forgettable example

  • 在某时刻被学习到了,然后从此就没有被遗忘过的样本,就叫unforgettable example

  • 从未被学习到的(即自始至终都预测是错的),不能算作unforgettable(但是,自始至终预测都是对的,就算)

三、实验设置&统计流程:

统计算法

上述统计算法更加清晰地告诉我们本文是如何进行对forgetting events进行统计的,即我们是在每个batch训练完之后统计一次。

本文使用了三个数据集:MNIST, permuted-MNIST(MNIST的像素重排版)和CIFAR-10,这三个数据集的学习难度是递增的。

四、☆ 实验观察

这一部分就是本论文的主要部分了,没有太多的理论,主要就是通过一系列的实验来向我们展示训练过程中发生了什么,但真的都挺有意思的,能给人带来很多启发和思考。

1. 遗忘次数的统计

number of forgetting events

从上图可以看出,随着数据集的复杂度和多样性(complexity & diversity)的增加,样本遗忘的情况越来越多。简单的数据集,有大量的unforgettable examples. 作者统计如下:

dataset # unforgettable examples
MNIST 91.7%
permuted-MNIST 75.3%
CIFAR-10 31.3%

另外,有些样本遗忘,可能是随机发生的,就是模型自己随便更新都可能造成遗忘,所以作者们专门做了一个统计,让模型用随机的梯度来更新,看看遗忘的情况:

forgetting by chance

可见随机遗忘的分布,跟真实遗忘的分布还是有很大差别的,而且随机遗忘的次数会很少,一般在2次以内。

2. 何时被第一次学到

一个样本究竟出现几次才会被模型学到?这是一个很有意思的问题,作者分别对unforgettable和forgettable的样本进行了统计:

first learning event

从上图可以发现,大部分的样本,在出现5次以内就可以被学习到。相比而言,unforgettable样本更早被学到。

3. 遗忘次数跟misclassification margin的关系

前面定义了classification margin,而misclassification margin这里定义为一个样本在所有forgetting events中的平均classification margin,所以这个的绝对值越大,就代表分类的模糊程度越大。

misclassification margin

上图是一个2D的直方图,代表了所有样本是如何分布的。总体上看,forgetting次数多的样本,其misclassification margin也很大。

4. 发现噪音样本

我们很自然可以想到,能否利用遗忘次数,来判断一个样本是否是噪音(标签错误)呢?作者从数据集中随机挑选了20%的样本改变其标签,然后做了如下统计:

noise detection

发现,噪音样本跟正常样本在遗忘次数上,分布十分不一样,遗忘次数会显著多于正常样本。因此我们可以利用这个特点,来帮助我们对数据集去噪,例如最近的文章DataCLUE: A Benchmark Suite for Data-centric NLP中就使用了这种方法。

上面展示的是label noise的结果,作者在附录部分还附上了对input添加noise的实验,也挺有意思的:

pixel noise

发现,对样本(图片)添加的noise越大,这个forgetting的统计就越接近一个正态分布,这也一定程度上反映了分类任务越难,样本遗忘的情况就越严重。

5. 微观视角的灾难性遗忘

这是一个很有意思的实验。

上面的很多分析都验证了神经网络确实会有遗忘,即使在同一个任务的训练中。为了跟经典的灾难性遗忘进行对照,作者仿照经典的continual learning的实验方法来设计了实验:将样本分两批,使用模型依次进行训练,并记录模型在两批样本上的分类准确率

continual learning

上图最左边,是使用一个数据集中随机挑选的两部分来轮流训练。我们发现,即使两个task都来自同一个数据分布,灾难性遗忘也可能发生!模型太健忘了。

右边的两个图,则是使用unforgettable和forgettable样本作为两个数据集来依次训练,可以发现两个结论:

  • 在容易遗忘的样本上训练完之后,再去难忘的样本上训练,灾难性遗忘很严重(刚刚把易遗忘的样本学会,就一下子忘记了)

  • 在难忘的样本上训练完之后,再去易遗忘的样本上训练,灾难性遗忘的现象很轻微。

6. 我们可以丢掉很多样本,还能保持泛化性能

在上面的实验我们可以看出,学习forgettable examples对于unforgettable examples上的泛化性能似乎影响不大,而反过来就影响很大。借助开头的那个图来理解一下:

这意味着forgettable examples的分布能够比较好地涵盖unforgettable examples的分布,这样才会使得学习新的样本对原来的decision boundary不会有太大改变。

所以,从这个角度看的话,forgettable examples比unforgettable examples蕴含了更多的信息,样本在训练中被遗忘的次数越多,它对分类任务的作用可能越大

因此,我们可以大胆假设,是不是我把unforgettable examples丢掉一大批,都不会怎么影响模型的性能呢?作者做了如下的样本丢弃实验:

removing unforgettable examples

左图中的绿线和蓝线分别代表按照被遗忘次数排序的样本和随机排列的样本,不断增大丢弃比例后的结果。可以发现,在CIFAR-10数据集中,我们可以把前35%最少遗忘的样本丢掉,只损失0.2%的准确率

右图则是同样去除5000个样本,但是改变这5000个样本中平均被遗忘次数。可以发现,大体上,包含的forgettable examples越多,效果越差。但是有意思的是存在一个明显的拐点,当forgettable examples达到一定比例时,效果又会抬升一点。作者解释,这说明数据集中可能存在某些异常点或者错误标注的样本(outliers or mislabeled examples),把他们去掉了对模型有好处,但这些样本往往被遗忘次数也很多。

removing unforgettable examples

这个图则对比了三个数据集,对比发现MNIST,permuted-MNIST和CIFAR-10可以分别移除高达80%,50%,30%的训练样本且几乎不影响性能。

7. 样本遗忘现象的稳定性

我们肯定还会关心这种样本遗忘现象,换了随机种子,换了模型,结果会不会差别很大,还是说,(不)容易遗忘的样本,换了模型和种子都依然(不)容易遗忘?

作者对此都做了实验探究,首先,使用了10个不同seed,对所有样本的number of forgetting进行统计,然后彼此之间计算排序的Pearson相关系数,发现高达89.2%,所以不同seed下,样本的遗忘现象是十分类似的。

然后,作者探究了在不同的训练阶段(不同的epoch时)的遗忘情况的差异,见下图最左边,实验表明,训练到75轮以后,样本遗忘的情况就基本稳定了。

(中间那个图我看不懂,就不讲了)

最右边那个图,是使用ResNet18来对forgetting events进行统计,然后使用这个统计结果,不断删减训练样本,在更大型的模型WideResNet上进行训练的结果,发现依然可以删除30%的数据还能保持性能基本不变,这说明,我们可以使用轻量的模型进行遗忘现象的统计,来辅助重型模型的训练。

总之,你如何训练(超参数、模型架构等)对遗忘现象的统计结果的影响不大,遗忘现象反映的是数据集本身的特点。

五、总结& 思考

写作上:

读到这里,我们应该可以发现,这就是一个对模型训练过程中的一些现象进行了一系列简单的统计,并没有什么技术含量,但是读完的感觉,却让我们大呼过瘾,原来深度学习这个黑箱子里还发生了这么多有趣的事情!

这篇文章,让我看到了搞深度学习的科研的另一种可能,我们不一定要设计复杂的模型,要提出什么深刻的数学理论才能做出好的研究,像本文这种对模型的行为的观察、对数据集特点的分析,也可以做出好的研究,并给后续的研究者提供很多经验和思考。

本文虽然像一个实验报告,使用的统计手段也很简单,但是本文设计实验的方法、如何从各种角度去对一个现象进行观测,是很值得我们学习的。

样本遗忘带来的启发

样本遗忘,以及灾难性遗忘,告诉我们神经网络本身存在一定缺陷,没法将学到的知识进行比较好的保留,知识很容易被覆盖。这明显跟人类学习过程不太一样,新的知识一般不会对曾经学过的知识进行巨大冲击,而是融合。所以这对于我们设计神经网络,设计训练方法,应该有很大启示,在continual learning领域应该已经有丰富的工作来试图解决这方面问题。

另外,样本遗忘现象本身,也可以帮助我们认识数据集,这对于Data-centric AI领域的研究应该也有很大帮助。


投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

整理不易,还望给个在看!

ICLR2019 | 模型训练会发生了大量的、反复的样本遗忘现象,如何解决?相关推荐

  1. 利用多 GPU 加速深度学习模型训练

    01 - 前言 深度学习模型通常使用 GPU 训练,因为 GPU 具有相比 CPU 更高的计算能力,以 Tesla V100 为例,使用 Tensor Core 加速的半精度浮点计算能力达到 125 ...

  2. 模型训练常用tricks

    一.背景 背景:常见NLP模型训练tricks 目标群体:Trainer 技术应用场景:仅适用于深度学习(狭义)模型训练,未涉及机器学习模型 整体思路:按训练前.训练中.训练后三个阶段划分 二.模型训 ...

  3. 深度学习模型训练过程

    深度学习模型训练过程 一.数据准备 基本原则: 1)数据标注前的标签体系设定要合理 2)用于标注的数据集需要无偏.全面.尽可能均衡 3)标注过程要审核 整理数据集 1)将各个标签的数据放于不同的文件夹 ...

  4. 用什么tricks能让模型训练得更快?先了解下这个问题的第一性原理

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨Horace He 来源丨机器之心 编辑丨极市平台 导读 深度 ...

  5. 【云原生AI】Fluid + JindoFS 助力微博海量小文件模型训练速度提升 18 倍

    简介: 深度学习平台在微博社交业务扮演着重要的角色.计算存储分离架构下,微博深度学习平台在数据访问与调度方面存在性能低效的问题.本文将介绍微博内部设计实现的一套全新的基于 Fluid(内含 Jindo ...

  6. 训练softmax分类器实例_第四章.模型训练

    迄今为止,我们只是把机器学习模型及其大多数训练算法视为黑盒.但是如果你做了前面几章的一些练习,你可能会惊讶于你可以在不知道任何关于背后原理的情况下完成很多工作:优化一个回归系统,改进一个数字图像分类器 ...

  7. 【深度学习】深度学习模型训练全流程!

    Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集.模型训练.模型加载和模型调参四个部分对深度学习中模型训练的全流程进行讲解. 一个成熟合格的深度学习训练流 ...

  8. 【天池赛事】零基础入门语义分割-地表建筑物识别 Task5:模型训练与验证

    [天池赛事]零基础入门语义分割-地表建筑物识别 Task1:赛题理解与 baseline(3 天) – 学习主题:理解赛题内容解题流程 – 学习内容:赛题理解.数据读取.比赛 baseline 构建 ...

  9. 加载tf模型 正确率很低_深度学习模型训练全流程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...

  10. 深度学习模型训练的一般方法(以DSSM为例)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 本文主要用于记录DSSM模型学习期间遇到的问题及分析.处理经验.先统领性地提出深度学习模型训练 ...

最新文章

  1. 来聊聊双目视觉的基础知识(视觉深度、标定、立体匹配)
  2. [译] 解密 Airbnb 的数据科学部门如何构建知识仓库
  3. 关于OPENGL的各个变换的顺序
  4. MyBatis创建SqlSession-怎么拿到一个SqlSessionTemplate?
  5. 东南亚版“QQ 音乐”:JOOX 的音乐推荐重构之路
  6. JavaScript substr() 方法
  7. emacs中安装markdown-mode
  8. 大数据职业理解_数据分析师真有那么好?其实正在面临3大职业困境
  9. 在Ubuntu下安装ros
  10. 医院his系统机房服务器,医院信息中心机房如何建设
  11. repeate 常用的每行显示几个共几行
  12. 识图在线识图_以图搜图在线搜索软件
  13. Arthas--深入排查java进程消耗CPU或内存过高问题
  14. 阿里天池:Airbnb短租房数据集分析
  15. {typedir} {style} {tid} {aid} 分别是什么意思?
  16. scrapy框架下的豆瓣电影评论爬取以及登录,以及生成词云和柱状图
  17. waiting for changelog lock.
  18. 表单验证-----验证图片格式
  19. C++ 代码模拟登录淘宝、天猫、支付宝等电商网站的实现
  20. Android通讯录模糊匹配搜索实现(号码、首字母,移动应用开发课程设计心得

热门文章

  1. STM32 ADC没有输入电压时,采集结果不为0
  2. 活动报名 | 前端攻城狮该怎样跳脱“围城”的焦虑
  3. 数的计数【Noip2001】
  4. 2018 计蒜之道 初赛 第一场
  5. Python自动发送邮件提示:smtplib.SMTPServerDisconnected: please run connect() first
  6. unsupported major.monor version 51.0 (unable to load *.servlet)………………
  7. Putty密钥(PrivateKey)导入SecureCRT
  8. 【react】---组件传值的介绍
  9. php 判断设备是手机还是平板还是pc
  10. 后缀数组三·重复旋律3