文章目录

  • 注意力汇聚:Nadaraya-Watson核回归
    • 1 - 生成数据集
    • 2 - 平均汇聚
    • 3 - 非参数注意力汇聚
    • 4 - 带参数注意力汇聚
      • 批量矩阵乘法
      • 定义模型
      • 训练
    • 5 - 小结

注意力汇聚:Nadaraya-Watson核回归

框架下的注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间交互形成了注意力汇聚,注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。在本节中,我们将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。1964年提出的Nadaraya-Watson核回归模型是⼀个简单但完整的例⼦,可以⽤于演⽰具有注意⼒机制的机器学习

import torch
from torch import nn
from d2l import torch as d2l

1 - 生成数据集

n_train = 50 # 训练样本数
x_train,_ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本
def f(x):return 2 * torch.sin(x) + x**0.8y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,)) # 训练样本的输出
x_test = torch.arange(0,5,0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数f(标记为“Truth”),以及学习得到的预测函数(标记为“Pred”)

def plot_kernel_reg(y_hat):d2l.plot(x_test,[y_truth,y_hat],'x','y',legend=['Truth','Pred'],xlim=[0,5],ylim=[-1,5])d2l.plt.plot(x_train,y_train,'o',alpha=0.5);

2 - 平均汇聚

y_hat = torch.repeat_interleave(y_train.mean(),n_test)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j6TRUXSQ-1662988499736)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107812.svg)]

3 - 非参数注意力汇聚


# X_repeat的形状:(n_test,n_train)
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1,n_train))# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2,dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights,y_train)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bHHxRgOe-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107813.svg)]

现在,我们来观察注意力的权重,这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知,“查询-键”对越接近,注意力汇聚的注意力权重就越高

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d4R52D1O-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107814.svg)]

4 - 带参数注意力汇聚

批量矩阵乘法

X = torch.ones((2,1,4))
Y = torch.ones((2,4,6))torch.bmm(X,Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值

weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
torch.bmm(weights.unsqueeze(1),values.unsqueeze(-1))
tensor([[[ 4.5000]],[[14.5000]]])

定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法,定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):def __init__(self,**kwargs):super().__init__(**kwargs)self.w = nn.Parameter(torch.rand((1,),requires_grad = True))def forward(self,queries,keys,values):# queries和attention_weights的形状为(查询个数,“键-值”对个数)queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))self.attention_weights = nn.functional.softmax(-((queries - keys) * self.w)**2 /2 ,dim=1)# values的形状为(查询个数,“键-值”对个数)return torch.bmm(self.attention_weights.unsqueeze(1),values.unsqueeze(-1)).reshape(-1)

训练

接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train,1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train,1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.5)
animator = d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[1,5])for epoch in range(5):trainer.zero_grad()l = loss(net(x_train,keys,values),y_train)l.sum().backward()trainer.step()print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')animator.add(epoch + 1, float(l.sum()))


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5nsQsico-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107815.svg)]

如下所示,训练完带参数的注意力汇聚模型后,我们发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑

# keys的形状:(n_test,n_train),每⼀⾏包含着相同的训练输⼊(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V9IQPxat-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107816.svg)]

为什么新的模型更不平滑了呢?我们看一下输出结果的绘制图:与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),xlabel='Sorted training inputs',ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rRT7kL7O-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107817.svg)]

5 - 小结

  • Nadaraya-Watson核回归时具有注意力机制的机器学习范例
  • Nadaraya-Watson核回归的注意⼒汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于你将值所对应的键核查询作为输入的函数
  • 注意力汇聚可以分为非参数型核带参数型

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归相关推荐

  1. 89. 注意力机制以及代码实现Nadaraya-Waston 核回归

    1. 心理学 动物需要在复杂环境下有效关注值得注意的点 心理学框架:人类根据随意线索和不随意线索选择注意点 随意:随着自己的意识,有点强调主观能动性的意味. 2. 注意力机制 2. 非参注意力池化层 ...

  2. 注意力机制(一):注意力提示、注意力汇聚、Nadaraya-Watson 核回归

    专栏:神经网络复现目录 注意力机制 注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分.在自然语言处理 ...

  3. 注意力汇聚:Nadaraya-Watson 核回归

    Nadaraya-Watson核回归是具有注意力机制的机器学习范例. Nadaraya-Watson核回归的注意力汇聚是对训练数据中输出的加权平均.从注意力的角度来看,分配给每个值的注意力权重取决于将 ...

  4. 10.2. 注意力汇聚:Nadaraya-Watson 核回归

    文章目录 10.2. 注意力汇聚:Nadaraya-Watson 核回归 10.2.1. 生成数据集 10.2.2. 平均汇聚 10.2.3. 非参数注意力汇聚 10.2.4. 带参数注意力汇聚 10 ...

  5. 注意力机制 - 注意力提示

    文章目录 注意力提示 1 - 生物学中的注意力提示 2 - 查询.键和值 3 - 注意力的可视化 4 - 小结 注意力提示 ⾃经济学研究稀缺资源分配以来,我们正处在"注意⼒经济"时 ...

  6. 注意力机制-深度学习中的注意力机制+注意力机制在自然语言处理中的应用

    1 深度学习中的注意力机制 https://mp.weixin.qq.com/s?__biz=MzA4Mzc0NjkwNA==&mid=2650783542&idx=1&sn= ...

  7. 【动手学深度学习】(task123)注意力机制剖析

    note 将注意力汇聚的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作. 当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数.当它们的长度相同时,使用缩放的& ...

  8. 【动手深度学习-笔记】注意力机制(一)注意力机制框架

    生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...

  9. Transformer:注意力机制(attention)和自注意力机制(self-attention)的学习总结

    目录 前言 1. 注意力机制 1.1非自主提示和自主提示 1.2 查询,键和值 1.3 注意力机制的公式 1.3.1 平均汇聚 1.3.2 非参数的注意力汇聚(Nadaraya-Watson核回归) ...

最新文章

  1. 面试题:Class.forName 和 ClassLoader 有什么区别?
  2. 代码的演化-DI(理解依赖注入di,控制反转ioc)
  3. 创业路上的这点事之 从无到有,从有到......
  4. 开发健壮的企业级应用的研究
  5. windows下Meteor+AngularJS开发的坑
  6. [置顶] 状态压缩DP 简单入门题 11题
  7. Linux/Ubuntu 安装与单机配置hadoop
  8. 一种调用dll的巧妙方法
  9. 触屏touch事件记录
  10. ACM算法分类及完成情况
  11. 用计算机模拟宇宙,科学家尝试利用计算机模拟整个宇宙的演化
  12. 3分钟微信支付商家注册0.2费率开户方法,0.38~0.6的必看
  13. 西数硬盘刷新固件_关于西数硬盘转速的fake news
  14. 故宫景点功课24:宁寿宫区6
  15. qq小程序开发者工具无法编写代码
  16. 江南爱窗帘十大品牌,怎么合理的搭配窗帘配色
  17. AI 编程助手 亚马逊CodeWhisperer使用简介
  18. 202020 公文系统安装技巧
  19. 程序员实用工具和网站(转)
  20. 基于Qt的国旗制作(巴勒斯坦国旗)

热门文章

  1. R 语言中如何调整 matrix 和 dataframe 中列的顺序
  2. 逻辑回归原理及其推导
  3. Rainbow Bridge:trustless bridge between NEAR and Ethereum
  4. cookie实现登录功能
  5. ssh配置免密登录、scp文件传输免密
  6. 《精英的傲慢:好的社会该如何定义成功》笔记与摘录二
  7. 7天入门Python 3 — Python对象属性及核心数据类型
  8. 太厉害了!java线程死锁例子
  9. Puppeteer将动态html页面生成pdf(终极解决方案)
  10. css 文字 3d旋转动画,CSS3 简单的三维文字旋转动画