多头总框架

多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

而多头self-attention模块,则是将Q,K,V通过参数矩阵映射后(即Q,K,V分别接一个全连接层),通过张量操作(X.reshape())将张量变换为可以实现多个头并行计算的样子,然后再做self-attention,将这个过程重复h(原论文中h=8)次,最后再将所有的结果拼接起来,再送入一个全连接层即可,图示如上:

代码解析:

 1、经过参数矩阵映射(即Q,K,V分别接一个全连接层):

self.W_q = nn.Linear(query_size,  num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size,    num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size,  num_hiddens, bias=bias)

2、通过张量操作将张量变换为可以实现多个头并行计算的样子(即下面的transpose_qkv函数) :

queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys =    transpose_qkv(self.W_k(keys),    self.num_heads)
values =  transpose_qkv(self.W_v(values),  self.num_heads)

3、然后再做self-attention:

output = self.attention(queries, keys, values, valid_lens)

4、将这个过程重复h(原论文中h=8)次(即上述的valid_lens实现重复h次):

# valid_lens 的形状: (batch_size,)或(batch_size,查询的个数)
if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)

5、最后再将所有的结果拼接起来,再送入一个全连接层即可:

output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)

其中W_o是:

self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

总体代码如下:

import torch
from torch import nn
from d2l import torch as d2lclass MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_heads# 此处使用缩放点积注意力作为每一个注意力头self.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size,  num_hiddens, bias=bias)self.W_k = nn.Linear(key_size,    num_hiddens, bias=bias)self.W_v = nn.Linear(value_size,  num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)# 经过变换后,输出的queries,keys,values 的形状: (batch_size * num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys =    transpose_qkv(self.W_k(keys),    self.num_heads)values =  transpose_qkv(self.W_v(values),  self.num_heads)# valid_lens 的形状: (batch_size,)或(batch_size,查询的个数)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)

上述代码31行有“ output = self.attention(queries, keys, values, valid_lens) ” ,其实这一步便是完成以下公式的操作,所谓多头,只是多个attention同时计算罢了


通过张量操作实现多个头并行计算

基于适当的张量操作,可以实现多头注意力的并行计算

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数(transpose_qkv与transpose_output)。

具体来说,transpose_output函数反转了transpose_qkv函数的操作。

#######################################################################
#### 为了能够使多个头并行计算,上面的MultiHeadAttention类将使用下面定义的两个转置函数。
#### 具体来说,transpose_output函数反转了transpose_qkv函数的操作。
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size * num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)

测试

下面我们使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。


#### 下面我们使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_size,num_queries,num_hiddens)
num_hiddens, num_heads = 100, 5
# key_size, query_size, value_size,与num_hiddens相同
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
print(attention.eval())batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) # (2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) # (2,6,100)
print(attention(X, Y, Y, valid_lens).shape)

10.5. 多头注意力 — 动手学深度学习 2.0.0-beta0 documentation

精(李沐)多头注意力,代码理解相关推荐

  1. 多头注意力代码解读(非常好的一个版本)

    初始化阶段, 其中要注意的是 hid_dim要和Q.K.V词向量的长度相等 import torch from torch import nnclass MultiheadAttention(nn.M ...

  2. 基于Pycharm运行李沐老师的深度学习课程代码

    最近在b站看李沐老师的深度学习课程,受益颇多.不过觉得光看视频实在是不过瘾,最好还是能实际的玩起来.鉴于我还是习惯使用pycharm,且不需要过多的中间过程展示,所以代码的编写基本都是在pycharm ...

  3. 李沐动手学深度学习:08 线性回归(代码逐行理解)

    目录 一.相关资料连接 1.1 李沐视频 1.2 代码.PPT 二.代码及笔记(使用Jupyter Notebook) 2.1 线性回归从零开始实现 2.1.1 基本概念 2.1.2 基础优化算法 2 ...

  4. 【深度学习】跟李沐学ai 线性回归 从零开始的代码实现超详解

    目录 一.引言 二.本文代码做了什么 如何利用数据集训练 三.代码实现与解析 一.导包 二.相应的函数实现 1 生成样本(数据集) 2 按批量读取数据集 3 定义模型 损失函数 算法 1 定义模型 2 ...

  5. 李沐「动手学深度学习」中文课程笔记来了!代码还有详细中文注释

    关注公众号,发现CV技术之美 本文转自机器之心,编辑张倩. markdown笔记与原课程视频一一对应,Jupyter代码均有详细中文注释,这份学习笔记值得收藏. 去年年初,机器之心知识站上线了亚马逊资 ...

  6. 多头注意力机制的理解

    先来看图: 从图片中可以看出V K Q 是固定的单个值,而Linear层有3个,Scaled Dot-Product Attention 有3个,即3个多头:最后cancat在一起,然后Linear层 ...

  7. 【Transformer 相关理论深入理解】注意力机制、自注意力机制、多头注意力机制、位置编码

    目录 前言 一.注意力机制:Attention 二.自注意力机制:Self-Attention 三.多头注意力机制:Multi-Head Self-Attention 四.位置编码:Positiona ...

  8. 李沐动手学深度学习V2-全卷积网络FCN和代码实现

    一.全卷积网络FCN 1. 介绍 语义分割是对图像中的每个像素分类,全卷积网络(fully convolutional network,FCN)采用卷积神经网络实现了从图像像素到像素类别的变换 ,与前 ...

  9. 深入理解深度学习——Transformer:解码器(Decoder)的多头注意力层(Multi-headAttention)

    分类目录:<深入理解深度学习>总目录 相关文章: ·注意力机制(Attention Mechanism):基础知识 ·注意力机制(Attention Mechanism):注意力汇聚与Na ...

最新文章

  1. pandas dropna
  2. 【Visual C++】游戏开发笔记四十一 浅墨DirectX教程之九 为三维世界添彩:纹理映射技术(一)...
  3. How HBO’s Silicon Valley built “Not Hotdog” with mobile TensorFlow, Keras React Native
  4. rds_dbsync数据源同步工具
  5. SQL中Truncate的用法
  6. silverlight: [HtmlPage_NotEnabled] 调试资料字符串不可用的解决
  7. 英特尔:赔你15亿算了;Nvidia:反正我早就不做你那块了
  8. SQL Server 中死锁产生的原因及解决办法
  9. Lingo多版本下载地址和安装教程
  10. Vscode搭建jdk源码阅读环境 wsl
  11. Ubuntu 查看本机IP地址
  12. 第三十三章 SQL命令 DROP INDEX
  13. 此文对你人生会有莫大好处的,建议永久保存
  14. linux 家用路由器,饱受折磨的家用路由器 | 在研究的127个家用路由器中,没有一个路由器幸免...
  15. Teardrop代码编程
  16. Carson带你学Android:你要的WebView与 JS 交互方式都在这里了
  17. 【题解】【循环】幂级数求和
  18. 遍身罗绮者 不是养蚕人
  19. ubuntu中抓包工具tcpdump使用详解
  20. Android中 自定义logo二维码绘制(仿微信QQ二维码)

热门文章

  1. JEECG 3.7 新装亮相,移动APP发布
  2. ORACLE START WITH 语句的树级结构例子
  3. Java知识点汇总1
  4. Redis-列表(List)基础
  5. Document API
  6. 再不懂ZooKeeper,就安安心心把这篇文章看完
  7. 华为VLAN间互访配置
  8. code第一部分:数组
  9. 找单词(母函数问题)
  10. java.net.SocketException: Broken pipe问题解决