【知识蒸馏】常见的知识蒸馏方式(二)
1、Distilling the Knowledge in a Neural Network
Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(teacher network:复杂、但推理性能优越)相关的软目标(soft-target)作为total loss的一部分,以诱导学生网络(student network:精简、低复杂度)的训练,实现知识迁移(knowledge transfer)。
如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做softmax变换,可以获得软化的概率分布(软目标),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,可以用one-hot矢量表示。total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的推理性能通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。
教师网络与学生网络也可以联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络softmax输出的交叉熵loss、学生网络softmax输出的交叉熵loss、以及教师网络数值输出与学生网络softmax输出的交叉熵loss):
联合训练的Paper地址:https://arxiv.org/abs/1711.05852
2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions
这篇文章将total loss重新定义如下:
GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch
total loss的Pytorch代码如下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将teacher network的预测输出缓存到CPU内存中,可以减轻GPU显存的overhead:
def loss_fn_kd(outputs, labels, teacher_outputs, params):"""Compute the knowledge-distillation (KD) loss given outputs, labels."Hyperparameters": temperature and alphaNOTE: the KL Divergence for PyTorch comparing the softmaxs of teacherand student expects the input tensor to be log probabilities! See Issue #2"""alpha = params.alphaT = params.temperatureKD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \F.cross_entropy(outputs, labels) * (1. - alpha)return KD_loss
3、Ensemble of Multiple Teachers
第一种算法:多个教师网络输出的soft label按加权组合,构成统一的soft label,然后指导学生网络的训练:
第二种算法:由于加权平均方式会弱化、平滑多个教师网络的预测结果,因此可以随机选择某个教师网络的soft label作为guidance:
第三种算法:同样地,为避免加权平均带来的平滑效果,首先采用教师网络输出的soft label重新标注样本、增广数据、再用于模型训练,该方法能够让模型学会从更多视角观察同一样本数据的不同功能:
Paper地址:
https://www.researchgate.net/publication/319185356_Efficient_Knowledge_Distillation_from_an_Ensemble_of_Teachers
4、Hint-based Knowledge Transfer(这个不是很懂)
为了能够诱导训练更深、更纤细的学生网络(deeper and thinner FitNet),需要考虑教师网络中间层的Feature Maps(作为Hint),用来指导学生网络中相应的Guided layer。此时需要引入L2 loss指导训练过程,该loss计算为教师网络Hint layer与学生网络Guided layer输出Feature Maps之间的差别,若二者输出的Feature Maps形状不一致,Guided layer需要通过一个额外的回归层,具体如下:
具体训练过程分两个阶段完成:第一个阶段利用Hint-based loss诱导学生网络达到一个合适的初始化状态(只更新W_Guided与W_r);第二个阶段利用教师网络的soft label指导整个学生网络的训练(即知识蒸馏),且total loss中soft target相关部分所占比重逐渐降低,从而让学生网络能够全面辨别简单样本与困难样本(教师网络能够有效辨别简单样本,而困难样本则需要借助真实标注,即hard target):
Paper地址:https://arxiv.org/abs/1412.6550
GitHub地址:https://github.com/adri-romsor/FitNets
5、Attention to Attention Transfer(这个不是很懂)
通过网络中间层的attention map,完成teacher network与student network之间的知识迁移。考虑给定的tensor A,基于activation的attention map可以定义为如下三种之一:
随着网络层次的加深,关键区域的attention-level也随之提高。文章最后采用了第二种形式的attention map,取p=2,并且activation-based attention map的知识迁移效果优于gradient-based attention map,loss定义及迁移过程如下:
Paper地址:https://arxiv.org/abs/1612.03928
GitHub地址:https://github.com/szagoruyko/attention-transfer
6、Flow of the Solution Procedure
暗知识亦可表示为训练的求解过程(FSP: Flow of the Solution Procedure),教师网络或学生网络的FSP矩阵定义如下(Gram形式的矩阵https://blog.csdn.net/u013066730/article/details/80940781):
训练的第一阶段:最小化教师网络FSP矩阵与学生网络FSP矩阵之间的L2 Loss,初始化学生网络的可训练参数:
训练的第二阶段:在目标任务的数据集上fine-tune学生网络。从而达到知识迁移、快速收敛、以及迁移学习的目的。
Paper地址:
http://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf
7、Knowledge Distillation with Adversarial Samples Supporting Decision Boundary(这个不是很懂)
从分类的决策边界角度分析,知识迁移过程亦可理解为教师网络诱导学生网络有效鉴别决策边界的过程,鉴别能力越强意味着模型的泛化能力越好:
文章首先利用对抗攻击策略(adversarial attacking)将基准类样本(base class sample)转为目标类样本、且位于决策边界附近(BSS: boundary supporting sample),进而利用对抗生成的样本诱导学生网络的训练,可有效提升学生网络对决策边界的鉴别能力。文章采用迭代方式生成对抗样本,需要沿loss function(基准类得分与目标类得分之差)的梯度负方向调整样本,直到满足停止条件为止:
loss function:
沿loss function的梯度负方向调整样本:
停止条件(只要满足三者之一):
结合对抗生成的样本,利用教师网络训练学生网络所需的total loss包含CE loss、KD loss以及boundary supporting loss(BS loss):
Paper地址:https://arxiv.org/abs/1805.05532
8、Label Refinery:Improving ImageNet Classification through Label Progression
这篇文章通过迭代式的诱导训练,主要解决训练期间样本的crop与label不一致的问题,以增强label的质量,从而进一步增强模型的泛化能力:
诱导过程中,total loss表示为本次迭代(t>1)网络的预测输出(概率分布)与上一次迭代输出(Label Refinery:类似于教师网络的角色)的KL散度:
文章实验部分表明,不仅可以用训练网络作为Label Refinery Network,也可以用其他高质量网络(如Resnet50)作为Label Refinery Network。并在诱导过程中,能够对抗生成样本,实现数据增强。
GitHub地址:https://github.com/hessamb/label-refinery
9、Meal V2 KD (Ensemble of Multi-Teachers)
具体请看博客https://blog.csdn.net/u013066730/article/details/111933231
MEAL V2的基本思路是通过知识蒸馏,将多个Teacher模型的效果ensemble、迁移到一个Student模型中,包括:Teacher模型集成,KL散度loss以及判别器:
- 多个Teacher的预测概率求平均;
- 仅依靠Teacher的Soft label;
- 判别器起到正则化作用;
- Student从预训练模型开始,减少蒸馏训练的开销;
Paper地址:https://arxiv.org/abs/2009.08453
GitHub:https://github.com/szq0214/MEAL-V2
10、Miscellaneous
-------- 知识蒸馏可以与量化结合使用,考虑了中间层Feature Maps之间的关系,可参考:
https://blog.csdn.net/nature553863/article/details/82147933
-------- 知识蒸馏与Hint Learning相结合,可以训练精简的Faster-RCNN,可参考:
https://blog.csdn.net/nature553863/article/details/82463249
-------- 知识蒸馏在Transformer模型压缩方面,主要采用Self-attention Knowledge Distillation,可参考:
https://blog.csdn.net/nature553863/article/details/106855786
-------- 模型压缩方面,更为详细的讨论,请参考:
https://blog.csdn.net/nature553863/article/details/81083955
【知识蒸馏】常见的知识蒸馏方式(二)相关推荐
- 揭秘:企业做知识管理常见的几种方式!
企业做知识管理的方式有很多种,下面将介绍比较常见的几种方式,并分享如何做好知识管理. 一.企业做知识管理的方式 创建知识库 创建知识库是最基本的知识管理方式,它可以帮助企业把知识信息整理归类,以便更好 ...
- 【论文泛读121】边际效用递减:探索BERT知识蒸馏的最小知识
贴一下汇总贴:论文阅读记录 论文链接:<Marginal Utility Diminishes: Exploring the Minimum Knowledge for BERT Knowled ...
- “烘焙”ImageNet:自蒸馏下的知识整合
©作者|葛艺潇 学校|香港中文大学博士生 研究方向|图像检索.图像生成等 最先进的知识蒸馏算法发现整合多个模型可以生成更准确的训练监督,但需要以额外的模型参数及明显增加的计算成本为代价.为此,我们提出 ...
- 【知识蒸馏】使用CoatNet蒸馏ResNet图像分类模型
本文转载自:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/127787791 ,仅作留用和学习,如有侵权,立刻删除! 文章目录 ...
- 知识蒸馏基础及Bert蒸馏模型
为了提高模型准确率,我们习惯用复杂的模型(网络层次深.参数量大),甚至会选用多个模型集成的模型,这就导致我们需要大量的计算资源以及庞大的数据集去支撑这个"大"模型.但是,在部署服务 ...
- 软考知识点——加密算法、常见计算机网络知识
目录 一.加密算法 1.常见的加密算法 (1)2021年下半年软考上午真题8 (2)2021年上半年软考上午真题9 2.加密技术的应用 3.网络安全协议分层 (1)2021年上半年软考上午真题7 (2 ...
- 软考项目管理领域的常见英文术语,特别是 9 大知识领域有关的知识
软考项目管理领域的常见英文术语,特别是 9 大知识领域有关的知识 软考项目管理领域的常见英文术语 一.项目管理基础术语 二.项目整体管理 三.项目范围管理 四.项目时间管理 五.项目人力资源管理 六. ...
- 跨行交易的一些常见的知识
跨行交易的一些常见的知识 实例 用户在a银行取b银行账户中的钱,这个时候就属于跨行取钱. 关键有两个问题: 首先是 信息流 银行之前是如何传递金钱的消息的. 第二是资金流: 银行之间是如何传递金钱的. ...
- 知识付费的七种变现方式
知识付费的七种变现方式. 一在线问答 以文字.音频.视频等方式来对提问者的问题进行回答.只要你在某些领域有丰富的知识积累,那么你的回答就能得到提问者的青睐,就可以赚取相应的佣金.不过这种收益方式效果甚 ...
- 数据库基础知识和常见术语学习
数据库基础知识和常见术语学习 什么是数据库 数据库系统 什么是数据库系统 数据库系统(DBS)的组成 数据库系统的特点 数据库管理系统(DBMS) 什么是数据库管理系统 数据库管理系统所提供的功能 数 ...
最新文章
- r语言和python-R语言和Python一块学习会弄混吗
- TeeChart经验总结 10.ZoomScroll
- 客服人员控制台Console,Salesforce Service Cloud的核心
- 初步学习Linux文件基本属性和Cygwin STATUS_ACCESS_VIOLATION 错误
- flatform installer web 安装php_web安装平台-微软web服务器配置安装工具(Web Platform Installer)5.0 官方最新版-东坡下载...
- 55 MM配置-评估和科目设置-定义账户分类参考
- python计算AA制时砍价后大家需要分摊的钱
- 如何使用PDF expert在Mac上给PDF调整页面顺序?
- LeetCode----两数之和
- 知网论文[全PDF下载],从此告别CAJ阅读器
- python嵩天ppt_python知识精华:嵩天微专业笔记
- Swing Copters摇摆直升机高分攻略,游戏攻略
- jfinal jboot 拦截器过滤文件上传请求 和 跨域解决方法
- 微信刷脸支付php后端,2.1 微信刷脸支付初始化
- html中如何访问ftp中的图片,CSS FTP上传网页图解教程
- linux上离线安装PostgreSQL和插件PostGIS
- 怎么理解预训练模型?
- Benchmark Analysis 7:SPEC2006.482sphinx
- 让AI帮你玩游戏(一) 基于目标检测用几个样本帮你实现在魔兽世界中钓鱼(群已满)
- ROS(1)创建工作空间和功能包过程