一. Seq2seq模型蒸馏方法总体过程如下

1. 训练teacher模型

2. 产生student模型

3. 利用teacher模型预测的logits和来自语料的true labels来计算student 模型的训练过程中的loss。

二. 涉及的具体步骤和参数有

1. 训练参数量相对较大的teacher模型。

2. 生成student模型,可以从teacher模型结构中抽取部分层组成,也可以随机初始化student模型的参数。

如果从teacher模型中抽取,则可以在训练时固定某些层,例如可在训练时freeze_embeds.

如果student模型的encoder和teacher模型的encoder完全一致,在训练时,可以考虑freeze_encoder。其他情况则不考虑freeze_encoder。

3. 根据teacher logits产生的时间不同,模型蒸馏可分为在线蒸馏离线蒸馏

        离线蒸馏是采用teacher模型,预先将decoder端每个token对应的词表(或类别)大小的概率分布预测出来,在训练时和true label一起输入来计算loss。

        在线蒸馏是同时将teacher模型和student模型加载到训练机上,在训练时利用teacher模型来预测每个token位置的概率分布(logits), 同时和true label一起参与loss的计算。

在线蒸馏时,teacher模型参数固定,只有student模型的参数为trainable状态。

三. 关于loss的计算

1. Loss共有3部分构成,即来自teacher_logits的loss_ce, 来自true_labels的loss_mlm, 和来自中间层的loss_hid.

对应3个loss部分在总的loss中的比例系数可以分别用alpha_ce, alpha_mlm, 和alpha_hid表示。因此总的loss可以表示为:

        loss_total = (alpha_ce * loss_ce) + (alpha_mlm * loss_mlm) + (alpha_hid * loss_hid)

其中,

loss_ce = distill_loss_fn(student_logits, teacher_logits,temperature)

loss_mlm = loss_fn(student_logits, true_labels)

loss_hid = mse_loss(student_hid, teacher_hid).

2. 关于loss_hid可以这样理解,采用teacher中的某些层来监督student中的各层的结果。例如采用一个12层的teacher模型,来蒸馏一个3层的student模型,如果只关注encoder端,可以用teacher_encoder  [0, 6, 11]层来分别监督student_encoder  [0, 1, 2]层的训练结果。

如果是离线蒸馏,并且需要在loss中计算student各层的损失,则在需要将teacher模型各层的结果,和teacher logits一起预先计算并保存。

3. loss中涉及的3个部分的损失函数不同,其中mlm对应的是一般的cross_entropy, hid对应的为mse,ce部分对应的为和温度相关的KLDivLoss, loss_ce具体可以描述为:

loss_ce = KLDivLoss(

softmax(student_logits/temperature, dim=-1),  # vocab_size

softmax(teacher_logits/temperature, dim=-1)

) * (temperature ^ 2)

关于最后需要乘温度的平方,可以阅读【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 - 知乎,简单表述为,loss_ce乘上 temperature^2 后,与loss_mlm的值相当,因此为了平衡loss_ce对损失的贡献,要乘temperature^2 。

4. 机器翻译(生成式)中使用离线蒸馏的问题

尝试将100句中英语料的teacher logtis预测并保存,发现保存后的文件为869M(.npy格式).这很大原因是因为label的维度过大导致的,因为teacher logits的最后一个维度为词表大小,词表大小为5w左右(裁剪后的mbart50模型)。

考虑到机器翻译的语句对经常为千万级别,对teacher logtis的存储空间要求较高,因此离线蒸馏在现有方法改进之前,并不适用机器翻译。

5. 综合来看,蒸馏涉及的主要参数有

--teacher_model

--student_encoder_layers=3

--student_decoder_layers=3

--temperature=2

--alpha_ce=0.5

--alpha_mlm=0.5

--alpha_hid=0

--freeze_encoder=False

--freeze_embeds

--max_sentence_length=64

--train_batch_size

--train_epochs=5

Seq2seq模型蒸馏方法相关推荐

  1. BERT模型蒸馏有哪些方法?

    ©PaperWeekly 原创 · 作者|蔡杰 学校|北京大学硕士生 研究方向|问答系统 我们都知道预训练模型的标准范式: pretrain-利用大量的未标记数据通过一些自监督的学习方式学习丰富的语义 ...

  2. 使用seq2seq模型进行机器翻译的方法不同

    自然语言处理| 深度学习 (Natural language processing | Deep learning) Machine translation is a computational li ...

  3. seq2seq模型_推断速度达seq2seq模型的100倍,谷歌开源文本生成新方法LaserTagger

    使用 seq2seq 模型解决文本生成任务伴随着一些重大缺陷,谷歌研究人员提出新型文本生成方法 LaserTagger,旨在解决这些缺陷,提高文本生成的速度和效率. 选自arXiv,作者:Eric M ...

  4. 推断速度达seq2seq模型的100倍,谷歌开源文本生成新方法LaserTagger

    选自arXiv 作者:Eric Malmi等 机器之心编译 序列到序列(seq2seq)模型给机器翻译领域带来了巨大变革,并成为多种文本生成任务的首选工具,如文本摘要.句子融合和语法纠错.模型架构改进 ...

  5. 降低预测过程计算成本,这些NLP模型压缩方法要知道

    编译 | 凯隐 出品 | AI科技大本营(ID:rgznai100) 近年来,基于谷歌Transformer的语言模型在神经机器翻译,自然语言推理和其他自然语言理解任务上取得了长足进展. 通过多种语言 ...

  6. 蚂蚁金服AAAI论文:基于长短期老师的样本蒸馏方法和自动车险定损系统的最新突破...

    来源 | 蚂蚁金服 出品 | AI科技大本营(ID:rgznai100) 一年一度在人工智能方向的顶级会议之一AAAI 2020于2月7日至12日在美国纽约举行,旨在汇集世界各地的人工智能理论和领域应 ...

  7. 【模型蒸馏】从入门到放弃:深度学习中的模型蒸馏技术

    点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要17分钟 跟随小博主,每天进步一丢丢 来自 | 知乎   作者 | 小锋子Shawn 地址 | https://zhuanlan.zhihu.c ...

  8. seq2seq模型_直观理解并使用Tensorflow实现Seq2Seq模型的注意机制

    采用带注意机制的序列序列结构进行英印地语神经机器翻译 Seq2seq模型构成了机器翻译.图像和视频字幕.文本摘要.聊天机器人以及任何你可能想到的包括从一个数据序列到另一个数据序列转换的任务的基础.如果 ...

  9. 娓娓道来!那些BERT模型压缩方法

    本文约3000字,建议阅读10+分钟 本文主要介绍知识蒸馏.参数共享和参数矩阵近似方法. 作者 | Chilia 哥伦比亚大学 nlp搜索推荐 整理 | NewBeeNLP 基于Transformer ...

最新文章

  1. 写给非技术人员的机器学习指南
  2. 深入理解Memcache原理
  3. Cisco二层交换机命令
  4. python初学者编程指南_动态编程初学者指南
  5. python Chrome + selenium自动化测试与python爬虫获取网页数据
  6. 一个方便使用的在线截图Web控件-WebImageMaker
  7. 维基百科(wikipedia)数据下载(含地理数据)
  8. 【算法题目】数组中的逆序对
  9. Java的break和continue关键字
  10. pycharm 设置环境变量
  11. 【渝粤教育】电大中专电商运营实操 (14)作业 题库
  12. 我读《写给大家看的设计书》
  13. 游戏上云成标配 云服务器该怎么选?
  14. 535. TinyURL 的加密与解密 : 设计一个 URL 简化系统
  15. 高德API 经纬度转换地市区县(含读取文件)
  16. 从MySQL Bug#67718浅谈B+树索引的分裂优化
  17. Python基础 PyCharm如何新建项目
  18. 使用163网易相册的朋友注意啦!
  19. 俄罗斯黑客挑战美国国家网络安全
  20. 后疫情时代,VR购物—零售业的硬核破局之道

热门文章

  1. 中文LLaMA模型和指令精调的Alpaca大模型:中文数据进行二次预训练,进一步提升了中文基础语义理解能力
  2. 【owl】OWL之动画
  3. 华三交换机查看上层交换机vlan
  4. (PCB系列三)AD六层板布线经验累积
  5. Android应用的加固与逆向
  6. input函数和类型转换
  7. Qt QImage像素格式小结
  8. SQL server数据库手动建库建表建约束,代码建库建表,数据库备份
  9. kaggle之路-准备工作
  10. 当访问共享文件夹时需输入用户名和密码的解决办法