文章目录

  • 多头注意力
    • 模型
    • 实现
    • 小结

多头注意力

在实践中,当给定 相同的查询、键和值的集合 时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖关系)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces) 可能是有益的。

为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的 h h h 组不同的线性投影(linear projections) 来变换查询、键和值。然后,这 h h h 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 h h h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)。对于 h h h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)

本质地讲,自注意力机制是:通过某种运算来直接计算得到句子在编码过程中每个位置上的注意力权重;然后再以权重和的形式来计算得到整个句子的隐含向量表示。

自注意力机制的缺陷是:模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置, 因此作者提出了通过多头注意力机制来解决这一问题。

下图展示了使用全连接层来实现可学习的线性变换的多头注意力。

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 q ∈ R d q \mathbf{q} \in \mathbb{R}^{d_q} q∈Rdq​、键 k ∈ R d k \mathbf{k} \in \mathbb{R}^{d_k} k∈Rdk​和值 v ∈ R d v \mathbf{v} \in \mathbb{R}^{d_v} v∈Rdv​,每个注意力头 h i \mathbf{h}_i hi​( i = 1 , … , h i = 1, \ldots, h i=1,…,h)的计算方法为:

h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p v , \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, hi​=f(Wi(q)​q,Wi(k)​k,Wi(v)​v)∈Rpv​,

其中,可学习的参数包括 W i ( q ) ∈ R p q × d q \mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q} Wi(q)​∈Rpq​×dq​、 W i ( k ) ∈ R p k × d k \mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k} Wi(k)​∈Rpk​×dk​和 W i ( v ) ∈ R p v × d v \mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v} Wi(v)​∈Rpv​×dv​,以及代表注意力汇聚的函数 f f f。
f f f 可以是之前学习的加性注意力缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 h h h 个头连结后的结果,因此其可学习参数是 W o ∈ R p o × h p v \mathbf W_o\in\mathbb R^{p_o\times h p_v} Wo​∈Rpo​×hpv​:

W o [ h 1 ⋮ h h ] ∈ R p o . \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}. Wo​⎣ ⎡​h1​⋮hh​​⎦ ⎤​∈Rpo​.

基于这种设计,每个头都可能会关注输入的不同部分,可以表示比简单加权平均值更复杂的函数。

import math
import torch
from torch import nn
from d2l import torch as d2l

实现

在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,我们设定 p q = p k = p v = p o / h p_q = p_k = p_v = p_o / h pq​=pk​=pv​=po​/h。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_q h = p_k h = p_v h = p_o pq​h=pk​h=pv​h=po​,则可以并行计算 h h h 个头。在下面的实现中, p o p_o po​是通过参数 num_hiddens 指定的。

class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries, keys, values的形状:# (batch_size,查询或“键-值”对的个数,num_hiddens)# valid_len 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values的形状:# (batch_size*num_heads,查询或“键-值”个数,num_hiddens/num_head)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或矢量) 复制 num_heads次,# 然后如此复制第二项,然后诸如此类valid_lens = torch.repeat_interleave(valid_lens,repeats=self.num_heads,dim=0)# output的形状:(batch_size*num_heads, 查询个数,num_hiddens/num_head)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size, 查询个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

为了能够使多个头并行计算,上面的 MultiHeadAttention 类将使用下面定义的两个转置函数。具体来说,transpose_output 函数反转了 transpose_qkv 函数的操作。

def transpose_qkv(X, num_heads):"""为了多头注意力的并行计算而变换形状"""# 输入X的形状(batch_size, 查询或”键-值“对的个数,num_hiddens)# 输出X的形状(batch_size,查询或”键-值“对的个数,# num_heads,num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状(batch_size,# num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 输出X的形状(batch_size*num_heads,# 查询或”键-值“对的个数,num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""# 输入X的形状(batch_size*num_heads,# 查询或”键-值“对的个数,num_hiddens/num_heads)# 输出X的形状(batch_size,# num_heads,查询或”键-值“对的个数,num_hiddens/num_heads)X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])# 输出X的形状(batch_size,查询或”键-值“对的个数,# num_heads,num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 输出X的形状(batch_size,查询或”键-值“对的个数,num_hiddens)return X.reshape(X.shape[0], X.shape[1], -1)

下面我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention 类。多头注意力输出的形状是 (batch_size,num_queries, num_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention((attention): DotProductAttention((dropout): Dropout(p=0.5, inplace=False))(W_q): Linear(in_features=100, out_features=100, bias=False)(W_k): Linear(in_features=100, out_features=100, bias=False)(W_v): Linear(in_features=100, out_features=100, bias=False)(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

小结

1、多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

2、基于适当的张量操作,可以实现多头注意力的并行计算。

Multihead Attention - 多头注意力相关推荐

  1. Transformer结构解读(Multi-Head Attention、AddNorm、Feed Forward)

    咱们还是照图讨论,transformer结构图如下,本文主要讨论Encoder部分,从低端输入inputs开始,逐个结构进行: 图一 一.首先说一下Encoder的输入部分: 在NLP领域,个人理解, ...

  2. 深入理解深度学习——Transformer:解码器(Decoder)的多头注意力层(Multi-headAttention)

    分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(Attention Mechanism):基础知识 ·注意力机制(Attention Mechanism):注意力汇聚与Na ...

  3. 深入理解深度学习——注意力机制(Attention Mechanism):带掩码的多头注意力(Masked Multi-head Attention)

    分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(AttentionMechanism):基础知识 ·注意力机制(AttentionMechanism):注意力汇聚与Nada ...

  4. Attention,Multi-head Attention--注意力,多头注意力详解

    Attention 首先谈一谈attention. 注意力函数其实就是把一个query,一个key-value的集合映射成一个输出.其中query,key,value,output(Attention ...

  5. 【基础整理】attention:浅谈注意力机制与自注意力模型(附键值对注意力 + 多头注意力)

    划水休息两天不看论文了 ~ 来重新复习一下基础qaq 以下讲解参考大名鼎鼎的 nndl 邱锡鹏 <神经网络与深度学习> 部分内容(详见第八章,注意力与外部记忆)是对于不太行的初学者也比较友 ...

  6. Transformer、多头注意力机制学习笔记:Attention is All You Need.

    文章目录 相关参考连接: https://blog.csdn.net/hpulfc/article/details/80448570 https://blog.csdn.net/weixin_4239 ...

  7. 比标准Attention快197倍!Meta推出多头注意力机制“九头蛇”

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 丰色 发自 凹非寺 量子位 | 公众号 QbitAI 尽管Trans ...

  8. 自注意力(Self-Attention)与Multi-Head Attention机制详解

    自注意力机制属于注意力机制之一.与传统的注意力机制作用相同,自注意力机制可以更多地关注到输入中的关键信息.self-attention可以看成是multi-head attention的输入数据相同时 ...

  9. 通过7个版本的attention的变形,搞懂transformer多头注意力机制

    --1-- Transformer模型架构 Transformer 由两个独立的模块组成,即Encoder和Decoder Encoder 编码器是一个堆叠N个相同的层.每层由两个子层组成,第一个是多 ...

最新文章

  1. nohup和的区别与关系
  2. PM到底做什么(What Do Product Managers Do?)
  3. 苹果开发账号过期不续费会怎样?
  4. 封神-运维大脑 | 日志检测工具
  5. C语言小算法:ACSII码(多字节)和Unicode(宽字节)互转
  6. 基于Python实现相关分析案例
  7. Atitit.每周末总结 于每周一计划日程表 流程表 v8 -------------import 上周遗漏日志补充 检查话费 检查流量情况 Crm问候 Crm表total and 问候
  8. RedHat 企业版5下系统故障恢复
  9. 慢性病管理系统/案列/APP/小程序/网站
  10. Hough变换的理解
  11. 利用PYTHON计算偏相关系数(Partial correlation coefficient)
  12. 解决Ubuntu19.04下网易云音乐打不开的问题
  13. 后盾网div+css,css定位(后盾网)
  14. 模拟摄影测量和数字摄影测量
  15. 关于TI XDS100V1和XDS100V3仿真器电脑无法识别的解决办法
  16. java fuoco车架_为速度而生 JAVA Fuoco铝合金气动公路
  17. 怎么将word转换成excel表格格式最简单
  18. num_workers
  19. 苹果计算机快捷键设置,那些你必须熟悉苹果电脑的快捷键,你知道吗?
  20. ESP32开发之旅——人体感应传感器HC-SR501

热门文章

  1. 【优化覆盖】基于matlab入侵杂草和花授粉混合算法无线传感器覆盖优化问题【含Matlab源码 1328期】
  2. 【Java】SE总结
  3. PAT乙级 1018 锤子剪刀布
  4. Java同一个线程对象能否多次调用start方法
  5. nodeType(节点类型) 属性值说明
  6. mysql源码分析——VIO数据结构
  7. 我的2022年度状态总结(Formal ver. )
  8. 设计模式之设计原则与思想:设计原则(二)
  9. 前端JS解决CST时间格式转成正常
  10. Win32 API 封装类总结