10.3注意力的评价函数
1. 注意力的构造函数
1.1 <key, value> , query
key :代表从原始输入x 中, 从事物自身本来所固有的某一个属性(或多个属性上)上提取出来的特征, 这些特征可以抽象的表示原始输入;
value: 从原始输入中, 进行映射,在另外一个特征维度上来表征输入;
query: (即可以是人脑的意识作用下, 也可以是机器学习出来的)提炼出来的一种属性特征;
1.2 注意力层的输出
上图的步骤:
使用 query 的属性与 key 的属性,通过注意力的评价函数, 得到 key 与 query 之间的注意力分数;
注意分数通过softmax 便得到注意力权重;
将注意力权重 作用到各个 value 上,得到最终注意力层的输出;
即最终该注意力层的输出: 是注意力权重与各个vaule 值的加权和得到;
1.3 注意力层数学描述
用数学语言描述,假设有一个查询 query∈Rvquery \in \mathbb{R}^vquery∈Rv
和 mmm 个“键-值”对 (k1,v1),....(km,vm)(k_1, v_1), .... (k_m, v_m)(k1,v1),....(km,vm),
其中 ki∈Rk,vi∈Rvk_i \in \mathbb{R}^k, v_i \in \mathbb {R}^vki∈Rk,vi∈Rv
则,注意力层的输出函数就被表示成值的加权和:
f(q,(k1,v1),…,(km,vm))=∑i=1mα(q,ki)vi∈Rv,f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v, f(q,(k1,v1),…,(km,vm))=i=1∑mα(q,ki)vi∈Rv,
其中注意力权重,
- 先通过注意力评分函数将查询和键 (q,kq, kq,k)两个向量映射成标量,
- 再经过softmax运算得到的。
α(q,ki)=softmax(a(q,ki))=exp(a(q,ki))∑j=1mexp(a(q,kj))∈R.\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}. α(q,ki)=softmax(a(q,ki))=∑j=1mexp(a(q,kj))exp(a(q,ki))∈R.
正如我们所看到的,选择不同的注意力评分函数会导致不同的注意力汇聚操作。
在本节中,我们将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。
2. 带有mask 的softmax 运算;
正如上面提到的,softmax操作用于输出一个概率分布作为注意力权重。
在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。 例如,为了在 9.5节中高效处理小批量数据集, 某些文本序列被填充了没有意义的特殊词元。
为了仅将有意义的词元作为值来获取注意力汇聚, 我们可以指定一个有效序列长度(即词元的个数), 以便在计算softmax时过滤掉超出指定范围的位置。
通过这种方式,我们可以在下面的masked_softmax函数中 实现这样的掩蔽softmax操作(masked softmax operation), 其中任何超出有效长度的位置都被掩蔽并置为0。
具体的实现:
便是 通过在最后一个轴 上 使用 掩蔽查过有效长度的数字来实现;
3. 加性注意力
一般来说,当查询和键是不同长度的矢量时, 我们可以使用加性注意力作为评分函数。 给定查询q∈Rq\mathbf{q} \in \mathbb{R}^qq∈Rq
和 键k∈Rk\mathbf{k} \in \mathbb{R}^kk∈Rk, 加性注意力(additive attention)的评分函数为:
a(q,k)=wv⊤tanh(Wqq+Wkk)∈R,a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},a(q,k)=wv⊤tanh(Wqq+Wkk)∈R,
其中可学习的参数是 Wq∈Rh×q\mathbf W_q\in\mathbb R^{h\times q}Wq∈Rh×q ,Wk∈Rh×k\mathbf W_k\in\mathbb R^{h\times k}Wk∈Rh×k, wv∈Rh\mathbf w_v\in\mathbb R^{h}wv∈Rh,
如 前面图中所示, 将查询和键连结起来后输入到一个多层感知机(MLP)中, 感知机包含一个隐藏层,其隐藏单元数是一个超参数。 通过使用作为激活函数,并且禁用偏置项。
下面我们来实现加性注意力。
#@save
class AdditiveAttention(tf.keras.layers.Layer):"""Additiveattention."""def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):super().__init__(**kwargs)self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)self.w_v = tf.keras.layers.Dense(1, use_bias=False)self.dropout = tf.keras.layers.Dropout(dropout)def call(self, queries, keys, values, valid_lens, **kwargs):queries, keys = self.W_q(queries), self.W_k(keys)# 在维度扩展后,# queries的形状:(batch_size,查询的个数,1,num_hidden)# key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)# 使用广播方式进行求和features = tf.expand_dims(queries, axis=2) + tf.expand_dims(keys, axis=1)features = tf.nn.tanh(features)# self.w_v仅有一个输出,因此从形状中移除最后那个维度。# scores的形状:(batch_size,查询的个数,“键-值”对的个数)scores = tf.squeeze(self.w_v(features), axis=-1)self.attention_weights = masked_softmax(scores, valid_lens)# values的形状:(batch_size,“键-值”对的个数,值的维度)return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
4. 缩放点积注意力
使用点积可以得到计算效率更高的评分函数, 但是点积操作要求查询和键具有相同的长度 d。 假设查询和键的所有元素都是独立的随机变量, 并且都满足零均值和单位方差, 那么两个向量的点积的均值为0,方差为d{d}d。
为确保无论向量长度如何, 点积的方差在不考虑向量长度的情况下仍然是1, 我们将点积除以d\sqrt{d}d, 则缩放点积注意力(scaled dot-product attention)评分函数为:
a(q,k)=q⊤k/d.a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}.a(q,k)=q⊤k/d.
在实践中,我们通常从小批量的角度来考虑提高效率, 例如基于个查询和个键-值对计算注意力, 其中查询和键的长度为,值的长度为。
查询 Q∈Rn×d\mathbf Q\in\mathbb R^{n\times d}Q∈Rn×d、 键 K∈Rm×d\mathbf K\in\mathbb R^{m\times d}K∈Rm×d和 值V∈Rm×v\mathbf V\in\mathbb R^{m\times v}V∈Rm×v的缩放点积注意力是:
softmax(QK⊤d)V∈Rn×v.\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.softmax(dQK⊤)V∈Rn×v.
在下面的缩放点积注意力的实现中,我们使用了暂退法进行模型正则化。
#@save
class DotProductAttention(tf.keras.layers.Layer):"""Scaleddotproductattention."""def __init__(self, dropout, **kwargs):super().__init__(**kwargs)self.dropout = tf.keras.layers.Dropout(dropout)# queries的形状:(batch_size,查询的个数,d)# keys的形状:(batch_size,“键-值”对的个数,d)# values的形状:(batch_size,“键-值”对的个数,值的维度)# valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)def call(self, queries, keys, values, valid_lens, **kwargs):d = queries.shape[-1]scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(tf.cast(d, dtype=tf.float32))self.attention_weights = masked_softmax(scores, valid_lens)return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
5. 小结
选择不同的注意力评价函数在注意力层中带来不同的注意力操作。
当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。
当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高。
10.3注意力的评价函数相关推荐
- 小米 10 年再创业,高端 5G 手机和 AIoT 有多少机会?
作者 | 小谦 责编 | 唐小引 头图 | CSDN 下载自东方 IC 出品 | CSDN(ID:CSDNnews) 从 2010 年到 2020 年,经过 10 年高速发展,小米已经成为一家坐拥数亿 ...
- Xgboost调参小结
XGBoost全称是eXtreme Gradient Boosting,由陈天奇所设计,和传统的梯度提升算法相比,XGBoost进行了许多改进,它能够比其他使用梯度提升的集成算法更加快速.关于xg ...
- respond with a status of 40_高中英语作文高分秘籍!50组高级替换词+40个高级句型+88个高级词组,还不快记下!...
很多同学都发现,高中阶段对于写作的考查要求更高,一篇没有错误但平淡无奇的文章是拿不到高分的,保证正确性的前提下,写作必须要有出彩之处才能得到阅卷老师的青睐,今天就为大家分享高中写作50组高级替换词+4 ...
- 硅谷独角兽公司的监控系统长啥样?
前言 不同的业务场景中我们对各个运维系统的需求也是不同的,Pinterest是来自于硅谷的初创公司,在他们成长的过程中一步步对运维系统进行改进和升级,如今的Pinterest 的监控系统更是实现了监控 ...
- CVPR2022知识蒸馏用于目标检测:Focal and Global Knowledge Distillation for Detectors
论文下载:https://arxiv.org/abs/2111.11837 源码下载:https://github.com/yzd-v/FGD Abstract 知识蒸馏已成功应用于图像分类.然而目标 ...
- 谷歌:科技让世界更美好
谷歌:科技让世界更美好 一个以改变世界为目标的企业是什么样子的?这大概是知道谷歌的人们都想要深入了解的一个问题. 从"不作恶"到"让世界更美好"的企业宗旨 ,在 ...
- XML之文档类型定义和合法性(转)
来至:liang--liang博客:http://www.cnblogs.com/liang--liang/archive/2008/01/15/1039277.html 好牛 XML被作为一种元标记 ...
- 日本机器人全球领先来自这三大顶尖技术
来源:工业机器人 ▍日本尖端技术之一:机器人精密减速机 世界上工业机器人使用的精密减速机几乎都为日本所垄断.尽管国内也量产的RV减速机,但国产机器人企业却鲜有选用的,原因是日本精密减速机技术遥遥领先, ...
- Zemax学习笔记——多重结构配置的激光扩束镜
假设你需要设计一个激光扩束器,使用的波长为1.053 μm,输入光束直径为100 mm,输出光束直径为 20mm,且输入光束与输出光束平行. 限制条件: 只能使用两个镜片 本设计必须是伽利略式的 只有 ...
- PyTorch实例2——文本情绪分类器
实例主要用于熟悉相关模型,并且练习创建一个模型的步骤:数据收集.数据预处理.构建模型.训练模型.测试模型.观察模型表现.保存模型 传送门:蓝桥云课实验 目录 1. 实验环境 2. 实验目的 3. 相关 ...
最新文章
- 415 (Unsupported Media Type)
- C++继承机制下的构造函数
- HBase 数据导入功能实现方式解释
- hdu 2461(线段树求面积并)
- Java设计模式笔记(6)观察者模式
- Laravel核心解读--Database(四) 模型关联
- Jakarta EE工作组正式成立
- Linux下Ipython安装
- 【2019南京ICPC网络赛 D】Robots【DAG上的随机游走】
- VS2017安装方法
- C++制作一个连点器
- html5ie11缩放,IE 11 页面缩放后再次打开不能保存之前的缩放比例
- laravel7 LogicException Please make sure the PHP Redis extension is installed and enabled
- 微信开发者工具 the permission value is offline verifying 异常
- 用商汤的mmdetection 学习目标检测中的 Recalls, Precisions, AP, mAP 算法 Part1
- 通讯录管理系统(C++)
- win10连不上网,“网络重置”后,网络适配器出现感叹号,右下角WLAN消失,网络连接是空白。
- 【NPDP产品经理】发散思维让你的思维视野更广阔
- 八大资管业务类型汇总
- 键盘右边数字键不能用,只能当方向键使用
热门文章
- pcb 受潮_怎样让PCB远离潮湿危害
- 分享一个超厉害的网站,几乎解决一切command not found问题
- fix协议封装挑战-数据有效性校验
- 解释:什么是木马、蠕虫、病毒
- Octapharma宣布,Nuwiq(R) (simoctocog alfa)用于既往未曾治疗的患者(PUP)的NuProtect研究的最终结果将在ASH 2019上呈报
- Windows操作系统单网卡设置双IP
- 黑鲨重装计算机安装无法继续,黑鲨装机,小编教你黑鲨怎么安装win7
- adobe怎么统计字数_pdf文档统计字数的问题
- vscode修改背景
- 【PDF合并】滴滴出行电子发票及行程报销单【一页打印】