文章目录

  • 一、摘要
  • 二、主要贡献
  • 三、创新点灵感分析
  • 四、总体框架
    • 4.1 算法介绍
    • 4.2 Generation with Masked Feature
  • 五、总结

[论文]:Yang Z, Li Z, Shao M, et al. Masked Generative Distillation[J]. arXiv preprint arXiv:2205.01529, 2022.
代码地址
论文地址
论文翻译


一、摘要

知识蒸馏已成功应用于各种任务。当前的蒸馏算法通常通过模仿教师的输出来提高学生的表现。本文表明,教师还可以通过指导学生的特征恢复来提高学生的表征能力。从这个角度来看,我们提出了掩蔽生成蒸馏(MGD),它很简单:我们屏蔽学生特征的随机像素,并迫使它通过一个简单的块生成教师的完整特征。MGD是一种真正通用的基于特征的蒸馏方法,可用于各种任务,包括图像分类、目标检测、语义分割和实例分割。我们在具有广泛数据集的不同模型上进行了实验,结果表明所有学生都取得了出色的改进。值得注意的是,我们将 ResNet-18 从 69.90% 提高到 71.69% ImageNet top-1 准确率,ResNet-50 主干的 RetinaNet 从 37.4 提高到 41.0 边界框 mAP,SOLO 基于 ResNet-50 从 33.1 提高到 36.2 Mask mAP,DeepLabV3 基于 ResNet-18 从 73.20 提高到 76.02 mIoU。

二、主要贡献

1.引入了一种新的基于特征的知识蒸馏方法,它使学生通过其掩码特征生成教师的特征,而不是直接模仿它。
2.提出了一种新的基于特征的蒸馏方法——掩蔽生成蒸馏,它简单且易于使用两个超参数。
3.我们通过对不同数据集的大量实验来验证我们的方法在各种模型上的有效性。对于图像分类和密集预测任务,学生使用 MGD 取得了显着的改进。

三、创新点灵感分析

之前的feature-based蒸馏方法通常会让学生模型尽可能模仿教师模型的输出,因为教师模型通常有着更强的表示能力。在本文中,作者发现直接去模仿教师模型来提升学生特征的表示能力其实是不必要的,如果让学生模型使用部分pixels来重建教师模型的全部特征,那么学生模型对这些使用到的pixels的表示能力也会得到提升。

上图为FPN输出的第一层要素的可视化。教师:RetinaNet-ResNeXt101。学生:RetinaNet-ResNet50。FGD是一种检测器的提取方法,它迫使学生模仿老师的特征。
由上图可以看出学生模型和教师模型的特征存在差异,同时教师模型的mAP也比学生模型高。采用SOTA的蒸馏方法进行蒸馏后(使用注意力来强迫学生模型模拟教师模型的特征),学生模型的特征与教师模型更相似,同时mAP也得到极大的提升。而使用本文的蒸馏方法训练后,学生模型与教师模型特征虽然相差较大,但是mAP甚至达到教师模型的水平。

四、总体框架

4.1 算法介绍

以前的基于特征的提取方法通常让学生尽可能地模仿老师的输出,因为老师的特征具有更强的表示能力。但是,我们认为没有必要直接模仿老师来提高学生特征的表征能力。用于提取的特征一般是通过深度网络的高阶语义信息。特征像素在一定程度上已经包含了相邻像素的信息。所以,如果能通过简单的分块,用部分像素还原老师的全部特征,这些用过的像素的表现力也能得到提升。
从这个角度出发,我们提出了一种简单有效的基于特征的提取方法——掩蔽生成提取法。如下图所示,我们首先屏蔽学生特征的随机像素,然后通过一个简单的块用屏蔽的特征生成教师的完整特征。由于在每次迭代中使用随机像素,因此在整个训练过程中将使用所有像素,这意味着该特征将更加鲁棒,并且其表示能力将得到提高。在我们的方法中,老师只是作为学生恢复特征的指导,并不要求学生直接模仿。

4.2 Generation with Masked Feature

对于基于 CNN 的模型,更深层的特征具有更大的感受野和更好地表示原始输入图像。换句话说,特征图像素已经在一定程度上包含了相邻像素的信息。因此,我们可以使用部分像素来恢复完整的特征图。
我们的方法旨在通过学生的掩码特征生成教师的特征,这有助于学生获得更好的表示。
我们用 T l ∈ R C × H × W T^l ∈ R^{C×H×W} Tl∈RC×H×W 和 S l ∈ R C × H × W ( l = 1 , . . , L ) S^l ∈ R^{C×H×W} (l = 1,.., L) Sl∈RC×H×W(l=1,..,L)教师和学生的第 l l l个特征图。首先,我们设置第 l l l 个随机掩码来覆盖学生的第 l l l 个特征,可以表示为:
M i , j l = { 0 , if  R i , j l < λ 1 Otherwise (1) M^l_{i,j}= \begin{cases} 0, & \text {if $R^l_{i,j}<\lambda $ } \\ 1 & \text{ Otherwise} \end{cases} \tag {1} Mi,jl​={0,1​if Ri,jl​<λ  Otherwise​(1)其中 R i , j l R^l_{i,j} Ri,jl​是 (0, 1) 中的随机数,i, j 分别是特征图的水平坐标和垂直坐标。λ 是一个超参数,表示掩码比率。第 l l l 个特征图由第 l l l个随机掩码覆盖。
对应的代码如下,self.lambda_mgd代表masked ratio. Defaults to 0.65,mat代表生成的随机掩码覆盖:

device = preds_S.device
mat = torch.rand((N,1,H,W)).to(device)
mat = torch.where(mat>1-self.lambda_mgd, 0, 1).to(device)

然后我们使用相应的掩码来覆盖学生的特征图,并尝试生成具有左像素的教师特征图,可以表述如下: G ( f a l i g n ( S l ) ⋅ M l ) ⟶ T l (2) G(f_{align}(S^l)\cdot M^l)\longrightarrow T^l\tag {2} G(falign​(Sl)⋅Ml)⟶Tl(2) G ( F ) = W l 2 ( R e L U ( W l 1 ( F ) ) ) (3) G(F)=W_{l2}(ReLU(W_{l1}(F))) \tag {3} G(F)=Wl2​(ReLU(Wl1​(F)))(3) G G G表示包含两个卷积层的投影仪层: W l 1 W_{l1} Wl1​ 和 W l 2 W_{l2} Wl2​,一个激活层 ReLU。在本文中,我们采用了 1×1 的卷积层对于适配层 f a l i g n f_{align} falign​,投影仪层 W l 1 W_{l1} Wl1​ 和 W l 2 W_{l2} Wl2​的3×3卷积层。用于将覆盖后的学生网络生成生成教师的feature_maps

公式2的代码为将学生网络特征与生成的随机掩码覆盖相乘,最终能得到覆盖后的特征:

masked_fea = torch.mul(preds_S, mat)

之后由公式3将新生成的masked_fea 进一步处理,尝试生成教师的feature_maps,对应的代码如下:

self.generation = nn.Sequential(nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True), nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))
new_fea = self.generation(masked_fea)

根据这种方法,我们设计了MGD的蒸馏损失 L d i s L_{dis} Ldis​:
L d i s ( S , T ) = ∑ l = 1 L ∑ k = 1 C ∑ i = 1 H ∑ j = 1 W ( T k , i , j l − G ( f a l i g n ( S k , i , j l ) M i , j l ) ) 2 (4) L_{dis}(S,T)=\sum\limits_{l=1}^L\sum\limits_{k=1}^C\sum\limits_{i=1}^H\sum\limits_{j=1}^W(T^l_{k,i,j}-G(f_{align}(S^l_{k,i,j})M^l_{i,j}))^2\tag {4} Ldis​(S,T)=l=1∑L​k=1∑C​i=1∑H​j=1∑W​(Tk,i,jl​−G(falign​(Sk,i,jl​)Mi,jl​))2(4)其中 L 是蒸馏层的总和,C、H、W 表示特征图的形状。S 和 T 分别表示学生和教师的特征。对应的代码如下:

dis_loss = loss_mse(new_fea, preds_T)/N

这里值得注意的是,本文仅需要两个超参数。分别为:掩码率 λ \lambda λ、loss平衡参数 α \alpha α,相比于其他的蒸馏算法调参更为简单。

五、总结

以前基于特征的提炼方法通常会让学生尽可能地模仿老师的输出,因为老师的特征具有更强的代表性。然而,作者认为没有必要直接模仿老师来提高学生特征的表示力。用于提炼的特征一般是通过深度网络的高阶语义信息。特征像素在一定程度上已经包含了相邻像素的信息。因此,如果可以通过简单的区块来使用部分像素来还原老师的完整特征,那么这些被使用的像素的表示力也可以得到提高。
通过这个掩膜,获得了部分的特征图,然后再生成新的特征图去模仿教师网络的特征图,相比原始的特征模仿,多的这一步,是增大网络学习的难度,从而迫使学生网络去学习一个更优秀的特征表示,而生成的特征图去模仿教师网络是因为教师网络的特征表示更优秀,通过模仿可以让学生网络训练时候的”进步“方向不走偏,往学习更优秀的特征表示的方向走。

【知识蒸馏】Masked Generative Distillation相关推荐

  1. 【AAAI 2021】跨层知识蒸馏:Cross-Layer Distillation with Semantic Calibration

    [AAAI 2021]跨层知识蒸馏:Cross-Layer Distillation with Semantic Calibration 论文地址: 代码地址: 主要问题: 主要思路: 具体实现: 基 ...

  2. 【论文笔记_知识蒸馏_2022】Masked Generative Distillation

    摘要 知识提炼已经成功地应用于各种任务.当前的蒸馏算法通常通过模仿老师的输出来提高学生的表现.本文表明,教师也可以通过指导学生的特征恢复来提高学生的表征能力.从这个角度出发,我们提出了掩蔽生成蒸馏(M ...

  3. 【知识蒸馏】知识蒸馏(Knowledge Distillation)技术详解

    参考论文:Knowledge Distillation: A Survey 1.前言 ​ 近年来,深度学习在学术界和工业界取得了巨大的成功,根本原因在于其可拓展性和编码大规模数据的能力.但是,深度学习 ...

  4. 一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理

    知识蒸馏算法原理精讲 文章目录 知识蒸馏算法原理精讲 1. 什么是知识蒸馏? 2. 轻量化网络的方式有哪些? 3. 为什么要进行知识蒸馏? 3.1 提升模型精度 3.2 降低模型时延,压缩网络参数 3 ...

  5. ECCV 2022 | 适用于分类,检测,分割的生成式知识蒸馏开源

    作者丨美索不达米亚平原@知乎 (已授权) 来源丨https://zhuanlan.zhihu.com/p/539496128 编辑丨极市平台 导读 本文主要介绍ECCV 2022关于知识蒸馏的工作: ...

  6. 适用于分类,检测,分割的生成式知识蒸馏开源

    关于知识蒸馏的工作: Masked Generative Distillation.该方法在图像分类和密集预测的实验中,其学生模型均获得大幅提升 文章链接:https://arxiv.org/abs/ ...

  7. ECCV 2022 | MGD:适用于分类、检测和分割的生成式知识蒸馏

    ©作者 | 美索不达米亚平原 单位 | 清华大学.字节跳动 本文介绍我们ECCV 2022关于知识蒸馏的工作: Masked Generative Distillation,方法适用于分类,检测与分割 ...

  8. ECCV 2022 | 清华字节提出MGD:适用于分类/检测/分割的生成式知识蒸馏

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:美索不达米亚平原 |  已授权转载(源:知乎)编辑:CVer https://zhuanlan.zhi ...

  9. #今日论文推荐#ECCV 2022 | 清华字节提出MGD:适用于分类/检测/分割的生成式知识蒸馏

    #今日论文推荐#ECCV 2022 | 清华&字节提出MGD:适用于分类/检测/分割的生成式知识蒸馏 知识蒸馏主要可以分为logit蒸馏和feature蒸馏.其中feature蒸馏具有更好的拓 ...

最新文章

  1. 水稻微生物组时间序列分析3-冲击图展示时间序序列变化
  2. python面相对象经典例子
  3. DP Review 1
  4. 星空下的痕迹 Jenkins学习(四)----------windows下Publish over FTP插件应用
  5. Java:14 个 Spring MVC 顶级技巧,随时用随时爽,一直用一直爽
  6. Java ClassLoader findSystemClass()方法与示例
  7. android 请求网络异步载入
  8. php array =,PHP Array 函数
  9. 数据结构拾遗(3) --红黑树的设计与实现(下)
  10. 调整手机titlebar与app的titlebar相衔接
  11. js 解析php arraylist,使用JSON将ArrayList从Android发送到PHP脚本
  12. C#写文本写Csv文件操作
  13. js 自学,云知梦知识 点理论
  14. 精益管理研究院陈逸超 | 用精益思维创造数据价值金矿
  15. 手游服务器微信互通,使命召唤手游QQ和微信可以一起玩吗
  16. response.sendRedirect 加域名或者不加域名的重定向加locahost或者不加localhost
  17. 云桌面服务器中兴,随需而动——中兴通讯VDI+VOI融合云桌面解决方案
  18. 从点阵到OLED屏幕——动态扫描显示原理
  19. y7000电池固件_y7000怎么刷电池固件|Surface Pro 3固件更新:电池续航问题终解决
  20. 杰理之手机同步时间接口【篇】

热门文章

  1. 码农翻身之大话编程篇:9 CPU阿甘
  2. RLC可以采用TM、UM、AM三种方式的区别是什么
  3. a链接的四种状态:link、visited、hover、active
  4. 如何提高谷歌排名?(17个要点)
  5. liunx系统文本处理命令
  6. 简单的闪避游戏的c语言,谁有一些简单小游戏的C语言程序?
  7. 'gbk' codec can't decode byte ... 的解决办法
  8. 英文诗歌数据-绘制英文词云图+英文本文分类(pytorch)
  9. 电脑常识——host文件修改(屏蔽网站或解开屏蔽)
  10. event中的stopPropagation和preventDefault