This section walks through the code for the MSA (multi-head self-attention) part of the VITGAN generative adversarial network.


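  1. Network architecture overview
  2. Mapping Network implementation
  3. PositionalEmbedding implementation
  4. MLP implementation
  5. MSA multi-head attention implementation
  6. SLN self-modulation implementation
  7. CoordinatesPositionalEmbedding implementation
  8. ModulatedLinear implementation
  9. Siren implementation
  10. Generator implementation
  11. PatchEmbedding implementation
  12. ISN implementation
  13. Discriminator implementation
  14. VITGAN implementation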

MSA multi-head attention overview



import tensorflow as tf


class MSA(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, discriminator):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.discriminator = discriminator  # True when the layer is used in the discriminator

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wk = tf.keras.layers.Dense(d_model, use_bias=False)
        self.wv = tf.keras.layers.Dense(d_model, use_bias=False)

        self.dense = tf.keras.layers.Dense(d_model, use_bias=False)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def scaled_dot_product_attention(self, q, k, v, mask):
        """Calculate the attention weights.
        q, k, v must have matching leading dimensions.
        k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look ahead)
        but it must be broadcastable for addition.

        Args:
            q: query shape == (..., seq_len_q, depth)
            k: key shape == (..., seq_len_k, depth)
            v: value shape == (..., seq_len_v, depth_v)
            mask: Float tensor with shape broadcastable
                to (..., seq_len_q, seq_len_k). Defaults to None.

        Returns:
            output, attention_weights
        """
        # matmul_qk.shape == (..., seq_len_q, seq_len_k)
        if self.discriminator:
            # Discriminator: pairwise Euclidean (L2) distance between queries and keys
            matmul_q = tf.expand_dims(q, axis=-2)  # (..., seq_len_q, 1, depth)
            matmul_k = tf.expand_dims(k, axis=-3)  # (..., 1, seq_len_k, depth)
            matmul_qk = tf.math.sqrt(
                tf.math.reduce_sum(tf.math.square(matmul_q - matmul_k), axis=-1))
        else:
            # Generator: dot product
            matmul_qk = tf.matmul(q, k, transpose_b=True)

        # scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        # add the mask to the scaled tensor.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

        return output, attention_weights

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask)

        # (batch_size, seq_len_q, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, d_model)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)

        # return output, attention_weights
        return output


if __name__ == "__main__":
    layer = MSA(128, 8, discriminator=False)
    x = tf.random.uniform([2, 5, 128], dtype=tf.float32)
    o = layer(x, x, x, None)
    tf.print(tf.shape(o))
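The only difference between the generator and discriminator paths in the layer above is how the attention logits are computed before scaling and softmax. The sketch below reproduces that branch for a single head in plain Python, without TensorFlow, so the arithmetic can be checked by hand. The helper names `softmax`, `attention_logits`, and `attend` are hypothetical and are not part of the original repository.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_logits(q, k, use_l2):
    """Return scaled logits as a len(q) x len(k) nested list.

    q, k: lists of equal-length vectors (one head, no batch axis).
    use_l2=True mirrors the discriminator branch (Euclidean distance);
    use_l2=False mirrors the generator branch (dot product).
    """
    scale = math.sqrt(len(k[0]))  # sqrt(depth), same scaling as the layer
    logits = []
    for qi in q:
        row = []
        for kj in k:
            if use_l2:
                score = math.sqrt(sum((a - b) ** 2 for a, b in zip(qi, kj)))
            else:
                score = sum(a * b for a, b in zip(qi, kj))
            row.append(score / scale)
        logits.append(row)
    return logits


def attend(q, k, v, use_l2):
    """Softmax the logits row by row, then take the weighted sum of v."""
    weights = [softmax(row) for row in attention_logits(q, k, use_l2)]
    out = [
        [sum(w * vj[d] for w, vj in zip(row, v)) for d in range(len(v[0]))]
        for row in weights
    ]
    return out, weights


if __name__ == "__main__":
    q = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    print(attention_logits(q, q, use_l2=False))  # dot-product logits
    print(attention_logits(q, q, use_l2=True))   # L2-distance logits
    print(attend(q, q, v, use_l2=True)[1])       # attention weights
```

Note that, as in the layer above, the distance enters the softmax with a positive sign, so in the L2 branch a key that is farther from the query receives a larger weight than an identical one.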


  • Paper: VITGAN: Training GANs with Vision Transformers
  • Reference for the multi-head attention implementation: code link
  • Source code: https://github.com/tfwcn/VITGAN-tf2

