Fastformer论文解读
Fastformer
首先开头先给出FastFormer(https://arxiv.org/abs/2108.09084)的结构示意图
当然,很多人第一眼看到这个图会有一定的理解,也存在一些疑惑,这样的结构如何实现平方复杂度到线性复杂度的转换呢,相信你看了下面的内容之后,对于上述结构会有个更好的认识。
概述
首先,论文中指出,FastFormer是一种在线性复杂度下就可以实现上下文建模的Transformer变体。在FastFormer中,作者首先利用additive attention mechanism从原有注意力机制的Query矩阵得到全局查询向量q\mathbf{q}q,然后利用元素级别的点乘实现对Key和q\mathbf{q}q的交互建模,得到p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1,p2,⋯,pN,之后再利用相同的机制得到全局键向量k\mathbf{k}k,在此利用元素级别的点乘实现对Value和k\mathbf{k}k的交互建模,得到u1,u2,⋯,uN\mathbf{u_1,u_2,\cdots,u_N}u1,u2,⋯,uN,再将所得到的向量经过Transformation模块,得到r1,r2,⋯,rN\mathbf{r_1,r_2,\cdots,r_N}r1,r2,⋯,rN,此时利用残差的思想,加上原始的Query矩阵,得到FastFormer的输出。
经过上面描叙,相信你已经了解了模块的整个流程,下面我们会进一步了解模型的细节部分以及如何实现注意力。
FastFormer
给定FastFormer的输入为E=[e1,e2,⋯,eN]∈RN×d\mathbf{E}=\mathbf{[e_1,e_2,\cdots , e_N]}\in \mathbb{R}^{N\times d}E=[e1,e2,⋯,eN]∈RN×d,和原始的attention机制一样,首先经过三个独立的线性变换模块,分别得到Q,K,V∈RN×d\mathbf{Q,K,V}\in \mathbb{R}^{N\times d}Q,K,V∈RN×d,分别表示为[q1,q2,⋯,qN]\mathbf{[q_1,q_2,\cdots , q_N]}[q1,q2,⋯,qN][k1,k2,⋯,kN]\mathbf{[k_1,k_2,\cdots , k_N]}[k1,k2,⋯,kN],[v1,v2,⋯,vN]\mathbf{[v_1,v_2,\cdots , v_N]}[v1,v2,⋯,vN]。
普通的Attention机制采用点积注意机制对query和key之间的交互进行建模,但是整个的复杂度是平方级别的。作者为了降低建模的复杂度,在进行建模之前,利用additive attention机制对序列进行总结,在线性复杂度下就可以得到一个全局向量q\mathbf{q}q,后续再利用这个全局查询语义向量进行建模。上述的整个计算过程的数学公式如下:
αi=exp(wqTqi/d)∑j=1Nexp(wqTqj/d)q=∑i=1Nαiqi\begin{aligned} \alpha_i&=\frac{\exp( \textbf w_q^T\mathbf{q}_i/\sqrt d)}{\sum_{j=1}^N\exp( \textbf w_q^T\mathbf{q}_j/\sqrt d)}\\ \mathbf q&=\sum_{i=1}^N\alpha_i\mathbf q_i \end{aligned}αiq=∑j=1Nexp(wqTqj/d)exp(wqTqi/d)=i=1∑Nαiqi
其中wqT∈Rd\textbf w_q^T\in \mathbb R^dwqT∈Rd是一个可学习的参数。
Fastformer的另一个核心问题是如何对全局查询向量q\mathbf{q}q与键矩阵key之间的交互进行建模。一般来说我们可以尝试向量相加、concat进行建模,但上述操作不能区别q\mathbf{q}q对不同键的影响,不利于理解上下文。作者则选择采用元素积(Element-wise product)来建模两个向量之间非线性关系,利用元素积操作得到p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1,p2,⋯,pN,计算公式如下:
pi=q∗ki\begin{aligned} \mathbf p_i = \mathbf {q}*\mathbf {k}_i \end{aligned}pi=q∗ki
同样,类似于上述流程利用additive attention和Element-wise product实现p1,p2,⋯,pN\mathbf{p_1,p_2,\cdots,p_N}p1,p2,⋯,pN和v1,v2,⋯,vN\mathbf{v_1,v_2,\cdots , v_N}v1,v2,⋯,vN的交互建模,最后连接到一个线性层得到矩阵R=[r1,r2,⋯,rN]∈RN×d\mathbf{R}=\mathbf{[r_1,r_2,\cdots , r_N]}\in \mathbb{R}^{N\times d}R=[r1,r2,⋯,rN]∈RN×d
βi=exp(wkTqi/d)∑j=1Nexp(wkTqj/d)k=∑i=1Nβipiui=k∗viR=WRU\begin{aligned} \beta_i&=\frac{\exp( \textbf w_k^T\mathbf{q}_i/\sqrt d)}{\sum_{j=1}^N\exp( \textbf w_k^T\mathbf{q}_j/\sqrt d)}\\ \mathbf k&=\sum_{i=1}^N\beta_i\mathbf p_i\\ \mathbf u_i &= \mathbf {k}*\mathbf {v}_i\\\mathbf R &=\text W_R \mathbf U \end{aligned}βikuiR=∑j=1Nexp(wkTqj/d)exp(wkTqi/d)=i=1∑Nβipi=k∗vi=WRU
基于残差的思想再加上初始矩阵Q\mathbf QQ,即可得到模块最后的输出O\mathbf OO
O=Q+R\begin{aligned} \mathbf O = \mathbf{Q+R} \end{aligned}O=Q+R
此外,作者还指出在多头的情况下,每个头的操作都是一样的,只是参数不一样,最终将每个头的输出在隐藏维度上进行联接。受到(Linformer: Self-attention with linear complexity )中共享参数的启发,作者将Query和Value的转换参数进行共享,减少内存成本;还将每个FastFormer Layer的参数进行共享,以进一步减少参数大小,降低过拟合的风险。
模型分析
复杂度分析
时间复杂度
对于additive attention来说,计算得到全局向量q,k\mathbf{q,k}q,k的时间复杂度和内存成本均为O(N∗d)O(N*d)O(N∗d);对于element-wise product来说,时间复杂度和内存成本均为O(N∗d)O(N*d)O(N∗d),因此整个的复杂度就是O(N∗d)O(N*d)O(N∗d)。相对于原始Transformer( Transformer)的复杂度O(N2∗d)O(N^2*d)O(N2∗d),复杂度由平方级下降到线性级别。
参数量}
FastFormer中所引入的额外参数只有WQ,WK,WV,WrT∈Rd×dW_Q,W_K,W_V,W_r^T\in \mathbb R^{d\times d}WQ,WK,WV,WrT∈Rd×d和WqT,WkT∈RdW_q^T,W_k^T\in \mathbb R^dWqT,WkT∈Rd,此时额外参数量为4d2+2d4d^2+2d4d2+2d,采用层内参数共享(WQ=WVW_Q=W_VWQ=WV)和层间参数共享后,模型的参数量最终为(3d2+2d)∗h(3d^2+2d)*h(3d2+2d)∗h。而仅仅计算原始Transformer( Transformer)的注意力矩阵的参数量(不包含前馈神经网络和正则化等模块参数),模型的单层参数量就已经达到4d2∗h4d^2*h4d2∗h。
表征分析
这里我们通过数学公式来分析FastFormer内部究竟做了哪些。
首先根据公式(2)和(3),我们可以得到:
pi=ki⋅∑jαjqj=∑jαjqjki\begin{aligned} \mathbf {p}_i&=\mathbf k_i\cdot \sum_j\alpha_j\mathbf q_j\\ &=\sum_j\alpha_j\mathbf q_j\mathbf k_i \end{aligned}pi=ki⋅j∑αjqj=j∑αjqjki
而由原有Attention我们可以知道
aij=softmax(QKTd)ai′=∑jsoftmax(QKTd)=∑jγijqikj\begin{aligned} a_{ij}&=\text {softmax}(\frac{QK^T}{\sqrt d}) \\ a_i^{'}&=\sum_j\text {softmax}(\frac{QK^T}{\sqrt d})\\ &=\sum_j\gamma_{ij}\mathbf q_i\mathbf k_j \end{aligned}aijai′=softmax(dQKT)=j∑softmax(dQKT)=j∑γijqikj
推导到这里我们可以看出,FastFormer中的矩阵P\mathbf PP充当着类似于原有Attention Map的角色,同理我们还能得到:
ui=vi⋅∑jβjpj=∑jβjvipj=∑jβjvi∑tαtqtkj=∑j∑tβjαtqtkjvi\begin{aligned} \mathbf {u}_i&=\mathbf v_i\cdot \sum_j\beta_j\mathbf p_j\\ &=\sum_j\beta_j\mathbf v_i\mathbf p_j\\ &=\sum_j\beta_j\mathbf v_i\sum_t\alpha_t\mathbf q_t\mathbf k_j\\ &=\sum_j\sum_t\beta_j\alpha_t\mathbf q_t\mathbf k_j\mathbf v_i \end{aligned}ui=vi⋅j∑βjpj=j∑βjvipj=j∑βjvit∑αtqtkj=j∑t∑βjαtqtkjvi
而原生Transformer模块中,我们得到的中间特征为
Oi=∑jaij⋅vj=∑i∑jϕijqikjvj\begin{aligned} O_i&=\sum _ja_{ij}\cdot \mathbf v_j\\ &=\sum _i\sum _j\phi_{ij} \mathbf q_i\mathbf k_j\mathbf v_j \end{aligned}Oi=j∑aij⋅vj=i∑j∑ϕijqikjvj
其形式和ui\mathbf {u}_iui类似,最后的Transformation模块则是充当了原有的FFN,通过公式推理,我们可以发现作者所提出的FastFormer本质上是类似于Transformer架构的模型。
实验结果
作者展示模型在情感分类、主题分类上的实验结果,还对于模型的推理时间进行试验,都取得了SOTA的效果,如下所展示,在此不多做介绍,有兴趣的同学可以阅读原文进一步了解。
本文仅作为技术交流和分享,严禁未经授权挪作他用。如果对上述存在译文,或者想进一步沟通可以在评论区评论或者联系邮箱1377157216@qq.com
Fastformer论文解读相关推荐
- 自监督学习(Self-Supervised Learning)多篇论文解读(下)
自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...
- 自监督学习(Self-Supervised Learning)多篇论文解读(上)
自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...
- 可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读
可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读 Visual Deprojection: Probabilistic Recovery of Collapsed Dimensions 论文链接: ...
- 从单一图像中提取文档图像:ICCV2019论文解读
从单一图像中提取文档图像:ICCV2019论文解读 DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D Regressi ...
- 点云配准的端到端深度神经网络:ICCV2019论文解读
点云配准的端到端深度神经网络:ICCV2019论文解读 DeepVCP: An End-to-End Deep Neural Network for Point Cloud Registration ...
- 图像分类:CVPR2020论文解读
图像分类:CVPR2020论文解读 Towards Robust Image Classification Using Sequential Attention Models 论文链接:https:// ...
- CVPR2020论文解读:手绘草图卷积网络语义分割
CVPR2020论文解读:手绘草图卷积网络语义分割 Sketch GCN: Semantic Sketch Segmentation with Graph Convolutional Networks ...
- CVPR2020论文解读:3D Object Detection三维目标检测
CVPR2020论文解读:3D Object Detection三维目标检测 PV-RCNN:Point-Voxel Feature Se tAbstraction for 3D Object Det ...
- CVPR2020论文解读:三维语义分割3D Semantic Segmentation
CVPR2020论文解读:三维语义分割3D Semantic Segmentation xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D ...
最新文章
- java.util.concurrent BlockingQueue详解
- 一些有用的Linux命令
- 创建两个相同名称的文件夹
- MFC让窗口最前端显示
- Linux 内核完成接口
- 使用LDA模型对新的文档进行分类
- Linux系统运维-Telnet命令
- 图像拼接算法的基本原理
- oracle 11g 重置,oracle数据库重置
- 一代人终将老去,但总有人正年轻
- MySql每晚12点都会弹出这个?
- flutter comsumer局部刷新的问题
- IP地址分配和IP地址的划分
- ObjectARX自定义实体
- 树结构(Java实现)
- Android中使用Post带参数请求的方法
- arduino液晶显示频
- android打印 编辑并打印 word
- 微信小程序 --长按复制、点击复制实现
- Ubuntu 10.04环境下载编译Android-2.2.1 (froyo) 源代码 1/2
热门文章
- 编码学习——UTF-8与Unicode互转具体流程
- php和mysql不在一台机器上_MySQL_在同一台机器上运行多个 MySQL 服务,**************************************** - phpStudy...
- A*搜索算法AStar_BFS
- java excel 导出 下载_使用Java导出Excel表格并由浏览器直接下载
- 【我的Android进阶之旅】解决魅族手机USB调试时,无法授权出现“Because an app is obscuring a permission request.”错误提示的问题
- 第二章 五行,金木水火土
- POI之excel固定模板导出
- 基于多阈值的形态提取遥感图像中的沿海线的特征方法(Qu Jishuang)
- AI+教育 I 69天流利说APP学习浅谈自适应学习
- 在电脑桌面上添加便签的方法步骤解析