论文

图注意力网络来自 Graph Attention Networks,ICLR 2018. https://arxiv.org/abs/1710.10903

注意力机制

代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizersclass GraphAttentionLayer(keras.layers.Layer):def compute_output_signature(self, input_signature):passdef __init__(self,input_dim,output_dim,adj,nodes_num,dropout_rate=0.0,activation=None,use_bias=True,kernel_initializer='glorot_uniform',bias_initializer='zeros',kernel_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,bias_constraint=None,coef_dropout=0.0,**kwargs):""":param input_dim: 输入的维度:param output_dim: 输出的维度,不等于input_dim:param adj: 具有自环的tuple类型的邻接表[coords, values, shape], 可以采用sp.coo_matrix生成:param nodes_num: 点数量:param dropout_rate: 丢弃率,防过拟合,默认0.5:param activation: 激活函数:param use_bias: 偏移,默认True:param kernel_initializer: 权值初始化方法:param bias_initializer: 偏移初始化方法:param kernel_regularizer: 权值正则化:param bias_regularizer: 偏移正则化:param activity_regularizer: 输出正则化:param kernel_constraint: 权值约束:param bias_constraint: 偏移约束:param coef_dropout: 互相关系数丢弃,默认0.0:param kwargs:"""super(GraphAttentionLayer, self).__init__()self.activation = activations.get(activation)self.use_bias = use_biasself.kernel_initializer = initializers.get(kernel_initializer)self.bias_initializer = initializers.get(bias_initializer)self.kernel_regularizer = regularizers.get(kernel_regularizer)self.bias_regularizer = regularizers.get(bias_regularizer)self.kernel_constraint = constraints.get(kernel_constraint)self.bias_constraint = constraints.get(bias_constraint)self.input_dim = input_dimself.output_dim = output_dimself.support = [tf.SparseTensor(indices=adj[0][0], values=adj[0][1], dense_shape=adj[0][2])]self.dropout_rate = dropout_rateself.coef_drop = coef_dropoutself.nodes_num = nodes_numself.kernel = Noneself.mapping = Noneself.bias = Nonedef build(self, input_shape):"""只执行一次"""self.kernel = self.add_weight(shape=(self.input_dim, self.output_dim),initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable=True)if self.use_bias:self.bias = self.add_weight(shape=(self.nodes_num, self.output_dim),initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,trainable=True)print('[GAT LAYER]: GAT W & b built.')def call(self, inputs, training=True):# 完成输入到输出的映射关系# inputs = tf.nn.l2_normalize(inputs, 1)raw_shape = inputs.shapeinputs = tf.reshape(inputs, shape=(1, raw_shape[0], raw_shape[1]))  # (1, nodes_num, input_dim)mapped_inputs = keras.layers.Conv1D(self.output_dim, 1, use_bias=False)(inputs)  # (1, nodes_num, output_dim)# mapped_inputs = tf.nn.l2_normalize(mapped_inputs)sa_1 = keras.layers.Conv1D(1, 1)(mapped_inputs)  # (1, nodes_num, 1)sa_2 = keras.layers.Conv1D(1, 1)(mapped_inputs)  # (1, nodes_num, 1)con_sa_1 = tf.reshape(sa_1, shape=(raw_shape[0], 1))  # (nodes_num, 1)con_sa_2 = tf.reshape(sa_2, shape=(raw_shape[0], 1))  # (nodes_num, 1)con_sa_1 = tf.cast(self.support[0], dtype=tf.float32) * con_sa_1  # (nodes_num, nodes_num) W_hicon_sa_2 = tf.cast(self.support[0], dtype=tf.float32) * tf.transpose(con_sa_2, [1, 0])  # (nodes_num, nodes_num) W_hjweights = tf.sparse.add(con_sa_1, con_sa_2)  # concatenationweights_act = tf.SparseTensor(indices=weights.indices,values=tf.nn.leaky_relu(weights.values),dense_shape=weights.dense_shape)  # 注意力互相关系数attention = tf.sparse.softmax(weights_act)  # 输出注意力机制inputs = tf.reshape(inputs, shape=raw_shape)if self.coef_drop > 0.0:attention = tf.SparseTensor(indices=attention.indices,values=tf.nn.dropout(attention.values, self.coef_dropout),dense_shape=attention.dense_shape)if training and self.dropout_rate > 0.0:inputs = tf.nn.dropout(inputs, self.dropout_rate)if not training:print("[GAT LAYER]: GAT not training now.")attention = tf.sparse.reshape(attention, shape=[self.nodes_num, self.nodes_num])value = tf.matmul(inputs, self.kernel)value = tf.sparse.sparse_dense_matmul(attention, value)if self.use_bias:ret = tf.add(value, self.bias)else:ret = tf.reshape(value, (raw_shape[0], self.output_dim))return self.activation(ret)

参考

https://blog.csdn.net/weixin_36474809/article/details/89401552

https://github.com/PetarV-/GAT

更多内容访问 omegaxyz.com
网站所有代码采用Apache 2.0授权
网站文章采用知识共享许可协议BY-NC-SA4.0授权
© 2020 • OmegaXYZ-版权所有 转载请注明出处

图注意力网络(GAT) TensorFlow解析相关推荐

  1. 【图结构】之图注意力网络GAT详解

    作者:張張張張 github地址:https://github.com/zhanghekai [转载请注明出处,谢谢!] GATGATGAT源代码地址:https://github.com/Petar ...

  2. 【GNN】图注意力网络GAT(含代码讲解)

    CSDN页面公式加载有问题,如果影响观看请戳本文的知乎版本:https://zhuanlan.zhihu.com/p/112938037 毫无疑问,图神经网络(Graph Neural Network ...

  3. 图神经网络 | (8)图注意力网络(GAT)

    本篇博客要介绍的是图注意力网络(Graph Attention Networks,GAT),它通过注意力机制(Attention Mechanism)来对邻居节点做聚合操作,实现对不同邻居权重的自适应 ...

  4. GNN动手实践(二):复现图注意力网络GAT

    参考论文:Graph Attention Networks 一.前言 GAT(图注意力网络)是GNNs中重要的SOTA模型,该模型是从空域角度来进行定义,能够用消息传递范式来进行解释.GAT与GCN最 ...

  5. 图注意力网络GAT - 《Graph Attention Networks》论文详解

    目录 前言 正文 图注意力机制层(Graph Attentional Layer) 层的输入 注意力系数 归一化注意力系数 通过邻居节点更新自身节点 层的输出 GAT相比于先前研究的优势 附作者简介 ...

  6. 注意力机制 神经网络_图注意力网络(GAT)

    引言 作者借鉴图神经网络中的注意力机制,提出了图注意力神经网络架构,创新点主要包含如下几个:①采用masked self-attention层,②隐式的对邻居节点采用不同权重③介绍了多头注意力机制. ...

  7. DeepLearning | 图注意力网络Graph Attention Network(GAT)论文、模型、代码解析

    本篇博客是对论文 Velikovi, Petar, Cucurull, Guillem, Casanova, Arantxa,et al. Graph Attention Networks, 2018 ...

  8. 图注意力网络(Graph Attention Network, GAT) 模型解读与代码实现(tensorflow2.0)

    前面的文章,我们讲解了图神经网络三剑客GCN.GraphSAGE.GAT中的两个: 图卷积神经网络(GCN)理解与tensorflow2.0代码实现 GraphSAGE 模型解读与tensorflow ...

  9. 147页详述「结构在神经网络中的复兴」,图注意力网络一作博士论文公开

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手 ...

  10. 图神经网络与图注意力网络相关知识概述

    #图神经网络# #图注意力网络# 随着计算机行业和互联网时代的不断发展与进步,图神经网络已经成为人工智能和大数据的重要研究领域.图神经网络是对相邻节点间信息的传播和聚合的重要技术,可以有效地将深度学习 ...

最新文章

  1. Ecshop实现仿Taobao地区运费模板
  2. Android系统--输入系统(一)必备的Linux知识_inotify和epoll
  3. Institute for Manufacturing virtual check in part 1
  4. Notepad++ JSON关键字自动提示
  5. codeforces C. Xor-tree
  6. Expression Studio 3在windows7下安装失败
  7. 《软件项目管理(第二版)》第 9 章——项目监督与控制 重点部分总结
  8. 电子商务概论_走进经管优质线上课堂(二)之电子商务概论
  9. CISCO安全 ×××技术
  10. [转载] Python中的数学函数,三角函数,随机数函数
  11. CloudEra Hadoop VMWare单节点环境设置
  12. SPSS操作(五):主成分分析
  13. Nokia手机S40平台手机开发环境的搭建的过程
  14. 多线程并发测试工具类
  15. 什么是5W1H分析法?
  16. linux 建立ssh隧道,在Linux、Windows、macOS上创建SSH隧道并通过SSH隧道连接到MySQL
  17. mysql 允许局域网连接_设置Mysql允许局域网或外部连接
  18. 我叫MT online 公会BOSS百分比、难度、BOSS技能及站位
  19. uni.navigateTo失效
  20. 专访吴军:未来10年,AI的发展方向是应用,不会出现重大的理论突破

热门文章

  1. java文件中注释出现乱码解决办法
  2. java导出excel弹出下载框_JavaWeb导出Excel文件并弹出下载框
  3. EasyExcel导出excel(写)
  4. php错误日志框架,错误与日志 - Laravel - 为 WEB 艺术家创造的 PHP 框架。
  5. java.lang.integer_java 中 关于java.lang.ArrayStoreException: java.lang.Integer异常,是什么原因?...
  6. 华为8lite支持云闪付吗_2K/120Hz屏?华为P40Pro尊享版价格曝光 | 一加8曝4.15发布
  7. python怎么弄成黑色背景图片_怎么能把图片的黑色背景改成透明背景
  8. mybatis插入时间_深入分析MyBatis源码
  9. php模板建站seo,phpwin建站教程,phpwind模板
  10. HighCharts:设置坐标轴字体样式