Distilling the knowledge in a neural network

Hinton 在论文中提出方法很简单,就是让学生模型的预测分布,来拟合老师模型(可以是集成模型)的预测分布,其中可通过用 logits 除以 temperature 来调节分布平滑程度,还避免一些极端情况影响。

蒸馏时的softmax

对于一个分类问题,定义soft label为模型的输出(即不同label的概率), hard label为最终正确的label(也就是ground truth),通常是通过最大化正确label的概率来进行学习的,但是不正确趋近于0的label也是有大有小的,这被称为"暗知识(Dark Knowledge)", 这也反应了模型的泛化能力。但因为过于趋紧0不利于student模型学习,为了让student也容易学习tearcher的输出,引入了带温度T的softmax概率为

比之前的softmax多了一个参数T(temperature),T越大产生的概率分布越平滑。
[Distilling the knowledge in a neural network]

蒸馏自由度还是很大的,并不需要一定按照 Hinton 最初论文里一样只对最后输出进行拟合,只要能让学生模型从老师模型中学习到东西就行。

DistilBert

DistillBert的做法相比bert-pkd就比较简单直接,还是保证模型的宽度不变,模型深度减为一半。主要在初始化和损失函数上下了功夫:

  • 损失函数:采用知识蒸馏损失、Masked Language Model损失和cosine embedding损失加起来的值。
  • 初始化:用Teacher模型的参数进行初始化,不过是从每两层中找一层出来。

Student architecture

和BERT类似,只是layer的数量减半
Student initialization

因为Student模型和Teacher模型每层的layer一样,因此每两层保留一层,利用相关的参数
Distillation

采用了RoBERTa的优化策略,动态mask,增大batch size,取消NSP任务的损失函数,
Training Loss

The final training objective is a linear combination of the distillation loss L_{ce}  with the supervised training loss, in our case the masked language modeling loss L_{mlm}  We found it beneficial to add a cosine embedding loss ( L_{cos} ) which will tend to align the directions of the student and teacher hidden states vectors.

最终的loss由三部分构成

1 蒸馏损失,即 L_ce = ∑ t i ∗ log ( s_i ), 其中 s_i 是student输出的概率, t_i 是teacher输出的概率,当BERT预测的 t_i​越高,而DistilBERT预测s_i越低,得到的Loss就会越高
    2 Mask language model loss,参考BERT,这部分也就是为hard loss
    3 Cosine Embedding Loss,利于让student学习和teacher一样的hidden state vector

[DistilBERT, a distilled version of BERT: smaller,faster, cheaper and lighter]

[DistilBert解读]

[模型训练损失值不变_Bert与模型蒸馏: PKD和DistillBert]

BERT-PKD (Patient Knowledge Distillation)

在hinton提到两个损失之上,再加上一个loss:L_PT。

PKD论文中做了对比,减少模型宽度和减少模型深度,得到的结论是减少宽度带来的efficiency提高不如减少深度来的更大。

论文所提出的多层蒸馏,即Student模型除了学习Teacher模型的概率输出之外,还要学习一些中间层的输出。论文提出了两种方法,第一种是Skip模式,即每隔几层去学习一个中间层,第二种是Last模式,即学习teacher模型的最后几层。如果是完全的去学习中间层的话,那么计算量很大。为了避免这个问题,我们注意到Bert模型中有个特殊字段[CLS],因为其在 BERT 分类任务中的重要性,在蒸馏过程中,让student模型去学习[CLS]的中间的输出,计算过程是先归一化,然后直接 均方差MSE 求损失。

Note:

1 至于学生模型中间层如何与老师模型中间层对应,论文中发现最佳策略是直接按倍数取老师模型对应层就行,比如1对2,2对4这样。

2 初始化的话就采用Teacher模型的前几层来做初始化。

3 更好的teacher模型会带来增长么?答案是不会的,可以看上图,把12层的Bert模型换成了24层的Bert模型,反而导致效果变差。究其原因,可能是因为在实验中,我们使用Teacher模型的前N层来初始化Student模型,对于24层模型来说,前N层更容易导致不匹配。而更好的方法则是Student模型先训练好,再去学Teacher模型。

[Patient Knowledge Distillation for BERT Model Compression]

TinyBERT

华为的 TinyBERT,比起上面的 PKD 只是对中间层 [CLS] 进行拟合,它更深入了一步。对 BERT 全范围进行拟合,词向量层,中间隐层,中间注意力矩阵,最后预测层。

在BERT 预训练阶段 和 Fine-tune阶段 分别做蒸馏,如下所示:

其中Transformer Distillation 在预训练和 fine-tune 阶段都是一样的,分为三个部分:

Note: 式11分别是所有token的 embedding、hidden layer outputs 和 attention matrix的MSE loss,L_pred 是 hinton 的dark knowledge。

在 预训练阶段 和 Fine-tune阶段 都仅使用了蒸馏的loss,而没有使用 MLM loss 和 分类 CE loss。

[TinyBERT:模型小7倍,速度快8倍,华中科大、华为出品]

[TinyBERT: Distilling BERT for Natural Language Understanding, Xiaoqi Jiao et al.  EMNLP(findings), 2020 [code]]

from: -柚子皮-

ref: [BERT 瘦身之路:Distillation,Quantization,Pruning]

深度学习:蒸馏Distill相关推荐

  1. 【模型蒸馏】从入门到放弃:深度学习中的模型蒸馏技术

    点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要17分钟 跟随小博主,每天进步一丢丢 来自 | 知乎   作者 | 小锋子Shawn 地址 | https://zhuanlan.zhihu.c ...

  2. 【深度学习】深度学习中的知识蒸馏技术(上)简介

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  3. [深度学习]知识蒸馏技术

    一 知识蒸馏(Knowledge Distillation)介绍 名词解释 teacher - 原始模型或模型ensemble student - 新模型 transfer set - 用来迁移tea ...

  4. 深度学习中的知识蒸馏技术(上)

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  5. 深度学习中的知识蒸馏技术!

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  6. 深度学习 模型压缩之知识蒸馏

    知识蒸馏 知识蒸馏 蒸馏方式 离线蒸馏 在线蒸馏 自我蒸馏 蒸馏算法 对抗蒸馏 多教师蒸馏 跨模态蒸馏 图蒸馏 无数据蒸馏 量化蒸馏 深度交互学习(Deep Mutal Learning) Demo ...

  7. 深度学习三大谜团:集成、知识蒸馏和自蒸馏

    编译:梦佳 校对:周寅张皓 集成(Ensemble,又称模型平均)是一种「古老」而强大的方法.只需要对同一个训练数据集上,几个独立训练的神经网络的输出,简单地求平均,便可以获得比原有模型更高的性能.甚 ...

  8. 深度学习中的3个秘密:集成、知识蒸馏和蒸馏

    作者:Zeyuan Allen-Zhu 来源:AI公园 编译:ronghuaiyang 在现在的标准技术下,例如过参数化.batch-normalization和添加残差连接,"现代&quo ...

  9. 【深度学习】协同优化器和结构化知识蒸馏

    [深度学习]协同优化器和结构化知识蒸馏 文章目录 1 概述 2 什么是RAdam(Rectified Adam) 3 Lookahead - 探索损失面的伙伴系统=更快,更稳定的探索和收敛. 4 Ra ...

  10. 【深度学习】深度学习之对抗样本问题和知识蒸馏技术

    文章目录 1 什么是深度学习对抗样本 2 深度学习对于对抗样本表现的脆弱性产生的原因 3 深度学习的对抗训练 4 深度学习中的对抗攻击和对抗防御 5 知识蒸馏技术5.1 知识蒸馏介绍5.2 为什么要有 ...

最新文章

  1. JCIM| 基于双向RNN的分子生成模型
  2. jmeter(十一)JDBC Request之Query Type
  3. java it_Java中的Iterator的用法
  4. web压测工具http_load原理分析
  5. Linux启动报错UNEXPECTED INCONSISTENCY解决方法
  6. Qt工作笔记-MySQL获取select表头(域)数据
  7. OpenLTE 基站相关头文件:用户、定时器、基站接口、消息接口
  8. Jquery_异步上传文件多种方式归纳
  9. IEEE和SCI等的通俗简介
  10. shp在MATLAB中裁剪数据,ENVI中利用Shape文件裁剪栅格数据
  11. 网易2012校园招聘笔试题目
  12. java lpad oracle_「oracle」lpad函数和rpad函数详解
  13. mysql 丢失msvcr120.dll_安装MySQL数据库提示计算机中丢失MSVCR120.dll文件
  14. 设置wsl2桥接模式和设置ip
  15. html压缩包怎么打开,展示电脑rar压缩包文件怎么打开?教你正确打开方式
  16. 学python历程中
  17. Docker安装OnlyOffice并配置自签证书和自己的域名证书
  18. K-Means聚类算法原理及其python和matlab实现
  19. docker部署环境
  20. SQLSyntaxErrorException: SELECT command denied to user ‘XXXXX‘@‘xxxx‘ for table ‘XXXX‘ 异常解决

热门文章

  1. Macbook上如何调整Windows分区大小,NTFS-FAT-FAT32
  2. php 跳转邮箱,JS简单实现点击跳转登陆邮箱功能的方法
  3. Python 已知三角形三边求三角形面积
  4. html 图片展示 3d,利用CSS3制作简单的3d半透明立方体图片展示
  5. Flutter-解决Try catch出现异常:type ‘_TypeError‘ is not a subtype of type ‘Exception‘ in type cast
  6. 基于51单片机+ULN2003控制步进电机S曲线加减速
  7. 利用python进行数据分析(四)
  8. 【手把手教安装】VM16 Pro安装Win10!!!
  9. 数字温度计设计c语言,数字温度计的设计与制作
  10. 异常:“ERROR: Permission to XXX.git denied to user”终极解决方案