Fastformer

首先开头先给出FastFormer(https://arxiv.org/abs/2108.09084)的结构示意图

当然,很多人第一眼看到这个图会有一定的理解,也存在一些疑惑,这样的结构如何实现平方复杂度到线性复杂度的转换呢,相信你看了下面的内容之后,对于上述结构会有个更好的认识。

概述

首先,论文中指出,FastFormer是一种在线性复杂度下就可以实现上下文建模的Transformer变体。在FastFormer中,作者首先利用additive attention mechanism从原有注意力机制的Query矩阵得到全局查询向量q\mathbf{q}q,然后利用元素级别的点乘实现对Key和q\mathbf{q}q的交互建模,得到p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1​,p2​,⋯,pN​,之后再利用相同的机制得到全局键向量k\mathbf{k}k,在此利用元素级别的点乘实现对Value和k\mathbf{k}k的交互建模,得到u1,u2,⋯,uN\mathbf{u_1,u_2,\cdots,u_N}u1​,u2​,⋯,uN​,再将所得到的向量经过Transformation模块,得到r1,r2,⋯,rN\mathbf{r_1,r_2,\cdots,r_N}r1​,r2​,⋯,rN​,此时利用残差的思想,加上原始的Query矩阵,得到FastFormer的输出。

经过上面描叙,相信你已经了解了模块的整个流程,下面我们会进一步了解模型的细节部分以及如何实现注意力。

FastFormer

给定FastFormer的输入为E=[e1,e2,⋯,eN]∈RN×d\mathbf{E}=\mathbf{[e_1,e_2,\cdots , e_N]}\in \mathbb{R}^{N\times d}E=[e1​,e2​,⋯,eN​]∈RN×d,和原始的attention机制一样,首先经过三个独立的线性变换模块,分别得到Q,K,V∈RN×d\mathbf{Q,K,V}\in \mathbb{R}^{N\times d}Q,K,V∈RN×d,分别表示为[q1,q2,⋯,qN]\mathbf{[q_1,q_2,\cdots , q_N]}[q1​,q2​,⋯,qN​][k1,k2,⋯,kN]\mathbf{[k_1,k_2,\cdots , k_N]}[k1​,k2​,⋯,kN​],[v1,v2,⋯,vN]\mathbf{[v_1,v_2,\cdots , v_N]}[v1​,v2​,⋯,vN​]。

普通的Attention机制采用点积注意机制对query和key之间的交互进行建模,但是整个的复杂度是平方级别的。作者为了降低建模的复杂度,在进行建模之前,利用additive attention机制对序列进行总结,在线性复杂度下就可以得到一个全局向量q\mathbf{q}q,后续再利用这个全局查询语义向量进行建模。上述的整个计算过程的数学公式如下:

αi=exp⁡(wqTqi/d)∑j=1Nexp⁡(wqTqj/d)q=∑i=1Nαiqi\begin{aligned} \alpha_i&=\frac{\exp( \textbf w_q^T\mathbf{q}_i/\sqrt d)}{\sum_{j=1}^N\exp( \textbf w_q^T\mathbf{q}_j/\sqrt d)}\\ \mathbf q&=\sum_{i=1}^N\alpha_i\mathbf q_i \end{aligned}αi​q​=∑j=1N​exp(wqT​qj​/d​)exp(wqT​qi​/d​)​=i=1∑N​αi​qi​​

其中wqT∈Rd\textbf w_q^T\in \mathbb R^dwqT​∈Rd是一个可学习的参数。

Fastformer的另一个核心问题是如何对全局查询向量q\mathbf{q}q与键矩阵key之间的交互进行建模。一般来说我们可以尝试向量相加、concat进行建模,但上述操作不能区别q\mathbf{q}q对不同键的影响,不利于理解上下文。作者则选择采用元素积(Element-wise product)来建模两个向量之间非线性关系,利用元素积操作得到p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1​,p2​,⋯,pN​,计算公式如下:

pi=q∗ki\begin{aligned} \mathbf p_i = \mathbf {q}*\mathbf {k}_i \end{aligned}pi​=q∗ki​​

同样,类似于上述流程利用additive attention和Element-wise product实现p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1​,p2​,⋯,pN​和v1,v2,⋯,vN\mathbf{v_1,v_2,\cdots , v_N}v1​,v2​,⋯,vN​的交互建模,最后连接到一个线性层得到矩阵R=[r1,r2,⋯,rN]∈RN×d\mathbf{R}=\mathbf{[r_1,r_2,\cdots , r_N]}\in \mathbb{R}^{N\times d}R=[r1​,r2​,⋯,rN​]∈RN×d

βi=exp⁡(wkTqi/d)∑j=1Nexp⁡(wkTqj/d)k=∑i=1Nβipiui=k∗viR=WRU\begin{aligned} \beta_i&=\frac{\exp( \textbf w_k^T\mathbf{q}_i/\sqrt d)}{\sum_{j=1}^N\exp( \textbf w_k^T\mathbf{q}_j/\sqrt d)}\\ \mathbf k&=\sum_{i=1}^N\beta_i\mathbf p_i\\ \mathbf u_i &= \mathbf {k}*\mathbf {v}_i\\\mathbf R &=\text W_R \mathbf U \end{aligned}βi​kui​R​=∑j=1N​exp(wkT​qj​/d​)exp(wkT​qi​/d​)​=i=1∑N​βi​pi​=k∗vi​=WR​U​

基于残差的思想再加上初始矩阵Q\mathbf QQ,即可得到模块最后的输出O\mathbf OO

O=Q+R\begin{aligned} \mathbf O = \mathbf{Q+R} \end{aligned}O=Q+R​

此外,作者还指出在多头的情况下,每个头的操作都是一样的,只是参数不一样,最终将每个头的输出在隐藏维度上进行联接。受到(Linformer: Self-attention with linear complexity )中共享参数的启发,作者将Query和Value的转换参数进行共享,减少内存成本;还将每个FastFormer Layer的参数进行共享,以进一步减少参数大小,降低过拟合的风险。

模型分析

复杂度分析

时间复杂度

对于additive attention来说,计算得到全局向量q,k\mathbf{q,k}q,k的时间复杂度和内存成本均为O(N∗d)O(N*d)O(N∗d);对于element-wise product来说,时间复杂度和内存成本均为O(N∗d)O(N*d)O(N∗d),因此整个的复杂度就是O(N∗d)O(N*d)O(N∗d)。相对于原始Transformer( Transformer)的复杂度O(N2∗d)O(N^2*d)O(N2∗d),复杂度由平方级下降到线性级别。

参数量}

FastFormer中所引入的额外参数只有WQ,WK,WV,WrT∈Rd×dW_Q,W_K,W_V,W_r^T\in \mathbb R^{d\times d}WQ​,WK​,WV​,WrT​∈Rd×d和WqT,WkT∈RdW_q^T,W_k^T\in \mathbb R^dWqT​,WkT​∈Rd,此时额外参数量为4d2+2d4d^2+2d4d2+2d,采用层内参数共享(WQ=WVW_Q=W_VWQ​=WV​)和层间参数共享后,模型的参数量最终为(3d2+2d)∗h(3d^2+2d)*h(3d2+2d)∗h。而仅仅计算原始Transformer( Transformer)的注意力矩阵的参数量(不包含前馈神经网络和正则化等模块参数),模型的单层参数量就已经达到4d2∗h4d^2*h4d2∗h。

表征分析

这里我们通过数学公式来分析FastFormer内部究竟做了哪些。

首先根据公式(2)和(3),我们可以得到:

pi=ki⋅∑jαjqj=∑jαjqjki\begin{aligned} \mathbf {p}_i&=\mathbf k_i\cdot \sum_j\alpha_j\mathbf q_j\\ &=\sum_j\alpha_j\mathbf q_j\mathbf k_i \end{aligned}pi​​=ki​⋅j∑​αj​qj​=j∑​αj​qj​ki​​

而由原有Attention我们可以知道

aij=softmax(QKTd)ai′=∑jsoftmax(QKTd)=∑jγijqikj\begin{aligned} a_{ij}&=\text {softmax}(\frac{QK^T}{\sqrt d}) \\ a_i^{'}&=\sum_j\text {softmax}(\frac{QK^T}{\sqrt d})\\ &=\sum_j\gamma_{ij}\mathbf q_i\mathbf k_j \end{aligned}aij​ai′​​=softmax(d​QKT​)=j∑​softmax(d​QKT​)=j∑​γij​qi​kj​​

推导到这里我们可以看出,FastFormer中的矩阵P\mathbf PP充当着类似于原有Attention Map的角色,同理我们还能得到:

ui=vi⋅∑jβjpj=∑jβjvipj=∑jβjvi∑tαtqtkj=∑j∑tβjαtqtkjvi\begin{aligned} \mathbf {u}_i&=\mathbf v_i\cdot \sum_j\beta_j\mathbf p_j\\ &=\sum_j\beta_j\mathbf v_i\mathbf p_j\\ &=\sum_j\beta_j\mathbf v_i\sum_t\alpha_t\mathbf q_t\mathbf k_j\\ &=\sum_j\sum_t\beta_j\alpha_t\mathbf q_t\mathbf k_j\mathbf v_i \end{aligned}ui​​=vi​⋅j∑​βj​pj​=j∑​βj​vi​pj​=j∑​βj​vi​t∑​αt​qt​kj​=j∑​t∑​βj​αt​qt​kj​vi​​

而原生Transformer模块中,我们得到的中间特征为

Oi=∑jaij⋅vj=∑i∑jϕijqikjvj\begin{aligned} O_i&=\sum _ja_{ij}\cdot \mathbf v_j\\ &=\sum _i\sum _j\phi_{ij} \mathbf q_i\mathbf k_j\mathbf v_j \end{aligned}Oi​​=j∑​aij​⋅vj​=i∑​j∑​ϕij​qi​kj​vj​​

其形式和ui\mathbf {u}_iui​类似,最后的Transformation模块则是充当了原有的FFN,通过公式推理,我们可以发现作者所提出的FastFormer本质上是类似于Transformer架构的模型。

实验结果

作者展示模型在情感分类、主题分类上的实验结果,还对于模型的推理时间进行试验,都取得了SOTA的效果,如下所展示,在此不多做介绍,有兴趣的同学可以阅读原文进一步了解。





本文仅作为技术交流和分享,严禁未经授权挪作他用。如果对上述存在译文,或者想进一步沟通可以在评论区评论或者联系邮箱1377157216@qq.com

Fastformer论文解读相关推荐

  1. 自监督学习(Self-Supervised Learning)多篇论文解读(下)

    自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...

  2. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

  3. 可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读

    可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读 Visual Deprojection: Probabilistic Recovery of Collapsed Dimensions 论文链接: ...

  4. 从单一图像中提取文档图像:ICCV2019论文解读

    从单一图像中提取文档图像:ICCV2019论文解读 DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D Regressi ...

  5. 点云配准的端到端深度神经网络:ICCV2019论文解读

    点云配准的端到端深度神经网络:ICCV2019论文解读 DeepVCP: An End-to-End Deep Neural Network for Point Cloud Registration ...

  6. 图像分类:CVPR2020论文解读

    图像分类:CVPR2020论文解读 Towards Robust Image Classification Using Sequential Attention Models 论文链接:https:// ...

  7. CVPR2020论文解读:手绘草图卷积网络语义分割

    CVPR2020论文解读:手绘草图卷积网络语义分割 Sketch GCN: Semantic Sketch Segmentation with Graph Convolutional Networks ...

  8. CVPR2020论文解读:3D Object Detection三维目标检测

    CVPR2020论文解读:3D Object Detection三维目标检测 PV-RCNN:Point-Voxel Feature Se tAbstraction for 3D Object Det ...

  9. CVPR2020论文解读:三维语义分割3D Semantic Segmentation

    CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D ...

最新文章

  1. java.util.concurrent BlockingQueue详解
  2. 一些有用的Linux命令
  3. 创建两个相同名称的文件夹
  4. MFC让窗口最前端显示
  5. Linux 内核完成接口
  6. 使用LDA模型对新的文档进行分类
  7. Linux系统运维-Telnet命令
  8. 图像拼接算法的基本原理
  9. oracle 11g 重置,oracle数据库重置
  10. 一代人终将老去,但总有人正年轻
  11. MySql每晚12点都会弹出这个?
  12. flutter comsumer局部刷新的问题
  13. IP地址分配和IP地址的划分
  14. ObjectARX自定义实体
  15. 树结构(Java实现)
  16. Android中使用Post带参数请求的方法
  17. arduino液晶显示频
  18. android打印 编辑并打印 word
  19. 微信小程序 --长按复制、点击复制实现
  20. Ubuntu 10.04环境下载编译Android-2.2.1 (froyo) 源代码 1/2

热门文章

  1. 编码学习——UTF-8与Unicode互转具体流程
  2. php和mysql不在一台机器上_MySQL_在同一台机器上运行多个 MySQL 服务,**************************************** - phpStudy...
  3. A*搜索算法AStar_BFS
  4. java excel 导出 下载_使用Java导出Excel表格并由浏览器直接下载
  5. 【我的Android进阶之旅】解决魅族手机USB调试时,无法授权出现“Because an app is obscuring a permission request.”错误提示的问题
  6. 第二章 五行,金木水火土
  7. POI之excel固定模板导出
  8. 基于多阈值的形态提取遥感图像中的沿海线的特征方法(Qu Jishuang)
  9. AI+教育 I 69天流利说APP学习浅谈自适应学习
  10. 在电脑桌面上添加便签的方法步骤解析