Paper Reading《Fastformer: Additive Attention Can Be All You Need 》

Paper url;笔者写作时作者Github尚未开源。Unofficial版本复现:TF版本,Pytorch版本,Keras版本;以及推荐一位Youtube上的大神Yannic Kilcher对本文进行的讲解。

1. Intuition

传统Transformer机制囿于512个token文本长度限制,涌现出了以下几种当前主流的Transformer变种,但同时也也存在着相应的缺点:

  1. 使用稀疏注意力机制降低计算复杂度。e.g. Longformer, BigBird.

    缺点:自注意力机制需要更多的tokens参与运算,大幅度提升了时间和计算开销,速度慢。

  2. 利用hashing编码技术加速自注意力机制的计算。e.g. Reformer

    缺点:计算复杂度常数很大,在处理长度有限的常见序列时效率很低。

  3. 近似地计算自注意力。e.g. Linformer

    缺点:以上下文标记的方式接近自注意力,不是文本建模的最佳方法。当序列长度很长时计算效率低。

2. Fast-former Architecture

2.1 回顾下Self-Attention

image-20210908150317288

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk

QKT)V

先对Q,KQ,KQ,K做点乘(i.e. dot-product)得到二者的相似度矩阵,矩阵的每一行都是一个query和所有key的相似性;然后用dk\sqrt{d_k}dk

进行scaling,之后用softmax函数后再和VVV相乘。

如何理解Self-Attention的Q,K ,V向量

目标是让模型能够各个单词词向量在上下文向量中代表着是哪种含义。

  • Key vector KKK:

    describes what the content of this token is so far, which allows the token to advertise what is has to offer。更像句子本身的内容(addressable representation)

  • Query Vector QQQ

    means what does this token wanna know about the other tokens in the sequence. 更像我想从别的token那得到什么信息。

    所以 KKK and QQQ 通常是不一样的

    通过计算QWQQ W^QQWQ以及KWKK W^KKWK的内积的softmax值表征了:针对于输入句中的每一个token所代表的信息想最终聚合到最后输出句子的token的。对每一个token都会算我们想要他这个token有多少信息“作用到”输出句子上。

  • Value Vector VVV:

    通过前文softmax得到的分布,我们输入到下一个layer的乘积。类似于斜率

2.2 Architecture

image-20210908145505532

根据input text的embedding matrix,记为E∈RN×dE\in\mathbb{R}^{N\times d}ERN×d,每个token代表的列向量记[e1,e2,⋯,eN][e_1,e_2,\cdots,e_N][e1,e2,,eN]

  1. 通过线性变换层得到矩阵Q,K,VQ,K,VQ,K,V的映射,Q,K,V∈RN×dQ,K,V\in \mathbb{R}^{N \times d}Q,K,VRN×d此处为原文typo写为了Rd×dR^{d\times d}Rd×d),Q=[q1,q2,⋯,qN]Q=[q_1,q_2,\cdots,q_N]Q=[q1,q2,,qN]K=[k1,k2,⋯,kN]K=[k_1,k_2,\cdots,k_N]K=[k1,k2,,kN]V=[v1,v2,⋯,vN]V=[v_1,v_2,\cdots,v_N]V=[v1,v2,,vN].

  2. 利用加性注意力价值概括融合(summarize)QQQ矩阵的信息得到浓缩上下文信息的global query vector

    QQQ矩阵转到qqq向量的权重如何得到?利用softmax的思想计算:
    αi=exp⁡(wqTqi/d)∑j=1Nexp⁡(wqTqj/d)\alpha_{i}=\frac{\exp \left(\mathbf{w}_{q}^{T} \mathbf{q}_{i} / \sqrt{d}\right)}{\sum_{j=1}^{N} \exp \left(\mathbf{w}_{q}^{T} \mathbf{q}_{j} / \sqrt{d}\right)} αi=j=1Nexp(wqTqj/d

    )exp(wqTqi/d

    )

    其中wq∈Rdw_q\in\mathbb{R}^dwqRd是一个要学习的向量参数。α∈RN×1\mathbf{\alpha}\in \mathbb{R}^{N\times1}αRN×1

    则global query vector即为q=∑i=1Nαiqi\mathbf{q}=\sum_{i=1}^{N} \alpha_{i} \mathbf{q}_{i}q=i=1Nαiqiq∈RN×1\mathbf{q}\in\mathbb{R}^{N\times1}qRN×1

  3. 我们知道在经典Transformer中我们是利用点积来对Q,KQ,KQ,K之间交互计算相似度进行建模,但Q⋅KTQ\cdot K^TQKT会带来二次的复杂度,那我们加性注意力机制如何降低运算复杂度的呢?

    • 拼接:简单地连接两个向量不能考虑它们之间的相互作用。
    • 相加:只能得到两个向量之间的线性相互作用,但也不能学习准确的上下文表示
    • **element-wise product:**对两个变量之间的非线性相互作用进行建模,这可能有助于对长序列的复杂上下文进行建模。(Xiang Wang, et al. 2017.)

    所以在此通过global query vector和KKK矩阵计算element-wise product得到global context-aware key matrix,即pi=q∗kip_i = q * k_ipi=qki,其实相对于传统transformer的步骤就是把qi⋅kiq_i\cdot k_iqiki变成了q∗kiq * k_iqki,可以叫他全局注意力矩阵,因为qqq向量是综合了所有token的信息。

    ※其实将上式展开可以看到:pi=ki∑jαjqj=∑jαjqjkjp_i = k_i\sum_j\alpha_j q_j=\sum_j\alpha_jq_jk_jpi=kijαjqj=jαjqjkj,虽然我们有二次成绩,但是我们并不会有二次的运算复杂度,Why?

    因为我们并没有原始的Transformer那样直接计算softmax{([QWQ⋅(KWK)T]/d}\text{softmax}\{([QW^Q\cdot (KW^K)^T]/\sqrt{d}\}softmax{([QWQ(KWK)T]/d

    },该部分需要把每个序列中的每个位置的token两两组合,即需要将两个n×dn\times dn×d的矩阵相乘,计算复杂度为O(n2)O(n^2)O(n2)

    而Fast-former将Q∈RN×dQ\in \mathbb{R}^{N\times d}QRN×d压缩为q∈RN×1q \in \mathbb{R}^{N\times1}qRN×1之后,就成为了在序列上的线性运算从而摆脱了softmax的非线性运算的影响,这也就是他为什么叫additive attention,也是本paper的trick所在。

    在此再对ppp也进行加性注意力机制,其每个pip_ipi的权重为:
    βi=exp⁡(wkTpi/d)∑j=1Nexp⁡(wkTpj/d)\beta_{i}=\frac{\exp \left(\mathbf{w}_{k}^{T} \mathbf{p}_{i} / \sqrt{d}\right)}{\sum_{j=1}^{N} \exp \left(\mathbf{w}_{k}^{T} \mathbf{p}_{j} / \sqrt{d}\right)} βi=j=1Nexp(wkTpj/d

    )exp(wkTpi/d

    )

    其中wk∈Rdw_k\in\mathbb{R}^dwkRd是一个要学习的向量参数。

    则global key vector即为k=∑i=1Nβipi\mathbf{k}=\sum_{i=1}^{N} \beta_{i} \mathbf{p}_{i}k=i=1Nβipi

  4. 同于以上步骤,通过global query vector qqq计算global context-aware key matrix得到global key vectorkkk的过程,我们再用global key vectorkkk来得到global key-value matrix,记为uuu

    也即ui=k∗viu_i=k*v_iui=kvi。因为后面不再需要value vector再去启发别的向量,所以将uiu_iui经过一层线性变换层得到输出R=[r1,r2,…,rN]∈RN×d\mathbf{R}=\left[\mathbf{r}_{1}, \mathbf{r}_{2}, \ldots, \mathbf{r}_{N}\right] \in \mathbb{R}^{N \times d}R=[r1,r2,,rN]RN×d

    至此,每个key和value向量都可以与global query/key向量交互以学习上下文表示。

  5. 最后,将经过线性变换后的输出R∈RN×dR\in\mathbb{R}^{N\times d}RRN×d加上global query vector q∈RN×1q\in\mathbb{R}^{N\times 1}qRN×1,要注意二者维度不一样,即在RRR的每一行都加上向量qqq即为Fast-former的final outpout。(这一步第一次看感觉大家也觉得有点trick,作者未做出原因解释。)

  6. 以上1-5为one head,多个Fast-former层叠加成Multi-head机制。通过借鉴Linformer(Wang et al. 2020)的参数共享技术,共享了value和query的线性变换层参数以降低参数量和过拟合的风险。(作者并未做出原因解释,可能是Linformer中做出了相应的解释?)

2.3 Complexity Analysis

对于global query and key vectors计算复杂度都是O(N⋅d)O(N\cdot d)O(Nd),总的计算复杂度为O(N⋅d)2O(N\cdot d)^2O(Nd)2,对比于原始经典Transformer则为O(N2⋅d)O(N^2\cdot d)O(N2d)

image-20210909185730941

3. Experiments

实验中使用Glove Embedding初始化token向量矩阵,本文选取了文本分类、新闻推荐及文本摘要的数据集进行验证。(后面更新的version的paper故意加上了新闻推荐的内容。)

3.1 实验效果

实验结果如下所示:

image-20210909184146039

综合来看竞争力不是很稳定。

在新闻推荐任务上的结果(MIND数据集):

image-20210909184352393

前面几种为常用的新闻推荐baselines。虽然最好的结果是集成的结果,但Fastformer还是比其他几个模型显著地高。

在文本摘要任务上的结果:

image-20210909185550166

整体而言效果不错。

3.2 时间效率比较

image-20210909193750404

可以看到Fast-former在时间效率上还是提升了很大一部分的。

3.3 不同组合特征的影响

针对2.2中第三点提到的问题,不同特征采用相加、拼接、点乘的组合方式的不同结果表现:

image-20210909194124537

点乘的方式还是显著的好很多。

3.4 不同参数共享方式的影响

综合来看还是采用Q-结合Layer-wise的效果最好,但其实好的并不显著,零点几个点。

4. 一些问题

  1. Fast-former的final output为何还要加上global query vector?理论上R=(r1,⋯,rN)R=(r_1,\cdots,r_N)R=(r1,,rN)已经融合了QKVQKVQKV三者的信息了,为何没有消融实验。只用global query vector,只用R=(r1,⋯,rN)R=(r_1,\cdots,r_N)R=(r1,,rN)效果又分别如何?

  2. 这是否还是Transformer?

    Quote from Yannic Kilcher:

    Kind of like.

    为什么不是?其实本文所命名的Q,K,V并不是原始Transformer中所提到的概念,你其实可以继续明明Q,K,V,X,Y等等好几竖列一列一列地应用加性注意力机制,并不限于这三个QKV矩阵。

    为什么是?他真正体现为self-attention得地方还是用了softmax的地方即计算global query vectorqqq的权重里面的wqw_qwq,他能够根据上下文动态地计算一个句子中不同的token的权重。结合原文的式子来看

    [1]αi=exp⁡(wqTqi/d)∑j=1Nexp⁡(wqTqj/d)\alpha_{i}=\frac{\exp \left(\mathbf{w}_{q}^{T} \mathbf{q}_{i} / \sqrt{d}\right)}{\sum_{j=1}^{N} \exp \left(\mathbf{w}_{q}^{T} \mathbf{q}_{j} / \sqrt{d}\right)}αi=j=1Nexp(wqTqj/d

    )exp(wqTqi/d

    )

    [2] q=∑i=1Nαiqi\mathbf{q}=\sum_{i=1}^{N} \alpha_{i} \mathbf{q}_{i}q=i=1Nαiqi

    [1]中的wqw_qwq是self-attention中的query,qi\mathbf{q_i}qi是self-attention中的key,[2]中的qi\mathbf{q_i}qi也是self-attention中的value。

    而Fast-former后面的Key 和 Value都是根据global query vector 静态学习到的,没有自己的学习参数矩阵,每一个单独的头算出来一个query,他是什么后面就定死了,只能通过多头来实现全方位的语义挖掘。

Paper Reading《Fastformer Additive Attention Can Be All You Need 》相关推荐

  1. Paper Reading《Taming Pretrained Transformers for Extreme Multi-label Text Classification 》

    @time:2020-11-30 github code arxiv paper SIGKDD 2020 Applied Data Track 1. 主要工作 针对极端多标签文本分类(Extreme ...

  2. paper reading——《Improving Person Re-identification by Attribute and Identity Learning》

    ##这篇文章是关于利用行人属性提升行人再识别 论文链接:https://arxiv.org/pdf/1703.07220.pdf ###摘要 行人再识别(reid)和属性识别有着一个共同的目标是描述行 ...

  3. Paper:《YOLOv4: Optimal Speed and Accuracy of Object Detection》的翻译与解读

    Paper:<YOLOv4: Optimal Speed and Accuracy of Object Detection>的翻译与解读 目录 YOLOv4的评价 1.四个改进和一个创新 ...

  4. Paper Reading 《SimCSE》

    Paper Reading: SimCSE SimCSE: Simple Contrastive Learning of Sentence Embeddings 尚未发表.Github. Paper. ...

  5. 2.Paper小结——《Privacy-preserving blockchain-based federated learning for traffic flow prediction》

    题目: 基于区块链的基于隐私保护的交通流量预测的联邦学习 0.Abstract: 交通流量预测已成为智能交通系统的重要组成部分.然而,现有的基于集中式机器学习的交通流量预测方法需要收集原始数据以进行模 ...

  6. 《4DenoiseNet: Adverse Weather Denoising from Adjacent Point Clouds》

    <4DenoiseNet: Adverse Weather Denoising from Adjacent Point Clouds> 一.引言:(阐述研究的动机,说明研究的目的) 问题: ...

  7. Paper:2017年的Google机器翻译团队《Transformer:Attention Is All You Need》翻译并解读

    Paper:2017年的Google机器翻译团队<Transformer:Attention Is All You Need>翻译并解读 目录 论文评价 1.Motivation: 2.创 ...

  8. Paper:Transformer模型起源—2017年的Google机器翻译团队—《Transformer:Attention Is All You Need》翻译并解读

    Paper:Transformer模型起源-2017年的Google机器翻译团队-<Transformer:Attention Is All You Need>翻译并解读 目录 <T ...

  9. Paper:《A Unified Approach to Interpreting Model Predictions—解释模型预测的统一方法》论文解读与翻译

    Paper:<A Unified Approach to Interpreting Model  Predictions-解释模型预测的统一方法>论文解读与翻译 导读:2017年11月25 ...

最新文章

  1. 安装模拟器遇到的问题
  2. jQuery Ajax 实例 ($.ajax、$.post、$.get)
  3. NHibernate Step By Step(10)-常用的配置属性
  4. java需要前台封装对象吗_javaEE之-----------类反射直接封装前台传过来的参数
  5. python制作动图-用Python制作迷宫GIF
  6. Web.xml 文件与server.xml 文件使用总结
  7. boost::hana::members用法的测试程序
  8. Net5 已经来临,让我来送你一个成功
  9. CityEngine Web Scene如何在IIS下部署
  10. [C++11]使用using和typedef给模板定义别名
  11. JAVA多线程总结(笔记)
  12. widget(6、menu)
  13. MyEclipse非正常关闭问题
  14. java生成excel到本地_java 将数据库中的数据导出成Excel文件 并保存到本地 将文件地址返回给前端...
  15. cuda_error_launch_failed: unspecified launch failure
  16. Java视频教程从入门到精通(2021完整版)
  17. 回溯法中解空间树的组织
  18. Python学习:使用pycharm运行py文件报错系统找不到指定的路径
  19. 中医秘笈:气不足则胖,血不足则瘦
  20. 牛客网-调整数组顺序使奇数位于偶数前面

热门文章

  1. keyshot 2023安装包下载及安装教程
  2. Google Ads买量基础篇:Google如何展示App广告系列
  3. 配置Hi提醒 让提醒消息可以转发到企业微信
  4. idea - 添加本地jar包依赖
  5. 第七章 凹凸映射 渐变纹理 遮罩纹理
  6. html长方形代码_html实现圆角矩形
  7. js 正则处理名字 只显示首和尾,中间用三个星号替代
  8. 在word里怎样调标尺?使用技巧分享!
  9. Jmeter-函数助手-随机函数的使用(模拟1000+个手机用户获取短信验证码)
  10. 讲真的,我后悔来到北大青鸟重庆大学城校区了......