Gradient Episodic Memory for Continual Learning(用于持续学习的梯度情景记忆)

  • 本篇论文的贡献
  • 创新性
  • Gradient of Episodic Memory (GEM)算法步骤
  • 算法原理图
  • 评估指标和实验结果
    • 评估指标
    • 实验数据集
    • 对比方法
    • 实验结果
  • 可能存在的问题

本篇论文的贡献

①提出了一组指标来评估模型在连续数据上的学习。
②提出了一种持续学习模型,称为梯度情景记忆 (GEM),它可以减轻遗忘,并支持正向反向迁移。③在 MNIST 和 CIFAR-100 数据集上证明了 GEM 与最先进的技术相比具有强大的性能。

创新性

大多数监督学习方法假设每个示例 ( x i , y i ) (x_{i}, y_{i}) (xi​,yi​)是来自固定概率分布 P 的相同且独立分布 (iid) 的样本。ERM 的直接应用会导致“灾难性遗忘”(学习新任务可能会损害学习者在以前解决的任务中的表现)。本文缩小了 ERM 与更类似于人类的学习之间的差距。
i)任务的数量很大
ii)每个任务的训练示例数量很少
iii)学习者只观察与每个任务有关的示例一次
iv)报告了衡量迁移和遗忘的指标,而不是仅报告所有任务的平均性能。

Gradient of Episodic Memory (GEM)算法步骤

①将任务描述符 t i t_{i} ti​引入输入样本 ( x i , y i ) (x_{i},y_{i}) (xi​,yi​)中构成数据连续体,假设整数任务描述符,并使用它们来索引情节记忆;假设数据连续体是局部独立同分布的,即每个 ( x i , t i , y i ) (x_{i},t_{i},y_{i}) (xi​,ti​,yi​)满足 ( x i , y i ) (x_{i},y_{i}) (xi​,yi​) 对于 P t i ( X , Y ) P_{ti}(X, Y) Pti​(X,Y)独立同分布。目标是学习一个预测器 f : X ×T→ Y,它可以随时查询以预测与测试对 (x, t) 关联的目标向量 y。

②定义第 k 个任务的记忆损失:


为了减少模型过拟合,同时允许模型存在正的后向迁移,将上述记忆损失 作为不等式约束,避免它们增加但允许它们减少,对于样本 ( x , t , y ) (x,t,y) (x,t,y)有如下约束条件:

其中, f θ t − 1 f_{\theta}^{t-1} fθt−1​表示第t-1个任务结束时,预测器的参数状态。
为了有效地解决上述约束条件。首先,没有必要存储旧的预测变量 f θ t − 1 f_{\theta}^{t-1} fθt−1​,只要保证在每次参数更新 g 后之前任务的损失不会增加。其次,假设函数是局部线性的(因为它发生在小的优化步骤周围),可以通过计算之前任务的损失梯度向量之间的角度来诊断损失的增加和建议的更新。将约束条件改写为:

其中 g k g_{k} gk​表示当前任务t之前的任务k的损失梯度,以内积的方式来判断两个梯度是否向**“锐角”方向更新。如果满足上述所有不等式约束,则建议的参数梯度更新 g 不太可能增加先前任务的损失。另一方面,如果违反了一个或多个不等式约束,则至少有一个先前的任务会在参数更新后经历损失的增加。如果发生违规,则建议将梯度 g 投影到满足所有约束的最接近的梯度 g ~ \tilde{g} g~​,通过求解L2范数找到这样一个满足所有约束的梯度 g ~ \tilde{g} g~​代替g进行模型更新,上述优化问题可以转换为:

原文中作者通过
二次规划的相关知识得到了 g ~ \tilde{g} g~​,这里暂时还没有研究明白,只能理解为通过二次规划**这种方法可以解决上述问题。

算法原理图


①对于连续数据体中的样本将其划分为多个不同的任务,对于每一个任务中的数据样本分别进行训练更新。(x,y)之间是独立同分布的,t之间是相互关联的。
②对于每一个任务中的训练样本,首先计算当前任务本身的损失函数梯度g;再计算与之前 k 个任务的记忆损失梯度 g k g_{k} gk​;通过上一部分算法步骤中的②得到满足所有约束条件的参数更新梯度 g ~ \tilde{g} g~​;根据 g ~ \tilde{g} g~​进行参数更新。
③对于每一个任务,模型训练完成后,计算每个任务对应的评估指标R矩阵。

评估指标和实验结果

评估指标


R i , j R_{i,j} Ri,j​ 表示模型在任务 t j t_{j} tj​上观察到任务 t i t_{i} ti​的最后一个样本后的测试分类准确度; b ‾ \overline b b 是随机初始化时每个任务的测试准确度向量。
ACC:在模型完成对任务 t i t_{i} ti​ 的学习后,我们评估它在所有 T 个任务上的测试性能。
BWT:后向迁移(BWT),这是学习任务 t 对前一个任务 k ≺ t 性能的影响。一方面,当学习某个任务 t 时,存在正向向后迁移,从而提高了前面某个任务 k 的性能。另一方面,当学习某个任务 t 会降低对某个先前任务 k 的性能时,存在负向后迁移。大的负反向迁移也称为(灾难性)遗忘。
FWT:前向迁移(FWT),这是学习任务 t 对未来任务 k>t 性能的影响。特别是,当模型能够执行“零样本”学习时,可能通过利用任务描述符中可用的结构来进行正向迁移。
这些指标越大,模型越好。如果两个模型具有相似的 ACC,则最优选的是具有较大 BWT 和 FWT 的模型。请注意,讨论第一个任务的后向传输或最后一个任务的前向传输是没有意义的。

实验数据集

①MNISTPermutations,MNIST 手写数字数据集的一个变体,其中每个任务都通过像素的固定排列进行转换。在这个数据集中,每个任务的输入分布是不相关的。
②MNIST Rotations,MNIST 的一种变体,其中每个任务都包含旋转了 0 到 180 度之间的固定角度的数字
③ 增量 CIFAR100 ,CIFAR 对象识别数据集的一个变体,具有 100 个类别 ,其中每个任务都引入了一组新的类别。对于总共 T 个任务,每个新任务都涉及来自 100/T 个类的不相交子集的示例。在这里,所有任务的输入分布都是相似的,但不同的任务需要不同的输出分布。
对于所有数据集,一共考虑了 T = 20 个任务。在 MNIST 数据集上,每个任务都有来自 10 个不同类别的 1000 个示例。在 CIFAR100 数据集上,每个任务都有来自 5 个不同类别的 2500 个示例。该模型按顺序观察任务,每个示例一次。每个任务的评估是在每个数据集的测试分区上执行的。

对比方法

  1. single:在所有任务中训练的单个预测器。
  2. independent:每个任务一个独立的预测器。每个独立预测器具有与“single”相同的架构,但隐藏单元比“single”少 T 倍。每个新的独立预测器可以随机初始化,或者是最后一个训练预测器的克隆(由网格搜索决定)。
  3. multimodal:具有与“single”相同的架构,但每个任务都有一个专用输入层(仅适用于 MNIST 数据集)。
  4. EWC :其中损失被正则化以避免灾难性遗忘。
  5. iCARL :一个类增量学习器,使用最接近示例算法进行分类,并通过使用情景记忆防止灾难性遗忘。 iCARL 要求跨任务的输入表示相同,因此该方法仅适用于在 CIFAR100 上的实验。

实验结果


图 1(左)总结了所有数据集和方法的平均准确度(ACC)、后向传输(BWT)和前向传输(FWT)。总体而言,GEM 的性能优于当前的方法,同时 GEM 能够最小化后向传输,并且能够使得其为正值,而前向传输可忽略不计或为更大的正值。

可能存在的问题

①GEM 不利用结构化任务描述符(例如解释如何解决第 i 个任务的自然语言段落)而是采用整数任务描述符,可以利用结构化任务描述符来获得正向前向迁移(零样本学习)。
②每次 GEM 迭代都需要每个任务进行一次反向传递,从而增加了计算时间。

【论文笔记】Gradient Episodic Memory for Continual Learning相关推荐

  1. (GEM)Gradient Episodic Memory for Continual Learning论文笔记

    (GEM)Gradient Episodic Memory for Continual Learning Abstract GEM:减轻了遗忘,同时允许有益的知识转移到先前的任务中. Introduc ...

  2. Gradient Episodic Memory for Continual Learning 论文阅读+代码解析

    一. 介绍 在开始进行监督学习的时候我们需要收集一个训练集 D t r = { ( x i , y i ) } i = 1 n D_{tr}=\{(x_i,y_i)\}^n_{i=1} Dtr​={( ...

  3. 论文笔记:Meta-attention for ViT-backed Continual Learning CVPR 2022

    论文笔记:Meta-attention for ViT-backed Continual Learning CVPR 2022 论文介绍 论文地址以及参考资料 Transformer 回顾 Self- ...

  4. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  5. 论文笔记:《DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks》

    论文笔记:<DeepGBM: A Deep Learning Framework Distilled by GBDT for Online Prediction Tasks> 摘要 1. ...

  6. 论文笔记VITAL: VIsual Tracking via Adversarial Learning

    论文笔记VITAL: VIsual Tracking via Adversarial Learning 1. 论文标题及来源 2. 拟解决问题 3. 解决方法 3.1 算法流程 4. 实验结果 4.1 ...

  7. 论文笔记——Fair Resource Allocation in Federated Learning

    论文笔记--Fair Resource Allocation in Federated Learning 原文论文链接--http://www.360doc.com/content/20/0501/1 ...

  8. Continual Learning 经典方法 — Gradient Episodic Memory (GEM)

    1. 终身学习目标: 缓解灾难性遗忘问题:当数据以online stream的方式训练模型时,训练完 Task 1 之后的模型,在学习 Task 2 的数据时往往会将 Task 1 的知识遗忘,导致在 ...

  9. 论文笔记(十六):Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning

    Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning 文章概括 摘要 1 介绍 2 大规模并 ...

最新文章

  1. Spring事务支持:利用继承简化配置
  2. Codeforces Gym 100342J Problem J. Triatrip 求三元环的数量 bitset
  3. VBS转化为exe可执行文件实例演示,vbs转exe工具推荐
  4. C#连接MySQL数据库 制作股票交易模拟程序
  5. 洛谷 P1816 忠诚 ST函数
  6. pwm控制的基本原理_单片机PWM控制基本原理详解~
  7. MTK 驱动开发(42)---GAT 工具使用
  8. 潜意识的力量:潜意识开发四大关键
  9. WinForm 设置窗体启动位置在活动屏幕右下角
  10. STC8H8K系列汇编和C51实战——实现跑马灯(汇编版)
  11. uniGUI session超时时间设置
  12. css单线边框_css border-collapse设置表格单线边框和双线边框
  13. es中的keyword相关功能
  14. STM32F103_study49_The punctual atoms(STM32 Bit operation and logical operation in C language )
  15. html表单实验结论,web前端开发技术实验报告-实验五
  16. jQuery 从零开始学习 (三) 属性与css样式
  17. Can not modify more than one base table through a join view
  18. 使用百度翻译开发平台,英文翻译为中文
  19. SimpleDeserializer encountered a child element, which is NOT expected, in something it was trying to
  20. 远程主机强迫关闭了一个现有的连接。请高手解答?

热门文章

  1. WITCH CHAPTER 0 [cry] 绝密开发中的史克威尔艾尼克斯的DX12技术演示全貌
  2. PE文件格式分析系列(文章3)----一个PE文件rdata段的分析(Win32工程Release版)(二)
  3. linux设备描述文件,iOS开发 - 超级签名实现之描述文件
  4. HTML 5 video 视频标签全属性详解(转)
  5. 什么是app报毒?该如何解决
  6. NFT is dead long live NFT
  7. C8051关闭看门狗汇编语言,汇编写启动代码之关看门狗
  8. javascript支持区号输入的省市二级联动下拉菜单
  9. https服务IE可以访问,而GOOGLE无法访问
  10. 上海儿童编程培训python课程体系是什么