Transformer 的出色表现让注意力机制出现在深度学习的各处。本文整理了深度学习中最常用的6种注意力机制的数学原理和代码实现。

1、Full Attention

2017的《Attention is All You Need》中的编码器-解码器结构实现中提出。它结构并不复杂,所以不难理解。

上图 1.左侧显示了 Scaled Dot-Product Attention 的机制。当我们有多个注意力时,我们称之为多头注意力(右),这也是最常见的注意力的形式公式如下:

公式1

这里Q(Query)、K(Key)和V(values)被认为是它的输入,dₖ(输入维度)被用来降低复杂度和计算成本。这个公式可以说是深度学习中注意力机制发展的开端。下面我们看一下它的代码:

class FullAttention(nn.Module):def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):super(FullAttention, self).__init__()self.scale = scaleself.mask_flag = mask_flagself.output_attention = output_attentionself.dropout = nn.Dropout(attention_dropout)def forward(self, queries, keys, values, attn_mask):B, L, H, E = queries.shape_, S, _, D = values.shapescale = self.scale or 1. / sqrt(E)scores = torch.einsum("blhe,bshe->bhls", queries, keys)if self.mask_flag:if attn_mask is None:attn_mask = TriangularCausalMask(B, L, device=queries.device)scores.masked_fill_(attn_mask.mask, -np.inf)A = self.dropout(torch.softmax(scale * scores, dim=-1))V = torch.einsum("bhls,bshd->blhd", A, values)if self.output_attention:return (V.contiguous(), A)else:return (V.contiguous(), None)

2、ProbSparse Attention

借助“Transformer Dissection: A Unified Understanding of Transformer’s Attention via the lens of Kernel”中的信息我们可以将公式修改为下面的公式2。第i个query的attention就被定义为一个概率形式的核平滑方法(kernel smoother):

公式2

从公式 2,我们可以定义第 i 个查询的稀疏度测量如下:

最后,注意力块的最终公式是下面的公式4。

代码如下:

class ProbAttention(nn.Module):def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):super(ProbAttention, self).__init__()self.factor = factorself.scale = scaleself.mask_flag = mask_flagself.output_attention = output_attentionself.dropout = nn.Dropout(attention_dropout)def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)# Q [B, H, L, D]B, H, L_K, E = K.shape_, _, L_Q, _ = Q.shape# calculate the sampled Q_KK_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_qK_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)# find the Top_k query with sparisty measurementM = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)M_top = M.topk(n_top, sorted=False)[1]# use the reduced Q to calculate Q_KQ_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q)Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_kreturn Q_K, M_topdef _get_initial_context(self, V, L_Q):B, H, L_V, D = V.shapeif not self.mask_flag:# V_sum = V.sum(dim=-2)V_sum = V.mean(dim=-2)contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()else: # use maskassert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention onlycontex = V.cumsum(dim=-2)return contexdef _update_context(self, context_in, V, scores, index, L_Q, attn_mask):B, H, L_V, D = V.shapeif self.mask_flag:attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)scores.masked_fill_(attn_mask.mask, -np.inf)attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)context_in[torch.arange(B)[:, None, None],torch.arange(H)[None, :, None],index, :] = torch.matmul(attn, V).type_as(context_in)if self.output_attention:attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attnreturn (context_in, attns)else:return (context_in, None)def forward(self, queries, keys, values, attn_mask):B, L_Q, H, D = queries.shape_, L_K, _, _ = keys.shapequeries = queries.transpose(2,1)keys = keys.transpose(2,1)values = values.transpose(2,1)U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) U_part = U_part if U_part<L_K else L_Ku = u if u<L_Q else L_Qscores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) # add scale factorscale = self.scale or 1./sqrt(D)if scale is not None:scores_top = scores_top * scale# get the contextcontext = self._get_initial_context(values, L_Q)# update the context with selected top_k queriescontext, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)return context.transpose(2,1).contiguous(), attn

这也是 Informer这个用于长序列时间序列预测的新型Transformer中使用的注意力。

3、LogSparse Attention

我们之前讨论的注意力有两个缺点:1. 与位置无关 2. 内存的瓶颈。为了应对这两个问题,研究人员使用了卷积算子和 LogSparse Transformers。

Transformer 中相邻层之间不同注意力机制的图示

卷积自注意力显示在(右)中,它使用步长为 1,内核大小为 k 的卷积层将输入(具有适当的填充)转换为Q/K。这种位置感知可以根据(左)中的形状正确匹配最相关的特征

他们不是使用步长为 1,卷积核大小 1,而是使用步长为 1 ,核大小为 k 的随意卷积(以确保模型无法访问未来的点)将输入转换为Q和K

代码实现

class Attention(nn.Module):def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1):super(Attention, self).__init__()if(sparse):print('Activate log sparse!')mask = self.log_mask(win_len, sub_len)else:mask = torch.tril(torch.ones(win_len, win_len)).view(1, 1, win_len, win_len)self.register_buffer('mask_tri', mask)self.n_head = n_headself.split_size = n_embd * self.n_headself.scale = scaleself.q_len = q_lenself.query_key = nn.Conv1d(n_embd, n_embd * n_head * 2, self.q_len)self.value = Conv1D(n_embd * n_head, 1, n_embd)self.c_proj = Conv1D(n_embd, 1, n_embd * self.n_head)self.attn_dropout = nn.Dropout(attn_pdrop)self.resid_dropout = nn.Dropout(resid_pdrop)def log_mask(self, win_len, sub_len):mask = torch.zeros((win_len, win_len), dtype=torch.float)for i in range(win_len):mask[i] = self.row_mask(i, sub_len, win_len)return mask.view(1, 1, mask.size(0), mask.size(1))def row_mask(self, index, sub_len, win_len):"""Remark:1 . Currently, dense matrices with sparse multiplication are not supported by Pytorch. Efficient implementationshould deal with CUDA kernel, which we haven't implemented yet.2 . Our default setting here use Local attention and Restart attention.3 . For index-th row, if its past is smaller than the number of cells the lastcell can attend, we can allow current cell to attend all past cells to fullyutilize parallel computing in dense matrices with sparse multiplication."""log_l = math.ceil(np.log2(sub_len))mask = torch.zeros((win_len), dtype=torch.float)if((win_len // sub_len) * 2 * (log_l) > index):mask[:(index + 1)] = 1else:while(index >= 0):if((index - log_l + 1) < 0):mask[:index] = 1breakmask[index - log_l + 1:(index + 1)] = 1  # Local attentionfor i in range(0, log_l):new_index = index - log_l + 1 - 2**iif((index - new_index) <= sub_len and new_index >= 0):mask[new_index] = 1index -= sub_lenreturn maskdef attn(self, query: torch.Tensor, key, value: torch.Tensor, activation="Softmax"):activation = activation_dict[activation](dim=-1)pre_att = torch.matmul(query, key)if self.scale:pre_att = pre_att / math.sqrt(value.size(-1))mask = self.mask_tri[:, :, :pre_att.size(-2), :pre_att.size(-1)]pre_att = pre_att * mask + -1e9 * (1 - mask)pre_att = activation(pre_att)pre_att = self.attn_dropout(pre_att)attn = torch.matmul(pre_att, value)return attndef merge_heads(self, x):x = x.permute(0, 2, 1, 3).contiguous()new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)return x.view(*new_x_shape)def split_heads(self, x, k=False):new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)x = x.view(*new_x_shape)if k:return x.permute(0, 2, 3, 1)else:return x.permute(0, 2, 1, 3)def forward(self, x):value = self.value(x)qk_x = nn.functional.pad(x.permute(0, 2, 1), pad=(self.q_len - 1, 0))query_key = self.query_key(qk_x).permute(0, 2, 1)query, key = query_key.split(self.split_size, dim=2)query = self.split_heads(query)key = self.split_heads(key, k=True)value = self.split_heads(value)attn = self.attn(query, key, value)attn = self.merge_heads(attn)attn = self.c_proj(attn)attn = self.resid_dropout(attn)return attnclass Conv1D(nn.Module):def __init__(self, out_dim, rf, in_dim):super(Conv1D, self).__init__()self.rf = rfself.out_dim = out_dimif rf == 1:w = torch.empty(in_dim, out_dim)nn.init.normal_(w, std=0.02)self.w = Parameter(w)self.b = Parameter(torch.zeros(out_dim))else:raise NotImplementedErrordef forward(self, x):if self.rf == 1:size_out = x.size()[:-1] + (self.out_dim,)x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)x = x.view(*size_out)else:raise NotImplementedErrorreturn x

来自:https://github.com/AIStream-Peelout/flow-forecast/

4、LSH Attention

Reformer的论文选择了局部敏感哈希的angular变体。它们首先约束每个输入向量的L2范数(即将向量投影到一个单位球面上),然后应用一系列的旋转,最后找到每个旋转向量所属的切片。这样一来就需要找到最近邻的值,这就需要局部敏感哈希(LSH)了,它能够快速在高维空间中找到最近邻。一个局部敏感哈希算法可以将每个向量 x 转换为 hash h(x),和这个 x 靠近的哈希更有可能有着相同的哈希值,而距离远的则不会。作者希望最近的向量最可能得到相同的哈希值,或者 hash-bucket 大小相似的更有可能相同。

局部敏感哈希算法使用球投影点的随机旋转,通过argmax在有符号的轴投影上建立bucket。在这个高度简化的2D描述中,对于三个不同的角哈希,两个点x和y不太可能共享相同的哈希桶(上图),除非它们的球面投影彼此接近(下图)。

通过固定一个大小为 [dₖ, b/2] 的随机矩阵 R 来获得 b 个哈希值。h(x) = argmax([xR;-xR]) 其中 [u;v] 表示两个向量的串联。这样就可以使用LSH,将查询位置I带入重写公式1:

下图我们可以示意性地解释 LSH 注意力:

  • 原始的注意力矩阵通常是稀疏的,但不利于计算
  • LSH Attention基于哈希桶进行键的排序进行查询
  • 在排序后的注意矩阵中,来自同一桶的对将聚集在对角线附近
  • 采用批处理方法,m个连续查询的块相互处理,一个块返回。

代码很长为了节约时间这里就不贴了:

https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py

5、Sparse Attention(Generating Long Sequences with Sparse Transformers)

OpenAI的Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。通过top-k选择,将注意退化为稀疏注意。这样,保留最有助于引起注意的部分,并删除其他无关的信息。这种选择性方法在保存重要信息和消除噪声方面是有效的。注意力可以更多地集中在最有贡献的价值因素上。

代码:https://github.com/openai/sparse_attention

6、Single-Headed Attention(Single Headed Attention RNN: Stop Thinking With Your Head)

SHA-RNN模型的注意力是简化到只保留了一个头并且唯一的矩阵乘法出现在query (下图Q) 那里,A是缩放点乘注意力 (Scaled Dot-Product Attention) ,是向量之间的运算。所以这种计算量比较小,能够快速的进行训练,就像它介绍的那样:

Obtain strong results on a byte level language modeling dataset (enwik8) in under 24 hours on a single GPU (12GB Titan V)

代码:https://github.com/Smerity/sha-rnn

引用:

  1. Kitaev, N., Ł. Kaiser, and A. Levskaya, Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  2. Li, S., et al., Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. Advances in Neural Information Processing Systems, 2019. 32.
  3. Zhou, H., et al. Informer: Beyond efficient transformer for long sequence time-series forecasting. in Proceedings of AAAI. 2021.
  4. Vaswani, A., et al., Attention is all you need. Advances in neural information processing systems, 2017. 30.
  5. Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever Generating Long Sequences with Sparse Transformers
  6. Stephen Merity ,Single Headed Attention RNN: Stop Thinking With Your Head

注意机制的发展到现在远远不止这些,在本篇文章中只整理了一些常见的注意力机制,希望对你有所帮助。

另外就是来自Erasmus University的Gianni Brauwers 和Flavius Frasincar在TKDE上发表的《A General Survey on Attention Mechanisms in Deep Learning》综述论文,提供了一个关于深度学习注意力机制的重要概述。各种注意力机制通过一个由注意力模型,统一符号和一个全面的分类注意力机制组成的框架来进行解释,还有注意力模型评价的各种方法。

有兴趣和有资源的话可以进行阅读,女神Alexandra Elbakyan的网站还未提供该论文。

https://www.overfit.cn/post/739299d8be4e4ddc8f5804b37c6c82ad

作者:Reza Yazdanfar

用于Transformer的6种注意力的数学原理和代码实现相关推荐

  1. 奇异值分解SVD数学原理及代码(Python)

    奇异值分解SVD数学原理及代码(Python) 首先简单介绍一下什么是正交矩阵(酉矩阵) 如果 或 其中,E为单位矩阵,或,则n阶实矩阵A称为正交矩阵.正交矩阵是实数特殊化的酉矩阵,因此总是属于正规矩 ...

  2. 【FPGA】CRC校验算法从数学原理到代码实现

    老规矩,转b站 [[FPGA]CRC校验算法从数学原理到代码实现-哔哩哔哩]

  3. tensor如何实现转置_PyTorch中的傅立叶卷积:通过FFT有效计算大核卷积的数学原理和代码实现...

    卷积 卷积在数据分析中无处不在.几十年来,它们已用于信号和图像处理.最近,它们已成为现代神经网络的重要组成部分. 在数学上,卷积表示为: 尽管离散卷积在计算应用程序中更为常见,但由于本文使用连续变量证 ...

  4. 几种风控算法的原理和代码实现

    一.基算法 1.决策树(Decision Tree) (1)原理:决策树根据样本数据集的数据特征对数据集进行划分,直到针对所有特征都划分过,或者划分的数据子集的所有数据的类别标签相同. (2)代码实现 ...

  5. OpenGL坐标变换及其数学原理,两种摄像机交互模型(附源程序)

    OpenGL坐标变换及其数学原理,两种摄像机交互模型(附源程序) 实验平台:win7,VS2010 先上结果截图(文章最后下载程序,解压后直接运行BIN文件夹下的EXE程序): a.鼠标拖拽旋转物体, ...

  6. 详解Transformer模型及相关的数学原理

    声明:本文参考了许多相关资料,视频,博客,结合<Attention is All You Need>这篇文章的每一个细节,从一个初学者的角度出发详细解读Transformer模型,无代码. ...

  7. Transformer模型有多少种变体?复旦邱锡鹏教授团队做了全面综述

    视学算法报道 转载自:机器之心 编辑:Liyuan.杜伟 自提出至今,Transformer 模型已经在自然语言处理.计算机视觉以及其他更多领域「大展拳脚」,学界也提出了各种各样基于原始模型的变体.但 ...

  8. 谷歌将AutoML应用于Transformer架构,翻译结果飙升,已开源!

    来源:新智元 本文约1600字,建议阅读8分钟. Evolved Transformer不仅实现了最先进的翻译结果,与原始的Transformer相比,它还展示了语言建模的改进性能. [ 导读 ]为了 ...

  9. 谷歌将AutoML应用于Transformer架构,翻译结果飙升,已开源

    https://www.toutiao.com/a6702613730661761548/ 2019-06-15 12:44:29 [新智元导读]为了探索AutoML在序列域中的应用是否能够取得的成功 ...

最新文章

  1. Sql 先进先出计算积分
  2. requireJS对文件合并与压缩(二)
  3. 电影院票务管理系统数据库设计(2)
  4. 有趣的Web版Ubuntu Linux
  5. 都说dlib是人脸识别的神器,那到底能不能识破妖怪的伪装?
  6. Kubernetes初步学习
  7. 浅谈OpenCL四大模型之Execution Model
  8. 客户和顾客是一个意思吗_“啤酒度数”和“啤酒酒精度”一个意思吗?
  9. Tableau可视化学习笔记:day09-10
  10. Vuex之store仓库计算属性Getter
  11. 复杂网络分析软件NetworkX和UCINET数据关联的方法
  12. matlab cnn 实例,Deep Learning学习 之 CNN代码解析(MATLAB)(示例代码)
  13. c语言中人脸磨皮算法,人脸磨皮算法
  14. MaxCompute SQL
  15. 移动通信网络规划:多址技术
  16. 伤病缠身仍愿竭力而战 澳网一别穆雷何时再见?
  17. 基于时延估计的动力型下肢假肢分段控制策略研究
  18. jadbc oracle clob,XML blob issue with External table
  19. 最后1天,购票渠道即将关闭!Unite 2018开发者大会全日程公布
  20. 中断使能和清除使能、中断挂起和清除挂起

热门文章

  1. 微博--图片,视频,评论抓取
  2. 4G网络要升级成5G,换卡还是换手机?
  3. Linux处理cds文件,Linux 使用CDS磁盘+LVM
  4. Y4M(YUV4MPEG2) 格式文件详解
  5. 10个最佳的大数据处理编程语言
  6. easyuefi如何添加引导_UEFI怎么用 UEFI安全启动设置添加方法步骤图解
  7. QtXlsx详细配置
  8. 【Web技术】1391- 页面可视化搭建工具前生今世
  9. css3 logo 自上而下动画 渐渐出现
  10. 关于majaro安装后的配置,简单记录 机型华硕FZ53v