链接:http://arxiv.org/abs/2010.02123

简介

Lifelong Language Knowledge Distillation终身语言知识提炼,是一种利用知识蒸馏的终身学习方法。
其主要思想是:每次遇到新任务时,不直接让model去学习,而是先在任务上训练一个teacher model,然后运用知识蒸馏技术,将知识传递给model。

  • 知识蒸馏:有两个模型: student model(小)和teacher model(大)。student model需要通过训练,模仿teacher model的行为并使得两者性能相近。

本文将知识蒸馏的思想运用到了终身学习的语言领域。但不同之处在于: L2KD的student model和teacher model是一样大的。

如下图所示。这种方法只需要为每个新任务多花一点时间训练一个一次性teacher model,在学习下一个任务时可以丢弃该模型;因此,L2KD不需要额外的内存或模型容量,这使得提出的模型在实际使用中更有效。

必须要指出的是:L2KD作为一种方法而非具体模型,可以加到大部分LLL模型上去。
因此,本文就将L2KD加到了LAMOL上去。
LAMOL介绍:https://blog.csdn.net/Baigker/article/details/121650749?spm=1001.2014.3001.5501

Proposed Approach

正如在简介中提到的,L2DK本质是一种知识蒸馏,并且在实际运用中要加到其他模型上去。因此本文也遵循这一顺序,即:先介绍LAMOL,再介绍知识蒸馏,最后才说明L2KD的原理。

LAMOL

在LAMOL的setting中,语言数据集中的所有样本都有三个部分:上下文、问题和答案。我们可以简单地将这三个部分连接成一个句子,并训练模型根据上下文和前面的问题生成答案。

除了生成给定问题的答案外,该模型还可以同时学习对整个训练样本建模。
通过这样做,在训练下一个任务时,模型可以生成前一个任务的样本(被称为伪数据),同时训练新任务的数据和前一个任务的伪数据。
因此,模型在适应新任务时忘记的更少。

知识蒸馏

语言模型

一般来说,语言模型的目标是使预测下一个词时的负对数似然(NLL)最小化:
而在知识蒸馏中,我们将student model和teacher model之间的预测误差最小化。
被认为是误差的目标单元可能在单词级或序列级进行。

Word-Level (Word-KD)

在预测下一个词时,我们最小化student和teacher的输出分布之间的交叉熵:

其中输入x<tx_{<t}x<t​来自标准答案(ground truth)序列。VVV表示词汇集,VkV_kVk​为VVV中的第kkk个单词。
θSθ_SθS​和θTθ_TθT​分别为学生模型和教师模型的参数。

Sequence-Level (Seq-KD)

我们将teacher model中的贪心解码或beam search输出序列x^\hat xx^作为硬目标直接最小化负对数似然,就像普通语言建模一样:

Seq-KD通常用于改善弱非自回归翻译(NAT)模型(Zhou et al., 2020),减少机器翻译数据集中的多模态问题(Gu et al.,2018)。

Soft Sequence-Level (Seq-KDsoft)

我们进一步研究软目标加上teacher解码序列是否对模型更有帮助,因此我们进行Seq−KDsoftSeq-KD_{soft}Seq−KDsoft​,对teacher model的贪心解码或beam search输出进行Word-KD。
Seq−KDsoftSeq-KD_{soft}Seq−KDsoft​和Word-KD之间的唯一区别是Word-KD的输入x<tx<tx<t现在被替换为x^<t\hat x<tx^<t,teacher model的输出序列:

注意,无论我们在知识蒸馏中使用何种损失函数,teacher model总是固定的。因此,LLL模型求参数θS∗θ^*_SθS∗​的优化过程可以写成:

L2DK

知识蒸馏可以应用于最小化LM和QA在LAMOL中的损失。假设有一个任务流的数据集{D1,D2,…}\{ D_1, D_2,…\}{D1​,D2​,…},我们的LLL模型从D1D_1D1​学习到Dm−1D_{m-1}Dm−1​,现在适用于DmD_mDm​。首先,我们通过最小化LAMOL中LM和QA的负对数似然损失来训练DmD_mDm​的教师模型,并获得模型参数θmTθ_m^TθmT​。
现在我们的LLL模型(参数θSθ_SθS​)可以通过从教师模型中知识蒸馏来训练DmD_mDm​:
给定一个训练样本Xim={x1,x2,…,xT}∈DmX^m_i = \{ x_1, x_2,…, x_T\} ∈D_mXim​={x1​,x2​,…,xT​}∈Dm​(包括上下文、问题和答案),我们将其最小化:

其中a1a_1a1​表示答案的起始位置。这里我们以Word-KD为例,但我们也可以将答案部分的文本替换为教师生成的答案,从而进行Seq−KDsoftSeq-KD_{soft}Seq−KDsoft​或Seq-KD。

LLL模型除了对来自DmD_mDm​的样本进行训练外,还会为之前的任务生成伪数据DprevD_{prev}Dprev​。然而,对于DprevD_{prev}Dprev​中的样本,我们不能在这里进行知识蒸馏,因为在我们的设置中,之前任务的教师模型在适应下一个任务后将被丢弃。因此,给定生成的数据Xiprev∈DprevX^{prev}_i∈D_{prev}Xiprev​∈Dprev​,我们在这里只最小化NLL损失:

最后,我们共同优化了两种损失,得到了LLL模型的参数θS∗θ^∗_SθS∗​:

整体算法流程:

【Lifelong learning】Lifelong Language Knowledge Distillation相关推荐

  1. 【论文翻译】Few Sample Knowledge Distillation for Efficient Network Compression

    Few Sample Knowledge Distillation for Efficient Network Compression 用于高效网络压缩的少样本知识提取 论文地址:https://ar ...

  2. 【Lifelong learning】LAMOL: LANGUAGE MODELING FOR LIFELONG LANGUAGE LEARNING

    链接:http://arxiv.org/abs/1909.03329v2 简介 之前的终身学习(LLL)模型大多应用于CV和游戏领域,在nlp方面的应用比较少,本文因此提出了一个语言专用的终身学习模型 ...

  3. 【NLP learning】Tokenizer分词技术概述

    [NLP learning]Tokenizer分词技术概述 目录 [NLP learning]Tokenizer分词技术概述 极简方法--空格分词(Space) 预训练模型的分词方法--子词分解/子标 ...

  4. 【Machine Learning】OpenCV中的K-means聚类

    在上一篇([Machine Learning]K-means算法及优化)中,我们介绍了K-means算法的基本原理及其优化的方向.opencv中也提供了K-means算法的接口,这里主要介绍一下如何在 ...

  5. 【Lifelong learning】Continual Learning with Knowledge Transfer for Sentiment Classification

    链接:http://arxiv.org/abs/2112.10021 简介 这是一篇在情感分类Sentiment Classification运用连续学习Continual Learning的pape ...

  6. 【Lifelong learning】Efficient Meta Lifelong-Learning with Limited Memory

    链接:http://arxiv.org/abs/2010.02500 简介 实现lifelong learning的最大问题便是catastrophic forgetting(机器会把之前的知识忘了) ...

  7. 【Distill 系列:三】CVPR2019 Relational Knowledge Distillation

    Relational Knowledge Distillation Relational Knowledge Distillation TL;DR teacher 和 student feature ...

  8. 【Machine Learning】TensorFlow实现K近邻算法预测房屋价格

    1前言 机器学习KNN算法(K近邻算法)的总体理论很简单不在这里赘述了. 这篇文章主要问题在于如果利用tensorflow深度学习框架来实现KNN完成预测问题,而不是分类问题,这篇文章中涉及很多维度和 ...

  9. 【Machine Learning】梯度下降算法介绍_02

    文章目录 前言 一.梯度 1.1 导数 1.2 偏导数 二.举例梯度下降 三.训练样本 四.梯度下降 4.1 量梯度下降(Batch Gradient Descent,BGD) 4.2 随机梯度下降( ...

最新文章

  1. MFC滑块的使用方式
  2. Scrapy基本用法
  3. java虚成员函数_Java常见知识点汇总(④)——虚函数、抽象函数、抽象类、接口...
  4. HTTP协议通信原理
  5. 资格赛:题目1:同构
  6. mxnet系列教程之1-第一个例子
  7. eos测试规格_希望您的测试更有效? 这样写您的规格。
  8. Ubuntu的多文件编译以及c语言的数组、函数
  9. httpServletRequest中的流只能读取一次的原因
  10. 如何保护Mac的数据安全?
  11. 008-对象—— 对象$this self parent 内存方式及使用方法讲解
  12. 《Elementary Methods in Number Theory》勘误
  13. oracle minus 利用率,Oracle Minus 取差集(也可以用来做分页,但效率不高)
  14. [CTF攻防世界] WEB区 关于备份的题目
  15. 计算2015年4月6日是一年中的第几星期
  16. H-A + B用于投入产出实践(VIII)
  17. 一个有意思的echarts3D树状图
  18. 如何给Windows计算机加域
  19. NameNode故障处理方法
  20. 游戏开发中,图片资源的精简

热门文章

  1. 3.6 高速缓冲存储器
  2. 为什么win7计算机没有d盘,win7没有D盘 win7d盘
  3. ArcSDE和Geodatabase10.1抢先版谍照介绍(1)
  4. Intel Thread Building Blocks (TBB) 入门篇
  5. 【计量经济学】模型设定问题
  6. 网站漏洞扫瞄时被云盾拦截解决方法
  7. Mixly 自定有OLED
  8. 如何用PS在图片上添加箭头
  9. 我的ubuntu之shell下载音乐
  10. 阿里巴巴微服务核心手册:Spring Boot+Spring cloud+Dubbo