关注公众号,发现CV技术之美

0

写在前面

目前的计算机视觉模型在进行增量学习新的知识的时候,就会出现灾难性遗忘的问题。缓解这种遗忘的最有效的方法需要大量重播(replay)以前训练过的数据;但是,当内存限制或数据合法性问题存在时,这种方法就存在一定的局限性。

在本文中,作者研究了无数据类增量学习(DFCIL)的问题,也就是增量学习能够学习新的知识,而不存储生成器或过去任务的训练数据。目前,DFCIL的一种方法是通过倒置学习分类模型的冻结副本,来合成图像用于训练,使得模型能够不忘记以前任务的知识,也不用replay以前训练过的数据。但是,作者通过实验表明了当使用标准蒸馏策略时,这种方法对于常见的类增量benchmark都是无效的。

因此,在本文中,作者分析了这种方法失败的原因,并提出了一种新的DFCIL增量蒸馏策略,提供了一个改进的交叉熵训练和重要性加权特征蒸馏。最终作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。

01

论文和代码地址

Always Be Dreaming: A New Approach for Data-Free Class-Incremental Learning

论文地址:https://arxiv.org/abs/2106.09701

代码地址:尚未开源

02

Motivation

目前,计算机视觉的一个局限是,它们通常使用一个包含在部署过程中所有可能遇到的数据的大型数据集,进行脱机训练。然而,现实情况是许多应用程序需要在遇到新的情况和数据后不断更新模型。这就是类增量学习的范式,在学习新任务的时候忘记以前学习到的知识的问题被称为在灾难性遗忘 。目前,比较成功的增量学习方法有一个缺点:它们需要大量的内存来replay以前看到过的或建模的数据,以避免灾难性遗忘问题。

这在很多计算机视觉的应用中也是不现实的,因为:

1)许多计算机视觉应用程序都是在设备上的,因此内存有限;

2)在工业界,可能会存在很多不允许被存储的数据(比如用户的隐私信息)。

因此,作者就提出了这样一个问题:计算机视觉系统如何能在不存储数据的情况下增量地学习新信息?作者将这样的设置称为无数据类增量学习(DFCIL)。

DFCIL的一种直观方法是同时训练生成模型进行采样以进行replay,以防止忘记以前的知识。但是与分类模型相比,训练生成模型的计算和内存都更密集。

因此,作者探索了模型反演图像 合成的概念,就是通过反转已经提供的推理网络,来获得网络中与训练数据具有相似激活作用的图像。这样一来,就不需要训练额外的网络(因为它只需要现有的推理网络)。

(上图展示了当使用合成数据进行基于replay类增量学习时,特征嵌入的分布。图a展示了合成数据的直接应用使模型的学习特征更容易区分是真实数据还是合成数据,而不是任务1和2,这也是本文要解决的主要问题;图b展示了修改分类损失和添加正则化可以减轻真实和合成图像之间的特征位移;图c是理想的特性分布,使任务1和任务2更可分离。)

上图展示了DFCIL增量学习失败的原因(图a),用当前任务的真实图像和代表过去任务的合成图像训练模型时,特征提取模型提取的特征会变成:当前真实图像的特征分布与当前真实图像的特征分布(即使他们不属于同一个类)更接近,与合成图像的特征分布更不接近 ,这就导致了预测时候的偏差。

这一现象表明,当训练一个具有两种数据分布的网络时,同时包含语义位移和分布位移,分布位移对特征嵌入有更高的影响。因此,来自以前任务的的测试图像将被识别为新的类,因为模型会更关注于它们的分布,而不是它们的语义内容(这就与分类任务的目标背道而驰了)。

为了解决这个问题,作者提出了一种新的类增量学习方法,该方法学习了具有局部分类损失的新任务特征,依赖于重要性加权特征蒸馏和线性分类head微调来分离新任务和过去任务的特征嵌入。

作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。

03

方法

3.1. 先验知识-类增量学习

在类增量学习中,一个模型需要学习了对应于M个语义对象类



























的数据,但这些数据是通过N个task依次暴露给模型的,每个任务中子类都不会重合。

我们用







来表示任务n中引入的类集,其中









表示任务n中对象类的数量。每个类只出现在单个任务中,模型目标就是逐步学习引入的新对象类,并对它们进行分类,同时保留之前学习过的类的知识。

为了描述推理模型,我们将





θ





,





表示在i时刻使用任务n的类训练的模型。

3.2. Baseline Approach

在本节中,作者基于之前工作,提出了一个Data-Free的用于类增量学习的baseline。

3.2.1. Model-Inversion Image Synthesis

大多数模型反演图像合成方法都是通过直接对先验的鉴别模型





θ









进行优化来合成图像。然而,一次优化一个Batch的图像在计算上是效率低下的。因此作者选择使用卷积网络参数化函数






φ




用噪声生成合成图像进行近似优化。这就使每个任务只需要训练一次






φ




,当前任务结束时就可以直接丢弃。

首先,






φ




需要生成多样性的图片,因此作者优化合成了图像的类预测的多样性,以匹配均匀分布。将






θ







表示为模型θ对输入x产生的预测类分布,需要使合成样本









的平均类预测向量的熵最大化,如下所是(label diversity loss):

其中












为信息熵。

除了多样性之外,为了在DFCIL中合成有用的图像,图像还需要校准的类置信度、特征统计数据的一致性和局部平滑的潜在空间。

对于校准的类置信度 ,作者使用了Content Loss,通过对图像张量









的类预测一致性最大化,这样





θ









就能对所有输入做出足够confident的预测了。Content Loss的具体计算表示如下所示:

通过将























相结合,就确保合成的图像将代表过去所有任务类的分布。

对于特征统计数据的一致性 ,先前的工作发现,模型反演的复杂性会导致





θ









特征的分布大大偏离合成图像的分布。因此,合成图像的Batch统计应该与





θ









中的Batch Norm层相匹配。基于此,作者进一步提出了stat alignment loss:

其中





代表KL散度。

对于局部平滑的潜在空间 ,先验知识告诉我们,自然图像在像素空间中比初始噪声更局部平滑。因此作者又提出了一个损失函数smoothness prior loss,这个函数就是生成图像









和高斯模糊版本的生成图像

















的L2距离:

最后,






φ




的损失函数为上面提到的损失函数之和:

3.2.2. Distilling Synthetic Data for Class-Incremental Learning

在类增量学习中,对合成图像的知识蒸馏通常被用于





θ





正则化,迫使它学习







,学习







的同时,将


















的知识遗忘减到最小。对于任务







,我们从任务











期间训练的





θ









的冻结副本中合成图像。这些合成图像帮助我们将任务


















中学习的知识提炼到我们当前的模型





θ





中。

在Baseline方法中,作者采用了DeepInversion中使用的蒸馏方法。具体表示为,给定当前的任务数据




和合成的蒸馏数据









,我们最小化:

其中














是一种知识蒸馏正则化方法:

3.3. Diagnosis: Feature Embedding Prioritizes Domains Over Semantics

为了探究为什么DFCIL的Baseline方法会失败,作者使用度量(MID)分析了嵌入特征之间的表征距离,这种度量用于捕获两个分布样本的平均图像embedding之间的距离。作者将这种度量实例化为Mean Image Distance (MID) score,高分表示不同的特征,低分表示相似的特征。计算如下:

作者计算任务1真实数据的特征embedding与任务2真实数据之间的MID,然后计算任务1真实数据的特征embedding与任务1合成数据之间的MID,结果如上图所示。对于(a)DeepInversion,任务1真实数据与任务1合成数据之间的MID得分明显高于任务1真实数据与任务2真实数据之间的MID得分。

这表明embedding空间对domain有更高的优先级,而不是语义,但这不是模型想要的结果。对于作者提出的方法(b),任务1真实数据和任务1合成数据之间的MID分数明显低得多,这表明特性embedding的语义优先于domain。

3.4. A New Distillation Strategy for DFCIL

基于上面的分析,作者提出了持续的学习应该在以下几个方面保持平衡:(1)针对新任务的学习特征;(2)最小化超过上一个任务的特征偏移;(3)在embedding空间中分离新的类和以前的类之间的类重叠。

对于上面的三个平衡,(1)和(3)可以通过










实现。但是作者认为,通过将其分成两种不同的损失,可以在学习新任务的时候,不区分真实图像和合成图像的特征。根据这个想法,作者提出了一种为DFCIL设计的新的类增量学习方法,该方法独立地解决这些目标。

(蓝色箭头表示之前合成的任务数据的计算路径,绿色箭头表示真实的当前任务数据的计算路径,黄色箭头表示真实数据和合成数据的计算路径。)

模型的overview如上图所示

3.4.1. Learning current task features

作者方法背后的intuition是需要学习当前task的特征的同时,绕过偏向最近task真实数据的特征表示。具体实现上,作者通过只计算在新的 线性分类head上的局部交叉熵分类损失来实现这一点。有了这种模式,作者阻止了模型学习通过domain分离新的和过去的类数据,损失函数如下:

3.4.2. Minimizing feature drift over previous task data

与真实的当前任务图像相比,蒸馏图像属于另一个domain,因此作者寻找了另一个损失函数,直接减轻遗忘的损失函数。要实现这个目标,一个选择是特性蒸馏:

虽然
















强化了过去任务数据的重要组成部分,但它的强正则性抑制了模型的学习新任务的能力。另一方面
















并不抑制新任务的学习,可能导致真实数据和合成数据的bias。

因此,作者提出了一种重要性加权特征蒸馏,它只强化了过去任务数据中最重要的组成部分,同时允许不那么重要的特性来适应新任务。表示如下:

W为重要性权重矩阵,W权重大的特征更为重要。

3.4.3. Separating Current and Past Decision Boundaries

最后,模型需要分离当前类和过去类的决策边界,而不允许特征空间来区分真实数据和合成数据。作者通过用交叉熵损失函数来fine-tuning线性分类head来实现。除了线性分类head之外,这个损失函数并不会更新





θ





,



:





中的任何参数:

3.4.4. Final Objective

最终模型的损失函数为上述损失函数之和,如下所示:

04

实验

4.1. DFCIL (CIFAR-100 )

从上表结果可以看出,本文的方法不仅优于DFCIL方法,甚至优于生成方法。

4.2. CIL with Replay Data (CIFAR-100 )

在上表中,作者将本文的方法(不存储回放数据)与其他存储回放数据的方法进行了比较。可以看出,本文方法的performance可以优于LwF和Rehersal,但是后者需要存储回放数据,这就意味着更高的内存消耗。

4.3. Ablation Study(CIFAR-100 )

从上表可以看出,文中对Data-Free增量学习专门设计的几个损失函数和蒸馏方法,对于整个模型性能的提高,都有着非常重要的作用。

4.4. DFCIL (ImageNet)

作者还使用ImageNet数据集来验证本文的方法在大规模图像数据集上的表现。可以看出,本文的方法在这个大规模图像数据集上的实验结果也没有比基于replay的方法落后太多。

05

总结

在本文中,作者表明现有的类增量学习方法在使用真实训练数据学习新任务和使用合成蒸馏数据保存过去的知识时,performance较差。因此,作者提出了一种新的方法来实现了无数据类增量学习的SOTA性能,并与基于replay的SOTA方法性能相当。

作者提出无数据类增量学习是希望消除在类增量学习中存储回放数据的需要,使计算机视觉的广泛和实际应用成为可能。不存储数据的增量学习解决方案,将对计算机视觉应用产生直接影响,进一步促进计算机视觉任务的落地应用。

▊ 作者简介

厦门大学人工智能系20级硕士

研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。

知乎/公众号:FightingCV

END,入群????备注:CV

让模型实现“终生学习”,佐治亚理工学院提出Data-Free的增量学习相关推荐

  1. 佐治亚理工学院计算科学与工程系博士生招生!

    仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 佐治亚理工学院计算科学与工程系助理教授 吴安琪 导师简介 吴安琪博士本科毕业于哈尔滨工业大学电子信息工程专业(师从于达仁教授),随后在 ...

  2. 佐治亚理工学院硕士建议:2022年你应该掌握这些机器学习算法

    来源:机器之心本文约1700字,建议阅读8分钟 2022年你应该知道的所有机器学习算法. 想要成为一名合格的 AI 工程师,并不是一件简单的事情,需要掌握各种机器学习算法.对于小白来说,入行 AI 还 ...

  3. 佐治亚理工学院发文:不要迷信可解释性,小心被误导

    编译 | 王晔 校对 | 琰琰 可解释性对人工智能发展来说至关重要,但在可解释系统的可信度方面,理解其可能带来的负面效应亦同等重要. 近日,佐治亚理工学院研究团队发表最新研究,重点讨论了可解释人工智能 ...

  4. ​真的存在可以检测万物的模型吗?联汇科技提出了一种有趣的解决方案

    你还在为你的检测模型只能检测固定类别的物体而烦恼吗?你还在为添加新的检测类别后需要从头开始训练模型而烦恼吗?你还在为标注目标检测模型的数据而烦恼吗?这说明你应该换一套思路来做目标检测啦!既然我们人类能 ...

  5. 佐治亚理工学院计算机系,UC联合学院学生拜访佐治亚理工学院电子与计算机工程系...

    美国当地时间7月20日早上9:30,重庆大学-辛辛那提大学联合学院(以下简称辛辛那提学院)院长张志清.学院实习管理老师冉茉莉,与正在美国佐治亚理工学院(以下简称Georgia Tech)参加亚特兰大暑 ...

  6. 性能超越最新序列推荐模型,华为诺亚方舟提出记忆增强的图神经网络

    作者 | Chen Ma, Liheng Ma等 译者 | Rachel 出品 | AI科技大本营(ID:rgznai100) 用户-商品交互的时间顺序可以揭示出推荐系统中用户行为随时间演进的序列性特 ...

  7. 基于模型的强化学习比无模型的强化学习更好?错!

    作者 | Carles Gelada and Jacob Buckman 编辑 | DeepRL 来源 | 深度强化学习实验室(ID:Deep-RL) [导读]许多研究人员认为,基于模型的强化学习(M ...

  8. R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification)

    R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification) Long Short Term 网络-- 一般就叫做 LSTM --是一 ...

  9. keras构建卷积神经网络(CNN(Convolutional Neural Networks))进行图像分类模型构建和学习

    keras构建卷积神经网络(CNN(Convolutional Neural Networks))进行图像分类模型构建和学习 全连接神经网络(Fully connected neural networ ...

最新文章

  1. excel python 形状_何使用Python操作Excel绘制柱形图
  2. 【内网穿透】生壳SSH映射 for Linux 使用教程
  3. Mongodb 数据模型概念
  4. 【转】04.Dicom 学习笔记-DICOM C-Move 消息服务
  5. linux命令cd回退_Linux命令一
  6. Python处理json字符串转化为字典
  7. Office online server 部署
  8. Linux基础命令的使用
  9. 用GIF图片来告诉大家程序猿的真实生活
  10. JZOJ-senior-5946. 【NOIP2018模拟11.02】时空幻境(braid)
  11. silk lobe资源公众号_资源合集 | 霞鹜公众号字体资源合集(截至 2019.11.30)
  12. vivo7.0系统怎么无root激活XPOSED框架的教程
  13. mathcad prime server system(PASS云计算书系统)开发
  14. 苹果计算机磁盘格式,Mac怎么将ntfs格式的磁盘格式化
  15. 【Xcode使用技巧】Xcode环境变量(environment variables)
  16. linux镜像文件超过4G怎么办,Systemback无法将超过4G的sblive文件转存为镜像文件的解决办法...
  17. 瘦,是一种信仰。轻,是一种理想
  18. 普洱市企业登记“区块链云签名”试点工作启动, 用户操作仅需5分钟!
  19. 从零开始前端学习[38]:html5中的弹性布局一(移动端响应式实现各种布局,极其重要)
  20. Python常见的魔方方法

热门文章

  1. 2-3树与2-3-4树【转载】
  2. 扩展指令集--指令参考说明
  3. N皇后问题——通俗易懂地讲解(C++)
  4. 广西二级c语言试题,广西区计算机等级考二级C语言笔试试题及答案.doc
  5. android导入导出txt通讯录,Android导入导出txt通讯录工具
  6. tensorflow 显存 训练_Tensorflow与Keras自适应使用显存方式
  7. u盘安装linux 7.4,U盘自动化安装CentOS7.4
  8. wordpress php执行短代码_PHP 8.0发布日期和PHP中JIT的状态
  9. vue中tab选项卡刷新页面后保持选中状态_Altium Designer中的快捷键汇总
  10. 组织JSON数据、JSON转换