点击上方,选择星标,每天给你送干货!


来自:NLP从入门到放弃

大家好,我是DASOU;

今天从代码角度深入了解一下知识蒸馏,主要核心部分就是分析一下在知识蒸馏中损失函数是如何实现的;

知识蒸馏一个简单的脉络可以这么去梳理:学什么,从哪里学,怎么学?

学什么:学的是老师的知识,体现在网络的参数上;

从哪里学:输入层,中间层,输出层;

怎么学:损失函数度量老师网络和学生网络的差异性;

从架构上来说,BERT可以蒸馏到简单的TextCNN,LSTM等,也就可以蒸馏到TRM架构模型,比如12层BERT到4层BERT;

之前工作中用到的是BERT蒸馏到TextCNN;

最近在往TRM蒸馏靠近,使用的是 Textbrewer 这个库(这个库太强大了);

接下来,我从代码的角度来梳理一下知识蒸馏的核心步骤,其实最主要的就是分析一下损失函数那块的代码形式。

我以一个文本分类的任务为例子,在阅读理解的过程中,最需要注意的一点是数据的流入流出的Shape,这个很重要,在自己写代码的时候,最重要的其实就是这个;

首先使用的是MNLI任务,也就是一个文本分类任务,三个标签;

输入为Batch_data:[32,128]---[Batch_size,seq_len];

老师网络:BERT_base:12层,Hidden_size为768;

学生网络:BERT_base:4层,Hidden_size为312;

首先第一个步骤是训练一个老师网络,这个没啥可说。

其次是初始化学生网络,然后将输入Batch_data流经两个网络;

在初始化学生网络的时候,之前有的同学问到是如何初始化的一个BERT模型的;

关于这个,最主要的是修改Config文件那里的层数,由正常的12改为4,然后如果你不是从本地load参数到学生网络,BERT模型的类会自动调用初始化;

然后我们来说数据首先流经学生网络,我们得到两个东西,一个是最后一层【CLS】的输出,此时未经softmax操作,所以是logits,维度为:[32,3]-[batch_size,label_size];

第二个东西是中间隐层的输出,维度为:[5,32,128,312],也就是 [隐层数量,batch_size,seq_len,Hidden_size];

需要注意的是这里的隐层数量是5,因为正常的隐层在模型定义的时候是4,然后这里是加上了embedding层;

还有一点需要注意的是,在度量学生网络和老师网络隐层差异的时候,这里是度量的seq_len,也就是对每个token的输出都做了操作;

如果在这里我们想做类似【CLS】的输出的时候,只需要提取最开始的一个[32,312]的向量就可以;不过,一般来说我们不这么做;

其次流经老师网络,我们同样得到两个东西,一个是最后一层【CLS】的输出,此时未经softmax操作,所以是logits,维度为:[32,3]-[batch_size,label_size];

第二个东西是中间隐层的输出,维度为:[5,32,128,768],也就是 [隐层数量,batch_size,seq_len,Hidden_size];

这里需要注意的是老师网络和学生网络隐层数量不一样,一个是768,一个是312。

这其实是一个很常见的现象;就是我们的学生网络在减少参数的时候,不仅会变矮,有时候我们也想让它变窄,也就是隐层的输出会发生变化,从768变为312;

这个维度的变化需要注意两点,首先就是在学生模型初始化的时候,不能套用老师网络的对应层的参数,因为隐层Hidden_size发生了变化。所以一般调用的是BERT自带的初始化方式;

其次就是在度量学生网络和老师网络差异性的时候,因为矩阵大小不一致,不能直接做MSE。在代码层面上,需要做一个线性映射,才能做MSE。

而且还需要注意的一点是,由于老师网络已经固定不动了,所以在做映射的时候我们是要对学生网路的312加一个线性层转化到768层,也就是说这个线性层是加在了学生网络;

整个架构的损失函数可以分为三种:首先对于【CLS】的输出,使用KL散度度量差异;对于隐层输出使用MSE和MMD损失函数进行度量;

对于损失函数这块的选择,其实我觉得没啥经验可说,只能试一试;

看了很多论文加上自己的经验,一般来说在最后面使用KL,中间层使用MSE会更好一点;当然有的实验也会在最后一层直接用MSE;玄学。

在初看代码的时候,MMD这个之前我没接触过,还特意去看了一下,关于理论我就不多说了,一会看代码吧。

首先对【CLS】的输出,代码如下:

def kd_ce_loss(logits_S, logits_T, temperature=1):if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:temperature = temperature.unsqueeze(-1)beta_logits_T = logits_T / temperaturebeta_logits_S = logits_S / temperaturep_T = F.softmax(beta_logits_T, dim=-1)loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()return loss

首先对于 logits_S,就是学生网络的【CLS】的输出,logits_T就是老师网络【CLS】的输出,temperature 在代码中默认参数是1,例子中设置为了8;

整个代码其实很简单,就是先做Temp的一个转化,注意这里我们对学生网络的输出和老师网络的输出都做了转化,然后做loss计算;

其次我们来看比较复杂的中间层的度量;

首先需要掌握一点,就是学生网络和老师网络层之间的对应关系;

学生网络是4层,老师网络12层,那么在对应的时候,简单的对应关系就是这样的:

layer_T : 0, layer_S : 0,
layer_T : 3, layer_S : 1,
layer_T : 6, layer_S : 2,
layer_T : 9, layer_S : 3,
layer_T : 12, layer_S : 4,

这个对应关系是需要我们认为去设定的,将学生网络的1层对应到老师网络的12层可不可以?当然可以,但是效果不一定好;

一般来说等间隔的对应上就好;

这个对应关系其实还有一个用处,就是学生网络在初始化的时候【假如没有变窄,只是变矮,也就是层数变低了】,那么可以从依据这个对应关系把权重copy过来;

学生网络的隐层输出为:[5,32,128,312],老师网络隐层输出为[5,32,128,768]

那么在代码实现的时候,需要做一个zip函数把对应层映射过去,然后每一层计算MSE,然后加起来作为损失函数;

我们来看代码:

inters_T = {feature: results_T.get(feature,[]) for feature in FEATURES}
inters_S = {feature: results_S.get(feature,[]) for feature in FEATURES}for ith,inter_match in enumerate(self.d_config.intermediate_matches):if type(layer_S) is list and type(layer_T) is list: ## MMD损失函数对应的情况inter_S = [inters_S[feature][s] for s in layer_S]inter_T = [inters_T[feature][t] for t in layer_T]name_S = '-'.join(map(str,layer_S))name_T = '-'.join(map(str,layer_T))if self.projs[ith]: ## 这里失去做学生网络隐层的映射#inter_T = [self.projs[ith](t) for t in inter_T]inter_S = [self.projs[ith](s) for s in inter_S]else:## MSE 损失函数inter_S = inters_S[feature][layer_S]inter_T = inters_T[feature][layer_T]name_S = str(layer_S)name_T = str(layer_T)if self.projs[ith]:inter_S = self.projs[ith](inter_S) # 需要注意的是隐层输出是312,但是老师网络是768,所以这里要做一个linear投影到更高维,方便计算损失函数intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S)  ## loss = F.mse_loss(state_S, state_T)total_loss += intermediate_loss * match_weight

这个代码里面比如迷糊的是【self.d_config.intermediate_matches】,打印出来发现是这个东西:

IntermediateMatch: layer_T : 0, layer_S : 0, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 3, layer_S : 1, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 6, layer_S : 2, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 9, layer_S : 3, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 12, layer_S : 4, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : [0, 0], layer_S : [0, 0], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [3, 3], layer_S : [1, 1], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [6, 6], layer_S : [2, 2], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [9, 9], layer_S : [3, 3], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [12, 12], layer_S : [4, 4], feature : hidden, weight : 1, loss : mmd, proj : None

简单说,这个变量存储的就是上面我们谈到的层与层之间的对应关系。前面5行就是MSE损失函数度量,后面那个注意看,层数对应的时候是一个列表,对应的是MMD损失函数;

我们来看一下MMD损失的代码形式:

def mmd_loss(state_S, state_T, mask=None):state_S_0 = state_S[0] # (batch_size , length, hidden_dim_S)state_S_1 = state_S[1] # (batch_size , length, hidden_dim_S)state_T_0 = state_T[0] # (batch_size , length, hidden_dim_T)state_T_1 = state_T[1] # (batch_size , length, hidden_dim_T)if mask is None:gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)  # (batch_size, length, length)gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)loss = F.mse_loss(gram_S, gram_T)else:mask = mask.to(state_S[0])valid_count = torch.pow(mask.sum(dim=1), 2).sum()gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)  # (batch_size, length, length)gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)loss = (F.mse_loss(gram_S, gram_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(1)).sum() / valid_countreturn loss

看最重要的代码就可以:

state_S_0 = state_S[0]#  32 128 312 (batch_size , length, hidden_dim_S)
state_T_0 = state_T[0] #  32 128 768 (batch_size , length, hidden_dim_T)
gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)
gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)

简单说就是现在自己内部计算bmm,然后两个矩阵之间做mse;这里如果我没理解错使用的是一个线性核函数;

损失函数代码大致就是这样,之后有时间我写个简单的repository,梳理一下整个流程;

说个正事哈

由于微信平台算法改版,公号内容将不再以时间排序展示,如果大家想第一时间看到我们的推送,强烈建议星标我们和给我们多点点【在看】。星标具体步骤为:

(1)点击页面最上方深度学习自然语言处理”,进入公众号主页。

(2)点击右上角的小点点,在弹出页面点击“设为星标”,就可以啦。

感谢支持,比心

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

整理不易,还望给个在看!

【知识蒸馏】如何写好BERT知识蒸馏的损失函数代码(一)相关推荐

  1. 计算机知识探索怎么写,计算机基础知识及探索.doc

    PAGE PAGE 23 HYPERLINK "/ASPX/602009818/JournalContent/1325923866.aspx"计算机基础知识参考试题及答案解析一.单 ...

  2. 手写数字识别中多元分类原理_广告行业中那些趣事系列:从理论到实战BERT知识蒸馏...

    导读:本文将介绍在广告行业中自然语言处理和推荐系统实践.本文主要分享从理论到实战知识蒸馏,对知识蒸馏感兴趣的小伙伴可以一起沟通交流. 摘要:本篇主要分享从理论到实战知识蒸馏.首先讲了下为什么要学习知识 ...

  3. BERT知识蒸馏TinyBERT

    1. 概述 诸如BERT等预训练模型的提出显著的提升了自然语言处理任务的效果,但是随着模型的越来越复杂,同样带来了很多的问题,如参数过多,模型过大,推理事件过长,计算资源需求大等.近年来,通过模型压缩 ...

  4. BERT知识蒸馏Distilled BiLSTM

    1. 概述 随着BERT模型的提出,在NLP上的效果在不断被刷新,伴随着计算能力的不断提高,模型的深度和复杂度也在不断上升,BERT模型在经过下游任务Fine-tuning后,由于参数量巨大,计算比较 ...

  5. 论文浅尝 | MulDE:面向低维知识图嵌入的多教师知识蒸馏

    笔记整理:朱渝珊,浙江大学在读博士,研究方向为快速知识图谱的表示学习,多模态知识图谱. Motivation 为了更高的精度,现有的KGE方法都会采用较高的embedding维度,但是高维KGE需要巨 ...

  6. bert模型蒸馏实战

    由于bert模型参数很大,在用到生产环境中推理效率难以满足要求,因此经常需要将模型进行压缩.常用的模型压缩的方法有剪枝.蒸馏和量化等方法.比较容易实现的方法为知识蒸馏,下面便介绍如何将bert模型进行 ...

  7. 【模型蒸馏】TinyBERT: Distilling BERT for Natural Language Understanding

    总述 TinyBert主要探究如何使用模型蒸馏来实现BERT模型的压缩. 主要包括两个创新点: 对Transformer的参数进行蒸馏,需要同时注意embedding,attention_weight ...

  8. K-BERT:BERT+知识图谱

    1 简介 本文根据2019年<K-BERT:Enabling Language Representation with Knowledge Graph>翻译总结的.如标题所述就是BERT+ ...

  9. 计算机网络基础知识论文摘要,计算机网络基础知识论文大纲格式 计算机网络基础知识论文框架如何写...

    [100个]计算机网络基础知识论文大纲格式供您参考,希望能解决毕业生们的计算机网络基础知识论文框架如何写相关问题,写好提纲那就开始写计算机网络基础知识论文吧! 五.高职<计算机网络>课程活 ...

  10. 【RPC框架、RPC框架必会的基本知识、手写一个RPC框架案例、优秀的RPC框架Dubbo、Dubbo和SpringCloud框架比较】

    一.RPC框架必会的基本知识 1.1 什么是RPC? RPC(Remote Procedure Call --远程过程调用),它是一种通过网络从远程计算机程序上请求服务,而不需要了解底层网络的技术. ...

最新文章

  1. Hibernate的Session介绍[转 adoocoke]
  2. python3数据类型:Number(数字)
  3. Android设计原则及规范指南!UI设计师值得一看!
  4. 使用javaGUI编写检测是否有网
  5. Isim你不得不知道的技巧(整理)
  6. Set ip IPv6 env (by quqi99)
  7. 苹果按键强制恢复出厂_【数码】苹果手机忘了解锁密码不要慌,你可以这样做!...
  8. UVa 437 巴比伦塔(The Tower of Babylon)
  9. 集成模型Bagging和Boosting的区别
  10. tmux鼠标配置出现错误unknown option: mode-mouse
  11. 【Python基础】第八篇 | 容器之列表的使用
  12. 一种简单、安全的Dota全图新思路 作者:LC 【转】
  13. 搭建企业私有Git服务
  14. 关于TLC2543不常见问题
  15. python sklearn svm_文本分类和预测 sklearn.svm.LinearSVC(1)
  16. mysql药品库管理项目简介_MySQL数据库项目化教程简介,目录书摘
  17. Cufllinks的安装与使用
  18. 用Scipy理解Gamma函数
  19. unity3d C#用匿名委托循环注册按钮点击事件报错:索引超界 ArgumentOutOfRangeException: Index was out of range. Must be non-ne
  20. 408 知识点笔记——操作系统(内存管理)

热门文章

  1. 基本数据类型的包装类和随机数
  2. 城市交通_ssl1636_floyd
  3. Linux系统安全保护措施
  4. linux用命令行来执行php程序
  5. 判断web app是否从主屏启动
  6. flash与javacript:图片交互
  7. 使用http连接到Analysis services
  8. 细说show slave status参数详解(最全)【转】
  9. HDU 1698 Just a Hook(线段树:区间更新)
  10. 程序员的目标应该是向牛人看齐而不是当经理或者赚大钱