论文地址:https://openaccess.thecvf.com/content_CVPR_2020/html/Zhao_Maintaining_Discrimination_and_Fairness_in_Class_Incremental_Learning_CVPR_2020_paper.html
发表于:CVPR 20

Abstract

深度神经网络(DNNs)已被应用于类增量学习中,其目的是解决现实世界中常见的不断学习新类的问题。标准DNN的一个缺点是它们容易发生灾难性的遗忘。知识蒸馏(KD)是一种常用的技术来缓解这个问题。在本文中,我们证明了它确实可以帮助模型在旧的类别中输出更多的判别结果。然而,它不能缓解模型倾向于将对象分类到新类的问题,导致KD的积极作用被隐藏和限制。我们观察到,造成灾难性遗忘的一个重要因素是,在类增量学习中,最后一个全连接(FC)层的权重是高度倾斜的。在本文中,我们提出了一个简单而有效的解决方案,其动机是上述的观察,以解决灾难性遗忘。首先,我们利用KD来维持旧类中的判别性。然后,为了进一步保持旧类和新类之间的公平性,我们提出了权重对齐(Weight Aligning, WA),在正常的训练过程后纠正FC层中有偏见的权重。与以前的工作不同,WA不需要任何额外的参数或事先的验证集,因为它利用了有偏见的权重本身提供的信息。我们在ImageNet-1000、ImageNet-100和CIFAR-100的不同设置下对提出的方法进行了评估。实验结果表明,所提出的方法可以有效地缓解灾难性遗忘,并明显优于最先进的方法。

Method

本文的动机和解法都谈不上新颖,依然是针对类增量学习中输出层(FC)倾向于预测新类的问题,提出添加一些模块进行纠偏。不同的是,本文的这一过程是"无监督"的,并不需要旧模型的统计信息或者额外的验证集之类的辅助。

本文的方法称为权重对齐(Weight Aligning, WA)。具体来说,是将新类权重向量的范数与旧类对齐。形式化地,记新旧类FC层的权重为:W=(Wold ,Wnew)\mathbf{W}=\left(\mathbf{W}_{\text {old }}, \mathbf{W}_{n e w}\right) W=(Wold ​,Wnew​) Wold =(w1,w2,⋯,wCold b)∈Rd×ColdbWnew=(wColdb+1,⋯,wColdb+Cb)∈Rd×Cb\begin{aligned} \mathbf{W}_{\text {old }} &=\left(\mathbf{w}_{1}, \mathbf{w}_{2}, \cdots, \mathbf{w}_{C_{\text {old }}^{b}}\right) \in \mathbb{R}^{d \times C_{o l d}^{b}} \\ \mathbf{W}_{n e w} &=\left(\mathbf{w}_{C_{o l d}^{b}+1}, \cdots, \mathbf{w}_{C_{o l d}^{b}+C^{b}}\right) \in \mathbb{R}^{d \times C^{b}} \end{aligned} Wold ​Wnew​​=(w1​,w2​,⋯,wCold b​​)∈Rd×Coldb​=(wColdb​+1​,⋯,wColdb​+Cb​)∈Rd×Cb​ 则有相应的正则化形式如下:Norm⁡old =(∥w1∥,⋯,∥wCold b∥)Normnew=(∥wColdb+1∥,⋯,∥wColdb+Cb∥)\begin{aligned} &\operatorname{Norm}_{\text {old }}=\left(\left\|\mathbf{w}_{1}\right\|, \cdots,\left\|\mathbf{w}_{C_{\text {old }}^{b}}\right\|\right) \\ &\text {Norm}_{n e w}=\left(\left\|\mathbf{w}_{C_{o l d}^{b}+1}\right\|, \cdots,\left\|\mathbf{w}_{C_{o l d}^{b}+C^{b}}\right\|\right) \end{aligned} ​Normold ​=(∥w1​∥,⋯,∥∥∥​wCold b​​∥∥∥​)Normnew​=(∥∥∥​wColdb​+1​∥∥∥​,⋯,∥∥∥​wColdb​+Cb​∥∥∥​)​ 进一步地,可以对新类的权重进行标准化,有:W^new=γ⋅Wnewγ=Mean⁡(Norm old )Mean⁡(Norm new)\begin{gathered} \widehat{\mathbf{W}}_{n e w}=\gamma \cdot \mathbf{W}_{n e w} \\ \gamma=\frac{\operatorname{Mean}\left(\text { Norm }_{\text {old }}\right)}{\operatorname{Mean}\left(\text { Norm }_{n e w}\right)} \end{gathered} Wnew​=γ⋅Wnew​γ=Mean( Norm new​)Mean( Norm old ​)​​ 接着回看网络输出的过程。可以简单拆分为xxx的特征被提取,然后再乘以分类头的全连接层进行分类,即:o(x)=(Oold (x)Onew(x))=(Wold Tϕ(x)WnewTϕ(x))\mathbf{o}(\mathbf{x})=\left(\begin{array}{c} \mathbf{O}_{\text {old }}(\mathbf{x}) \\ \mathbf{O}_{n e w}(\mathbf{x}) \end{array}\right)=\left(\begin{array}{c} \mathbf{W}_{\text {old }}^{T} \phi(\mathbf{x}) \\ \mathbf{W}_{n e w}^{T} \phi(\mathbf{x}) \end{array}\right) o(x)=(Oold ​(x)Onew​(x)​)=(Wold T​ϕ(x)WnewT​ϕ(x)​) 由于上文我们对权重进行了对齐,则纠偏后的输出为:ocorrected (x)=(Wold Tϕ(x)W^new Tϕ(x))=(Wold Tϕ(x)γ⋅WnewTϕ(x))=(oold (x)γ⋅onew (x))\begin{aligned} &\mathbf{o}_{\text {corrected }}(\mathbf{x})=\left(\begin{array}{c} \mathbf{W}_{\text {old }}^{T} \phi(\mathbf{x}) \\ \widehat{\mathbf{W}}_{\text {new }}^{T} \phi(\mathbf{x}) \end{array}\right) \\ &=\left(\begin{array}{c} \mathbf{W}_{\text {old }}^{T} \phi(\mathbf{x}) \\ \gamma \cdot \mathbf{W}_{n e w}^{T} \phi(\mathbf{x}) \end{array}\right)=\left(\begin{array}{c} \mathbf{o}_{\text {old }}(\mathbf{x}) \\ \gamma \cdot \mathbf{o}_{\text {new }}(\mathbf{x}) \end{array}\right) \end{aligned} ​ocorrected ​(x)=(Wold T​ϕ(x)Wnew T​ϕ(x)​)=(Wold T​ϕ(x)γ⋅WnewT​ϕ(x)​)=(oold ​(x)γ⋅onew ​(x)​)​ 总的来说方法确实极其simple,但是由于代码似乎还没有开源,为此也尚不清楚以上操作具体如何实现。

[论文阅读] Maintaining Discrimination and Fairness in Class Incremental Learning相关推荐

  1. 论文阅读:Natural Language Processing Advancements By Deep Learning: A Survey

    文章目录 一.介绍 二.背景 1.人工智能和深度学习 (1)多层感知机 (2)卷积神经网络 (3)循环神经网络 (4)自编码器 (5)生成对抗网络 2.NLP中深度学习的动机 三.NLP领域的核心概念 ...

  2. 论文阅读 【CVPR-2022】 A Simple Multi-Modality Transfer Learning Baseline for Sign Language Translation

    论文阅读 [CVPR-2022] A Simple Multi-Modality Transfer Learning Baseline for Sign Language Translation st ...

  3. 强化学习泛化性 综述论文阅读 A SURVEY OF GENERALISATION IN DEEP REINFORCEMENT LEARNING

    强化学习泛化性 综述论文阅读 摘要 一.介绍 二.相关工作:强化学习子领域的survey 三.强化学习中的泛化的形式 3.1 监督学习中泛化性 3.2 强化学习泛化性背景 3.3 上下文马尔可夫决策过 ...

  4. 【论文阅读】Misshapen Pelvis Landmark Detection WithLocal-Global Feature Learning for DiagnosingDevelop

    作者及团队:刘川斌 Chuanbin Liu; 谢洪涛; 张思成; 毛振东; 孙俊; 张永东 会议及时间:IEEE Transactions on Medical Imaging 2020-12| 期 ...

  5. 【论文阅读】Search-Based Testing Approach for Deep Reinforcement Learning Agents

    文章目录 一.论文信息 二.论文结构 三.论文内容 Abstract 摘要 一.论文信息 题目: Search-Based Testing Approach for DeepReinforcement ...

  6. 【论文阅读笔记】FLAME: Taming Backdoors in Federated Learning

    个人阅读笔记,若有错误欢迎指正 会议: USENIX Security Symposium 2022  论文地址:[2101.02281] FLAME: Taming Backdoors in Fed ...

  7. 论文阅读:Natural Language Processing Advancements By Deep Learning: A Survey 深度学习在自然语言处理中的进展

    Natural Language Processing Advancements By Deep Learning: A Survey 深度学习在自然语言处理中的进展 目录 Natural Langu ...

  8. [论文阅读] Nearest Neighbor Classifier Embedded Network for Active Learning

    论文地址:https://www.aaai.org/AAAI21Papers/AAAI-39.WanF.pdf 代码:https://github.com/WanFang13/NCE-Net 发表于: ...

  9. 论文阅读:《Neural Machine Translation by Jointly Learning to Align and Translate》

    https://blog.csdn.net/u011239443/article/details/80521026 论文地址:http://pdfs.semanticscholar.org/071b/ ...

最新文章

  1. Ubuntu16.04LTS Install Intel® RealSense™ ROS from Sources
  2. JSP的改动需要重启应用服务器才能生效?
  3. axure rp 创建弹框_如何在Axure RP 9中创建交换机
  4. 怎样实现登录用户管理_如何编写程序实现图书管理系统里面的用户管理功能
  5. Android面试,View绘制流程以及invalidate()等相关方法分析
  6. python程序运行结果始终为0_Python:始终运行程序
  7. 动态网页和静态网页的区别是什么?
  8. MySQL 基准测试(mysqlslap)出现 Using a password on the command line interface can be insecure 警告
  9. hihocoder-1014 Trie树
  10. 在拓扑引擎内检测到故障,错误代码255
  11. [LintCode] Swap Nodes in Pairs
  12. 微信小程序 页面递归生成
  13. js 对象 浅拷贝 和 深拷贝
  14. 大数据与人工智能结课论文
  15. 人民币利率互换小幅上行,通胀不乐观致紧缩预期趋浓_183
  16. SDJZ2537LOL如何拯救小学生
  17. uniapp启动页面
  18. 您的首个 App 内购买项目必须以新的 App 版本提交
  19. 数据结构--员工管理系统--链表实现
  20. 第三十三篇,网络编程TCP协议通讯过程实现和函数接口

热门文章

  1. 用SpringGraph制作拓扑图和关系图
  2. python创建系列_一起学python系列之类(创建和使用类)
  3. c++ 合并2个txt_多个表达矩阵文件合并
  4. django3与vue3本地搭建
  5. redisson版本_通用Redisson版本
  6. matlab 写入 MYSQL_阿里开源MySQL中间件Canal快速入门
  7. PyTorch 学习笔记(一):让PyTorch读取你的数据集
  8. WampServer安装教程
  9. 分类(Classification)
  10. 机器学习中二分类逻辑回归的学习笔记