CONTRASTIVE REPRESENTATION DISTILLATION

我们常常希望将表征性知识从一个神经网络转移到另一个神经网络。这方面的例子包括将一个大型网络提炼成一个较小的网络,将知识从一种感觉模式转移到另一种感觉模式,或者将一系列模型集合成一个单一的估计器。知识提炼是解决这些问题的标准方法,它使教师和学生网络的概率输出之间的KL背离最小。我们证明这个目标忽略了教师网络的重要结构知识。这促使我们提出了另一个目标,即训练学生在教师的数据表述中捕捉到更多的信息。我们把这个目标表述为对比性学习。实验证明,我们的新目标在各种知识转移任务上优于知识蒸馏和其他尖端蒸馏器,包括单一模型压缩、集合蒸馏和跨模式转移。我们的方法在许多转移任务中创造了新的最先进的技术,当与知识蒸馏相结合时,有时甚至超过了教师网络。

1 INTRODUCTION

知识提炼(KD)将知识从一个深度学习模型(教师)转移到另一个(学生)。最初由Hinton等人(2015)提出的目标是最小化教师和学生输出之间的KL散度。当输出是一个分布时,这种表述具有直观的意义,例如,在类上的probability mass function。然而,我们经常希望转移关于一个表示的知识。例如,在 "跨模式提炼 "的问题中,我们可能希望将图像处理网络的表示转移到声音(Aytar等人,2016)或深度(Gupta等人,2016)处理网络,这样,图像的深度特征和相关的声音或深度特征是高度相关的。在这种情况下,KL发散是不确定的。

表征性知识是结构化的–各维度呈现出复杂的相互依赖关系。在(Hinton等人,2015)中引入的原始KD目标将所有维度视为独立的,并以输入为条件。让yTy^TyT为老师的输出,ySy^SyS是学生的输出。然后是原始KD目标函数,ψ, has the fully factored form:ψ(yS,yT)=Σiϕi(yiS,yiT)ψ(y^S,y^T)=\Sigma_i\phi_ i(y^S_i,y^T_i)ψ(yS,yT)=Σiϕi(yiS,yiT). 这样一个考虑因素的目标不足以转移结构知识,即输出维度i和j之间的依赖关系。这类似于图像生成中,由于输出维度之间的独立假设,L2目标产生模糊结果的情况。

为了克服这个问题,我们希望有一个能够捕捉到相关性和高阶输出依赖性的目标。为了实现这一点,在本文中,我们利用了e the family of contrastive objectives(Gutmann & Hyvärinen,2010;Oord等人,2018;Arora等人,2019;Hjelm等人,2018)。近年来,这些目标函数被成功地用于density estimation和表征学习,特别是在自我监督的情况下。在这里,我们将它们调整为从一个深度网络到另一个深度网络的知识提炼任务。我们表明,在表示空间中工作是很重要的,与最近的工作如Zagoruyko & Komodakis(2016a);Romero等人(2014)类似。然而,请注意,这些作品中使用的损失函数并没有明确地试图捕捉表征空间中的相关性或高阶依赖关系。

我们的目标是最大限度地降低师生之间的互信息表示。我们发现,这导致在一些知识转移任务中表现更好,我们推测这是因为对比目标更好地传递了教师表征中的所有信息,而不是仅仅传递关于条件独立输出类概率的知识。有些令人惊讶的是,对比目标甚至改进了最初提出的提取类概率知识的任务的结果,例如,将一个大型CIFAR100网络压缩成一个性能几乎相同的小型网络。我们认为这是因为不同类别概率之间的相关性包含了有用的信息,可以规范学习问题。我们的论文在两个主要独立发展的文献之间建立了联系:知识提炼和表征学习。这种联系使我们能够利用表征学习的强大方法来显著改进知识提炼的SOTA。我们的贡献是:

  1. 一个基于对比的目标,在深度网络之间转移知识。
  2. 应用于模型压缩、跨模式转移和ensemble distillation。
  3. 对最近的12种蒸馏方法进行基准测试;CRD优于所有其他方法

2 RELATED WORK

注意力转移(Zagoruyko & Komodakis, 2016a)侧重于网络的特征图,而不是输出logits。这里的想法是在教师和学生的特征图中激发出类似的反应模式(称为 “attention”)。然而,在这种方法中,只有具有相同空间分辨率的特征图可以被结合,这是一个重要的限制,因为它需要学生和教师网络具有非常相似的架构。

这种技术实现了最先进的提炼结果(以学生网络的泛化为标准)。FitNets(Romero等人,2014)也通过使用回归来指导学生网络的特征激活来处理中间表征。由于Zagoruyko和Komodakis(2016a)做了这种回归的加权形式,他们往往表现得更好。其他论文(Yim等人,2017;Huang & Wang,2017;Kim等人,2018;Yim等人,2017;Huang & Wang,2017;Ahn等人,2019;Koratana等人,2019)执行了基于表示的各种标准。我们在本文中使用的对比性目标与CMC中使用的目标相同(Tian et al., 2019)。但我们从不同的角度推导,并给出严格的证明,我们的目标是互信息的下限。我们的目标也与(Oord等人,2018;Gutmann & Hyvärinen,2010)中介绍的InfoNCE和NCE目标有关。Oord等人(2018)在自我监督的表征学习的背景下使用对比学习。他们表明,他们的目标是最大化相互信息的下限。Hjelm等人(2018)使用了一种非常相关的方法。InfoNCE和NCE密切相关,但与对抗性学习不同(Goodfellow等人,2014)。在(Goodfellow,2014)中,表明Gutmann & Hyvärinen(2010)的NCE目标可以导致最大似然学习,但不是对抗性目标。

3 METHOD

对比学习的关键思想是非常普遍的:对于 "positive " pair,学习一个在某些度量空间中接近的表征,并在 "negative " pair之间分离表征。图1直观地解释了我们如何为我们考虑的三个任务构建对比性学习:模型压缩、跨模式转移和ensemble distillation。

图1:我们考虑的三种提取设置:(a)压缩模型,(b)将知识从一种模式(如RGB)转移到另一种模式(如深度),(c)将网络集合提取到单个网络中。建构目标鼓励教师和学生将相同的输入映射到接近的表示(在某些度量空间中),并将不同的输入映射到遥远的表示,如阴影圆所示。

3.1 CONTRASTIVE LOSS

给定两个深度神经网络,一个是教师fTf^TfT,一个是学生fSf^SfS。设x为网络输入;我们将倒数第二层(在logits之前)的表征分别表示为fT(x)和fS(x)f^T(x)和f^S(x)fT(x)fS(x)。设xix_ixi代表训练样本,xjx_jxj为随机选取样本。我们想推近表示fT(xi)和fS(xi)f^T(x_i)和f^S(x_i)fT(xi)fS(xi),而分离fT(xj)和fS(xj)f^T(x_j)和f^S(x_j)fT(xj)fS(xj)。为了便于记法,我们分别为学生和教师的数据表示定义了随机变量S和T:

直观地说,我们将考虑联合分布p(S,T)p(S, T)p(S,T)和边际分布的乘积p(S)p(T)p(S)p(T)p(S)p(T),因此,通过最大化这些分布之间的KL散度,我们可以最大化学生和教师表示之间的互信息。为了设置一个能够实现这一目标的适当的损失,让我们定义一个带有latent variable C的分布q,它决定一个tuple(fT(xi),fS(xj))(f^T(x_i), f^S(x_j ))(fT(xi),fS(xj))是来自联合(C = 1)还是边际分布的乘积(C = 0)

Now, suppose in our data, we are given 1 congruent pair (drawn from the joint distribution, i.e. the same input provided to T and S) for every N incongruent pairs (drawn from the product of marginals; independent randomly drawn inputs provided to T and S). Then the priors on the latent C are:

通过简单的应用贝叶斯法则,C=1类的后验为

接下来,我们观察到与互信息的联系,如下所示

然后取两边的期望 p(t,S)p(t,S)p(t,S)(相当于 q(T,S∣C=1)q(T,S | C=1)q(T,SC=1))并重新排列((equivalently w.r.t. q(T,S∣C=1)q(T,S | C=1)q(T,SC=1)) and rearranging),得到:

其中I(T;S)I(T; S)I(T;S)是教师和学生embeddings分布之间的互信息。因此,最大化Eq(T,S∣C=1)logq(C=1∣T,S)\mathbb E_{q(T ,S|C=1)}log\ q(C = 1|T, S)Eq(T,SC=1)logq(C=1T,S),通过学生网络的参数S,增加了互信息的下限。然而,我们不知道真实的分布q(C = 1|T, S);所以,我们通过拟合一个模型h:{T,S}→[0,1]h:\{\mathcal T ,\mathcal S\} → [0, 1]h:{T,S}[0,1]来估计它,通过来自数据分布q(C=1∣T,S)q(C = 1|T, S)q(C=1T,S)的样本,其中T和S\mathcal T和\mathcal STS代表embeddings的域( the domains of the embeddings)。我们最大化这个模型下的数据的对数似然(一个二元分类问题):

We term h the critic since we will be learning representations that optimize the critic’s score.假设h有足够的表现力,h∗(T,S)=q(C=1∣T,S)h^∗(T, S)=q(C = 1|T, S)h(T,S)=q(C=1T,S)(通过吉布斯不等式;证明见第6.2.1节),因此我们可以用h∗重写公式9

因此,我们看到,最佳critic是一个estimator ,其期望值降低了互信息的范围。我们希望学习一个学生,使其表征与教师的表征之间的互信息最大化,这就提出了以下优化问题

这里的一个明显的困难是,最佳critich∗取决于当前的学生。我们可以通过将(12)中的约束弱化来规避这个困难:

这说明我们可以在学习h的同时共同优化fSf^SfS。我们注意到,由于(16),fS∗=argmaxfSLcritic(h)f^{S∗} = arg\ max_{f^S} \mathcal L_{critic}(h)fS=argmaxfSLcritic(h)对于任何h,也是一个基于互信息优化下限(较弱的下限)的表示,因此我们的公式不依赖于h的完美优化。

我们可以选择用满足h:{T,S}→[0,1]h:\{\mathcal T,\mathcal S\}\rightarrow[0,1]h:{T,S}[0,1]的任何函数族. 在实践中,我们使用以下方法:

其中M是数据集的cardinality,τττ是调整concentration level的温度。在实践中,由于S和T的维度可能不同,gS和gTg^S和g^TgSgT将它们线性转化为相同的维度,并在内积之前通过L-2 norm进一步归一化它们。公式(18)的形式受到NCE的启发(Gutmann & Hyvärinen, 2010; Wu et al., 2018)。我们的表述与InfoNCE损失(Oord等人,2018)相似,即我们最大化了互信息的下限。然而,我们使用了一个不同的目标和约束,在我们的实验中,我们发现它比InfoNCE更有效。

Implementation. 理论上,公式16中较大的N会导致MI的更严格的下限。在实践中。为了避免使用非常大的batch规模,我们遵循Wu等人(2018)的做法,实现了一个内存缓冲器存储每个数据样本的潜在特征,这些特征是由以前的批次计算出来的。因此,在 训练中,我们可以有效地从存储缓冲区中检索大量的负面样本。

3.2 KNOWLEDGE DISTILLATION OBJECTIVE

Hinton等人(2015)提出了知识提炼损失。除了学生输出ySy^SyS和 one-hot label y之间的常规交叉熵损失外,它还要求学生网络输出尽可能地与教师输出相似,即最小化他们输出概率之间的交叉熵。完整的目标是

3.3 CROSS-MODAL TRANSFER LOSS

在图1(b)所示的跨模态转移任务中,教师网络是在一个具有大规模标记的数据集的源模态X上训练的。然后,我们希望将知识转移给学生网络,但要将其适应于另一个数据集或模态Y。但教师网络的特征仍然有价值,可以帮助学生在另一个领域学习。在这个转移任务中,我们使用对比性损失公式10来匹配学生和教师的特征。此外,我们还考虑了其他的提炼目标,比如上一节中讨论的KD,注意力转移Zagoruyko & Komodakis(2016a)和FitNet Romero等人(2014)。

这种transfer是在一个成对但无标签的数据集D={(xi,yi)∣i=1,...,L,xi∈X,yi∈Y}D = \{(x_i, y_i)|i = 1, ..., L, x_i ∈ \mathcal X , y_i ∈\mathcal Y\}D={(xi,yi)i=1,...,L,xiX,yiY}上进行。在这种情况下,对于源模态上的原始训练任务,没有这种数据的真实标签y,因此我们在所有测试的目标中忽略了H(y,yS)H(y, y^S)H(y,yS)项。之前的跨模态工作Aytar等人(2016);Hoffman等人(2016b;a)使用L2回归或KL-散度。

4 ENSEMBLE DISTILLATION LOSS

在1©所示的集合蒸馏的情况下,我们有M>1个教师网络,fTif^{T_i}fTi和一个学生网络fSf^SfS。我们采用对比框架,在每个教师网络fTif^{T_i}fTi和学生网络fSf^SfS的特征之间定义了多个成对的对比性损失。这些损失相加,得出最终的损失(要最小化)

网络,fTif^{T_i}fTi和一个学生网络fSf^SfS。我们采用对比框架,在每个教师网络fTif^{T_i}fTi和学生网络fSf^SfS的特征之间定义了多个成对的对比性损失。这些损失相加,得出最终的损失(要最小化)

CONTRASTIVE REPRESENTATION DISTILLATION相关推荐

  1. 【CVPR 2021】基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation

    [CVPR 2021]基于Wasserstein Distance对比表示蒸馏方法:Wasserstein Contrastive Representation Distillation 论文地址: ...

  2. Simple Contrastive Representation Adversarial Learning for NLP Tasks

    论文目的 对比学习是自监督学习在NLP中的应用,本文使用对抗方法生成对比学习需要的训练样本对,对比学习提升了对抗训练的泛华能力,对抗训练也增强了对比学习的鲁棒性,提出了两种结构:supervised ...

  3. BEVDistill: Cross-Modal BEV Distillation for Multi-View 3D Object Detection

    Paper name BEVDistill: Cross-Modal BEV Distillation for Multi-View 3D Object Detection Paper Reading ...

  4. 杂谈 | 当前知识蒸馏与迁移学习有哪些可用的开源工具?

    所有参与投票的 CSDN 用户都参加抽奖活动 群内公布奖项,还有更多福利赠送 作者&编辑 | 言有三 来源 | 有三AI(ID:yanyousan_ai) [导读]知识蒸馏与迁移学习不仅仅属于 ...

  5. 无需多个模型也能实现知识整合?港中文MMLab提出「烘焙」算法,全面提升ImageNet性能...

    视学算法专栏 转载自:机器之心 作者:葛艺潇 来自港中文 MMLab 的研究者提出一种烘焙(BAKE)算法,为知识蒸馏中的知识整合提供了一个全新的思路,打破了固有的多模型整合的样式,创新地提出并尝试了 ...

  6. 【模型蒸馏】从入门到放弃:深度学习中的模型蒸馏技术

    点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要17分钟 跟随小博主,每天进步一丢丢 来自 | 知乎   作者 | 小锋子Shawn 地址 | https://zhuanlan.zhihu.c ...

  7. ICLR2020:40篇计算机视觉github开源论文合集

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 转自极市平台 1. star:9819|Weakly Supervised Dis ...

  8. 【杂谈】当前知识蒸馏与迁移学习有哪些可用的开源工具?

    知识蒸馏与迁移学习不仅仅属于模型优化的重要技术之一,也是提升模型跨领域泛化能力的重要技术,那么当前有哪些可用的知识蒸馏和迁移学习开源工具呢? 作者&编辑 | 言有三 1 PaddleSlim ...

  9. 最高一万星!GitHub 标星最多的 40 篇 ICLR2020 计算机视觉论文合集,附打包下载

    编译|极市平台 1. star:9819|Weakly Supervised Disentanglement with Guarantees(弱监督学习) 论文:https://arxiv.org/p ...

  10. “烘焙”ImageNet:自蒸馏下的知识整合

    ©作者|葛艺潇 学校|香港中文大学博士生 研究方向|图像检索.图像生成等 最先进的知识蒸馏算法发现整合多个模型可以生成更准确的训练监督,但需要以额外的模型参数及明显增加的计算成本为代价.为此,我们提出 ...

最新文章

  1. LeetCode简单题之整理字符串
  2. 词法分析器构造工具Flex基础学习
  3. P1744 采购特价商品(SPFA求最短路径模板)
  4. python入门需要什么基础知识_Python 基础之:入门必备知识
  5. 动态“神还原”李焕英旧照,用技术致敬每一位妈妈!
  6. 区块链JAVA数字交易所官方商业版开发级全套三端纯源码
  7. ASP禁止刷新当前页
  8. winXP 下安装python3.3.2
  9. 华为发布全球首款 5G 汽车通讯硬件;今日头条系产品大裁员;三星手机推迟上市 | 极客头条...
  10. nginx一招配置,帮你快速隐藏php后缀名
  11. 【设计模式笔记】代理模式
  12. ideaIU-2019.3.2.exe安装教程
  13. 简单易用的运动控制卡(八):直线插补和手轮运动
  14. 7-6 jmu-python-随机生成密码 (10 分)习题解答
  15. unity 裙子摆动_【Unity Shader】摇摆的小草——顶点动画
  16. WEP加密概念-个人笔记
  17. GBase 8c 的安全特性
  18. iPhone转Android体验,一直用苹果手机,突然换成安卓是什么体验?网友:差距太大!...
  19. 一种简单、安全的Dota全图新思路 作者:LC 【转】
  20. pycharm查找替换快捷键

热门文章

  1. 民生服务是“双创”永恒主题 且听“鸿雁旅居网”、“熊猫中医”谈背后心路历程...
  2. jsp医疗报销管理系统 myeclipse开发技术 mysql数据库
  3. 北京医保报销比例,范围
  4. PPT精美模板免费下载网站 高清壁纸免费下载网站 在线PS(Photoshop在线网站)网站 分享
  5. cadence ETS安装过程
  6. 企业OA系统在低代码平台中要如何开发?
  7. linux中rm件命令,Linux rm命令详解
  8. Windows程序闪退原因查看方法----事件查看器
  9. ROS学习笔记9 —— launch文件
  10. 指投:3 常见的指数基金品种