2018-Self-Attention with Relative Position Representations
文章目录
- 1. Title
- 2. Summary
- 3. Problem Statement
- 4. Method(s)
- 4.1 Relation-aware Self-Attention
- 4.2 Relative Position Representation
- 4.3 Efficient Implementation
- 5. Evaluation
- 6. Conclusion
1. Title
Self-Attention with Relative Position Representations
https://github.com/evelinehong/Transformer_Relative_Position_PyTorch
2. Summary
Transformer的核心结构Self-Attention机制由于其无法对输入token的相对位置或绝对位置信息进行建模,因此,目前主流的方案都是在输入token之外再额外加上一个Positional Encoding来引入位置信息。本文则是从Self-Attention机制内部出发,通过在计算过程中引入token之间的相对位置关系向量,打破了Self-Attention机制的Permutation-Invariant特性,从而更高效地完成了位置信息的编码,性能得到了提升。
阅读本文主要是在阅读Vision Transformer相关论文中看到了相关应用,Relative Positional Encoding在CV领域也有很多应用,对Vision Transformer性能的提升也是比较明显的。
3. Problem Statement
不同于RNN、CNN,Transformer结构没有显式对相对或者绝对位置进行建模的能力,为此,目前常见的做法是输入中额外添加包含位置信息的特征表示。
但是本文则是从另一个角度出发,Transformer之所以无法对相对或者绝对位置建模,是因为其核心操作Self-Attention是Permutation Invariant,这个性质的简单说明可以参见我另一篇博客:Conditional Positional Encodings for Vision Transformers。
因此,倘若能够打破Self-Attention操作的Permutation Invariant特性,即可不再需要额外的位置信息的输入。
4. Method(s)
4.1 Relation-aware Self-Attention
将输入看做是一个带有标签的有向全连接图。
对于两个输入元素 x i x_i xi和 x j x_j xj之间的边通过两个向量来表示 a i j V , a i j K ∈ R d a a_{i j}^{V}, a_{i j}^{K} \in \mathbb{R}^{d_{a}} aijV,aijK∈Rda,这些向量表示在多个head之间共享, d a = d z d_a=d_z da=dz。通过引入边的特征表示,原始的Self-Attention机制修改为以下计算方式:
z i = ∑ j = 1 n α i j ( x j W V + a i j V ) z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}+a_{i j}^{V}\right) zi=j=1∑nαij(xjWV+aijV) α i j = exp e i j ∑ k = 1 n exp e i k \alpha_{i j}=\frac{\exp e_{i j}}{\sum_{k=1}^{n} \exp e_{i k}} αij=∑k=1nexpeikexpeij e i j = x i W Q ( x j W K + a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz
即对于各个Value和Key来说,都会引入一个相互的位置关系表示,从而打破了Self-Attention的Permutation-Invariant。
4.2 Relative Position Representation
考虑到计算量、内存消耗以及远距离的精确位置信息效用不是很足等因素,本文对最远的Relative Position Distance限制为 k k k。
a i j K = w c l i p ( j − i , k ) K a i j V = w c l i p ( j − i , k ) V clip ( x , k ) = max ( − k , min ( k , x ) ) \begin{aligned} a_{i j}^{K} &=w_{\mathrm{clip}(j-i, k)}^{K} \\ a_{i j}^{V} &=w_{\mathrm{clip}(j-i, k)}^{V} \\ \operatorname{clip}(x, k) &=\max (-k, \min (k, x)) \end{aligned} aijKaijVclip(x,k)=wclip(j−i,k)K=wclip(j−i,k)V=max(−k,min(k,x))
在这种设定下,仅需要学习 w K = ( w − k K , … , w k K ) w^{K}=\left(w_{-k}^{K}, \ldots, w_{k}^{K}\right) wK=(w−kK,…,wkK)和 w V = ( w − k V , … , w k V ) w^{V}=\left(w_{-k}^{V}, \ldots, w_{k}^{V}\right) wV=(w−kV,…,wkV)。
下面结合https://github.com/evelinehong/Transformer_Relative_Position_PyTorch这份代码,对这个部分进行更详细地阐述。
class RelativePosition(nn.Module):def __init__(self, num_units, max_relative_position):super().__init__()self.num_units = num_unitsself.max_relative_position = max_relative_positionself.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))nn.init.xavier_uniform_(self.embeddings_table)def forward(self, length_q, length_k):range_vec_q = torch.arange(length_q)range_vec_k = torch.arange(length_k)distance_mat = range_vec_k[None, :] - range_vec_q[:, None]distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionfinal_mat = torch.LongTensor(final_mat)embeddings = self.embeddings_table[final_mat]return embeddings
代码说明见下图:
4.3 Efficient Implementation
对于一个长度为 n n n和一个head数为 h h h的Multi-Head Self-Attention来说,通过在多个head之间共享Relative Position Representation,使得其空间复杂度由 O ( h n 2 d a ) O(hn^2d_a) O(hn2da)下降至 O ( n 2 d a ) O(n^2d_a) O(n2da),同时,不同Sequence之间也可以进行共享。
因此,对于一个batchsize为 b b b的序列来说,其空间复杂度由 O ( b h n d z ) O(bhnd_z) O(bhndz)上升为 O ( b h n d z + n 2 d a ) O(bhnd_z+n^2d_a) O(bhndz+n2da),其中 O ( n 2 d a ) O(n^2d_a) O(n2da)为 b b b个Sequence的Relative Position Representation所带来的额外空间消耗。
当没有Relative Position Representation时, e i j e_{ij} eij可以通过 b h bh bh个并行化的 n × d z n \times d_z n×dz和 d z × n d_z \times n dz×n进行矩阵乘法高效得到。这种高效计算的简单推导如下图:
当加入Relative Positional Representation之后,上述高效计算的前提就被打破了, e i j e_{ij} eij的计算不能分解为 q i q_i qi和 k j k_j kj两个独立的部分了,而是 q i q_i qi和 k i j k_{ij} kij两个不完全独立的部分,此时无法直接将其转化为高效的矩阵计算。
为了解决这个问题,作者将 k i j k_{ij} kij部分拆开,将其分为两个部分分开计算,每个部分可以独立采用一个并行化的高效计算矩阵运算来完成:
e i j = x i W Q ( x j W K ) T + x i W Q ( a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+x_{i} W^{Q}\left(a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz
上式中,第一部分与未加入Relative Positional Representation时计算方式一样,第二部分则采用稍微不太一样的矩阵计算来完成:
记上式右侧部分为 e i j ′ e_{ij}' eij′,记 x i W Q x_iW^Q xiWQ为 q i q_i qi,记 a i j K a_{ij}^K aijK为 k i j k_{ij} kij,忽略分母项,则右侧部分可表示为 e i j ′ = q i k i j T e_{ij}'=q_ik_{ij}^T eij′=qikijT。
- 我们一共有 b h × n bh \times n bh×n个 q i q_i qi,每个 q i q_i qi的维度为 d z d_z dz
- 同样我们一共有 n × n n \times n n×n个 k i j k_{ij} kij,每个 k i j k_{ij} kij的维度为 d z d_z dz
为了能够进行高效的矩阵计算,我们需要将 q i q_i qi和 k i j k_{ij} kij进行重新解释(reshape):
- q i q_i qi也可以表示为我们一共有 n × b h n \times bh n×bh个 q i ′ q_i' qi′,每个 q i ′ q_i' qi′的维度为 d z d_z dz
- k i j k_{ij} kij也可以表示为我们一共有 n × n n \times n n×n个 k i j ′ k_{ij}' kij′,每个 k i j ′ k_{ij}' kij′的维度为 d z d_z dz(含义没有发生变化)
此时我们便可以对 q i ′ q_i' qi′和 k i j ′ k_{ij}' kij′进行 n n n个并行化的两个大小为 b h × d z bh \times d_z bh×dz和 d z × n d_z \times n dz×n的矩阵计算来加速计算。最终再重新reshape回原始的大小即可完成 e i j e_{ij} eij两个部分的高效并行化计算。
具体可以参见以下代码:
class MultiHeadAttentionLayer(nn.Module):def __init__(self, hid_dim, n_heads, dropout, device):super().__init__()assert hid_dim % n_heads == 0self.hid_dim = hid_dimself.n_heads = n_headsself.head_dim = hid_dim // n_headsself.max_relative_position = 2self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)self.fc_q = nn.Linear(hid_dim, hid_dim)self.fc_k = nn.Linear(hid_dim, hid_dim)self.fc_v = nn.Linear(hid_dim, hid_dim)self.fc_o = nn.Linear(hid_dim, hid_dim)self.dropout = nn.Dropout(dropout)self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)def forward(self, query, key, value, mask=None):# query = [batch size, query len, hid dim]# key = [batch size, key len, hid dim]# value = [batch size, value len, hid dim]batch_size = query.shape[0]len_k = key.shape[1]len_q = query.shape[1]len_v = value.shape[1]# get q k vquery = self.fc_q(query) # b n dkey = self.fc_k(key) # b n dvalue = self.fc_v(value) # b n dr_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) # b h n d/hr_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) # b h n d/hattn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2)) # first item of equal (5) b h n nr_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim) # n b*h d/hr_k2 = self.relative_position_k(len_q, len_k) # n n d/hattn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1) # b*h n nattn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k) # second item of equal (5) b h n nattn = (attn1 + attn2) / self.scale if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)attn = self.dropout(torch.softmax(attn, dim=-1))# attn = [batch size, n heads, query len, key len]r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)weight1 = torch.matmul(attn, r_v1)r_v2 = self.relative_position_v(len_q, len_v)weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)weight2 = torch.matmul(weight2, r_v2)weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)x = weight1 + weight2# x = [batch size, n heads, query len, head dim]x = x.permute(0, 2, 1, 3).contiguous()# x = [batch size, query len, n heads, head dim]x = x.view(batch_size, -1, self.hid_dim)# x = [batch size, query len, hid dim]x = self.fc_o(x)# x = [batch size, query len, hid dim]return x
5. Evaluation
本篇论文主要是用于NLP领域,其实验结果如下:
6. Conclusion
本文主要是从Self-Attention机制本身出发,在计算过程中引入了相对位置信息,从而打破了Self-Attention的Permutation-Invariant特性,提升了各个word之间关系构建能力。
2018-Self-Attention with Relative Position Representations相关推荐
- 文本生成(二)【NLP论文复现】Relative position representations 相对位置编码突破Bert的文本长度限制!
Relative position representations 相对位置编码突破Bert文本512长度的限制 前言 Self-Attention with Relative Position Re ...
- 论文阅读笔记:Self-Attention with Relative Position Representations
提示:阅读论文时进行相关思想.结构.优缺点,内容进行提炼和记录,论文和相关引用会标明出处. 文章目录 前言 介绍 相关 具体结构 Relation-aware自注意力 相对位置表示 高效实现 实验结果 ...
- Relative Position Representations
Self-Attention with Relative Position Representations 摘要 在原始transformer,位置信息通过加一个position的embedding实 ...
- 论文阅读——Self-Attention with Relative Position Representations
Self-Attention with Relative Position Representations Abstract 2017年Vaswani等人提出的Transformer需要在输入中添加绝 ...
- How Self-Attention with Relative Position Representations works
本文的主要内容是基于相对位置表示的自注意力机制是如何工作的. 1. 引论 本篇文章是基于 Self-Attention with Relative Position Representatio ...
- 相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记
- [NLP] 相对位置编码 Relative Position Representatitons (RPR)
1. 翻译:https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-281 ...
- 【论文笔记】Rethinking and Improving Relative Position Encoding for Vision Transformer
论文 论文题目:Rethinking and Improving Relative Position Encoding for Vision Transformer 接收:ICCV 2021 论文地址 ...
- 一文读懂css的相对定位【relative position】以及相对定位为什么要设置偏移量?
目录 何为定位 偏移量 垂直方向 top bottom 水平方向 left right relative-相对定位 何为相对定位 相对定位的特点 实例 元素代码的起始位置 元素若不开启相对定位,即便设 ...
最新文章
- mybatis 报错最终解决 :argument type mismatch
- 出现module ‘xgboost‘ has no attribute ‘DMatrix‘的临时解决方法
- 【转】mysqldump的锁表的问题
- 七夕礼物没送对?飞桨PaddlePaddle帮你读懂女朋友的小心思
- python 进度条程序_Python:显示程序运行进度条
- abstract不能和哪些关键字共存 学习
- mysql 日志还原数据库_通过Mysql-bin日志恢复还原数据
- HDU -2546饭卡(01背包+贪心)
- cloudstack java api_CloudStack API编程指引
- java数据库编程之JDBC
- 荣耀20/20 Pro相机规格曝光:DxOMark排名或将再次改变
- 利用Postman工具测试若依前后端分离接口
- CSS Sprite的应用【转】
- Java中的Object 类的常见方法
- 差分进化算法python_L单目标差分进化算法
- Microsoft.SharePoint.dll分享
- 显示前半内容后半内容用省略号_2015年广东中考满分作文赏析:特别的一朵花_1500字...
- 腾讯地图实时精准定位
- JavaSE进阶26 - IO流概述、字节流、字符流、转换流、缓冲流
- BGP(1):BGP 的基本机制