该章节介绍VITGAN对抗生成网络中,MSA多头注意力 部分的代码实现。

目录(文章发布后会补上链接):

  1. 网络结构简介
  2. Mapping NetWork 实现
  3. PositionalEmbedding 实现
  4. MLP 实现
  5. MSA多头注意力 实现
  6. SLN自调制 实现
  7. CoordinatesPositionalEmbedding 实现
  8. ModulatedLinear 实现
  9. Siren 实现
  10. Generator生成器 实现
  11. PatchEmbedding 实现
  12. ISN 实现
  13. Discriminator鉴别器 实现
  14. VITGAN 实现

MSA多头注意力 简介


MSA多头注意力,生成器时使用乘积计算注意力,判别器时使用欧氏距离计算注意力,实现参考Tensorflow官网教程:教程地址

代码实现

import tensorflow as tfclass MSA(tf.keras.layers.Layer):def __init__(self, d_model, num_heads, discriminator):super().__init__()self.num_heads = num_headsself.d_model = d_modelself.discriminator = discriminator # 是否鉴别器assert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = tf.keras.layers.Dense(d_model, use_bias=False)self.wk = tf.keras.layers.Dense(d_model, use_bias=False)self.wv = tf.keras.layers.Dense(d_model, use_bias=False)self.dense = tf.keras.layers.Dense(d_model, use_bias=False)def split_heads(self, x, batch_size):"""Split the last dimension into (num_heads, depth).Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)"""x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))return tf.transpose(x, perm=[0, 2, 1, 3])def scaled_dot_product_attention(self, q, k, v, mask):"""Calculate the attention weights.q, k, v must have matching leading dimensions.k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.The mask has different shapes depending on its type(padding or look ahead)but it must be broadcastable for addition.Args:q: query shape == (..., seq_len_q, depth)k: key shape == (..., seq_len_k, depth)v: value shape == (..., seq_len_v, depth_v)mask: Float tensor with shape broadcastableto (..., seq_len_q, seq_len_k). Defaults to None.Returns:output, attention_weights"""# (..., seq_len_q, seq_len_k)if self.discriminator:# 欧氏距离matmul_q = tf.expand_dims(q, axis=-2)matmul_k = tf.expand_dims(k, axis=-3)matmul_qk = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(matmul_q - matmul_k), axis=-1))else:# 乘积matmul_qk = tf.matmul(q, k, transpose_b=True)# scale matmul_qkdk = tf.cast(tf.shape(k)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)# add the mask to the scaled tensor.if mask is not None:scaled_attention_logits += (mask * -1e9)# softmax is normalized on the last axis (seq_len_k) so that the scores# add up to 1.attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)return output, attention_weightsdef call(self, v, k, q, mask):batch_size = tf.shape(q)[0]q = self.wq(q)  # (batch_size, seq_len, d_model)k = self.wk(k)  # (batch_size, seq_len, d_model)v = self.wv(v)  # (batch_size, seq_len, d_model)# (batch_size, num_heads, seq_len_q, depth)q = self.split_heads(q, batch_size)# (batch_size, num_heads, seq_len_k, depth)k = self.split_heads(k, batch_size)# (batch_size, num_heads, seq_len_v, depth)v = self.split_heads(v, batch_size)# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)# (batch_size, seq_len_q, num_heads, depth)scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)# (batch_size, seq_len_q, d_model)output = self.dense(concat_attention)# return output, attention_weightsreturn outputif __name__ == "__main__":layer = MSA(128, 8, discriminator=False)x = tf.random.uniform([2,5,128], dtype=tf.float32)o = layer(x,x,x,None)tf.print(tf.shape(o))

参考资料:

  • 论文地址:VITGAN: Training GANs with Vision Transformers
  • 多头注意力实现代码参考:代码地址
  • 源码:https://github.com/tfwcn/VITGAN-tf2

【人工智能笔记】第三十六节:TF2实现VITGAN对抗生成网络,MSA多头注意力 实现相关推荐

  1. OpenCV学习笔记(三十六)——Kalman滤波做运动目标跟踪 OpenCV学习笔记(三十七)——实用函数、系统函数、宏core OpenCV学习笔记(三十八)——显示当前FPS OpenC

    OpenCV学习笔记(三十六)--Kalman滤波做运动目标跟踪 kalman滤波大家都很熟悉,其基本思想就是先不考虑输入信号和观测噪声的影响,得到状态变量和输出信号的估计值,再用输出信号的估计误差加 ...

  2. Python编程基础:第三十六节 模块Modules

    第三十六节 模块Modules 前言 实践 前言 我们目前所有的代码都写在一个文档里面.如果你的项目比较大,那么把所有功能写在一个文件里就非常不便于后期维护.为了提高我们代码的可读性,降低后期维护的成 ...

  3. VTK学习笔记(三十六)VTK图像填充

    VTK学习笔记(三十六)VTK图像填充 1.官方示例 2.其他例子 总结 1.官方示例 来自官方示例代码,自己只是添加了理解. 代码: #include <vtkCamera.h> #in ...

  4. Slicer学习笔记(三十六)slicer坐标系

    Slicer学习笔记(三十六)slicer坐标系 1.坐标系统 1.1.世界坐标系 1.2.解剖坐标系 1.3.图像坐标系 1.4.图像变换 1.5.二维示例或计算IJtoLS矩阵 1.6.与其他软件 ...

  5. 【OS学习笔记】三十六 保护模式十:通过中断发起任务切换----中断任务

    上一篇文章学习了:OS学习笔记]三十五 保护模式十:中断描述符表.中断门和陷阱门 本篇文章接着上一篇文章学习中断任务. 我们在前面文章中一直在说通过中断发起任务切换,本文就是将之前没有说明白的内容:通 ...

  6. Udacity机器人软件工程师课程笔记(三十六) - GraphSLAM

    一.引入 GraphSLAM是解决完整的slam问题的slam算法.这意味着该算法将恢复整个路径和地图,而不仅仅是最近的姿势和地图.这种差异使它可以考虑当前姿势与先前姿势之间的依赖性.适用于我们的Gr ...

  7. 重学前端学习笔记(三十六)--Flex 布局

    笔记说明 重学前端是程劭非(winter)[前手机淘宝前端负责人]在极客时间开的一个专栏,每天10分钟,重构你的前端知识体系,笔者主要整理学习过程的一些要点笔记以及感悟,完整的可以加入winter的专 ...

  8. Android开发笔记(三十六)展示类控件

    View/ViewGroup View是单个视图,所有的控件类都是从它派生出来:而ViewGroup是个视图组织,所有的布局视图类都是从它派生出来.由于View和ViewGroup是基类,因此很少会直 ...

  9. 微软企业库4.1学习笔记(三十六)日志模块 简介

    日志模块 企业库的日志模块简单的实现了日志功能的常用功能.开发者可以利用模块在下面的位置记录信息: 事件日志 电子邮件 数据库 消息队列 文本文件 WMI的事件查看器 自定义的位置 模块为记录在任何位 ...

最新文章

  1. 智能,万亿维空间中的求解
  2. 张宏江:开源时代如何解决人的思维孤岛
  3. Gradle之依赖管理
  4. echarts雷达图线的样式_echarts 雷达图的个性化设置
  5. 微信小程序小模块界面,边框带阴影
  6. 关于“因为数据库正在使用,所以无法获得对数据库的独占访问权”的最终解决方案...
  7. deepin终端编译c程序_Deepin Linux安装使用Visual Studio Code(VSCode)调试C++
  8. Apache Camel简介
  9. python第2位的值_Python组通过匹配元组列表中的第二个元组值
  10. 在python3中如何加载静态文件详版步骤
  11. 温馨剪纸风三八妇女节PSD分层海报模板
  12. 【csdn】markdown使用教程
  13. 开源中国翻译频道链接收藏
  14. java类可视化doxygen_安装doxygen(一个自动文档生成工具)+Graphviz图形可视化软件...
  15. wps 云服务器登录_WPS云服务使用协议
  16. 通过管道方式(CreatePipe)获取DOS命令行执行后的返回结果
  17. 离散数学 数学三大危机
  18. 想要下载的编程软件太难找?部分软件官网介绍
  19. matlab均方根误差
  20. 电子科技大学软件工程860考研上岸初试经验分享

热门文章

  1. 24时区来源,CST,CET,UTC,DST,Unix时间戳概述、关系、转换
  2. 0xc000007b 问题总结
  3. Hander异步消息处理机制完全解析
  4. layui 实现动态 radio 、select下拉框 jQuery赋值方法
  5. 《大数据架构和算法实现之路:电商系统的技术实战》——2.4 案例实践
  6. PAC学习框架-泛化误差
  7. Emacs 入门指引(一) Emacs简介
  8. Anniversary Cake (深搜)
  9. HTML显示默认图片实现
  10. android:ellipsize属性