Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion

1. 论文信息

论文标题

Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion

论文来源

CVPR2020, https://arxiv.org/abs/1912.08795

代码

https: //github.com/NVlabs/DeepInversion

2. 背景梳理

将所学知识从一个训练好的神经网络转移到一个新的神经网络具有许多吸引人的应用。此类知识迁移任务的方法基本都基于知识蒸馏。然而,这些工作假设先前使用的训练数据集,或者代表先前训练数据集分布的一些真实图像可用。考虑到数据集可能规模大、难以传输、管理,且涉及到数据隐私问题,先前使用的训练数据集可能无法获得。

在缺乏原先训练数据的情况下,一个有趣的问题就出现了——如何从训练的模型中恢复训练数据并将其用于知识迁移。DeepDream[1] 合成或变换输入图像,使其在给定模型的输出层中对选定的类别产生高输出响应。这个方法对输入(随机噪声或者自然图像)进行优化,使用一些正则项,在保持选定的输出激活是固定的情况下,让中间特征无约束。但以这种方式生成的图像,缺少自然图像的统计特征,很容易被识别为非自然图像,并且对知识迁移不是很有用。

论文贡献

  • 提出了DeepInversion,用于为分类模型合成class-conditioanl的图像。此外,通过Adaptive DeepInversion来利用学生与教师的分歧,提高合成图像的多样性。
  • 我们将DeepInversion应用于无数据剪枝、无数据知识迁移和增量学习。

3. 方法

3.1 背景

知识蒸馏

给定教师模型 pTp_TpT​ 和数据集XXX,学生模型的参数WSW_SWS​ 可以通过下式学到
min⁡WS∑x∈XKL(pT(x),pS(x))\min_{W_S}\sum_{x\in X}KL(p_T(x),p_S(x)) WS​min​x∈X∑​KL(pT​(x),pS​(x))
其中pT(x)=p(x,WT)p_T(x)=p(x,W_T)pT​(x)=p(x,WT​)和pS(x)=p(x,WS)p_S(x)=p(x,W_S)pS​(x)=p(x,WS​)是教师和老师的输出。

DeepDream

给定一个随机初始化的输入x^∈RH×W×C\hat{x}\in R^{H\times W\times C}x^∈RH×W×C和一个任意的目标标签yyy,通过优化以下目标来合成图片:
min⁡x^L(x^,y)+R(x^),\min_{\hat{x}} \mathcal{L}(\hat{x},y)+\mathcal{R}(\hat{x}), x^min​L(x^,y)+R(x^),
其中L(⋅)\mathcal{L}(\cdot)L(⋅)为分类损失,R(⋅)\mathcal{R}(\cdot)R(⋅)为图像正则项。DeepDream使用以下图像先验来避免生成不实际的图像:
Rprior(x^)=αtvRTV(x^)+αl2Rl2(x^)\mathcal{R}_{prior}(\hat{x})=\alpha_{tv}\mathcal{R}_{TV}(\hat{x})+\alpha_{l_2}\mathcal{R}_{l_2}(\hat{x}) Rprior​(x^)=αtv​RTV​(x^)+αl2​​Rl2​​(x^)
其中RTV\mathcal{R}_{TV}RTV​和Rl2\mathcal{R}_{l_2}Rl2​​分别为总方差惩罚和x^\hat{x}x^的二范数。图像先验正则项可以稳定的收敛到有效图像,但是这些图像的分布仍然与自然图像相差甚远,因此导致知识蒸馏效果不理想。

3.2 DeepInversion

我们使用新的特征正则化项来提高DeepDream生成的图像质量。为了增强生成图像与原始图像在不同层次特征的相似性,我们提出最小化两者在特征图的距离。

我们假设特征统计量在batches间遵循高斯分布,因此可以用均值μ\muμ和方差σ2\sigma^2σ2来定义。特征分布正则项可以表示为:
Rfeature(x^)=∑l∥μl(x^)−E(μl(x)∣X)∥2+∑l∥σl2(x^)−E(σl2(x)∣X)∥2\mathcal{R}_{feature}(\hat{x})=\sum_l\Vert\mu_l(\hat{x})-E(\mu_l(x)|X)\Vert_2+\sum_l\Vert\sigma_l^2(\hat{x})-E(\sigma_l^2(x)|X)\Vert_2 Rfeature​(x^)=l∑​∥μl​(x^)−E(μl​(x)∣X)∥2​+l∑​∥σl2​(x^)−E(σl2​(x)∣X)∥2​
其中μ(x^)\mu(\hat{x})μ(x^)和σl2(x^)\sigma^2_l(\hat{x})σl2​(x^)是batch-wise的第lll层卷积层对应的特征图的均值和方差估计。

存储在BN层中的均值和方差的移动平均统计数据可以被用来近似上式的均值方差。
E(μl(x)∣X)≃BNl(running_mean)E(\mu_l(x)|X)\simeq \text{BN}_l(running\_mean) E(μl​(x)∣X)≃BNl​(running_mean)
E(σl2(x)∣X)≃BNl(running_variance)E(\sigma_l^2(x)|X)\simeq \text{BN}_l(running\_variance) E(σl2​(x)∣X)≃BNl​(running_variance)
特征分布正则化大大提高了生成图像的质量。
DeepInversion总的正则化项
RDI(x^)=Rprior(x^)+αfRfeature(x^)\mathcal{R}_{DI}(\hat{x})=\mathcal{R}_{prior}(\hat{x})+\alpha_{f}\mathcal{R}_{feature}(\hat{x}) RDI​(x^)=Rprior​(x^)+αf​Rfeature​(x^)

3.3 Adaptive DeepInversion

除了图像质量,生成图像的多样性在避免重复和冗余的合成方面也起着至关重要的作用。我们提出了一种基于图像生成过程与学生网络之间的迭代竞争增强的图像生成机制。主要目的是鼓励学生和教师之间产生分歧。因此,我们引入了一个额外的基于Jensen-Shannon散度的图像生成损失,对输出分布相似性进行惩罚。
Rcompete(x^)=1−JS(pT(x^),pS(x^))\mathcal{R}_{compete}(\hat{x})=1-JS(p_T(\hat{x}),p_S(\hat{x})) Rcompete​(x^)=1−JS(pT​(x^),pS​(x^))
JS(pT(x^),pS(x^))=12(KL(pT(x^),M)+KL(pS(x^),M))JS(p_T(\hat{x}),p_S(\hat{x}))=\frac{1}{2}(KL(p_T(\hat{x}),M)+KL(p_S(\hat{x}),M)) JS(pT​(x^),pS​(x^))=21​(KL(pT​(x^),M)+KL(pS​(x^),M))
其中M=12⋅(pT(x^)+pS(x^))M=\frac{1}{2}\cdot (p_T(\hat{x})+p_S(\hat{x}))M=21​⋅(pT​(x^)+pS​(x^))是教师和学生分布的平均值。
在优化过程中,这个新的项会使得生成的新图像学生无法轻松分类而教师可以。如下图所示,我们的方案在学习过程中迭代扩展了图像的分布覆盖。

Adaptive DeepInversion的正则化项为
RADI(x^)=RDI(x^)+αcRcompete(x^)\mathcal{R}_{ADI}(\hat{x})=\mathcal{R}_{DI}(\hat{x})+\alpha_{c}\mathcal{R}_{compete}(\hat{x}) RADI​(x^)=RDI​(x^)+αc​Rcompete​(x^)

4. 实验

在CIFAR-10上验证了每个部分的有效性,并且在ImageNet数据集上显示了Deep Inversion在(a)剪枝;(b)知识迁移;(c)增量学习上的应用。
这里仅展示在CIFAR-10上的结果,其他实验请阅读原文。


从表中可以看出:DeepInversion和Adaptive DeepInversion 与DeepDream相比,极大的提升了无数据知识蒸馏的效果,ADI训练的学生模型甚至接近了教师的效果。和DAFL[2](利用教师,训练一个生成器将噪声转换为图片)相比,DI和ADI生成的图片更具有视觉保真度(见下图),且不需要额外的生成网络。

5. 总结

该文章提出了DeepInversion,利用已训练的模型合成高分辨率、高保真度的训练图像。此外,使用Adaptive DeepInversion,通过迭代的方式提高了生成图像的多样性。

Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion相关推荐

  1. Dreaming to Distill Data-free Knowledge Transfer via DeepInversion

    Dreaming to Distill: Data-free Knowledge Transfer via DeepInversion 我们提出了DeepInversion,一种从图像分布中合成图像的 ...

  2. CVPR 2017:Interspeices Knowledge Transfer for Facial KeyPoint Detection(跨物种脸部关键点检测知识迁移)

    CVPR 2017: Interspeices Knowledge Transfer for Facial KeyPoint Detection(跨物种脸部关键点检测知识迁移) 一.介绍 本文主要涉及 ...

  3. Cross Domain Knowledge Transfer for Person Re-identification笔记

    Cross Domain Knowledge Transfer for Person Re-identification笔记 1 介绍 2 相关工作 3 方法 3.1 特征提取的ResNet 3.2 特 ...

  4. 【论文翻译】UniT: Unified Knowledge Transfer for Any-Shot Object Detection and Segmentation

    UniT: Unified Knowledge Transfer for Any-Shot Object Detection and Segmentation UniT:任意样本量的目标检测和分割的统 ...

  5. Open-Vocabulary Multi-Label Classification via Multi-modal Knowledge Transfer 论文解读

    Open-Vocabulary Multi-Label Classification via Multi-modal Knowledge Transfer 论文解读 前言 Motivation Con ...

  6. 【cvpr2022-论文笔记】《L2G: A Simple Local-to-Global Knowledge Transfer .... Semantic Segmentation》

    目录 文章概述 网络架构 Classification Loss Attention Transfer Loss Shape Tansfer Loss 相关讨论 本文记录弱监督语义分割领域论文笔记&l ...

  7. CL-ReLKT: Cross-lingual Language Knowledge Transfer for MultilingualRetrieval Question Answering论文阅读

    CL-ReLKT: Cross-lingual Language Knowledge Transfer for Multilingual Retrieval Question Answering 摘要 ...

  8. FTP登录提示Can't open data connection for transfer of /

    服务器: 系统:windows server 2008 R2 standard 是否开启防火墙:是 FTP客户端:Filezilla server 本地: FTP服务端:winscp 使用winscp ...

  9. Contour Knowledge Transfer for Salient Object Detection

    Contour Knowledge Transfer for Salient Object Detection 摘要 1 Introduction 2 Related Work 3 Approach ...

最新文章

  1. 不信你看!这次Python和AI真的玩儿大了!!
  2. Node-RED安装图形化节点dashboard实现订阅mqtt主题并在仪表盘中显示温度
  3. 均值滤波、中值滤波、混合中值滤波C++源码实例
  4. Markdown 基础学习
  5. Memcached 使用 及.NET客户端调用
  6. C#中IEnumerableT.Distinct()将指定实体类对象用Lambda表达式实现多条件去重
  7. mac睡眠快捷键_mac键盘快捷键大全哪里有?
  8. android 串口一直打开_串口通讯你真的会了吗?不妨来看看这些经验
  9. 问题-[Delphi]用LoadLibrary加载DLL时返回0的错误
  10. 【Axure报错】-Unable to connect to Axure Share. Please make sure you have an internet connection and try
  11. python 当前时间的零点,python 获取当天凌晨零点的时间戳方法
  12. android_静默安装/adb执行/软件搬家/消息派发
  13. 00018计算机应用基础2021,2021年全国自考10月00018计算机应用基础历年试题含答案.doc...
  14. Ubuntu下载安装VSCode(解决安装失败问题)
  15. tic/toc/cputime测试时间的区别
  16. 产品经理求职方法指南:面试通关
  17. win7识别到移动硬盘,但不显示盘符解决办法
  18. 《Hud 2589》Phalanx详解
  19. Hierarchical Graph Network for Multi-hop Question Answering 论文笔记
  20. scipy中的imread,imresize怎么用

热门文章

  1. 【JSD-Day01】语言基础第一天
  2. MCE公司:MCE 中国生命科学研究促进奖获奖论文集锦三
  3. Python菜鸟学习手册14----标准库+代码实例
  4. Day 3 (云计算-zsn)
  5. 微信运动步数无限修改教程最高98800
  6. 爬虫36计之1.1 爬取高清MM图片壁纸
  7. 川师计算机类专业收分安徽,四川师范大学专业收分
  8. 论文中文翻译——Vulnerability Dataset Construction Methods Applied To Vulnerability Detection A Survey
  9. 团队作业10——事后诸葛亮分析
  10. 正则表达式--常用用法及lookahead、lookbehind