官方链接:https://tensorflow.google.cn/versions/r2.1/api_docs/python/tf/keras/layers/Attention

tf.keras.layers.Attention(
    use_scale=False, **kwargs
)

Inputs are query tensor of shape [batch_size, Tq, dim]value tensor of shape [batch_size, Tv, dim] and key tensor of shape [batch_size, Tv, dim]. The calculation follows the steps:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]distribution = tf.nn.softmax(scores).
  3. Use distribution to create a linear combination of value with shape batch_size, Tq, dim]return tf.matmul(distribution, value).

例子1

import tensorflow as tf
import numpy as npquery = tf.convert_to_tensor(np.asarray([[[1., 1., 1., 3.]]]))key_list = tf.convert_to_tensor(np.asarray([[[1., 1., 2., 4.], [4., 1., 1., 3.], [1., 1., 2., 1.]],[[1., 0., 2., 1.], [1., 2., 1., 2.], [1., 0., 2., 1.]]]))query_value_attention_seq = tf.keras.layers.Attention()([query, key_list])print('query shape:', query.shape)
print('key shape:', key_list.shape)print('result 1:',query_value_attention_seq)

结果:

query shape: (1, 1, 4)
key shape: (2, 3, 4)
result 1: tf.Tensor(
[[[1.8067516  1.         1.7310829  3.730812  ]][[0.99999994 1.9293262  1.0353367  1.9646629 ]]], shape=(2, 1, 4), dtype=float32)

根据文档中提到步骤自己实现

scores = tf.matmul(query, key_list, transpose_b=True)distribution = tf.nn.softmax(scores)result = tf.matmul(distribution, key_list)
print('result 2:',query_value_attention_seq)

结果如下:可以看到结果是和我们理解的一样的

result 2: tf.Tensor(
[[[1.8067516  1.         1.7310829  3.730812  ]][[0.99999994 1.9293262  1.0353367  1.9646629 ]]], shape=(2, 1, 4), dtype=float32)

tf.keras.layers.Attention 理解总结相关推荐

  1. 批标准化 tf.keras.layers.BatchNormalization 中的trainable参数与training参数比较

    巨坑提醒:tf.keras与tensorflow混用,trainable=False根本不起作用.正文不用看了. 摘要: 在tensorflow中,training参数和trainable参数是两个不 ...

  2. 批标准化 tf.keras.layers.BatchNormalization 参数解析与应用分析

    Table of Contents 函数调用 设置training=None时可能存在的问题 :tf.keras.backend.learning_phase()的特点 批标准化函数产生的变量是可训练 ...

  3. Tensorflow学习之tf.keras(一) tf.keras.layers.Model(另附compile,fit)

    模型将层分组为具有训练和推理特征的对象. 继承自:Layer, Module tf.keras.Model(*args, **kwargs ) 参数 inputs 模型的输入:keras.Input ...

  4. 全连接层tf.keras.layers.Dense()介绍

    函数原型 tf.keras.layers.Dense(units, # 正整数,输出空间的维数activation=None, # 激活函数,不指定则没有use_bias=True, # 布尔值,是否 ...

  5. Tensorflow 2.x(keras)源码详解之第七章:keras中的tf.keras.layers

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  6. tensorflow keras 上采样(放大图片) tf.keras.layers.UpSampling2D 示例

    input_shape = (4, 28, 28, 3) # 样本数:4,图片大小:28 * 28,通道:3 inputs = tf.random.normal(input_shape) print( ...

  7. tf.keras.layers.MaxPool2D 最大池化层 示例

    tf.keras.layers.MaxPool2D 最大池化层 示例 import tensorflow as tf import numpy as np inputs = np.random.ran ...

  8. tf.keras.layers.Conv1D 一维卷积 示例

    tf.keras.layers.Conv1D 一维卷积 示例 import tensorflow as tf from tensorflow import keras import numpy as ...

  9. tf.keras.layers.MaxPool2D 示例 池化层

    tf.keras.layers.MaxPool2D 示例 池化层 import tensorflow as tf import numpy as np inputs = np.random.randi ...

最新文章

  1. 【微信小程序】token/session失效了该怎么跳转页面
  2. 其实昨天去加班也没有干什么事情,就只有3个人
  3. 总结C语言中的数组知识点
  4. H264—MP4格式及在MP4文件中提取H264的SPS、PPS及码流
  5. 【名额有限】云开发AI拓展能力等你来体验!
  6. ubuntu安装php5-mysql_Ubuntu下安装Apache2, php5 mysql
  7. 一个好的技术团队应该怎么选择开发语言
  8. ROS----窃听小乌龟行动计划
  9. 【redis】redis基础命令,分布式锁,缓存问题学习大集合
  10. Python爬虫之小米应用商店
  11. SQL SERVER如何通过SQL语句获服务器硬件和系统信息
  12. python学习(二)----字典
  13. python list平均数_数据分析之Python干货笔记
  14. 基于QT的音乐播放器
  15. 我的框架——MyBean
  16. iOS 中的编码方式详解(主要讲解Unicode)
  17. cGAN/cDCGAN,MNIST数据集初体验(内含原理,代码)
  18. vissim跟驰模型_vissim简介
  19. 2008年的高考分析-山西省临汾市
  20. 常用简体中文字体转Unicode和Unicode 2编码对照表

热门文章

  1. Mono for Android—初体验之“电话拨号器”
  2. CentOS 5.8 Zimbra邮件系统安装与配置
  3. 【原创】vegas提示NTDLL.DLL出错的解决办法
  4. JDK Executor执行器的应用
  5. 容器编排技术 -- Kubernetes Ingress解析
  6. NodeJS 使用官方oracledb库连接数据库教程
  7. 这部日本「神作」彻底拉低了我入门AI的门槛
  8. Fedora安装Docker
  9. Kubernetes 环境搭建 - MacOS
  10. Python3标准库:asyncio异步I/O、事件循环和并发工具