文章目录

  • 1. Title
  • 2. Summary
  • 3. Problem Statement
  • 4. Method(s)
    • 4.1 Relation-aware Self-Attention
    • 4.2 Relative Position Representation
    • 4.3 Efficient Implementation
  • 5. Evaluation
  • 6. Conclusion

1. Title

Self-Attention with Relative Position Representations
https://github.com/evelinehong/Transformer_Relative_Position_PyTorch

2. Summary

Transformer的核心结构Self-Attention机制由于其无法对输入token的相对位置或绝对位置信息进行建模,因此,目前主流的方案都是在输入token之外再额外加上一个Positional Encoding来引入位置信息。本文则是从Self-Attention机制内部出发,通过在计算过程中引入token之间的相对位置关系向量,打破了Self-Attention机制的Permutation-Invariant特性,从而更高效地完成了位置信息的编码,性能得到了提升。
阅读本文主要是在阅读Vision Transformer相关论文中看到了相关应用,Relative Positional Encoding在CV领域也有很多应用,对Vision Transformer性能的提升也是比较明显的。

3. Problem Statement

不同于RNN、CNN,Transformer结构没有显式对相对或者绝对位置进行建模的能力,为此,目前常见的做法是输入中额外添加包含位置信息的特征表示。
但是本文则是从另一个角度出发,Transformer之所以无法对相对或者绝对位置建模,是因为其核心操作Self-Attention是Permutation Invariant,这个性质的简单说明可以参见我另一篇博客:Conditional Positional Encodings for Vision Transformers。
因此,倘若能够打破Self-Attention操作的Permutation Invariant特性,即可不再需要额外的位置信息的输入。

4. Method(s)

4.1 Relation-aware Self-Attention

将输入看做是一个带有标签的有向全连接图。
对于两个输入元素 x i x_i xix j x_j xj之间的边通过两个向量来表示 a i j V , a i j K ∈ R d a a_{i j}^{V}, a_{i j}^{K} \in \mathbb{R}^{d_{a}} aijV,aijKRda,这些向量表示在多个head之间共享, d a = d z d_a=d_z da=dz。通过引入边的特征表示,原始的Self-Attention机制修改为以下计算方式:
z i = ∑ j = 1 n α i j ( x j W V + a i j V ) z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}+a_{i j}^{V}\right) zi=j=1nαij(xjWV+aijV) α i j = exp ⁡ e i j ∑ k = 1 n exp ⁡ e i k \alpha_{i j}=\frac{\exp e_{i j}}{\sum_{k=1}^{n} \exp e_{i k}} αij=k=1nexpeikexpeij e i j = x i W Q ( x j W K + a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz

xiWQ(xjWK+aijK)T
即对于各个Value和Key来说,都会引入一个相互的位置关系表示,从而打破了Self-Attention的Permutation-Invariant。

4.2 Relative Position Representation

考虑到计算量、内存消耗以及远距离的精确位置信息效用不是很足等因素,本文对最远的Relative Position Distance限制为 k k k
a i j K = w c l i p ( j − i , k ) K a i j V = w c l i p ( j − i , k ) V clip ⁡ ( x , k ) = max ⁡ ( − k , min ⁡ ( k , x ) ) \begin{aligned} a_{i j}^{K} &=w_{\mathrm{clip}(j-i, k)}^{K} \\ a_{i j}^{V} &=w_{\mathrm{clip}(j-i, k)}^{V} \\ \operatorname{clip}(x, k) &=\max (-k, \min (k, x)) \end{aligned} aijKaijVclip(x,k)=wclip(ji,k)K=wclip(ji,k)V=max(k,min(k,x))

在这种设定下,仅需要学习 w K = ( w − k K , … , w k K ) w^{K}=\left(w_{-k}^{K}, \ldots, w_{k}^{K}\right) wK=(wkK,,wkK)w V = ( w − k V , … , w k V ) w^{V}=\left(w_{-k}^{V}, \ldots, w_{k}^{V}\right) wV=(wkV,,wkV)

下面结合https://github.com/evelinehong/Transformer_Relative_Position_PyTorch这份代码,对这个部分进行更详细地阐述。

class RelativePosition(nn.Module):def __init__(self, num_units, max_relative_position):super().__init__()self.num_units = num_unitsself.max_relative_position = max_relative_positionself.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))nn.init.xavier_uniform_(self.embeddings_table)def forward(self, length_q, length_k):range_vec_q = torch.arange(length_q)range_vec_k = torch.arange(length_k)distance_mat = range_vec_k[None, :] - range_vec_q[:, None]distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionfinal_mat = torch.LongTensor(final_mat)embeddings = self.embeddings_table[final_mat]return embeddings

代码说明见下图:

4.3 Efficient Implementation

对于一个长度为 n n n和一个head数为 h h h的Multi-Head Self-Attention来说,通过在多个head之间共享Relative Position Representation,使得其空间复杂度由 O ( h n 2 d a ) O(hn^2d_a) O(hn2da)下降至 O ( n 2 d a ) O(n^2d_a) O(n2da),同时,不同Sequence之间也可以进行共享。
因此,对于一个batchsize为 b b b的序列来说,其空间复杂度由 O ( b h n d z ) O(bhnd_z) O(bhndz)上升为 O ( b h n d z + n 2 d a ) O(bhnd_z+n^2d_a) O(bhndz+n2da),其中 O ( n 2 d a ) O(n^2d_a) O(n2da)b b b个Sequence的Relative Position Representation所带来的额外空间消耗。

当没有Relative Position Representation时, e i j e_{ij} eij可以通过 b h bh bh个并行化的 n × d z n \times d_z n×dzd z × n d_z \times n dz×n进行矩阵乘法高效得到。这种高效计算的简单推导如下图:

当加入Relative Positional Representation之后,上述高效计算的前提就被打破了, e i j e_{ij} eij的计算不能分解为 q i q_i qik j k_j kj两个独立的部分了,而是 q i q_i qik i j k_{ij} kij两个不完全独立的部分,此时无法直接将其转化为高效的矩阵计算。

为了解决这个问题,作者将 k i j k_{ij} kij部分拆开,将其分为两个部分分开计算,每个部分可以独立采用一个并行化的高效计算矩阵运算来完成:
e i j = x i W Q ( x j W K ) T + x i W Q ( a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+x_{i} W^{Q}\left(a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz

xiWQ(xjWK)T+xiWQ(aijK)T

上式中,第一部分与未加入Relative Positional Representation时计算方式一样,第二部分则采用稍微不太一样的矩阵计算来完成:
记上式右侧部分为 e i j ′ e_{ij}' eij,记 x i W Q x_iW^Q xiWQq i q_i qi,记 a i j K a_{ij}^K aijKk i j k_{ij} kij,忽略分母项,则右侧部分可表示为 e i j ′ = q i k i j T e_{ij}'=q_ik_{ij}^T eij=qikijT

  • 我们一共有 b h × n bh \times n bh×nq i q_i qi,每个 q i q_i qi的维度为 d z d_z dz
  • 同样我们一共有 n × n n \times n n×nk i j k_{ij} kij,每个 k i j k_{ij} kij的维度为 d z d_z dz

为了能够进行高效的矩阵计算,我们需要将 q i q_i qik i j k_{ij} kij进行重新解释(reshape):

  • q i q_i qi也可以表示为我们一共有 n × b h n \times bh n×bhq i ′ q_i' qi,每个 q i ′ q_i' qi的维度为 d z d_z dz
  • k i j k_{ij} kij也可以表示为我们一共有 n × n n \times n n×nk i j ′ k_{ij}' kij,每个 k i j ′ k_{ij}' kij的维度为 d z d_z dz(含义没有发生变化)

此时我们便可以对 q i ′ q_i' qik i j ′ k_{ij}' kij进行 n n n个并行化的两个大小为 b h × d z bh \times d_z bh×dzd z × n d_z \times n dz×n的矩阵计算来加速计算。最终再重新reshape回原始的大小即可完成 e i j e_{ij} eij两个部分的高效并行化计算。

具体可以参见以下代码:

class MultiHeadAttentionLayer(nn.Module):def __init__(self, hid_dim, n_heads, dropout, device):super().__init__()assert hid_dim % n_heads == 0self.hid_dim = hid_dimself.n_heads = n_headsself.head_dim = hid_dim // n_headsself.max_relative_position = 2self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)self.fc_q = nn.Linear(hid_dim, hid_dim)self.fc_k = nn.Linear(hid_dim, hid_dim)self.fc_v = nn.Linear(hid_dim, hid_dim)self.fc_o = nn.Linear(hid_dim, hid_dim)self.dropout = nn.Dropout(dropout)self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)def forward(self, query, key, value, mask=None):# query = [batch size, query len, hid dim]# key = [batch size, key len, hid dim]# value = [batch size, value len, hid dim]batch_size = query.shape[0]len_k = key.shape[1]len_q = query.shape[1]len_v = value.shape[1]# get q k vquery = self.fc_q(query)  # b n dkey = self.fc_k(key)  # b n dvalue = self.fc_v(value)  # b n dr_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # b h n d/hr_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # b h n d/hattn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))  # first item of equal (5) b h n nr_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim)  # n b*h d/hr_k2 = self.relative_position_k(len_q, len_k)  # n n d/hattn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)  # b*h n nattn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)  # second item of equal (5) b h n nattn = (attn1 + attn2) / self.scale if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)attn = self.dropout(torch.softmax(attn, dim=-1))# attn = [batch size, n heads, query len, key len]r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)weight1 = torch.matmul(attn, r_v1)r_v2 = self.relative_position_v(len_q, len_v)weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)weight2 = torch.matmul(weight2, r_v2)weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)x = weight1 + weight2# x = [batch size, n heads, query len, head dim]x = x.permute(0, 2, 1, 3).contiguous()# x = [batch size, query len, n heads, head dim]x = x.view(batch_size, -1, self.hid_dim)# x = [batch size, query len, hid dim]x = self.fc_o(x)# x = [batch size, query len, hid dim]return x

5. Evaluation

本篇论文主要是用于NLP领域,其实验结果如下:

6. Conclusion

本文主要是从Self-Attention机制本身出发,在计算过程中引入了相对位置信息,从而打破了Self-Attention的Permutation-Invariant特性,提升了各个word之间关系构建能力。

2018-Self-Attention with Relative Position Representations相关推荐

  1. 文本生成(二)【NLP论文复现】Relative position representations 相对位置编码突破Bert的文本长度限制!

    Relative position representations 相对位置编码突破Bert文本512长度的限制 前言 Self-Attention with Relative Position Re ...

  2. 论文阅读笔记:Self-Attention with Relative Position Representations

    提示:阅读论文时进行相关思想.结构.优缺点,内容进行提炼和记录,论文和相关引用会标明出处. 文章目录 前言 介绍 相关 具体结构 Relation-aware自注意力 相对位置表示 高效实现 实验结果 ...

  3. Relative Position Representations

    Self-Attention with Relative Position Representations 摘要 在原始transformer,位置信息通过加一个position的embedding实 ...

  4. 论文阅读——Self-Attention with Relative Position Representations

    Self-Attention with Relative Position Representations Abstract 2017年Vaswani等人提出的Transformer需要在输入中添加绝 ...

  5. How Self-Attention with Relative Position Representations works

      本文的主要内容是基于相对位置表示的自注意力机制是如何工作的. 1. 引论   本篇文章是基于 Self-Attention with Relative Position Representatio ...

  6. 相对位置编码之RPR式:《Self-Attention with Relative Position Representations》论文笔记

  7. [NLP] 相对位置编码 Relative Position Representatitons (RPR)

    1. 翻译:https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-281 ...

  8. 【论文笔记】Rethinking and Improving Relative Position Encoding for Vision Transformer

    论文 论文题目:Rethinking and Improving Relative Position Encoding for Vision Transformer 接收:ICCV 2021 论文地址 ...

  9. 一文读懂css的相对定位【relative position】以及相对定位为什么要设置偏移量?

    目录 何为定位 偏移量 垂直方向 top bottom 水平方向 left right relative-相对定位 何为相对定位 相对定位的特点 实例 元素代码的起始位置 元素若不开启相对定位,即便设 ...

最新文章

  1. mybatis 报错最终解决 :argument type mismatch
  2. 出现module ‘xgboost‘ has no attribute ‘DMatrix‘的临时解决方法
  3. 【转】mysqldump的锁表的问题
  4. 七夕礼物没送对?飞桨PaddlePaddle帮你读懂女朋友的小心思
  5. python 进度条程序_Python:显示程序运行进度条
  6. abstract不能和哪些关键字共存 学习
  7. mysql 日志还原数据库_通过Mysql-bin日志恢复还原数据
  8. HDU -2546饭卡(01背包+贪心)
  9. cloudstack java api_CloudStack API编程指引
  10. java数据库编程之JDBC
  11. 荣耀20/20 Pro相机规格曝光:DxOMark排名或将再次改变
  12. 利用Postman工具测试若依前后端分离接口
  13. CSS Sprite的应用【转】
  14. Java中的Object 类的常见方法
  15. 差分进化算法python_L单目标差分进化算法
  16. Microsoft.SharePoint.dll分享
  17. 显示前半内容后半内容用省略号_2015年广东中考满分作文赏析:特别的一朵花_1500字...
  18. 腾讯地图实时精准定位
  19. JavaSE进阶26 - IO流概述、字节流、字符流、转换流、缓冲流
  20. BGP(1):BGP 的基本机制

热门文章

  1. C#编写简单的迷宫游戏
  2. 分离扫描文档方法(1) —— Dynamic Web TWAIN:如何使用空白页作为扫描文档的分隔器
  3. UI设计素材干货,动效的优秀模板
  4. android音乐播放器开发在线加载歌词
  5. 做自媒体不赚钱了,有多少人是月入过千的?
  6. 【matlab】模拟退火算法代码分析(附sj.txt文件)
  7. 基于PHP后台请求亚马逊订单列表listOrder接口
  8. 图片镜像水平翻转,垂直翻转以及顺时针,逆时针旋转
  9. win10下使用iverilog仿真+gtkwave/WaveDrom查看波形
  10. BFS团战可以输、提莫必须死(转载)