Contrastive Model Inversion for Data-Free Knowledge Distillation

Model inversion,其目标是从预训练的模型中恢复训练数据,最近被证明是可行的。然而,现有的inversion方法通常存在模式崩溃问题,即合成的样本彼此高度相似,因此对下游任务(如知识蒸馏)的有效性有限。在本文中,我们提出了 Contrastive Model Inversion (CMI),其中数据多样性被明确地建模为一个可优化的目标,以缓解模式崩溃问题。

我们主要观察到,在相同数据量的约束下,数据多样性越高,对模型训练帮助越大。为此,我们在CMI中引入了一个对比学习目标,鼓励合成的样本与之前的batch已经合成的样本有更多的多样性。在CIFAR-10、CIFAR-100和Tiny-ImageNet上进行的预训练模型的实验表明,CMI不仅产生了比现有方法更真实的样本,而且当生成的数据被用于知识蒸馏时,也取得了明显优越的性能。

Code is available at https://github.com/zju-vipa/DataFree.

1 Introduction

现有的KD方法在很大程度上依赖于大量的训练数据,将知识从预先训练好的教师模型转移给学生。然而,在许多情况下,由于隐私或传输的原因,训练数据并不与预训练的模型一起发布,这使得这些方法无法适用。因此,无数据KD被提出来解决这个问题。无数据KD的重要步骤是Model inversion,其目标是从预训练的教师模型中恢复训练数据。有了合成数据,学生模型就可以通过直接利用数据驱动的KD方法轻松地学习。

Model inversion本身已被研究了很长时间。例如,[Mahendran和Vedaldi,2015]研究了Model inversion以更好地理解深层表征。[Fredrikson等人,2015]研究了Model inversion攻击来推断敏感信息。最近,随着无数据KD获得更多关注,Model inversion的研究再次出现[Yin等人,2020;Lopes等人,2017;Fang等人,2019] 。具体来说,无数据KD对Model inversion提出了更高的要求,原因如下:第一,生成的数据应该遵循与原始训练数据相同的分布,否则学生模型无法很好的学习。第二,生成数据应该具有丰富的多样性。

遗憾的是,现有的inversion方法仍然不能满足这种要求。例如,工作[Chen et al., 2019],通过拟合 “one-hot” prediction分布来inversion分类模型。[Yin等人,2020]则是通过使用存储在教师模型的批量归一化层中的统计数据对中间特征图的分布进行正则化来合成图像。这两种方法都依赖于对真实数据分布的一些假设,并通过拟合先验分布独立优化每个实例。由于没有明确的约束条件来鼓励数据的多样性,这些方法受到模式崩溃问题的影响,生成的实例变成了彼此高度相似

Fang等人,2019年;Choi等人,2020年]提出通过挖掘更难的或对抗性的例子来产生更多数据进行训练。虽然对于无数据的KD来说,取得了一些性能上的提高,但生成的数据往往看起来是不真实的。

在本文中,我们试图通过促进数据多样性的角度来缓解无数据KD中的模式崩溃问题。通过实验,我们发现在相同的数据量下,更高的数据多样性表明了更强的实例区分能力( higher data diversity indicates stronger instance discrimination)。在这一现象的启发下,我们首先提出了一个 based on instance discrimination的数据多样性定义,然后提出了Contrastive Model Inversion (CMI)来解决模式崩溃问题,同时使生成的数据分布更接近真实数据分布。通过这种方式,生成的数据变得更加多样化和真实。

具体来说,在CMI中,我们引入了另一个对比学习目标,其中positive图像对包括同一数据样本的剪裁图像和完整图像,而negative图像对包括两个不同的数据样本。通过鼓励在某些距离定义下positive图像对相互靠近,negative图像对相互远离,CMI大大改善了图像的多样性和真实性,从而促进了无数据KD的性能。在CIFAR-10、CIFAR-100和Tiny-ImageNet上进行的预训练模型的实验表明,CMI不仅确保了合成比现有技术在视觉上更合理的样本,而且在生成的数据用于知识蒸馏时,也取得了明显的优越性能。

我们的贡献如下:

  • 我们提出了数据多样性的定义,这使我们能够将多样性明确地纳入优化目标,以提高生成数据的多样性。
  • 我们提出了一种新的Contrastive Model Inversion方法来处理无数据KD中的模式崩溃,同时强制要求生成的数据分布更加接近真实的数据。
  • 我们进行了广泛的实验来验证CMI相对于现有技术水平的优越性。

2 Related Works

Model Inversion(MI)旨在从预训练模型的参数中重新构建输入,它最初是为了理解神经网络的深度表征提出的[Mahendran and Vedaldi, 2015]。给定一个函数映射φ(x)和输入x,一个标准的Model Inversion问题可以被形式化为寻找一个x’来实现最小的d(φ(x), φ(x’)),其中d(⋅,⋅)d(\cdot, \cdot)d(,)是一个误差函数,例如,MSE误差。这种范式被称为模型反转攻击[Wu et al., 2016],被广泛用于模型安全[Zhang et al., 2020]和可解释性[Mahendran and Vedaldi, 2015]等多个领域。最近,Inversion技术在知识迁移中显示出其有效性[Lopes等人,2017;Yin等人,2020],促进了无数据蒸馏的发展。

无数据知识蒸馏 旨在从容量大的教师那里学习学生模型,而不需要获取真实世界的数据[Lopes et al., 2017;Chen et al., 2019; Ma et al., 2020],从而实现模型压缩[Yu et al., 2017]。现有的无数据工作的贡献可以大致分为两类:adversarial training and data prior。adversarial training的动机是鲁棒性优化,困难的样本被合成用于学生学习[Micaelli和Storkey,2019;Fang等人,2019]。data prior为无数据KD提供了另一个视角,合成的数据必须满足某些的先验,如total variance prior[Mahendran和Vedaldi,2015]和 Batch normalization statistics[Yin等人,2020]。

对比学习在自我监督学习领域取得了巨大的进展[Chen et al., 2020; He et al., 2020]。其核心思想是将每个样本作为一个不同的类别,并学习如何区分它们[Wu et al., 2018; Liu et al., 2021]。在这项工作中,我们从另一个角度重新审视对比学习框架,它的instance discrimination 能力被用来为 model inversion中的数据多样性建模。

3 Method

3.1 Preliminary

Model inversion作为无数据知识蒸馏的重要步骤,旨在从预先训练好的教师模型ft(x;θt)f_t(x; θ_t)ft(x;θt)中恢复训练数据X’,以替代无法获得的原始数据X。在这一部分,我们讨论三种典型的inversion技术

BN regularization 最初是在[Yin et al., 2020]中引入的,通过假设特征服从高斯分布的假设来正则化X的分布。正则化通常表示为feature statistics N(µl(x),σl2(x))\mathcal N(µ_l(x), σ^2_l(x))N(µl(x),σl2(x))和Batch normalization statistics N(µl,σl2)N(µ_l, σ^2_l)N(µl,σl2)之间的差距,具体如下

Class prio 通常被引入到类条件生成中,它基于网络对x∈X′x∈\mathcal X'xX做出 “one-hot"预测的假设[Chen等人,2019]。给定一个预先定义的类别c,它鼓励最小化交叉熵损失

Adversarial Distillation的动机是robust optimization,生成在教师ft(x;θt)和学生fs(x;θs)之间产生大的分歧[Micaelli和Storkey,2019;Fang等人,2019]的样本集x,即最大化KL散度项

统一框架 结合上述技术,将形成一个统一的 inversion框架[Choi等人,2020],用于无数据知识蒸馏。

其中α、β和γ是不同损失的平衡项。由于在这个框架中没有明确的多样性约束,传统的inversion方法可能倾向于 “偷懒”,重复合成重复的样本。为了克服这个问题,我们提出了一种diversityaware inversion技术,即y Contrastive Model Inversion (CMI)。

3.2 Contrastive Model Inversion

Overview 有了预先训练好的教师模型ft(x;θt)f_t(x; θ_t)ft(x;θt),CMI的目标是产生一组具有丰富多样性的x∈X′x∈\mathcal X'xX,有了它就可以从教师那里提取全面的知识。在这一节中,我们为数据的多样性提出了一个有趣的定义,并在此基础上介绍了所提出的Contrastive Model Inversion(CMI)。我们的动机是直观的:在相同数据量的约束下,更高的多样性通常表示更强的实例可区分性。为此,我们用 instance discrimination问题对数据多样性进行建模[Wu et al.,2018],并通过对比学习构建一个可优化的目标。

Definition of Data Diversity

给定一组数据X′\mathcal X'X,对数据多样性的直观描述是 “数据集中的样本有多大的可区分性(how distinguishable are the samples from the dataset)”,这显示了多样性和 instance distinguishability之间的正相关关系。因此,如果我们有合适的度量d(x1,x2)d(x_1, x_2)d(x1,x2)来估计instance pair {x1,x2}\{x_1, x_2\}{x1,x2}的distinguishability,那么我们可以为数据多样性制定一个明确的定义,如下所示。

其中d(x1,x2)d(x_1, x_2)d(x1,x2)将应用于X中所有可能的(x1,x2)(x_1, x_2)(x1,x2)对。有各种方法来定义d(⋅,⋅)d(\cdot, \cdot)d(,),导致不同的多样性标准。例如,预训练模型ft(x;θt)f_t(x; θ_t)ft(x;θt)实际上是一个嵌入函数,它将数据x映射到一个高级特征空间,其中一个简单的度量可以定义为d(x1,x2)=∣∣ft(x1)−ft(x2)∣∣d(x_1, x_2) = ||f_t(x_1)-f_t(x_2)||d(x1,x2)=ft(x1)ft(x2),这被称为感知距离[Li等人, 2003]。然而,由于以下问题,这种距离对于多样性估计可能是有问题的。1)函数ftf_tft实际上没有被明确训练为测量样本之间的相似性,其中欧氏距离的含义对我们来说是未知的。2)、embedding ft(x)f_t(x)ft(x)可能编码了关于输入的结构信息,而这些信息不能被这个度量所捕获。3)这个距离度量是无界的,我们无法弄清楚它应该有多大才能表示出一个好的多样性。在这种情况下,在ftf_tft上最大化这样的距离度量可能只会导致不是我们想要的结果。因此,需要一个更合适的嵌入空间来构建一个有意义的distinguishability度量。在下文中,我们提出了一个基于学习的数据多样性度量,它是通过解决一个对比性学习目标来建立的

Data Diversity from Contrastive Learning 对比学习最初是为了以自我监督的方式从数据中学习有用的表征,其中通过将每个样本视为一个不同的类别来建立instance-level discrimination[Wu et al., 2018]。通过对比学习,网络可以学习如何区分不同的样本,这正好与我们对度量d(⋅,⋅)d(\cdot, \cdot)d(,)的要求相吻合。在此基础上,我们引入另一个网络h(⋅)h(\cdot)h()作为教师网络ftf_tft的 instance discriminator,接受特征ft(x)f_t(x)ft(x)作为输入,并将其投射到一个新的嵌入空间。为了简化,我们用v = h(x)来表示v=(h⋅ft)(x)v = (h \cdot f_t)(x)v=(hft)(x),因为教师网络是固定的。在h(⋅)h(\cdot)h()的新嵌入空间中,我们用简单的余弦相似度来描述数据对x1和x2之间的关系,如下所示。

然后,可以用对比学习框架[Chen et al., 2020]的形式来表述instance discrimination问题,每个instance 将被随机转换为不同的views,并应被正确匹配。对于每个instance x∈X′x∈\mathcal X'xX,我们通过随机增强构建一个positive view x+,并将其他instances视为negative view x-。对比性学习损失的形式化为:

其中常数Z(x−)Z(x^-)Z(x)指的是每个实例xi的负样本量。因此,我们可以通过最小化对比性损失LcrL_{cr}Lcr来直接最大化多样性Ldiv\mathcal L_{div}Ldiv

Model Inversion 在上一部分中,我们将数据多样性与对比学习目标结合起来,可以直接优化,使数据更加多样化。本节将对比学习整合到model inversion中,形成我们最终的算法,即contrastive model inversion。

图1:contrastive model inversion方法的说明图。在每个时间步骤中,一个重新初始化的生成器在 instance discrimination的目标下训练,以合成distinguishable samples。

如图1所示,我们的方法由四个部分组成:生成器g(⋅;θg)g(\cdot;θ_g)g(θg)、教师网络ft(⋅;θt)f_t(\cdot;θ_t)ft(θt)、 instance discriminator h(⋅;θh)h(\cdot;θ_h)h(;θh)和memory bank B。判别器是一个简单的多层感知机,如[Chen et al., 2020]中使用的,它接受倒数第二层的表征以及中间特征的 global pooling作为输入。

CMI的核心思想是逐步合成一些新的样本,这些样本可以很容易地与memory bank中的历史样本区分开来。因此,model inversion过程是以 "case-by-case "策略处理的,这意味着在每个时间步骤T中,生成器将只合成一批数据。具体来说,在时间步骤T的开始,我们重新初始化生成器,并迭代优化其latent code z 以及参数θg。在这种情况下,生成器只负责数据分布的一小部分

与[Yin等,2020]中使用的独立更新不同像素的策略相比,"case-by-case "生成器可以为像素提供更强的正则化,因为它们是由共享权重θgθ_gθg产生的。在合成过程中,随机增强将被应用于合成图像,以产生一个 local view x和一个global view x+,用于对比学习。然而,请注意,单一batch的训练将不足以训练判别器。因此,我们让存储在n memory bank B中的历史图像也参与到学习过程中。现在, contrastive model inversion的目标可以被形式化为以下内容:

其中Linv(⋅)\mathcal L_{inv}(\cdot)Linv()指的是方程4中广泛使用的inversion criterion,它只适用于图像的 global view,Lcr\mathcal L_{cr}Lcr指的是所提出的数据多样性的对比性损失。请注意,Lcr\mathcal L_{cr}Lcr同时考虑了synthetic batch g(z;θg)g(z; θ_g)g(z;θg)和B的历史数据,其中历史数据将为当前图像合成提供有用的指导。在对比学习过程中,we stop the gradient on global view and only allow backpropagation on local ones as done in [Chen and He, 2020].。We found that this operation can provide more clear gradient for local pattern synthesis.

对比性模型反演的完整算法总结在Alg中。1. 存储在memory bank B中的合成图像将被用于下游的蒸馏任务。

3.3 Decision Adversarial Distillation

有了数据X,很容易用KL散度来训练学生。然而,合成也许不是知识迁移的最佳方式,其中一些重要的模式被遗漏。对抗性蒸馏法是一种流行的提高学生成绩的技术,它将学生纳入到图像合成中,用公式10使教师和学生之间的disagreement最大化。然而,大的disagreement 可能并不总是对应于有价值的样本,因为它们可能只是一些异常值。在这项工作中,我们更加关注那些 boundary samples,并引入decision adversarial loss。

函数1{⋅}\mathbb 1\{\cdot\}1{}是一个指标,当教师和学生对x产生相同的预测时,启用对抗性学习,否则禁用。与公式10中的 unbounded loss项不同,我们的决策对抗性损失将使x接近决策边界,这可以提供更多关于教师网络的信息。

4 Experiments

https://arxiv.org/pdf/2105.08584.pdf

Contrastive Model Inversion for Data-Free Knowledge Distillation相关推荐

  1. FreeKD:Free-direction Knowledge Distillation for Graph Neural Networks

    标题:FreeKD:Free-direction Knowledge Distillation for Graph Neural Networks 1.什么是知识蒸馏? 1.1 什么是知识: we t ...

  2. Zero-shot knowledge distillation in deep networks

    Zero-shot knowledge distillation in deep networks Objective Can we do Knowledge Distillation without ...

  3. arXiv2022.10 | EfficientVLM: 基于Knowledge Distillation and Modal-adaptive Pruning的快、准VLP model

    北航.ETH Zurich.Bytedance AI Lab X-VLM : https://arxiv.org/abs/2111.08276  Bytedance AI Lab EfficientV ...

  4. Mosaicking to Distill Knowledge Distillation from Out-of-Domain Data

    Mosaicking to Distill: Knowledge Distillation from Out-of-Domain Data 在本文中,我们试图解决一项雄心勃勃的任务,即域外知识蒸馏(O ...

  5. 李宏毅作业七其二 Network Compression (Knowledge Distillation)

    Network Compression --Knowledge Distillation 前言 一.knowledge distillation是什么? 1.原理 2.KL散度 3.Readme 二. ...

  6. 深度学习——(12)Knowledge distillation(Demo)

    深度学习--(12)Knowledge distillation(Demo) 原本昨天晚上要写的,但是奈何手中有更紧迫的任务需要做,所以自己还没有实战,昨天看到了一个简单的demo,自己写了一部分注释 ...

  7. Knowledge Distillation 知识蒸馏详解

    文章目录 往期文章链接目录 Shortcoming of normal neural networks Generalization of Information Knowledge Distilla ...

  8. 【论文翻译】Few Sample Knowledge Distillation for Efficient Network Compression

    Few Sample Knowledge Distillation for Efficient Network Compression 用于高效网络压缩的少样本知识提取 论文地址:https://ar ...

  9. 论文翻译: Relational Knowledge Distillation

    Relational Knowledge Distillation 这是 CVPR 2019年的一篇文章的译文. 文章链接: Relational Knowledge Distillation 附上G ...

最新文章

  1. R语言distRhumb函数计算距离实战(两个地理点之间的Rhumb距离)
  2. java心电图心率计算_java如何画心电图?
  3. flask 实现异步非阻塞----gevent
  4. [LOJ 6042]「雅礼集训 2017 Day7」跳蚤王国的宰相(树的重心+贪心)
  5. 使用Python往Elasticsearch插入数据
  6. php unexpected t_object_operator,php - PHP中的“Unexpected T_OBJECT_OPERATOR”错误
  7. 深入剖析Redis高可用集群架构原理
  8. codejock 用法记录
  9. 1296. 聪明的燕姿
  10. mac更新后Git无法使用的问题
  11. java 传值为不可变_Java函数传参(String的不可变性)
  12. 动物识别系统代码python_动物识别系统代码
  13. ubuntu20.4安装anaconda和pycharm
  14. 卡方分布、F分布、t分布和正态分布的关系
  15. 做门户网站 个人站长的新好出路
  16. PB数据窗口9种风格
  17. 程序运算小数时为什么会出错?
  18. 字符串strip()介绍
  19. c#实现文件重命名操作
  20. Risk-V编程,实现快速排序

热门文章

  1. 哈哈哈~井字棋(无心版),快来初步感受一下代码世界的乐趣吧
  2. SAP MM 常见移动类型及定义
  3. 在CentOS 7配置IPv6 DNS Server
  4. win10照片查看器恢复办法
  5. OpenCV Flann
  6. 2021年中式烹调师(初级)模拟考试系统及中式烹调师(初级)实操考试视频
  7. 【共识专栏】共识的分类(上)
  8. redis的几种常见客户端
  9. 富途证券后端PHP面经
  10. RNA-Seq质控工具RseQC安装使用