文 | 炼丹学徒
编 | 小轶

从前车马很慢,显卡跑的也慢,一生只够爱一个RNN。后来时代进步了,数据量和计算力阔绰了,堆叠起来的Transformer能够在更深更宽的模型结构里吃下去更多的数据。从19年的预训练浪潮开始,暴力美学兴起,更深的Transformer更久的预训练更大的模型参数量,暴力出奇迹一个个NLP榜单被刷新,但谁又记得起来当初Transformer论文里“解决RNN无法并行化训练问题”的追求效率的motivation呢?身在普通高校,手握2080Ti和Titan V,向着大厂的预训练模型望洋兴叹,我们开始怀念起当初人人都训练得起的LSTM和GRU。那是精巧轻量的模型,那是人人都刷的起SOTA的时代。

今天这篇来自微软的论文告诉我们,大厂里有一些研究员也还是爱我们的,Finetuning Pretrained Transformers into RNNs,在保持性能的情况下,将预训练好的Transformer模型微调到其RNN变体,极大地降低显存使用和计算开销。

论文题目:
Finetuning Pretrained Transformers into RNNs

论文链接:
https://arxiv.org/abs/2103.13076

Arxiv访问慢的小伙伴也可以在 【夕小瑶的卖萌屋】订阅号后台回复关键词 【0407】 下载论文PDF~

本文提出的模型名为 T2R,代表 Transformer to RNN 。转换的过程为 swap-then-finetune ,即,对于一个预训练好的 Transformer 模型,我们将其










的注意力计算改为线性







的替换模块,然后进行微调
。可以预感到,其核心就在于如何用线性的子层对注意力层进行模拟。接下来,我们对其进行详解。

概述

在2019年EMNLP论文 Transformer Disp [1] 中,作者提出:可以将注意力层






的相似度计算()替换为核函数的分数

ICML'20的另一工作Transformers are RNNs [2]则在此基础上进一步优化,提出了










的注意力计算替换为线性







的模块

今天要讲的 T2R 这篇文章是紧随上面 ICML'20 这篇工作进行的。之前 Transformers are RNNs 的方法中,使用的核函数没有参数,不可训。而 T2R 把核函数里封装了一个MLP变成可训练的。T2R原文的推导直接使用了 Transformers are RNNsTransformer Disp 的结论,因而推导过程并不完整。我们今天也沿着T2R的思路进行讲解,如果想要更深入了解 Transformer 转 RNN 领域的,可以阅读下面两篇论文:

[1] Tsai et al. Transformer Disp: A Unified Understanding of Transformer's Attention via the Lens of Kernel. EMNLP 2019

[2] Katharopoulos et al. Transformers are RNNs: Fast autoregressive transformers with linear attention. ICML 2020

Transformer开销

Transformer 由多头注意力层、前馈层、层归一化层堆叠后组成。本篇论文中要替换的,就是其中的多头注意力层。

在开始讲解如何替换之前,我们还是先梳理一下传统Transformer的多头注意力层。整个计算过程可以总结如下图所示:

▲传统Transformer的多头注意力层计算过程

这张图我们自下往上看。首先,我们将多头注意力层的source隐状态记作






















,target隐状态记作






















如何理解此处的source和target:比如,在解码器的编码器-解码器注意力层中,




就是编码器端的序列长度,




就是解码器端的长度。在自回归推断的解码器自注意力层中,




就是已生成序列(加上自己)的长度,




等于1,指当前要预测的这个字符。

从隐状态




,我们通过线性变换得到








。则,注意力层的输出为:

其中,






操作
旨在计算









的相似度(这里划重点!等一会儿就要对这个计算动手脚了!):

上述的多头注意力的计算是我们熟知的。论文对其复杂度进行了分析。设多头数为




,每个头的隐状态长度




,每个




的隐状态总长







,则有如下结论:

  • 特征计算:即由隐状态




    计算得到








    的过程,复杂度分别为











    ,























  • 注意力计算: 由








    计算得到最终输出的过程,复杂度为









    ,与








    的长度成平方关系。

  • 推断时的显存








    ,与已经解码的长度线性相关。

注意力层的RNN替代方案

T2R的注意力层计算过程则如下图所示:

首先,我们注意到原始的注意力计算中,









的相似度计算方式()需要先进行点乘,放缩后再进行指数运算,难以开展后续的近似优化。所以这里的关键之处就在于,T2R






的相似度计算方案替换为核函数的乘积


























































































此处,



















的参数都是通过一个单层MLP学习得到的。
















维矩阵,














维bias向量,即,T2R的相似度计算核函数将原本




维的向量降到了




维然后进行相似度计算。对于多头计算中的每一个头,他们的



















是独立学出来的。因此,T2R在每一层中,共增加了










个可学习的参数(小于总参数量的2%)。

我们把新的相似度计算方法代入到注意力的输出式中,得到:
























































,则:

而根据 Transformers are RNNs [2] 的结论,此处的












可以视作RNN递归的隐状态
。比如,在解码器端做自回归生成时,每个词向它前文的单词进行注意力计算来预测下一个词,















可以被定义为递归的隐状态:

























































注意到我们主要讨论的






函数是针对






来计算相似度的,而






是由喂入该层的隐状态线性变化得到的。为了加速推断速度,具体实现中把





















































代入,得到从隐状态

























直接线性变换得到的结果,从而在推断的时候不需要计算






,而从隐状态直接计算得到相似度的值,即:

其中,

此时的开销:

  • 特征计算:我们记




    输出




    维的特征向量,则生成








    的复杂度为










    ,






















  • 注意力计算: 由








    计算得到最终输出的过程,假设k<<M,N,此时复杂度为
















    ,与








    的长度成线性关系。

  • 推断时的显存:假设k<<M,则占用显存








    ,为常数。

Transformer和T2R对比

讲到这里,我们再对比一下传统Transformer和T2R的差异:

  • 特征计算:计算




    不变,计算


















    ,











    降为










    ,










  • 注意力计算: 由









    降为
















    ,平方->线性。

  • 推断时的显存:由








    降为








    ,线性->常数。

实验

数据集的效果

T2R主要使用ELU和RFA作为baseline进行比较。ELU和RFA为此前的另外两篇使用核函数转Transformer为RNN工作。因为ELU和RFA的核函数都是不可训练的,所以无法取代预训练好的模型里的注意力层进行功能上的替换和拟合。

首先,T2R在语言模型上开展了实验。数据集使用WikiText-103,评测指标使用困惑度 perplexity 。发现T2R因为在核函数中放置了可训练的MLP,在加载预训练模型时获得更大的收益。

此外,T2R在翻译任务上开展实验,使用数据集 WMT14 EN-DE,WMT14 EN-FR 和 WMT17 ZH-EN。研究员们发现虽然随机初始化时,T2R弱于另外两个baseline,但是加载预训练后反超另外两个baseline。

生成时的加速和显存节省

研究员发现 T2R 比另外两个模型的推断速度更快(如下左图所示),因为使用了更小的特征维度,以及更快的特征计算方法。对于推断时的显存占用,Transformer 随着输出序列的增长而线性增加,转为 RNN 结构的模型则保持常数(如下右图所示)。

消融实验

随着核函数输出特征尺寸的增大,其效果也更加接近Transformer。相比于之前的工作,T2R 可以通过控制特征尺寸从而在效果和速度间权衡。

小结

本文提出的T2R,在 Transformers are RNNs 的基础上,将无参数的核函数封装为 MLP 加激活函数,从而可训练。在此基础上,T2R 替换掉预训练 Transformer 的注意力层,从而降低了计算消耗和显存使用,并且得到和原预训练模型相似的结果。

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

Transformer太大了,我要把它微调成RNN相关推荐

  1. Hugging Face实战(NLP实战/Transformer实战/预训练模型/分词器/模型微调/模型自动选择/PyTorch版本/代码逐行解析)下篇之模型训练

    模型训练的流程代码是不是特别特别多啊?有的童鞋看过Bert那个源码写的特别特别详细,参数贼多,运行一个模型百八十个参数的. Transformer对NLP的理解是一个大道至简的感觉,Hugging F ...

  2. 阿里二面:Redis 中的 AOF 文件太大了怎么办?

    一.前言 写这篇文章的目的是来自我的一位粉丝的投稿,说面试阿里被问到了这个问题.不得不说阿里的面试问的都挺有质量,一般的我们只会关注 Redis 的两种持久化方式 RDB 和 AOF.但老周这里盲猜面 ...

  3. 《数据分析实战 基于EXCEL和SPSS系列工具的实践》一3.4 数据量太大了怎么办

    本节书摘来自华章出版社<数据分析实战 基于EXCEL和SPSS系列工具的实践>一书中的第3章,第3.4节,纪贺元 著,更多章节内容可以访问云栖社区"华章计算机"公众号查 ...

  4. 不管工作压力太大了,还是工作压力太小了;都容易引起开发人员的辞职风波...

    最近一件事情充分意识到,工作压力太大了开发人员容易辞职,集团公司有一个大型软件项目进展不利,我们公司的领导干部也被叫过去开会需要进行协助开发,由于这个项目会影响到整个集团的声誉,上级领导也很重视,项目 ...

  5. win10打开程序响应很慢_小程序商城打开加载很慢?你上传的图片是不是太大了,压缩一下吧!...

    原创:轻栈 今天分享一篇纯干货,看完能给小程序商城提速. 打开小程序商城,有时能看到加载条,先看到内容要等等. 等待是一件消磨耐心的事情,在这个浮躁的时代,愿意等的人真的少.所以,我们要找出导致小程序 ...

  6. StringBuilder 和 String拼接10万个字符串的速度测试差别太大了

    /*** StringBuilder 和 String拼接10万个字符串的速度测试差别太大了* String耗时毫秒: 32693* StringBuilder耗时毫秒: 16*/ public cl ...

  7. 《数据分析实战:基于EXCEL和SPSS系列工具的实践》一3.4 数据量太大了怎么办

    本节书摘来华章计算机<数据分析实战:基于EXCEL和SPSS系列工具的实践>一书中的第3章 ,第3.4节,纪贺元 著 更多章节内容可以访问云栖社区"华章计算机"公众号查 ...

  8. html中图片太大了,css背景图片太大的坏处与解决方法

    在制作网页的过程中,有时候我们为了页面的个性.漂亮,会给通过CSS给网页设置一个很大的背景图片,可是背景图片太大的话不但会给我们的访客带来一些烦恼,还会对搜索引擎不友好,导致排名的降低,实在是得不偿失 ...

  9. 如何压缩动态图片大小?gif图太大了怎么压缩?

    对于新媒体行业人员来说,平时在工作中需要存非常多的素材,这些素材中有很多就是gif格式的,随着积累的素材越来越多,这些素材会占用大量的储存空间,那么遇到这种情况应该怎么办呢?应该如何压缩动态图片大小? ...

最新文章

  1. 8年程序员210天没找到工作,小公司老板:降薪5千,爱来不来
  2. mybatis修改mysql变量_Java通过MyBatis框架对MySQL数据进行增删查改的基本方法
  3. EditThisCookie使用
  4. python下载网页歌词_python3个人学习笔记-批量下载分析歌词2
  5. 宝塔服务器管理助手Linux面版-使用教程
  6. java jtable 单元格编辑_java – 在基于JTable面板的单元格编辑器中...
  7. opencv puttext
  8. mysql修改服务器ip,mysql数据库修改服务器ip
  9. python安装第三方库太慢,很容易失败报错?教你如何提速
  10. 电容或电感的电压_眼见不一定为实!电阻、电容和电感的实际等效模型
  11. java提现功能开发_如何利用java实现提现金额到支付宝账户的功能
  12. 需求(Java):使用Jsoup获取知乎网页的信息,信息如下:
  13. 如何用计算机求极限,计算器的极限_500字
  14. 使用光泵磁力仪(OPMs)非接触测量视网膜活动
  15. scrollHeight, clientHeight, offsetHeight的区别
  16. 浏览器报Uncaught ReferenceError: require is not defined
  17. 一个奇怪的bug,记录一下
  18. 计算机高级工考试题库2018,维修电工高级工试题题库2018
  19. linux优化网页加载过程,【zz】Linux起步过程中硬件模块的加载
  20. 图标跟着摄像机(Camera)orthographicSize的值改变大小

热门文章

  1. C++学习10 static静态成员变量和静态成员函数
  2. Java的几个同步辅助类
  3. 迪美特TVZ8双核智能高清播放器 在电视上编程不是梦
  4. FarPoint Spread For .Net 4.0
  5. STL中vectortype的复制
  6. 北航博士,研究所月入两万
  7. 那些年,我和发哥在恒大的日子
  8. 面试官让你用C语言实现大数相乘,慌吗?
  9. rk3188开机失败(ump_file_open() 251)
  10. 计算机硬件知识考证题,计算机硬件知识题(答案)资料