参考资料:3D视角剖析Attention
https://zhuanlan.zhihu.com/p/441240252
Sebastian Raschka大佬的博客
https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html
完整代码及可视化图片见,必看:https://nbviewer.org/github/lryup/self-learning/blob/main/self_attention.ipynb

想法来源:现实生活中,我们对世界的认知,90%左右来自于视觉感知,当我们放眼望去,无数的风景一扫而过,而真正引起关注的少之又少。如万千花丛中,我仅会关注最大、最艳的那朵;所谓众里寻她千百度,万眼直盯她脸部。大部分都只是背景色,毫不重要。注意力机制也就是关注其中最想关注的地方。

注意力机制的重要性:最近爆火的自然语言对话chatGPT已经颠覆式反应强人工智能时代的到来,他无所不能,能有逻辑,有条理的回答各种问题,甚至其水平已经超越大部分大学生。你问他用了什么架构,以及如何学懂chatGPT,以下是他的回复。可想而知注意力机制的重要性。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oWuW4AnC-1680052040869)(attachment:image.png)]

自注意力机制(Self-Attention)
例子:如英文这句话,“Life is short, eat dessert first”,人生苦短,先吃甜点吧。
1.先对句子进行编码
(1)因编码要求输入的是数字,我们对字符进行数字化,一个单词用一个数字表示。

sentence = 'Life is short, eat dessert first'
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)
{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
import torch
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)
tensor([0, 4, 5, 2, 1, 3])

将每个整形数字进行向量编码(vector embedding);假设每个数字编码为16维,因此,6个数字编码维度为6*16维。

torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()
print(embedded_sentence)
print(embedded_sentence.shape)
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],[ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],[ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,-1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],[-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],[-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],[ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])

(2)定义权重矩阵( W q , W k , W v W_q,W_k,W_v WqWk,Wv
自注意力需要三个权重矩阵:Wq,Wk,Wv,三个权重可以根据模型训练调节。分别将三个权重矩阵点乘输入向量x
(上面6*16的向量编码),将其投影得到query,key,value序列;
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-l1tIQpt6-1680052040870)(attachment:image.png)]
索引i指的是在长度为T的输入序列中的token(词)索引位置.
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-aDpgzMtR-1680052040870)(attachment:image-2.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sQEbp1Uy-1680052040870)(attachment:image-3.png)]
d k d_k dk这个维度受W的控制,这里和上面向量编码d维(16)不一样。权重W的维度设置可以对原始输入向量的维度进行降维或者升维控制。假如这里dq,dk设置24维。dv设置28维。从这看出这里是将16维进行升维。

torch.manual_seed(123)
d = embedded_sentence.shape[1]
d_q, d_k, d_v = 24, 24, 28
W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)
W_query.shape,W_key.shape,W_value.shape
(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

计算非归一化注意力权重
假设我们对第二个输入元素的注意向量感兴趣——那第二个输入元素在这里充当query:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bzTwzAT7-1680052040871)(attachment:image.png)]
例子:以第二个单词为例。

x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)#[24, 16]*[16]=[24]
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)print(query_2.shape)
print(key_2.shape)
print(value_2.shape)
torch.Size([24])
torch.Size([24])
torch.Size([28])

然后我们可以将其推广到计算剩余的key,以及所有输入元素的value,因为我们将在下一步计算非标准化注意力权重ω时需要它们:

keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).Tprint("keys.shape:", keys.shape)
print("values.shape:", values.shape)
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])

现在我们已经有了所有需要的键和值,我们可以进行下一步,计算非标准化注意力权重ω,如下图所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4n4rB6Ck-1680052040871)(attachment:image-2.png)]
将query2的值分别与不同的key值相乘,就可以知道当前所查询对象与其他词的注意力权重大小。如下:(11.14:无归一化)

omega_24 = query_2.dot(keys[4])
print(omega_24)
tensor(11.1466)
query_2.shape,keys.shape
(torch.Size([24]), torch.Size([6, 24]))

因为我们将需要这些来计算注意力分数,让我们计算所有输入token(词)的ω值,如上图所示

omega_2 = query_2.matmul(keys.T)
print(omega_2)
tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800])
torch.sqrt(torch.tensor(24)),omega_2[0]/torch.sqrt(torch.tensor(24))
(tensor(4.8990), tensor(1.7515))
omega_2/torch.sqrt(torch.tensor(24))
tensor([ 1.7515, -1.5635,  0.6646,  0.2122,  2.2753, -0.0980])

如果是所有的,通过计算Q*K得到, 6 ∗ 6 6*6 66的矩阵。如下以四个单词为列
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VSrbTwmK-1680052040872)(attachment:image.png)]

计算注意力得分
可以看出通过上面计算的ω值差异很大,我们需要对其进行规范化(softmax)得到α。其中1/sqrt(dk)对ω进行缩放(确保权向量的欧几里得长度大致相同,这有助于防止注意力权重变得太小或太大,这可能导致数值不稳定或影响模型在训练期间的收敛能力。),然后通过softmax函数进行归一化[0-1],如下所示:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lVDnBAcc-1680052040872)(attachment:image.png)]
进一步,我们可以计算注意力权重如下,总和相加为1;

import torch.nn.functional as Fattention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
attention_weights_2,sum(attention_weights_2)
(tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458]), tensor(1.))

最终,最后一步是计算上下文向量(context vector) z ( 2 ) z^{(2)} z(2)
这是我们原始query输入 x ( 2 ) x^{(2)} x(2)的注意力加权版本,包括所有其他输入元素作为它的上下文,通过注意力权重:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CumxWFFN-1680052040873)(attachment:image.png)]

context_vector_2 = attention_weights_2.matmul(values)
print(context_vector_2.shape)
print(context_vector_2)
attention_weights_2.shape,values.shape
torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,-0.9564, -0.5265,  0.0624,  1.7084])(torch.Size([6]), torch.Size([6, 28]))

注意:这里输出维度 d v = 28 d_v=28 dv=28,多余原始输入的16维。这里的 d v d_v dv维度是任意的。

多头注意力
多头注意力机制和自注意力机制有什么关联?
带缩放的点积注意力机制(scaled dot-product attention)中(也就是上面softmax后的值),
输入序列使用query,key,和value的三个矩阵进行转换。
这三个矩阵可以看作是多头注意环境下的单个注意头。下图总结了我们之前提到的这个关注点
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PP3SDueY-1680052040873)(attachment:image.png)]

也就是说,多头注意力涉及多个这样的头部,每个头部由query、key和value矩阵组成。
这个概念类似于卷积神经网络中多核的使用(多个卷积核,也就是W多个)。多头的意思(比如一个头注意到了该序列中的名词,虽然也能用作表示,但是总归有信息损失,这时就可以让另外一个头注意动词,介词等等,当然还有词之间的指代关系等其它的复杂的关系,如此一来,整个序列的信息被多次收集,整合多个头的信息,得到相较来说更完整的序列信息。)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UjdCHDlb-1680052040874)(attachment:image.png)]
为了在代码中说明这一点,假设我们有3个注意力头,那么我们现在扩展 d ′ × d d'×d d×d维度权重矩阵 3 × d ′ × d 3×d'×d 3×d×d:

h = 3
multihead_W_query = torch.rand(h, d_q, d)
multihead_W_key = torch.rand(h, d_k, d)
multihead_W_value = torch.rand(h, d_v, d)
multihead_W_query.shape,multihead_W_key.shape,multihead_W_value.shape
(torch.Size([3, 24, 16]), torch.Size([3, 24, 16]), torch.Size([3, 28, 16]))

因此,每个查询元素现在都是 3 × d q 3×d_q 3×dq维度,其中 d q = 24 d_q=24 dq=24(在这里,让我们将重点放在索引位置2对应的第3个元素上):

multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)
x_2.shape
torch.Size([3, 24])torch.Size([16])

然后,我们可以类似的方式获取key和value:

multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)
multihead_key_2.shape,multihead_value_2.shape
(torch.Size([3, 24]), torch.Size([3, 28]))

现在,这些key和value元素是特定于query元素的。但是,与前面类似,我们还需要其他序列元素的value和key,以便计算query的注意力分数。我们可以通过将输入序列嵌入扩展到大小3,即注意头的数量:

stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)#复制3份
print(stacked_inputs.shape)
torch.Size([3, 16, 6])

现在,我们可以通过torch.bmm()(批量矩阵乘法)来计算所有的键和值:

multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)
multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])

现在我们有张量来表示第一个维度中的三个注意力头。第三维度和第二个维度分别是字数和嵌入尺寸。为了使值和键更直观地解释,我们将交换第二个和第三个维度,从而得到与原始输入序列具有相同维度结构的张量,embedded_sentence:

multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)
multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])

然后,我们按照与前面相同的步骤来计算未缩放的注意力权重(unscaled attention weights)ω和注意力权重α,然后进行scaled-softmax计算,以获得输入元素 x ( 2 ) x^{(2)} x(2)的上下文向量z,其维度为 h × d v h×d_v h×dv(这里: 3 × d v 3×d_v 3×dv)。

交叉注意力(Cross Attention)
在上面的代码演练中,我们设置 d q = d k = 24 和 d v = 28 d_q=d_k=24和d_v=28 dq=dk=24dv=28。换句话说,我们对query和key序列使用相同的维度。虽然value矩阵 W v W_v Wv通常被选择具有与query和key矩阵相同的维度(例如在PyTorch的MultiHeadAttention类中),但我们可以为value维度选择任意数字大小。
由于维度有时有点难以跟踪,让我们在下图中总结到目前为止我们所涉及的所有内容,它描述了单个注意力头的各种张量大小。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5wWPsKlu-1680052040874)(attachment:image.png)]

现在,上面的插图对应于transformer中使用的自注意力机制。我们尚未讨论的这种注意机制的一种特殊形式是交叉注意。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3DyjosxO-1680052040874)(attachment:image.png)]

什么是交叉注意力,他和自注意力有什么不同?
在自注意力中,使用相同的输入序列。在交叉注意力中,我们混合或组合两种不同的输入序列。在上面的原始Transformer架构中,左边是解码器模块返回的序列,右边是编码器部分处理的输入序列。
注意,在交叉注意力中,两个输入序列x1和x2可以有不同数量的元素。但是,它们的嵌入尺寸必须匹配。
下图说明了交叉注意的概念。如果我们设x1=x2,这就相当于自我注意。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-F65X1t26-1680052040875)(attachment:image.png)]

这在代码中是如何工作的呢?以前,当我们在本文开头实现自注意机制时,我们使用以下代码计算第二个输入元素的query以及所有key和value,如下所示:

torch.manual_seed(123)d = embedded_sentence.shape[1]
print("embedded_sentence.shape:", embedded_sentence.shape)d_q, d_k, d_v = 24, 24, 28W_query = torch.rand(d_q, d)
W_key = torch.rand(d_k, d)
W_value = torch.rand(d_v, d)x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
print("query.shape", query_2.shape)keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).Tprint("keys.shape:", keys.shape)
print("values.shape:", values.shape)
embedded_sentence.shape: torch.Size([6, 16])
query.shape torch.Size([24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])

交叉注意的唯一变化是我们现在有了第二个输入序列,例如,第二个句子有8个输入元素而不是6个输入元素。这里,假设这是一个有8个tokens(词)的句子。

embedded_sentence_2 = torch.rand(8, 16) # 2nd input sequencekeys = W_key.matmul(embedded_sentence_2.T).T
values = W_value.matmul(embedded_sentence_2.T).Tprint("keys.shape:", keys.shape)
print("values.shape:", values.shape)
keys.shape: torch.Size([8, 24])
values.shape: torch.Size([8, 28])

注意,与self-attention相比,键和值现在有8行而不是6行。其他一切保持不变。
我们在上面讨论了很多语言Transformer。在原始的transformer架构中,当我们在语言翻译上下文中从输入句子转换到输出句子时,交叉注意是有用的。输入句子表示一个输入序列,翻译表示第二个输入序列(两个句子可以有不同的字数)。
另一个使用交叉注意力的流行模型是Stable Diffusion。Stable Diffusion使用U-Net模型中生成的图像和用于调节的文本提示之间的交叉注意,如使用潜在扩散模型的高分辨率图像合成中所述(https://arxiv.org/abs/2112.10752)-原始论文描述了Stable Diffusion模型,后来被Stability AI采用来实现流行的Stable Diffusion扩散模型。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NQc3naOn-1680052040875)(attachment:image.png)]

结论
在这篇文章中,我们看到了自注意力是如何一步步编码的。然后,我们将这个概念扩展到广泛使用的大型语言转换器组件——多头注意力。在讨论了自注意和多头注意之后,我们介绍了另一个概念:交叉注意,这是一种自注意,我们可以应用在两个不同的序列之间。这已经是很多信息了。让我们把使用这个多头注意力块的神经网络训练留到以后的文章中。


众里寻她千百度,他眼仅观她脸处--无处不在的注意力机制(self-attention)相关推荐

  1. Asp.net Ajax Control Toolkit设计编程备忘录(色眼窥观版)——第5回(错不了专辑)

    色即设--设计,从网页设计师的角度出发.因为自己的的特殊性(本身是软件工程师,但是对网页设计却有浓厚的兴趣),所以我的此系列文章不仅仅从编程角度出发,还将从设计的角度出发来细数AjaxControlT ...

  2. Asp.net Ajax Control Toolkit设计编程备忘录(色眼窥观版)——第3回(UE专辑)

    前言: 色即设--设计,从网页设计的角度出发.因为自己的的特殊性(本身是软件工程师,但是对网页设计却有浓厚的兴趣),所以此系列文章不仅仅从编程角度出发,还将从设计的角度出发来细数AjaxControl ...

  3. Asp.net Ajax Control Toolkit设计编程备忘录(色眼窥观版)——第4回(忍者专辑)

    ====================================================== 注:本文源代码点此下载 ================================= ...

  4. “众里寻她千百度”情人节

     爱若天高似海深,情如地厚比水纯.沧海桑田常变化,唯有真情永长存. 2月14日,又是一年情人节,有人欢喜有人愁. 无论二人相约浪漫,还是一人以烟酒相伴,或疯狂,或遗忘-- 但我们每个人都很明白:爱,永 ...

  5. 软件测试架构师——众里寻她千百度

    "众里寻她千百度,蓦然回首,人却在灯火阑珊处",还算幸运.而对于"软件测试架构师", 众里寻她(他)千百度,那人何在?难以上青天. 软件测试架构师是一个新职位, ...

  6. 众里寻她千百度,那人却在优衣库。

    众里寻她千百度,那人却在优衣库. -------   我的第一篇微小说 ...  衣带渐宽终不悔,优衣脱的人憔悴.

  7. 软件测试架构师——众里寻她千百度 1

    "众里寻她千百度,蓦然回首,人却在灯火阑珊处",还算幸运.而对于"软件测试架构师", 众里寻她(他)千百度,那人何在?难以上青天. 软件测试架构师是一个新职位, ...

  8. 众里寻她千百度,蓦然回首,那bug却在灯火阑珊处

    今天发现consul上的A服务处于failed状态,幸运的是服务部署了两份,以预防单点故障,做负载均衡,连忙查看http://ip:port/health输出,内容如下: { "status ...

  9. 拉勾网引入百度 AI,上线全新企业及招聘者身份审核机制;AI 法律咨询服务系统落户厦门海沧...

    福建省首套人工智能法律咨询服务系统落户厦门海沧 雷锋网(公众号:雷锋网)消息 日前,海沧区司法局依托"法治海沧"微信公众号平台,在福建省率先上线了"智能海沧AI人工智能& ...

最新文章

  1. 关于FATFS文件系统挂载多个磁盘
  2. Android 线程池概念及使用
  3. SQL SERVER 2008安全配置
  4. 【C++设计技巧】C++中的RAII机制
  5. Java日历compareTo()方法与示例
  6. 【Drools二】打工人学习Drools基础语法
  7. 常见的通配符_技术干货 | 常见的mysql注入语句
  8. 黑白块游戏java代码_用java做的一个小游戏—黑白反斗棋(适合菜鸟)
  9. OpenStack本地存储选项的现在与未来
  10. WiFi----Wireshark抓包及分析说明
  11. creo减速器建模实例,减速箱proE整体及零件图
  12. 常见路径规划算法介绍
  13. 卸载控制面板(Control Panel)存在的重复程序(CrowdStrike)
  14. 声卡是HDA Intel,芯片为IDT 92HD81B1C5的ubuntu12.04下声音很小的解决方法
  15. lisp成套电气设计_针对电气成套行业的专业设计仿真软件
  16. Docker 2375 端口入侵服务器,部分解决方案
  17. 解决在win2003 enterprise上搭建IIS遇到的“需要Service Pack 2 CD-Rom 上的文件“问题
  18. python持续集成工具_21 个好用的持续集成工具,总有一款适合你
  19. 按照字符串长度大小进行升序排列
  20. 蓝牙「5.0」和「4.2」的区别???

热门文章

  1. html safari图片不显示,html - 某些FA图标显示在FF,Chrome和Safari中,但没有浏览器会全部显示它们 - 堆栈内存溢出...
  2. 安信证券 | 神州信息:金融软服增长较快,行业信创蓄势待发
  3. CentOS系列之Elasticsearch(二):查询
  4. dotnet 基于 debian 创建一个 docker 的 sdk 镜像
  5. 验证是否正确迁出CESM2
  6. 数据仓库建模方法/范式建模法/维度建模法/事实表/维度表/优缺点/建模流程/概念建模/逻辑建模/物理建模
  7. c++进制转换(完整)
  8. 如何做一个真正的男人
  9. https://wenku.baidu.com/view/24def725e53a580217fcf
  10. NXP i.MX 8M Mini处理器