注意力原理

注意力计算时有3个输入:

  • Q:可以看作是多个特征的集合,在序列模型中结构通常是:(batch_size, seq_len_q, depth),seq_len_q是时间长度,代表一段时间的depth维特征。
  • K:可以看作是当前要计算注意力的特征,用来与Q的多个特征,进行矩阵相乘,计算出K的注意力权重。
  • V:代表的是结果,将Q、K计算出的注意力分数,与V相乘,得到一个叠加了权重的V值。这就是注意力层的输出。

这样计算其实与人的注意力思考方式有点相似,就是人考虑一件事的重要性K,与以前的经验Q进行比较,然后得到重要程度,再与现在要做的事情V,按重要性排序。注意力模块差不多就是这样的简单原理。然后看下面的结构图与实现代码,就很好理解了。

结构如图:

注意力公式:

Tensorflow 实现

  • 多头注意力

def scaled_dot_product_attention(q, k, v, mask):"""计算注意力权重。q, k, v 必须具有匹配的前置维度。k, v 必须有匹配的倒数第二个维度,例如:seq_len_k = seq_len_v。虽然 mask 根据其类型(填充或前瞻)有不同的形状,但是 mask 必须能进行广播转换以便求和。参数:q: 请求的形状 == (..., seq_len_q, depth)k: 主键的形状 == (..., seq_len_k, depth)v: 数值的形状 == (..., seq_len_v, depth_v)mask: Float 张量,其形状能转换成(..., seq_len_q, seq_len_k)。默认为None。返回值:输出,注意力权重"""# matmul_qk(bs, 8, seq_len_q, seq_len_k)matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)# 缩放 matmul_qkdk = tf.cast(tf.shape(k)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)# 将 mask 加入到缩放的张量上。if mask is not None:# mask为1的位置变成非常小的数scaled_attention_logits += (mask * -1e9)  # softmax 在最后一个轴(seq_len_k)上归一化,因此分数# 相加等于1。attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)return output, attention_weightsdef get_angles(pos, i, d_model):'''获取角度'''angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))return pos * angle_ratesdef positional_encoding(position, d_model):'''位置编码'''angle_rads = get_angles(np.arange(position)[:, np.newaxis],np.arange(d_model)[np.newaxis, :],d_model)# 将 sin 应用于数组中的偶数索引(indices);2iangle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])# 将 cos 应用于数组中的奇数索引;2i+1angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])pos_encoding = angle_rads[np.newaxis, ...]return tf.cast(pos_encoding, dtype=tf.float32)def create_padding_mask(seq):'''创建填充遮挡,1为遮挡位置'''seq = tf.cast(tf.math.equal(seq, 0), tf.float32)# 添加额外的维度来将填充加到# 注意力对数(logits)。return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)def create_look_ahead_mask(size):'''创建前瞻遮挡,1为遮挡位置'''mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)return mask  # (seq_len, seq_len)class MultiHeadAttention(tf.keras.layers.Layer):'''多头注意力'''def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelassert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = tf.keras.layers.Dense(d_model)self.wk = tf.keras.layers.Dense(d_model)self.wv = tf.keras.layers.Dense(d_model)self.dense = tf.keras.layers.Dense(d_model)def split_heads(self, x, batch_size):"""分拆最后一个维度到 (num_heads, depth).转置结果使得形状为 (batch_size, num_heads, seq_len, depth)"""x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))return tf.transpose(x, perm=[0, 2, 1, 3])def call(self, v, k, q, mask):batch_size = tf.shape(q)[0]q = self.wq(q)  # (batch_size, seq_len, d_model)k = self.wk(k)  # (batch_size, seq_len, d_model)v = self.wv(v)  # (batch_size, seq_len, d_model)q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)return output, attention_weights
  • Bahdanau注意力

class BahdanauAttention(tf.keras.layers.Layer):def __init__(self, units):super().__init__()# For Eqn. (4), the  Bahdanau attentionself.W1 = tf.keras.layers.Dense(units, use_bias=False)self.W2 = tf.keras.layers.Dense(units, use_bias=False)self.attention = tf.keras.layers.AdditiveAttention()def call(self, query, value, mask):shape_checker = ShapeChecker()shape_checker(query, ('batch', 't', 'query_units'))shape_checker(value, ('batch', 's', 'value_units'))shape_checker(mask, ('batch', 's'))# From Eqn. (4), `W1@ht`.w1_query = self.W1(query)shape_checker(w1_query, ('batch', 't', 'attn_units'))# From Eqn. (4), `W2@hs`.w2_key = self.W2(value)shape_checker(w2_key, ('batch', 's', 'attn_units'))query_mask = tf.ones(tf.shape(query)[:-1], dtype=bool)value_mask = maskcontext_vector, attention_weights = self.attention(inputs = [w1_query, value, w2_key],mask=[query_mask, value_mask],return_attention_scores = True,)shape_checker(context_vector, ('batch', 't', 'value_units'))shape_checker(attention_weights, ('batch', 't', 's'))return context_vector, attention_weights

参考资料:https://www.tensorflow.org/text/tutorials/transformer#scaled_dot_product_attention

【人工智能笔记】第三十节:注意力原理分析,及tensorflow 2.0 实现相关推荐

  1. OpenCV学习笔记(三十六)——Kalman滤波做运动目标跟踪 OpenCV学习笔记(三十七)——实用函数、系统函数、宏core OpenCV学习笔记(三十八)——显示当前FPS OpenC

    OpenCV学习笔记(三十六)--Kalman滤波做运动目标跟踪 kalman滤波大家都很熟悉,其基本思想就是先不考虑输入信号和观测噪声的影响,得到状态变量和输出信号的估计值,再用输出信号的估计误差加 ...

  2. OpenCV学习笔记(三十一)——让demo在他人电脑跑起来 OpenCV学习笔记(三十二)——制作静态库的demo,没有dll也能hold住 OpenCV学习笔记(三十三)——用haar特征训练自己

    OpenCV学习笔记(三十一)--让demo在他人电脑跑起来 这一节的内容感觉比较土鳖.这从来就是一个老生常谈的问题.学MFC的时候就知道这个事情了,那时候记得老师强调多次,如果写的demo想在人家那 ...

  3. 【OS学习笔记】三十六 保护模式十:通过中断发起任务切换----中断任务

    上一篇文章学习了:OS学习笔记]三十五 保护模式十:中断描述符表.中断门和陷阱门 本篇文章接着上一篇文章学习中断任务. 我们在前面文章中一直在说通过中断发起任务切换,本文就是将之前没有说明白的内容:通 ...

  4. 【OS学习笔记】三十四 保护模式十:中断和异常区别

    上几篇文章学习了分页机制的一些原理: [OS学习笔记]三十 保护模式九:段页式内存管理机制概述 [OS学习笔记]三十一 保护模式九:页目录.页表和页三者的关系详解 今天继续学习保护模式下的关于中断与异 ...

  5. WPF,Silverlight与XAML读书笔记第三十九 - 可视化效果之3D图形

    原文:WPF,Silverlight与XAML读书笔记第三十九 - 可视化效果之3D图形 说明:本系列基本上是<WPF揭秘>的读书笔记.在结构安排与文章内容上参照<WPF揭秘> ...

  6. Python编程基础:第三十节 文件检测File Detection

    第三十节 文件检测File Detection 前言 实践 前言 我们通常会涉及到文件相关的操作,例如检测.读写.复制.删除等等.本节我们一起来学习文件检测相关知识,即检测指定路径下是否存在该文件. ...

  7. 【OS学习笔记】三十九 保护模式十:中断和异常的处理与抢占式多任务对应的汇编代码----动态加载的用户程序/任务一代码

    本文是以下几篇文章对应的动态加载的用户程序/任务一代码: [OS学习笔记]三十四 保护模式十:中断和异常区别 [OS学习笔记]三十五 保护模式十:中断描述符表.中断门和陷阱门 [OS学习笔记]三十六 ...

  8. 【OS学习笔记】三十八 保护模式十:中断和异常的处理与抢占式多任务对应的汇编代码----微型内核汇代码

    本文是以下几篇文章对应的微型内核代码汇编代码: [OS学习笔记]三十四 保护模式十:中断和异常区别 [OS学习笔记]三十五 保护模式十:中断描述符表.中断门和陷阱门 [OS学习笔记]三十六 保护模式十 ...

  9. 【OS学习笔记】三十五 保护模式十:中断描述符表、中断门和陷阱门

    上一篇文章学习了中断与异常的概念:[OS学习笔记]三十四 保护模式十:中断和异常区别 本片文章接着学习以下内容: 中断描述符表 中断门 陷阱门 1 中断描述符表 我们前面讲了无数次,在实模式下,是由位 ...

  10. 【OS学习笔记】三十二 保护模式九:分页机制对应的汇编代码之---内核代码

    本片文章是以下两篇文章: [OS学习笔记]三十 保护模式九:段页式内存管理机制概述 [OS学习笔记]三十一 保护模式九:页目录.页表和页三者的关系详解 对应的内核汇编代码. ;代码清单16-1;文件名 ...

最新文章

  1. RPC 笔记(07)— socket 通信(多进程服务器)
  2. 缓存处理类(MemoryCache结合文件缓存)
  3. Visual Studio 2019更新到16.2.3
  4. 助力小白常见JS逆向乱杀喂饭教程——Url加密
  5. Ubuntu服务器安装snmpd(用于监控宝)
  6. 消费者驱动的契约测试_告诉我们您想要什么,我们将做到:消费者驱动的合同测试消息传递...
  7. EhCache 常用配置项详解
  8. 孤零零好可怜的光棍节
  9. python入门基础系列八_03python—9个基础常识-python小白入门系列
  10. form表单会跨域_我的Vue不小心跨域了o()o 干它
  11. 享元模式在 Java String 中的应用
  12. python 自然语言处理(四)____词典资源
  13. html如何添加阿里图标,CSS引入阿里iconfont图标步骤
  14. HttpClient的使用与连接资源释放
  15. python人脸识别训练模型_AI的强大!用Python实现一个简单的人脸识别--中享思途...
  16. UT-FT-ST测试
  17. narwal机器人_Narwal云鲸智能扫拖机器人,会自己洗拖布
  18. 2022安全员-C证上岗证题目及答案
  19. 企业微信可以取消实名认证吗?如何操作?
  20. 职业生涯规划jd网上商城

热门文章

  1. Spring--Spring配置
  2. 外国人最习惯用的社交软件有哪些?
  3. 俄勒冈大学计算机科学专业,俄勒冈大学计算机与信息科学详解 热门专业不容错过...
  4. 均值不等式中考_不等式(初三不等式100道带答案)
  5. linux使用教程PDF,腾讯、阿里Java高级面试真题汇总
  6. 一般英文(java)
  7. 703n的OpenWrt配置四:把路由器变成下载机
  8. 简单好用的录音软件?
  9. 计算几何VS解析几何
  10. 【解决】CSS下拉菜单不会显示的问题