【人工智能笔记】第三十六节:TF2实现VITGAN对抗生成网络,MSA多头注意力 实现
该章节介绍VITGAN对抗生成网络中,MSA多头注意力 部分的代码实现。
目录(文章发布后会补上链接):
- 网络结构简介
- Mapping NetWork 实现
- PositionalEmbedding 实现
- MLP 实现
- MSA多头注意力 实现
- SLN自调制 实现
- CoordinatesPositionalEmbedding 实现
- ModulatedLinear 实现
- Siren 实现
- Generator生成器 实现
- PatchEmbedding 实现
- ISN 实现
- Discriminator鉴别器 实现
- VITGAN 实现
MSA多头注意力 简介
MSA多头注意力,生成器时使用乘积计算注意力,判别器时使用欧氏距离计算注意力,实现参考Tensorflow官网教程:教程地址
代码实现
import tensorflow as tfclass MSA(tf.keras.layers.Layer):def __init__(self, d_model, num_heads, discriminator):super().__init__()self.num_heads = num_headsself.d_model = d_modelself.discriminator = discriminator # 是否鉴别器assert d_model % self.num_heads == 0self.depth = d_model // self.num_headsself.wq = tf.keras.layers.Dense(d_model, use_bias=False)self.wk = tf.keras.layers.Dense(d_model, use_bias=False)self.wv = tf.keras.layers.Dense(d_model, use_bias=False)self.dense = tf.keras.layers.Dense(d_model, use_bias=False)def split_heads(self, x, batch_size):"""Split the last dimension into (num_heads, depth).Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)"""x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))return tf.transpose(x, perm=[0, 2, 1, 3])def scaled_dot_product_attention(self, q, k, v, mask):"""Calculate the attention weights.q, k, v must have matching leading dimensions.k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.The mask has different shapes depending on its type(padding or look ahead)but it must be broadcastable for addition.Args:q: query shape == (..., seq_len_q, depth)k: key shape == (..., seq_len_k, depth)v: value shape == (..., seq_len_v, depth_v)mask: Float tensor with shape broadcastableto (..., seq_len_q, seq_len_k). Defaults to None.Returns:output, attention_weights"""# (..., seq_len_q, seq_len_k)if self.discriminator:# 欧氏距离matmul_q = tf.expand_dims(q, axis=-2)matmul_k = tf.expand_dims(k, axis=-3)matmul_qk = tf.math.sqrt(tf.math.reduce_sum(tf.math.square(matmul_q - matmul_k), axis=-1))else:# 乘积matmul_qk = tf.matmul(q, k, transpose_b=True)# scale matmul_qkdk = tf.cast(tf.shape(k)[-1], tf.float32)scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)# add the mask to the scaled tensor.if mask is not None:scaled_attention_logits += (mask * -1e9)# softmax is normalized on the last axis (seq_len_k) so that the scores# add up to 1.attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)return output, attention_weightsdef call(self, v, k, q, mask):batch_size = tf.shape(q)[0]q = self.wq(q) # (batch_size, seq_len, d_model)k = self.wk(k) # (batch_size, seq_len, d_model)v = self.wv(v) # (batch_size, seq_len, d_model)# (batch_size, num_heads, seq_len_q, depth)q = self.split_heads(q, batch_size)# (batch_size, num_heads, seq_len_k, depth)k = self.split_heads(k, batch_size)# (batch_size, num_heads, seq_len_v, depth)v = self.split_heads(v, batch_size)# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)scaled_attention, attention_weights = self.scaled_dot_product_attention(q, k, v, mask)# (batch_size, seq_len_q, num_heads, depth)scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)# (batch_size, seq_len_q, d_model)output = self.dense(concat_attention)# return output, attention_weightsreturn outputif __name__ == "__main__":layer = MSA(128, 8, discriminator=False)x = tf.random.uniform([2,5,128], dtype=tf.float32)o = layer(x,x,x,None)tf.print(tf.shape(o))
参考资料:
- 论文地址:VITGAN: Training GANs with Vision Transformers
- 多头注意力实现代码参考:代码地址
- 源码:https://github.com/tfwcn/VITGAN-tf2
【人工智能笔记】第三十六节:TF2实现VITGAN对抗生成网络,MSA多头注意力 实现相关推荐
- OpenCV学习笔记(三十六)——Kalman滤波做运动目标跟踪 OpenCV学习笔记(三十七)——实用函数、系统函数、宏core OpenCV学习笔记(三十八)——显示当前FPS OpenC
OpenCV学习笔记(三十六)--Kalman滤波做运动目标跟踪 kalman滤波大家都很熟悉,其基本思想就是先不考虑输入信号和观测噪声的影响,得到状态变量和输出信号的估计值,再用输出信号的估计误差加 ...
- Python编程基础:第三十六节 模块Modules
第三十六节 模块Modules 前言 实践 前言 我们目前所有的代码都写在一个文档里面.如果你的项目比较大,那么把所有功能写在一个文件里就非常不便于后期维护.为了提高我们代码的可读性,降低后期维护的成 ...
- VTK学习笔记(三十六)VTK图像填充
VTK学习笔记(三十六)VTK图像填充 1.官方示例 2.其他例子 总结 1.官方示例 来自官方示例代码,自己只是添加了理解. 代码: #include <vtkCamera.h> #in ...
- Slicer学习笔记(三十六)slicer坐标系
Slicer学习笔记(三十六)slicer坐标系 1.坐标系统 1.1.世界坐标系 1.2.解剖坐标系 1.3.图像坐标系 1.4.图像变换 1.5.二维示例或计算IJtoLS矩阵 1.6.与其他软件 ...
- 【OS学习笔记】三十六 保护模式十:通过中断发起任务切换----中断任务
上一篇文章学习了:OS学习笔记]三十五 保护模式十:中断描述符表.中断门和陷阱门 本篇文章接着上一篇文章学习中断任务. 我们在前面文章中一直在说通过中断发起任务切换,本文就是将之前没有说明白的内容:通 ...
- Udacity机器人软件工程师课程笔记(三十六) - GraphSLAM
一.引入 GraphSLAM是解决完整的slam问题的slam算法.这意味着该算法将恢复整个路径和地图,而不仅仅是最近的姿势和地图.这种差异使它可以考虑当前姿势与先前姿势之间的依赖性.适用于我们的Gr ...
- 重学前端学习笔记(三十六)--Flex 布局
笔记说明 重学前端是程劭非(winter)[前手机淘宝前端负责人]在极客时间开的一个专栏,每天10分钟,重构你的前端知识体系,笔者主要整理学习过程的一些要点笔记以及感悟,完整的可以加入winter的专 ...
- Android开发笔记(三十六)展示类控件
View/ViewGroup View是单个视图,所有的控件类都是从它派生出来:而ViewGroup是个视图组织,所有的布局视图类都是从它派生出来.由于View和ViewGroup是基类,因此很少会直 ...
- 微软企业库4.1学习笔记(三十六)日志模块 简介
日志模块 企业库的日志模块简单的实现了日志功能的常用功能.开发者可以利用模块在下面的位置记录信息: 事件日志 电子邮件 数据库 消息队列 文本文件 WMI的事件查看器 自定义的位置 模块为记录在任何位 ...
最新文章
- 智能,万亿维空间中的求解
- 张宏江:开源时代如何解决人的思维孤岛
- Gradle之依赖管理
- echarts雷达图线的样式_echarts 雷达图的个性化设置
- 微信小程序小模块界面,边框带阴影
- 关于“因为数据库正在使用,所以无法获得对数据库的独占访问权”的最终解决方案...
- deepin终端编译c程序_Deepin Linux安装使用Visual Studio Code(VSCode)调试C++
- Apache Camel简介
- python第2位的值_Python组通过匹配元组列表中的第二个元组值
- 在python3中如何加载静态文件详版步骤
- 温馨剪纸风三八妇女节PSD分层海报模板
- 【csdn】markdown使用教程
- 开源中国翻译频道链接收藏
- java类可视化doxygen_安装doxygen(一个自动文档生成工具)+Graphviz图形可视化软件...
- wps 云服务器登录_WPS云服务使用协议
- 通过管道方式(CreatePipe)获取DOS命令行执行后的返回结果
- 离散数学 数学三大危机
- 想要下载的编程软件太难找?部分软件官网介绍
- matlab均方根误差
- 电子科技大学软件工程860考研上岸初试经验分享