作者 | Chilia

整理 | NewBeeNLP

首先来看一下原始Transformer的复杂度

self-attention复杂度

记:序列长度为n,一个位置的embedding大小为d。例如(32,512,768)的序列,n=512,d=768.

首先,得到的QKV都是大小为 的。

  • 相似度计算 : 与 运算,得到 矩阵,复杂度为

  • softmax计算: 对每行做softmax复杂度为 ,则n行的复杂度为

  • 乘上V加权: 与运算,得到矩阵,复杂度为

多头selfattention复杂度
  • Attention操作复杂度:首先经过"切头",把输出变成 长度,就是 和 的运算,由于h为常数,复杂度为

  • 之后的softmax和乘V加权同上。

  • 之后,还需要把这些头拼接起来,经过一层线性映射之后输出。concat操作拼起来形成nxd的矩阵,然后经过输出线性映射,保证输出也是的,所以是与计算,复杂度为

故最后的复杂度为:

1. Sparse Transformer

  • 论文:Generating Long Sequences with Sparse Transformers

  • 地址:https://arxiv.org/pdf/1904.10509.pdf

Sparse Attention是为了解决Transformer模型随着长度n的增加,Attention部分所占用的内存和计算呈平方增加的问题。原始Transformer的复杂度为 , 而sparse transformer试图把此复杂度降低为 .这样,就可以处理上千长度的输入,层数可以达到上百层。

1.1 Intuition

Transformer的Decoder部分是一个 自回归(AR) 模型。对于图像生成任务,可以把图像的像素点按照从上到下从左到右的方式当成一个序列,然后在序列上去做自回归。

论文中首先构造了一个128层的full-attention transformer网络,并在Cifar10图像生成问题上进行了训练。如下图所示,底部的黑色部分表示尚未生成到的部分,白色凸显的部分则是当前步骤注意力权重高的地方。

  • (a)中是transformer中比较低层layer的注意力,可以看到,低层次的时候主要关注的还是 局部区域 的部分。

  • (b)在第19层和20层,Attention学习到了横向和纵向的规律。

  • (c)还有可能学习到和数据本身相关的attention。比如下图,第二列第二张学习到了鸟的边缘。

  • (d) 64-128层的注意力是高度稀疏的,只有极少的像素点有较高的注意力。

无论如何,注意力权重高的地方只占一小部分,这就为 稀疏注意力 提供了数据上的支持。作为解决注意力平方问题的早期论文,本文从图像生成的问题上揭示了attention的原罪,那就是其实不需要那么 密集 的注意力。

1.2 Factorized Self-attention

Sparse Transformer就是把full self-attention 分解 成若干个小的、复杂度低的self-attention。这个过程叫做factorization。

定义集合 , 这个集合中的每个元素还是集合,表示第i位input可以关注的位置。对于full-attention, 显然就是 {j:j<i}.

每个位置的attention现在就变成了下图公式。其实没多大变化,只不过以前可以关注自己之前的所有位置,现在只能关注到一些特定的位置而已。

对于factorized self-attention,使用p个sparse注意力头,每个sparse注意力头有着不同的关注列表,第m个注意力头的关注列表记作

  • 为了保证sparse注意力头的高效性( efficiency ), 我们必须要保证 是 复杂度的。

  • 同时,为了保证sparse注意力头是有效( valid )的,我们需要保证每个位置都可以经过一些路径attend到 之前所有位置 (毕竟,这样才属于"factorize" full -attention)。同时这个路径长度不超过p+1,这样保证所有原本在全注意力上能够传递的信号在稀疏注意力的框架下仍然可以有效传递。

两种可能的sparse attention方法

当p = 2时,即两个注意力头的时候,文章给出了如下两种可以的sparse attention方法,能够满足上文所述的efficiency和valid条件。

(1)strided attention

  • 一个注意力头只能关注当前位置前 个位置

  • 另一个注意力头只能关注当前位置前面隔 "跳"的位置

这样相当于关注当前行、当前列的信息,就如之前看的图像生成例子中的(b)一样。所以,这种注意力机制比较适用于图像。

(2)fixed attention

  • Ai(1) = {j: floor(j/l) = floor(i/l)}

  • Ai(2) = {j: j mod l ∈ {t, t+1, ..., l}},其中t=l-c且c是超参数。

一般情况下,l取值为{128, 256}, c取值为{8, 16, 32}。这种模式非常适合于NLP问题,因为一般一句话的最后一个hidden state(下图浅蓝色)包含了整句话最多的意思。另外一个注意力头也可以关注到当前位置的前面每个token。

稀疏注意力的组合

一个直接的方法是在不同的层使用不同稀疏机制。这样每个层的不同机制”交织(interleave)“在一起。

另一种方式则是在每个层使用 组合 的稀疏注意力,组合的方法则是把经过不同稀疏注意力机制的输出concat起来,就像普通的多头一样。

深度残差Transformer

深层次的Transformer训练起来十分困难,因为使用残差的方式会比较好。除了我们熟悉的transformer层内的layernorm之外,还增加了 层间 的残差连接,可以处理上百层的层。

2. Longformer

  • 论文:Longformer: The Long-Document Transformer

  • 地址:https://arxiv.org/pdf/2004.05150.pdf

2.1 问题提出

BERT模型能处理的最大序列长度是512. 这是因为普通transformer的时间复杂度是随着序列长度n而平方增长的。如果我们想要处理更长的序列该怎么办呢?

  • 最简单的方法就是直接截断成512长度的。这点普遍用于文本分类问题。

  • 截成多个长度为512的序列段(这些序列段可以互相overlapping),每个都输入给Bert获得输出,然后将多段的输出拼接起来。

  • 两个阶段去解决问题,就像搜索里面的召回 - 排序一样。一般用于Question-Answer问题,第一个阶段去选择相关文档,第二个阶段去找到对应的answer。

无论哪种方式,毫无疑问都会带来损失:截断会带来损失,两阶段会带来cascading error。如果能直接处理长序列就好了。

2.2 局部和全局attention的结合

Longformer将局部attention和全局attention结合起来,局部attention用来捕捉局部信息,一般用于 底层 (就像上文sparse attention中看到的,底层attention主要关注局部信息,是十分稀疏的)。全局attention则捕捉全局信息,用于 高层 ,目的在于综合所有的信息,从而得到一个好的representation。

Sliding window

滑动窗口的大小为w,那么每个位置只attend前后w/2个位置。将模型多层叠加起来之后, 高层 的每个位置都可以关注到input的每个位置(就像卷积的感受野一样,这里可以有全局感受野)。一个 层的transformer,最上层的感受野是 的。

这样,每一层的计算复杂度就是而不是的了。

另外,每一层的w其实可以不同,鉴于越高层需要的全局信息越多,可以在层级较高的时候把w调大。来达到模型效率(efficiency)和模型表达能力(representation capacity)的平衡。

Dilated Sliding Window

引入dilated window的目的是为了再避免增加计算量的情况下继续增大感受野,类似空洞卷积。一个window有着大小为d的gap,那么最高层的感受野就是 的。

对于多头注意力,可以让有些头不用dilation,专注于关注 局部 信息;有些头用dilation,关注 更远 的信息。另外,底层不适合用dilated sliding window, 因为底层需要去学习局部的信息;高层可以适当的使用少量的dilated window,以降低参数量。

2.3 Global Attention

究竟要选择哪种attention方式,其实是和任务有关的。对于MLM任务,或许只关注局部信息就足够了,所以使用滑窗是可以的;但是对于分类任务,BERT模型把整句话的信息都集中在了[CLS]中,所以 [CLS]应该能够关注到所有位置 。对于QA,我们将question和document拼接起来送入transformer中,由于每个位置都需要去比较看是否贴近 question ,所以理应所有位置都能关注到question的每个token,因此question的每个token需要具有全局注意力。

这里的”全局注意力“指的是,某个位置上的token可以关注所有其他位置,所有其他位置也都可以关注这个token。具体要选择那个位置赋予全局注意力,是和任务的性质有关的。

3. Transformer-XL

其实transformer-XL并不是解决transformer复杂度问题的,而是用来解决长文本的long-term dependency问题。但是transformer-XL在推理阶段可以达到比vanilla transformer快1800倍的加速,所以在这里也一并介绍了。

3.1 问题的提出

由于BERT等transformer模型的最长输入长度只有512,在处理长文本的时候只能像我们上文说的那样,截成若干个512长度的片段(segment),依次输入到BERT中训练,如下图所示。这样导致的问题就是,数据最多只能关注到自己所在片段的那512个token,段和段之间的信息丢失了。

在测试阶段,以文本生成这种 自回归任务 为例,需要依次取时间片为L = 512 的分段,然后将整个片段提供给模型后预测一个结果。在下个时间片时再将这个分段向 右移一个单位 ,这个新的片段也将通过整个网络的计算后得到一个值。Transformer的这个特性导致其预测阶段的 计算量是非常大的

3.2 Transformer XL

Transformer-XL的核心包括两部分:片段循环(segment-level recurrence)和相对位置编码(relative positional encoding)

3.2.1 Segment-Level Recurrence with State Reuse

在训练阶段,上一个segment的隐藏状态会被 缓存下来 ,然后在计算当前段的时候再重复使用上一个segment的隐层状态。因为上个片段的特征在当前片段进行了 重复使用 ,这也就赋予了Transformer-XL建模更长期的依赖的能力。

长度为的连续两个segment表示为 和 。的隐层节点的状态表示为,其中是隐层节点的维度。的隐层节点的状态的计算过程为:

其中表示stop-gradient,表示这一部分并不参与BP的计算, 表示两个隐层节点在长度维度进行拼接。

不要被这个复杂的公式吓到!其实它想表达的意思很简单,就是每次在算一个segment的self-attention时,用当前这个segment的每个token当成Query向量,然后当前这个segment+上一个segment的每个token当成Key和Value向量,Query去关注Key和Value。这样,就把两个原本割裂的segment用attention给”粘合“了起来。记segment长度为N,那么一个L层的网络,最上面的层可以关注到的”感受野“就是O(N*L). 训练阶段如下:

除了能够关注到更远的位置以外,另一个好处 推理速度 的提升。Transformer的自回归架构每次只能前进一个time step,而Transfomer-XL的推理过程直接复用上一个片段的表示而不是从头计算,每次可以前进一个 segment长度 。其实这是一个空间换时间的方法。在原先transformer的方法中,推理时每次都只移动一个time step,因此只需要记录上一个segment的最后一个hidden state。现在则需要记录上一个segment的所有hidden state。推理阶段如下:

3.2.2 相对位置编码

Transformer的位置编码是以segment为单位的,表示为,第个元素表示的是在这个segment中第个元素的位置编码,表示的是能编码的最大长度,即segment长度。对于不同的segment来说,它们的位置编码是完全相同的,我们完全没法确认它属于哪个segment或者它在分段之前的输入数据中的相对位置。

为了解决这个问题,可以采用相对位置编码的方式。其思想是:一个位置在i的query向量去关注位置j的key向量,我们并不需要知道i和j是什么,真正重要的是i-j这个 相对距离 。所以,使用一个相对位置偏差的embedding矩阵 来进行位置偏差的编码。之后,我们需要把这个描绘相对位置的embedding融入到传统transformer的attention计算中去。

那么,如何做到这样的融合呢?

位置i的向量 作为query,要去和位置j的向量 计算注意力权重( 作为key),使用 绝对位置 的计算公式如下。其中,E表示embedding,U代表绝对位置编码。

将相对位置编码融入attention计算之后:

可以发现做了如下的几处改进:

1) 被拆分成 和 ,也就是说输入序列和位置编码不再共享权值。(蓝色和浅黄色部分)

2)绝对位置编码换成了相对位置编码 (棕色)。

3)引入了两个新的可学习的参数 和 来替换Transformer中的query向量 。表明 对于所有的query位置对应的query位置向量是相同的。因为我们已经把相对位置编码融入了key端,那么query端就不再需要位置编码了。(红色和绿色部分)

一起交流

想和你一起学习进步!『NewBeeNLP』目前已经建立了多个不同方向交流群(机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 / 等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定o要备注信息才能通过)

不得不看!降低Transformer复杂度的方法相关推荐

  1. 一文看懂推荐系统:召回02:Swing 模型,和itemCF很相似,区别在于计算相似度的方法不一样

    一文看懂推荐系统:召回02:Swing 模型,和itemCF很相似,区别在于计算相似度的方法不一样 提示:最近系统性地学习推荐系统的课程.我们以小红书的场景为例,讲工业界的推荐系统. 我只讲工业界实际 ...

  2. NC:MetaSort通过降低微生物群落复杂度以突破宏基因组组装难题

    点评:目前本领域研究最大的问题是缺少大量细菌基因组作为参考宏基因组序列.这一问题严重限制了研究的准确度和进一步功能研究.而目前的研究还主要集中在扩增子和宏基因组数据的宏观描述上,即使发布了大量宏基因组 ...

  3. 计算机主机平时怎么保养,怎样保养电脑(不得不看的四个好习惯)

    怎样保养电脑:这是很多用户都会考虑的问题,因为电脑的寿命跟用户的保养程度是成正比的.就是有的用户对电脑保养的一些常识不太了解,导致电脑寿命大幅度缩短的情况,小编今天教大家爱惜电脑不得不看的四个好习惯来 ...

  4. 想要专升本你不得不看的全干货_吐血整理_专升本_计算机文化基础( 十 三 )

    大家好,我是阿Ken.很快就要整理完第三章了~ 对于专升本_计算机文化基础我已经在博客里整理了已经一半多了,希望能够在我整理后能够帮助其他的小伙伴,这月底整理完所有的专升本_计算机文化基础的笔记,感兴 ...

  5. 蔬菜大棚成本_蔬菜大棚的类型及建造成本 不得不看

    原标题:蔬菜大棚的类型及建造成本 不得不看 大棚蔬菜利润与成本 一般的每亩地的蔬菜大棚骨架需要5500元左右,为了节省造价及延长大棚的使用时间,比较经济的建造大棚的方法是,用菱镁大棚骨架以及塑料纸;如 ...

  6. java降低if的圈复杂度_如何降低圈复杂度?

    我正在研究将RequestDTO发送到Web服务的类.我需要先验证请求,然后再发送. 可以从3个不同的地方发送请求,每个" requesttype"都有不同的验证规则,例如requ ...

  7. ​2018你不得不看的国内CRM软件排行榜

    2018你不得不看的国内CRM软件排行榜 短短几年时间,CRM在中国的发展就已经非常迅猛,现在已经成为了管理软件增长最快的产业.在我们总结的CRM软件排行榜中,腾讯企点的CRM软件赫然摆在前列.而CR ...

  8. Peer J:整合高通量绝对丰度定量方法解析土壤细菌群落及动态

    本文转自"上海天昊生物",已获授权 英文题目: Assessing soil bacterial community and dynamics by integrated high ...

  9. python 文本相似度计算函数_四种计算文本相似度的方法对比

    作者:Yves Peirsman 编译:Bing 编者按:本文作者为Yves Peirsman,是NLP领域的专家.在这篇博文中,作者比较了各种计算句子相似度的方法,并了解它们是如何操作的.词嵌入(w ...

最新文章

  1. 利用AOP实现对方法执行时间的统计
  2. java jar包命令行下可以双击不运行解决方法(改变java默认图标)
  3. Android Activity中加入View后进行后台截屏截图
  4. 带你玩转Pandas
  5. leetcode 274, 275. H-Index I, II(H 指数问题合集,线性查找/二分查找)
  6. POJ3764-The xor-longest Path【Trie(字典树)】
  7. Qt数字与字符串之间的相互转换
  8. gstreamer读取USB摄像头H264帧并用rtmp推流
  9. 算法高级(41)-推荐算法实现
  10. import librosa出错解决方案
  11. 位图切割器位图裁剪器
  12. 查找php超时原因_php环境搭建(正确配置nginx和php)
  13. 一网打尽 SCI、SCIE、SSCI 、EI等指标及影响因子查询
  14. 行政区村界线_中国各省界线是如何形成的?古代行政区划界原则
  15. 小米线刷教程+小米8背面指纹版的MIUI10和MIUI11包分享
  16. Android人脸支付功能,android支付宝上刷脸支付的人脸识别技术
  17. Hot and cold pages
  18. Android Studio获取数字签名(SHA1)
  19. html5 3d场景设计,基于 HTML5 WebGL 的加油站 3D 可视化监控
  20. u盘坏了数据可以恢复吗?实用小方法

热门文章

  1. Axure通用版电商后台管理系统+通用版移动端商城商户端+电商管理系统+对账管理+消息管理+内容管理+运营管理、会员管理、订单管理、促销管理、财务管理+通用版商城前后端电商系统+电商用户数据大屏看板
  2. 房屋征收管理系统高保真原型征收项目管理(项目预警)房屋测绘管理(测绘确权)协议管理测绘报表管理web端后台管理系统
  3. 记录一次OOM排查经历
  4. Python---String 字符串类型
  5. 中国象棋口诀及要领精髓
  6. adb shell top 使用
  7. 如何拔出手上的刺,假如不用缝衣针挑出来的话
  8. golang(7 方法重写)
  9. iOS 关于NSString的一些方法
  10. [STL][C++]LIST