论文题目:Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
论文链接:https://arxiv.org/pdf/1903.12136.pdf

摘要

在自然语言处理文献中,神经网络变得越来越深入和复杂。这一趋势的苗头就是深度语言表示模型,其中包括BERT、ELMo和GPT。这些模型的出现和演进甚至导致人们相信上一代、较浅的语言理解神经网络(例如LSTM)已经过时了。然而这篇论文证明了如果没有网络架构的改变、不加入外部训练数据或其他的输入特征,基本的“轻量级”神经网络仍然可以具有竞争力。文本将最先进的语言表示模型BERT中的知识提炼为单层BiLSTM,以及用于句子对任务的暹罗对应模型。在语义理解、自然语言推理和情绪分类的多个数据集中,知识蒸馏模型获得了与ELMo的相当结果,参数量只有ELMo的大约1/100倍,而推理时间快了15倍。

1 简介

关于自然语言处理研究中,神经网络模型已经成了主力军,并且模型结构层出不穷,好像永无止境一样,这些过程中最开始的神经网络例如LSTM变得容易被忽视。例如ELMo模型在2018年一些列任务上取得了sota效果,再到双向编码表示模型Bert、GPT-2在更多任务上取得了很大提升。

但是如此之大的模型在实践落地的过程中是存在问题的:

  • 由于参数量特别大,例如 BERT 和 GPT-2,在移动设备等资源受限的系统中是不可部署的。
  • 由于推理时间效率低,它们也可能不适用于实时系统,对于QPS压测很多场景基本是不过关的。
  • 根据摩尔定律可知,我们需要在一定时间过后重新压缩模型以及重新评估模型性能。

针对上述问题,本文提出了一种基于领域知识的高效迁移学习方法:

  • 作者将BERT-large蒸馏到了单层的BiLSTM中,参数量减少了100倍,速度提升了15倍,效果虽然比BERT差不少,但可以和ELMo打成平手。
  • 同时因为任务数据有限,作者基于以下规则进行了10+倍的数据扩充:用[MASK]随机替换单词;基于POS标签替换单词;从样本中随机取出n-gram作为新的样本

2 相关工作

关于模型压缩的背景介绍,大家可以看下 李rumor的文章https://zhuanlan.zhihu.com/p/273378905,总结比较精炼和到位,这里不再重复赘述:

Hinton在NIPS2014[1]提出了知识蒸馏(Knowledge Distillation)的概念,旨在把一个大模型或者多个模型ensemble学到的知识迁移到另一个轻量级单模型上,方便部署。简单的说就是用小模型去学习大模型的预测结果,而不是直接学习训练集中的label。

在蒸馏的过程中,我们将原始大模型称为教师模型(teacher),新的小模型称为学生模型(student),训练集中的标签称为hard label,教师模型预测的概率输出为soft label,temperature(T)是用来调整soft label的超参数。

蒸馏这个概念之所以work,核心思想是因为好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让学生模型学习到教师模型的泛化能力,理论上得到的结果会比单纯拟合训练数据的学生模型要好。

在BERT提出后,如何瘦身就成了一个重要分支。主流的方法主要有剪枝、蒸馏和量化。量化的提升有限,因此免不了采用剪枝+蒸馏的融合方法来获取更好的效果。接下来将介绍BERT蒸馏的主要发展脉络,从各个研究看来,蒸馏的提升一方面来源于从精调阶段蒸馏->预训练阶段蒸馏,另一方面则来源于蒸馏最后一层知识->蒸馏隐层知识->蒸馏注意力矩阵。

3 模型方法

本篇论文第一步选择teacher 模型和student模型,第二步确立蒸馏程序:确立logit-regression目标函数和迁移数据集构建。

3.1 模型选择

对于“teacher”模型,本文选择Bert去做微调任务,比如文本分类,文本对分类等。对文本分类,可以直接将文本输入到bert,拿到cls输出直接softmax,可以得到每个标签概率:y(B)=softmax(Wh)y (B) = softmax(Wh)y(B)=softmax(Wh),其中W∈Rk∗dW\in R^{k *d}WRkd是softmax权重矩阵,k是类别个数。对于文本对任务,我们可以直接两个文本输入到Bert提取特征,然后收入到softmax进行分类。

对于“student”模型,本文选择的是BiLSTM和一个非线性分类器。如下图所示:


主要流程是将文本词向量表示,输入到BiLSTM,选取正向和反向最后时刻的隐藏层输出并进行拼接,然后经过一个relu输出,输入到softmax得到最后的概率。

3.2 蒸馏目标

yi=softmax(z)=exp(wiTh)∑iexpWjThy_{i}=softmax(z)=\frac{exp(w_{i}^{T}h)}{\sum_{i}exp{W_{j}^{T}h}}yi=softmax(z)=iexpWjThexp(wiTh)
其中wiw_{i}wi是权重矩阵WWW的第i行,zzz等于wThw^ThwTh

蒸馏的目标就是为了最小化student模型与teacher模型的平方误差MSE:
Ldistill=∣∣Z(B)−Z(S)∣∣22L_{distill}=||Z(B)-Z(S)||_{2}^{2}Ldistill=Z(B)Z(S)22
其中Z(B)Z(B)Z(B)Z(S)Z(S)Z(S)分类代表teacher和student模型的logit输出

最终蒸馏模型的训练函数可以将MSE损失和交叉熵损失结合起来:
L=α∗LCE+(1−α)Ldistill=−α∑itilog(yiS)−(1−α)∣∣Z(B)−Z(S)∣∣22L=\alpha *L_{CE}+(1-\alpha)L_{distill}\\ =-\alpha\sum_{i}t_{i}log(y_{i}^{S})-(1-\alpha)||Z^(B)-Z^(S)||_{2}^{2}L=αLCE+(1α)Ldistill=αitilog(yiS)(1α)Z(B)Z(S)22

3.3 数据增强

  • 用[MASK]随机替换单词:“I loved the comedy.”变成“I [MASK] the comedy”
  • 基于POS标签替换单词;“What do pigs eat?” 变成“How do pigs eat?”
  • 从样本中随机取出n-gram作为新的样本

4 实验结果

本文采用的数据集为SST-2、MNLI、QQP
实验结果如下:

推理更加快:

5 蒸馏代码

https://github.com/qiangsiwei/bert_distill

# coding:utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from keras.preprocessing import sequence
import pickle
from tqdm import tqdm
import numpy as np
from transformers import BertTokenizer
from utils import load_data
from bert_finetune import BertClassificationUSE_CUDA = torch.cuda.is_available()
if USE_CUDA: torch.cuda.set_device(0)
FTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
device = torch.device('cuda' if USE_CUDA else 'cpu')class RNN(nn.Module):def __init__(self, x_dim, e_dim, h_dim, o_dim):super(RNN, self).__init__()self.h_dim = h_dimself.dropout = nn.Dropout(0.2)self.emb = nn.Embedding(x_dim, e_dim, padding_idx=0)self.lstm = nn.LSTM(e_dim, h_dim, bidirectional=True, batch_first=True)self.fc = nn.Linear(h_dim * 2, o_dim)self.softmax = nn.Softmax(dim=1)self.log_softmax = nn.LogSoftmax(dim=1)def forward(self, x):embed = self.dropout(self.emb(x))out, _ = self.lstm(embed)hidden = self.fc(out[:, -1, :])return self.softmax(hidden), self.log_softmax(hidden)class Teacher(object):def __init__(self, bert_model='bert-base-chinese', max_seq=128, model_dir=None):self.max_seq = max_seqself.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)self.model = torch.load(model_dir)self.model.eval()def predict(self, text):tokens = self.tokenizer.tokenize(text)[:self.max_seq]input_ids = self.tokenizer.convert_tokens_to_ids(tokens)input_mask = [1] * len(input_ids)padding = [0] * (self.max_seq - len(input_ids))input_ids = torch.tensor([input_ids + padding], dtype=torch.long).to(device)input_mask = torch.tensor([input_mask + padding], dtype=torch.long).to(device)logits = self.model(input_ids, input_mask, None)return F.softmax(logits, dim=1).detach().cpu().numpy()def train_student(bert_model_dir="/data0/sina_up/dajun1/src/doc_dssm/sentence_bert/bert_pytorch",teacher_model_path="./model/teacher.pth",student_model_path="./model/student.pth",data_dir="data/hotel",vocab_path="data/char.json",max_len=50,batch_size=64,lr=0.002,epochs=10,alpha=0.5):teacher = Teacher(bert_model=bert_model_dir, model_dir=teacher_model_path)teach_on_dev = True(x_tr, y_tr, t_tr), (x_de, y_de, t_de), vocab_size = load_data(data_dir, vocab_path)l_tr = list(map(lambda x: min(len(x), max_len), x_tr))l_de = list(map(lambda x: min(len(x), max_len), x_de))x_tr = sequence.pad_sequences(x_tr, maxlen=max_len)x_de = sequence.pad_sequences(x_de, maxlen=max_len)with torch.no_grad():t_tr = np.vstack([teacher.predict(text) for text in t_tr])t_de = np.vstack([teacher.predict(text) for text in t_de])with open(data_dir+'/t_tr', 'wb') as fout: pickle.dump(t_tr,fout)with open(data_dir+'/t_de', 'wb') as fout: pickle.dump(t_de,fout)model = RNN(vocab_size, 256, 256, 2)if USE_CUDA: model = model.cuda()opt = optim.Adam(model.parameters(), lr=lr)ce_loss = nn.NLLLoss()mse_loss = nn.MSELoss()for epoch in range(epochs):losses, accuracy = [], []model.train()for i in range(0, len(x_tr), batch_size):model.zero_grad()bx = Variable(LTensor(x_tr[i:i + batch_size]))by = Variable(LTensor(y_tr[i:i + batch_size]))bl = Variable(LTensor(l_tr[i:i + batch_size]))bt = Variable(FTensor(t_tr[i:i + batch_size]))py1, py2 = model(bx)loss = alpha * ce_loss(py2, by) + (1-alpha) * mse_loss(py1, bt)  # in paper, only mse is usedloss.backward()opt.step()losses.append(loss.item())for i in range(0, len(x_de), batch_size):model.zero_grad()bx = Variable(LTensor(x_de[i:i + batch_size]))bl = Variable(LTensor(l_de[i:i + batch_size]))bt = Variable(FTensor(t_de[i:i + batch_size]))py1, py2 = model(bx)loss = mse_loss(py1, bt)if teach_on_dev:loss.backward()             opt.step()losses.append(loss.item())model.eval()with torch.no_grad():for i in range(0, len(x_de), batch_size):bx = Variable(LTensor(x_de[i:i + batch_size]))by = Variable(LTensor(y_de[i:i + batch_size]))bl = Variable(LTensor(l_de[i:i + batch_size]))_, py = torch.max(model(bx, bl)[1], 1)accuracy.append((py == by).float().mean().item())print(np.mean(losses), np.mean(accuracy))torch.save(model, student_model_path)if __name__ == "__main__":train_student()

参考链接

  • 【经典简读】知识蒸馏(Knowledge Distillation) 经典之作
  • 【论文笔记】Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
  • 知识蒸馏论文选读(二)

给Bert加速吧!NLP中的知识蒸馏论文 Distilled BiLSTM解读相关推荐

  1. Paper:《Distilling the Knowledge in a Neural Network神经网络中的知识蒸馏》翻译与解读

    Paper:<Distilling the Knowledge in a Neural Network神经网络中的知识蒸馏>翻译与解读 目录 <Distilling the Know ...

  2. 【深度学习】深度学习中的知识蒸馏技术(上)简介

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  3. 深度学习中的知识蒸馏技术(上)

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  4. 深度学习中的知识蒸馏技术!

    本文概览: 1. 知识蒸馏介绍 1.1 什么是知识蒸馏? 在化学中,蒸馏是一种有效的分离不同沸点组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的.化学蒸馏条件:( ...

  5. 目标检测中的知识蒸馏方法

    目标检测中的知识蒸馏方法 知识蒸馏 (Knowledge Distillation KD) 是模型压缩(轻量化)的一种有效的解决方案,这种方法可以使轻量级的学生模型获得繁琐的教师模型中的知识.知识蒸馏 ...

  6. 深度学习中的知识蒸馏技术(下)

    本文概览: 写在前面: 这是一篇介绍知识蒸馏在推荐系统中应用的文章,关于知识蒸馏理论基础的详细介绍,请看我的这篇文章: 深度学习中的知识蒸馏技术(上) 1. 背景介绍 1.1 简述推荐系统架构 如果从 ...

  7. 知识蒸馏论文翻译(5)—— Feature Normalized Knowledge Distillation for Image Classification(图像分类)

    知识蒸馏论文翻译(5)-- Feature Normalized Knowledge Distillation for Image Classification(图像分类) 用于图像分类的特征归一化知 ...

  8. 知识蒸馏论文翻译(7)—— Knowledge Distillation from Internal Representations(内部表征)

    知识蒸馏论文翻译(7)-- Knowledge Distillation from Internal Representations(内部表征) 文章目录 知识蒸馏论文翻译(7)-- Knowledg ...

  9. 知识蒸馏论文翻译(1)——CONFIDENCE-AWARE MULTI-TEACHER KNOWLEDGE DISTILLATION(多教师知识提炼)

    知识蒸馏论文翻译(1)--CONFIDENCE-AWARE MULTI-TEACHER KNOWLEDGE DISTILLATION(多教师知识提炼) 文章目录 知识蒸馏论文翻译(1)--CONFID ...

最新文章

  1. 谷歌CEO称公司预计每月收购一家小公司
  2. 关于保存到session里的信息
  3. L1-045 宇宙无敌大招呼
  4. php数据库中数据查询
  5. 多站点IIS的架设:端口法
  6. 【学习笔记】【C语言】进制
  7. html页面性能优化两则
  8. c语言数组特殊初始化方法
  9. Greenplum技术浅析
  10. Python敏感词过滤DFA算法+免费附带敏感词库
  11. RGB三色灯珠WS2812B/WS2815B
  12. 学计算机办公文员软件,办公文员必须掌握的办公软件有哪些
  13. ai人工智能_人工智能能否赢得奥运
  14. python电路仿真软件_4种电路仿真软件比较 - SmartLinkCloud,智联网云平台 - OSCHINA - 中文开源技术交流社区...
  15. MyBatisPlus的使用--十数个案例足以让你步入mybatisplus
  16. PIE-engine APP 教程 ——基于PIE云平台的城市生态宜居性评价系统——以京津冀城市群为例
  17. GEE|假彩色目视解译山东省玉米、水稻、小麦等样本集制作代码
  18. LeetCode 876、链表的中间结点
  19. 华罗庚统筹法与计算机专业,华罗庚的优选法与统筹法
  20. 大话西游猛击源码_我们猛击Return(Enter)键可能会演变的原因

热门文章

  1. 洛谷-Cow Gymnastics B
  2. VS快捷键大全(超详细)
  3. 深度优先搜索DFS | Morris遍历:力扣99. 恢复二叉搜索树
  4. 新来个阿里 P7,仅花 2 小时,做出一个多线程永动任务,看完直接跪了
  5. Windows文件名区分大小写
  6. uva12627 Erratic Expansion
  7. Install Samba in CRUX
  8. 抖音如何查看访客记录,丨国仁网络
  9. vue项目Echarts更新数据是数据表展示错版
  10. k8s 部署 TCP node应用