来自:NLP从入门到放弃

PKD[1]核心点就是不仅仅从Bert(老师网络)的最后输出层学习知识去做蒸馏,它还另加了一部分,就是从Bert的中间层去学习

简单说,PKD的知识来源有两部分:中间层+最后输出,当然还有Hard labels

它缓解了之前只用最后softmax输出层的蒸馏方式出现的过拟合而导致泛化能力降低的问题。

接下来,我们从PKD模型的两个策略说起:PKD-Last 和 PKD-Skip。

1.PKD-Last and PKD-Skip

PKD的本质是从中间层学习知识,但是这个中间层如何去定义,就各式各样了。

比如说,我完全可以定位我只要奇数层,或者我只要偶数层,或者说我只要最中间的两层,等等,不一而足。

那么作者,主要是使用了这么多想法中的看起来比较合理的两种。

PKD-Last,就是把中间层定义为老师网络的最后k层

这样做是基于老师网络越靠后的层数含有更多更重要的信息。

这样的想法其实和之前的蒸馏想法很类似,也就是只使用softmax层的输出去做蒸馏。但是从感官来看,有种尾大不掉的感觉,不均衡。

另一个策略是 就是PKD-Skip,顾名思义,就是每跳几层学习一层

这么做是基于老师网络比较底层的层也含有一些重要性信息,这些信息不应该被错过。

作者在后面的实验中,证明了,PKD-Skip 效果稍微好一点(slightly better);

作者认为PKD-Skip抓住了老师网络不同层的多样性信息。而PKD-Last抓住的更多相对来说同质化信息,因为集中在了最后几层。

2. PKD

2.1架构图

两种策略的PKD的架构图如下所示,注意观察图,有个细节很容易忽视掉:

PKD_Models

我们注意看这个图,Bert的最后一层(不是那个绿色的输出层)是没有被蒸馏的,这个细节一会会提到

2.2 怎么蒸馏中间层

这个时候,需要解决一个问题:我们怎么蒸馏中间层?

仔细想一下Bert的架构,假设最大长度是128,那么我们每一层Transformer encoder的输出都应该是128个单元,每个单元是768维度。

那么在对中间层进行蒸馏的时候,我们需要针对哪一个单元?是针对所有单元还是其中的部分单元?

首先,我们想一下,正常KD进行蒸馏的时候,我们使用的是[CLS]单元Softmax的输出,进行蒸馏。

我们可以把这个思想借鉴过来,一来,对所有单元进行蒸馏,计算量太大。二来,[CLS] 不严谨的说,可以看到整个句子的信息。

为啥说是不严谨的说呢?因为[CLS]是不能代表整个句子的输出信息,这一点我记得Bert中有提到。

2.3蒸馏层数和学生网络的初始化

接下来,我想说一个很小的细节点,对比着看上面的模型架构图:

Bert(老师网络)的最后一层 (Layer 12 for BERT-Base) 在蒸馏的时候是不予考虑

原因的话,其一可以这么理解,PKD创新点是从中间层学习知识,最后一层不属于中间层。当然这么说有点牵强附会。

作者的解释是最后一层的隐层输出之后连接的就是Softmax层,而Softmax层的输出已经被KD Loss计算在内了。

比如说,K=5,那么对于两种PKD的模式,被学习的中间层分别是:

PKD-Skip: ;

PKD-Last:

还有一个细节点需要注意,就是学生网络的初始化方式,直接使用老师网络的前几层去初始化学生网络的参数。

2.4 损失函数

首先需要注意的是中间层的损失,作者使用的是MSE损失。如下:

中间层损失计算

整个模型的损失主要是分为两个部分:KD损失和中间层的损失,如下:

Loss_of_PKD

超参数问题:

3. 实验效果

实验效果可以总结如下:

  1. PKD确实有效,而且Skip模型比Last效果稍微好一点。

  2. PKD模型减少了参数量,加快了推理速度,基本是线性关系,毕竟减少了层数

除了这两点,作者还做了一个实验去验证:如果老师网络更大,PKD模型得到的学生网络会表现更好吗

这个实验我很感兴趣。

直接上结果图:

Larger_Teacher

KD情况下,注意不是PKD模型,看#1 和#2,在老师网络增加的情况下,效果有好有坏。这个和训练数据大小有关。

KD情况下,看#1和#3,在老师网络增加的情况下,学生网络明显变差。

作者分析是因为,压缩比高了,学生网络获取的信息变少了。

也就是大网络和小网络本身效果没有差多少,但是学生网络在老师是大网络的情况下压缩比大,学到的信息就少了。

更有意思的是对比#2和#3,老师是大网络的情况下,学生网络效果差。

这里刚开始没理解,后来仔细看了一下,注意#2 的学生网络是,也就是它的初始化是从来的,占了一半的信息。

好的,写到这里

下载一:中文版!学习TensorFlow、PyTorch、机器学习、深度学习和数据结构五件套!后台回复【五件套】
下载二:南大模式识别PPT后台回复【南大模式识别】

说个正事哈

由于微信平台算法改版,公号内容将不再以时间排序展示,如果大家想第一时间看到我们的推送,强烈建议星标我们和给我们多点点【在看】。星标具体步骤为:

(1)点击页面最上方深度学习自然语言处理”,进入公众号主页。

(2)点击右上角的小点点,在弹出页面点击“设为星标”,就可以啦。

感谢支持,比心

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

推荐两个专辑给大家:

专辑 | 李宏毅人类语言处理2020笔记

专辑 | NLP论文解读

专辑 | 情感分析


整理不易,还望给个在看!

PKD-Bert:基于多层网络的Bert知识蒸馏相关推荐

  1. 从设计网络就开始知识蒸馏

    如图所示使用大网络指导小网络进行拟合行为后预测的时候直接抛弃大网的网络,这样后期是要将网络逐层拆开即可.

  2. 基于小样本知识蒸馏的乳腺癌组织病理图像分类

    基于小样本知识蒸馏的乳腺癌组织病理图像分类 期刊:中国计量大学学报 时间:2022 研究院:中国计量大学 关键词:乳腺癌 :知识蒸馏 :图像分类 :小样本学习 :卷积神经网络 方法简介 本文使用的知识 ...

  3. 【深度学习】深度学习中的知识蒸馏技术(上)简介

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  4. [深度学习]知识蒸馏技术

    一 知识蒸馏(Knowledge Distillation)介绍 名词解释 teacher - 原始模型或模型ensemble student - 新模型 transfer set - 用来迁移tea ...

  5. 【知识蒸馏】知识蒸馏(Knowledge Distillation)技术详解

    参考论文:Knowledge Distillation: A Survey 1.前言 ​ 近年来,深度学习在学术界和工业界取得了巨大的成功,根本原因在于其可拓展性和编码大规模数据的能力.但是,深度学习 ...

  6. 深度学习中的知识蒸馏技术(上)

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  7. 深度学习中的知识蒸馏技术!

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  8. 张祥雨团队最新工作:用于物体检测的实例条件知识蒸馏 | NeurIPS 2021

    [专栏:前沿进展]在青源LIVE第31期中,旷视研究院张祥雨团队的张培圳研究员深入浅出地为我们介绍了其团队被 NeurIPS 2021 录用的论文「用于物体检测的实例条件知识蒸馏」.本期报告首先简要回 ...

  9. 关于知识蒸馏,你想知道的都在这里!

    "蒸馏",一个化学用语,在不同的沸点下提取出不同的成分.知识蒸馏就是指一个很大很复杂的模型,有着非常好的效果和泛化能力,这是缺乏表达能力的小模型所不能拥有的.因此从大模型学到的知识 ...

  10. 关于“知识蒸馏“,你想知道的都在这里!

    "蒸馏",一个化学用语,在不同的沸点下提取出不同的成分.知识蒸馏就是指一个很大很复杂的模型,有着非常好的效果和泛化能力,这是缺乏表达能力的小模型所不能拥有的.因此从大模型学到的知识 ...

最新文章

  1. 使用消息来处理多线程程序中的一些问题
  2. Overlay 网络 — VxLAN 虚拟可扩展局域网协议
  3. Go Code Review Comments 翻译 编写优雅golang代码
  4. 牛逼哄哄的SLAM技术 即将颠覆哪些领域?
  5. 数据结构之trie树——First! G,电子字典,Type Printer,Nikitosh and xor
  6. bakaxl启动器怎么导入整合包_bakaxl启动器加皮肤光影mod
  7. php larver 导出e,laravel5 Excel导出
  8. java 的clean code 技巧
  9. VC下__func__未定义,改用__FUNCTION__
  10. Python库的下载及导入使用教程
  11. 基于STM32的小说阅读器
  12. java同步代码块作用_Java之同步代码块
  13. IP冲突,中国移动光猫路由-中兴F673A之修改IP篇
  14. 【Python】UnicodeDecodeError: 'gbk' codec can't decode byte 0xfe
  15. android 跳应用市场评分,Android 应用中跳转到应用市场评分示例
  16. 努力是你最幸福的时候
  17. 计算机技术教学,小学计算机技术教学计划
  18. slf4j报错:SLF4J:Failed to load class org.slf4j.impl.StaticLoggerBinder.Defaulting to no-operat有效解决办法
  19. 20145212罗天晨 恶意代码分析
  20. 统计文章中的单词数量

热门文章

  1. C语言读取文件大量数据到数组
  2. Changing a remote's URL
  3. 94-《纪元2205》游戏体会.(2015.11.12)
  4. 安装GIT,集成到Powershell中
  5. C# 中using的几个用途
  6. Android设计模式--之命令模式
  7. flash跟随鼠标样式
  8. PHP $_SERVER详解
  9. 转载:SPFA算法学习
  10. js数组操作大全(转)