知识蒸馏 | 知识回顾
每天给你送来NLP技术干货!
来自:GiantPandaCV
【引言】
知识回顾(KR)发现学生网络深层可以通过利用教师网络浅层特征进行学习,基于此提出了回顾机制,包括ABF和HCL两个模块,可以在很多分类任务上得到一致性的提升。
1摘要
知识蒸馏通过将知识从教师网络传递到学生网络,但是之前的方法主要关注提出特征变换和实施相同层的特征。
知识回顾Knowledge Review选择研究教师与学生网络之间不同层之间的路径链接。
简单来说就是研究教师网络向学生网络传递知识的链接方式。
代码在:https://github.com/Jia-Research-Lab/ReviewKD
2KD简单回顾
KD最初的蒸馏对象是logits层,也即最经典的Hinton的那篇Knowledge Distillation,让学生网络和教师网络的logits KL散度尽可能小。
随后FitNets出现开始蒸馏中间层,一般通过使用MSE Loss让学生网络和教师网络特征图尽可能接近。
Attention Transfer进一步发展了FitNets,提出使用注意力图来作为引导知识的传递。
PKT(Probabilistic knowledge transfer for deep representation learning)将知识作为概率分布进行建模。
Contrastive representation Distillation(CRD)引入对比学习来进行知识迁移。
以上方法主要关注于知识迁移的形式以及选择不同的loss function,但KR关注于如何选择教师网络和学生网络的链接,一下图为例:
(a-c)都是传统的知识蒸馏方法,通常都是相同层的信息进行引导,(d)代表KR的蒸馏方式,可以使用教师网络浅层特征来作为学生网络深层特征的监督,并发现学生网络深层特征可以从教师网络的浅层学习到知识。
教师网络浅层到深层分别对应的知识抽象程度不断提高,学习难度也进行了提升,所以学生网络如果能在初期学习到教师网络浅层的知识会对整体有帮助。
KR认为浅层的知识可以作为旧知识,并进行不断回顾,温故知新。如何从教师网络中提取多尺度信息是本文待解决的关键:
提出了Attention based fusion(ABF) 进行特征fusion
提出了Hierarchical context loss(HCL) 增强模型的学习能力。
3Knowledge Review
形式化描述
X是输入图像,S代表学生网络,其中代表学生网络各个层的组成。
Ys代表X经过整个网络以后的输出。代表各个层中间层输出。
那么单层知识蒸馏可以表示为:
M代表一个转换,从而让Fs和Ft的特征图相匹配。D代表衡量两者分布的距离函数。
同理多层知识蒸馏表示为:
以上公式是学生和教师网络层层对应,那么单层KR表示方式为:
具体
与之前不同的是,这里计算的是从j=1 to i 代表第i层学生网络的学习需要用到从第1到i层所有知识。
同理,多层的KR表示为:
Fusion方式设计
已经确定了KR的形式,即学生每一层回顾教师网络的所有靠前的层,那么最简单的方法是:
直接缩放学生网络最后一层feature,让其形状和教师网络进行匹配,这样可以简单使用一个卷积层配合插值层完成形状的匹配过程。这种方式是让学生网络更接近教师网络。
这张图表示扩展了学生网络所有层对应的处理方式,也即按照第一张图的处理方式进行形状匹配。
这种处理方式可能并不是最优的,因为会导致stage之间出现巨大的差异性,同时处理过程也非常复杂,带来了额外的计算代价。
为了让整个过程更加可行,提出了Attention based fusion
, 这样整体蒸馏变为:
如果引入了fusion的模块,那整体流程就变为下图所示:
但是为了更高的效率,再对其进行改进:
可以发现,这个过程将fusion的中间结果进行了利用,即, 这样循环从后往前进行迭代,就可以得到最终的loss。
具体来说,ABF的设计如下(a)所示,采用了注意力机制融合特征,具体来说中间的1x1 conv对两个level的feature提取综合空间注意力特征图,然后再进行特征重标定,可以看做SKNet的空间注意力版本。
而HCL Hierarchical context loss 这里对分别来自于学生网络和教师网络的特征进行了空间池化金字塔的处理,L2 距离用于衡量两者之间的距离。
KR认为这种方式可以捕获不同level的语义信息,可以在不同的抽象等级提取信息。
4实验
实验部分主要关注消融实验:
第一个是使用不同stage的结果:
蓝色的值代表比baseline 69.1更好,红色代表要比baseline更差。通过上述结果可以发现使用教师网络浅层知识来监督学生网络深层知识是有效的。
第二个是各个模块的作用:
5源码
主要关注ABF, HCL的实现:
ABF实现:
class ABF(nn.Module):def __init__(self, in_channel, mid_channel, out_channel, fuse):super(ABF, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),nn.BatchNorm2d(mid_channel),)self.conv2 = nn.Sequential(nn.Conv2d(mid_channel, out_channel,kernel_size=3,stride=1,padding=1,bias=False),nn.BatchNorm2d(out_channel),)if fuse:self.att_conv = nn.Sequential(nn.Conv2d(mid_channel*2, 2, kernel_size=1),nn.Sigmoid(),)else:self.att_conv = Nonenn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignorenn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignoredef forward(self, x, y=None, shape=None, out_shape=None):n,_,h,w = x.shape# transform student featuresx = self.conv1(x)if self.att_conv is not None:# upsample residual featuresy = F.interpolate(y, (shape,shape), mode="nearest")# fusionz = torch.cat([x, y], dim=1)z = self.att_conv(z)x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w))# output if x.shape[-1] != out_shape:x = F.interpolate(x, (out_shape, out_shape), mode="nearest")y = self.conv2(x)return y, x
HCL实现:
def hcl(fstudent, fteacher):
# 两个都是list,存各个stage对象loss_all = 0.0for fs, ft in zip(fstudent, fteacher):n,c,h,w = fs.shapeloss = F.mse_loss(fs, ft, reduction='mean')cnt = 1.0tot = 1.0for l in [4,2,1]:if l >=h:continuetmpfs = F.adaptive_avg_pool2d(fs, (l,l))tmpft = F.adaptive_avg_pool2d(ft, (l,l))cnt /= 2.0loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnttot += cntloss = loss / totloss_all = loss_all + lossreturn loss_all
ReviewKD实现:
class ReviewKD(nn.Module):def __init__(self, student, in_channels, out_channels, shapes, out_shapes,): super(ReviewKD, self).__init__()self.student = studentself.shapes = shapesself.out_shapes = shapes if out_shapes is None else out_shapesabfs = nn.ModuleList()mid_channel = min(512, in_channels[-1])for idx, in_channel in enumerate(in_channels):abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1))self.abfs = abfs[::-1]self.to('cuda')def forward(self, x):student_features = self.student(x,is_feat=True)logit = student_features[1]x = student_features[0][::-1]results = []out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])results.append(out_features)for features, abf, shape, out_shape in zip(x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):out_features, res_features = abf(features, res_features, shape, out_shape)results.insert(0, out_features)return results, logit
6参考
https://zhuanlan.zhihu.com/p/363994781
https://arxiv.org/pdf/2104.09044.pdf
https://github.com/dvlab-research/ReviewKD
投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。
方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。
记得备注呦
整理不易,还望给个在看!
知识蒸馏 | 知识回顾相关推荐
- 知识蒸馏 | 知识蒸馏理论篇
之前在<一文搞懂[知识蒸馏][Knowledge Distillation]算法原理>这篇文章中介绍过一些知识蒸馏的原理,这篇博文将会着重介绍目标检测领域的知识蒸馏原理. 文章目录 1.& ...
- 论文阅读:Knowledge Distillation: A Survey 知识蒸馏综述(2021)
论文阅读:Knowledge Distillation: A Survey 知识蒸馏综述2021 目录 摘要 Introduction Background 知识 基于响应的知识 基于特征的知识 基于 ...
- 知识蒸馏论文翻译(3)—— Ensembled CTR Prediction via Knowledge Distillation
知识蒸馏论文翻译(3)-- Ensembled CTR Prediction via Knowledge Distillation 经由知识蒸馏的集合CTR预测 文章目录 知识蒸馏论文翻译(3)-- ...
- ACL 2021 | 结构化知识蒸馏方法
本文介绍了上海科技大学屠可伟课题组与阿里巴巴达摩院的一项合作研究,提出了在结构预测问题上一种较为通用的结构化知识蒸馏方法.该论文已被 ACL 2021 接受为长文. 论文标题: Structural ...
- 知识蒸馏在广告系统中的应用(一)
上篇文章主要和大家聊的是强化学习在推荐混排中的应用,今天我们会开启一个全新的系列"知识蒸馏在广告系统中的应用".本文将主要涉及第一部分--背景介绍.背景介绍大致分为三块:简述广告系 ...
- BERT知识蒸馏TinyBERT
1. 概述 诸如BERT等预训练模型的提出显著的提升了自然语言处理任务的效果,但是随着模型的越来越复杂,同样带来了很多的问题,如参数过多,模型过大,推理事件过长,计算资源需求大等.近年来,通过模型压缩 ...
- 知识蒸馏——pytorch实现
轻量化网络 知识蒸馏可以理解为轻量化网络的一个tricks,轻量化网络是深度学习的一个大的发展趋势,尤其是在移动端,终端边缘计算这种对算力和运算时间有要求的场景中. 轻量化网络可以有以下四种方式实现: ...
- BERT知识蒸馏Distilled BiLSTM
1. 概述 随着BERT模型的提出,在NLP上的效果在不断被刷新,伴随着计算能力的不断提高,模型的深度和复杂度也在不断上升,BERT模型在经过下游任务Fine-tuning后,由于参数量巨大,计算比较 ...
- 深度总结 | 知识蒸馏在推荐系统中的应用
查看全文 http://www.taodudu.cc/news/show-3127021.html 相关文章: 手把手教导 3分钟让你快速入门地图可视化 Born-Again Neural Netwo ...
- 联邦学习——用data-free知识蒸馏处理Non-IID
<Data-Free Knowledge Distillation for Heterogeneous Federated Learning>ICML 2021 最近出现了利用知识蒸馏来解 ...
最新文章
- 注册与验证码php源代码,PHP验证码处理源代码
- lua学习笔记之函数
- Win7下面安装SQL Server2005
- 腾讯初探AI+农业 获国际AI温室种植大赛亚军
- object.prototype.call
- 十分钟让你明白AIDL
- mt4交易系统源码_mt4周边:一款免费的数据下载工具
- Eclipse设置Android Logcat输出字体大小
- python 在线培训费用-python培训班费用
- 企业进销存管理系统(一)
- c语言1至100的累乘求和,c语言 累加累乘课件.ppt
- 手机HTML5 audio 无法自动播放下一首
- Android 继承于PopuWindow的自定义弹出窗体
- 你还不知道如何去学习3D建模,那你来找我,我教你
- c 语言编辑器 win7旗舰版,如何使用大地win7旗舰版内置字符编辑程序
- LC振荡器稳定度与品质因数的关系
- 超大数据10进制转2进制详解(可推广到其他进制)/ Codeup 100000579 问题 C: 进制转换
- 华为设备配置策略路由引流到旁挂防火墙
- VB基础版版务处理_20041210
- 神奇数字7(你在知网搜不到的冷知识)
热门文章
- python 2.7.9 安装beautifulsoup4
- 设置NumericStepper控件不可用状态的字体颜色。
- 30天敏捷结果(24):恢复你的精力
- tensorflow搭建神经网络
- python运算符及优先级顺序
- 第三次小组实践作业小组每日进度汇报:2017-12-2
- 201521145048 《Java程序设计》第3周学习总结
- [解题报告][搜索+剪枝技巧]幻方
- Android依赖注入:Dagger、RoboGuice和ButterKnife
- Leetcode Contains Duplicate II