基于GIN的图表征网络的实现

基于图同构网络的图表征学习包含以下过程:

  1. 首先计算得到节点表征;
  2. 然后对图上各个节点的表征做图池化,得到图的表征。

基于图同构网络的图表征模块(GINGraphRepr Module)

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbeddingclass GINGraphRepr(nn.Module):def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):super(GINGraphPooling, self).__init__()self.num_layers = num_layersself.drop_ratio = drop_ratioself.JK = JKself.emb_dim = emb_dimself.num_tasks = num_tasksif self.num_layers < 2:raise ValueError("Number of GNN layers must be greater than 1.")self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)# Pooling function to generate whole-graph embeddingsif graph_pooling == "sum":self.pool = global_add_poolelif graph_pooling == "mean":self.pool = global_mean_poolelif graph_pooling == "max":self.pool = global_max_poolelif graph_pooling == "attention":self.pool = GlobalAttention(gate_nn=nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))elif graph_pooling == "set2set":self.pool = Set2Set(emb_dim, processing_steps=2)else:raise ValueError("Invalid graph pooling type.")if graph_pooling == "set2set":self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)else:self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)def forward(self, batched_data):h_node = self.gnn_node(batched_data)h_graph = self.pool(h_node, batched_data.batch)output = self.graph_pred_linear(h_graph)if self.training:return outputelse:# At inference time, relu is applied to output to ensure positivity# 因为预测目标的取值范围就在 (0, 50] 内return torch.clamp(output, min=0, max=50)
  • sum::对节点表征求和
  • mean:对节点表征求平均
  • max:取节点表征的最大值
  • attention:基于Attention对节点表征加权求和
  • set2set:另一种基于Attention对节点表征加权求和

基于图同构网络的节点嵌入模块(GINNodeEmbedding Module)

import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):"""Output:node representations"""def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):"""GIN Node Embedding Module"""super(GINNodeEmbedding, self).__init__()self.num_layers = num_layersself.drop_ratio = drop_ratioself.JK = JK# add residual connection or notself.residual = residualif self.num_layers < 2:raise ValueError("Number of GNN layers must be greater than 1.")# 首先用AtomEncoder做嵌入得到第0层节点表征self.atom_encoder = AtomEncoder(emb_dim)# List of GNNsself.convs = torch.nn.ModuleList()self.batch_norms = torch.nn.ModuleList()# 从第1层到第num_layers层,点表征的计算都以上一层的节点表征、边和边的属性为输入for layer in range(num_layers):self.convs.append(GINConv(emb_dim))self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))def forward(self, batched_data):x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr# computing input node embeddingh_list = [self.atom_encoder(x)]  # 先将类别型原子属性转化为原子表征for layer in range(self.num_layers):h = self.convs[layer](h_list[layer], edge_index, edge_attr)h = self.batch_norms[layer](h)if layer == self.num_layers - 1:# remove relu for the last layerh = F.dropout(h, self.drop_ratio, training=self.training)else:h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)if self.residual:h += h_list[layer]h_list.append(h)# Different implementations of Jk-concatif self.JK == "last":node_representation = h_list[-1]elif self.JK == "sum":node_representation = 0for layer in range(self.num_layers + 1):node_representation += h_list[layer]return node_representation

GINConv的层数越多,此节点嵌入模块的感受野越大,节点i的表征最远能捕获到节点i的距离为num_layers的邻接节点的信息。

图同构网络的关键组件GINConv

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder### GIN convolution along the graph structure
class GINConv(MessagePassing):def __init__(self, emb_dim):'''emb_dim (int): node embedding dimensionality'''super(GINConv, self).__init__(aggr = "add")self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))self.eps = nn.Parameter(torch.Tensor([0]))self.bond_encoder = BondEncoder(emb_dim = emb_dim)def forward(self, x, edge_index, edge_attr):edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))return outdef message(self, x_j, edge_attr):return F.relu(x_j + edge_attr)def update(self, aggr_out):return aggr_out

GINConv模块遵循“消息传递、消息聚合、消息更新”这一过程:

  • 首先self.propagate()方法开始执行,该方法接收edge_index, x, edge_attr三个参数,edge_index是形状为[2,num_edges]的张量。
  • 在消息传递过程中,此张量先按行拆分为x_i和x_j张量,x_j表示消息传递的源节点,x_i表示消息传递的目标节点。
  • 接着message()方法被调用,此方法定义了从源节点到目标节点的消息,这里要传递的消息是源节点表征与边表征之和的relu()的输出。我们在初始化时用aggr="add"定义了消息聚合方式,那么传入一个目标节点的消息被求和得到aggr_out,它还是目标节点的中间过程的消息。
  • 然后执行消息更新过程,update()方法被调用。我们希望在更新中加入目标节点自身的消息,因此在update方法中只返回输入的aggr_out。
  • 最后在forward()方法中执行消息的更新。

原文地址

基于图神经网络的图表示学习方法相关推荐

  1. 【图神经网络】图分类学习研究综述[2]:基于图神经网络的图分类

    基于GNN的图分类学习研究综述[2]:基于图神经网络的图分类 论文阅读:基于GNN的图分类学习研究综述 3. 基于图神经网络的图分类 3.1 卷积 3.2 池化 论文阅读:基于GNN的图分类学习研究综 ...

  2. 当图网络遇上计算机视觉!计算机视觉中基于图神经网络和图Transformer的方法和最新进展...

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 点击进入-> CV 微信技术交流群 可能是目前最全面的<当图网络遇上计算机视觉>综述!近四 ...

  3. 图神经网络的图网络学习(上)

    图神经网络的图网络学习(上) 原文:Learning the Network of Graphs for Graph Neural Networks 摘要 图神经网络 (GNN) 在许多使用图结构数据 ...

  4. 「图神经网络复杂图挖掘」 的研究进展

    来源:专知 图神经网络对非欧式空间数据建立了深度学习框架,相比传统网络表示学习模型,它对图结构能够实施更加深层的信息聚合操作.近年来,图神经网络完成了向复杂图结构的迁移,诞生了一系列基于复杂图的图神经 ...

  5. 图神经网络与图注意力网络相关知识概述

    #图神经网络# #图注意力网络# 随着计算机行业和互联网时代的不断发展与进步,图神经网络已经成为人工智能和大数据的重要研究领域.图神经网络是对相邻节点间信息的传播和聚合的重要技术,可以有效地将深度学习 ...

  6. 图神经网络基础--图结构数据

    图神经网络基础–图结构数据 注:本节大部分内容(包括图片)来源于"Chapter 2 - Foundations of Graphs, Deep Learning on Graphs&quo ...

  7. 图神经网络的图网络学习(下)

    原文:Learning the Network of Graphs for Graph Neural Networks 1. 文章信息 作者 Yixiang Shan, Jielong Yang, X ...

  8. 图神经网络之图卷积网络——GCN

    图卷积网络--GCN 一.前置基础知识回顾 图的基本概念 构造图神经网络的目的 训练方式 二.回顾卷积神经网络在图像及文本上的发展 图像上的卷积网络 文本上的卷积网络 图卷积网络的必要性 三.图卷积网 ...

  9. 基于图神经网络的图表征学习方法

    图表征学习是指将整个图表示成低维.实值.稠密的向量形式,用来对整个图结构进行分析,包括图分类.图之间的相似性计算等. 相比之前的图节点,图的表征学习更加复杂,但构建的方法是建立在图节点表征的基础之上. ...

最新文章

  1. [导入]ZT笑到内伤:史上最雷,最爆寒的电影字幕
  2. MinGW问题解决:gcc: installation problem, cannot exec `cc1'
  3. jquery iCheck 插件
  4. roku能不能安装软件_如何阻止假期更改Roku主题
  5. 大家都在说的分布式系统到底是什么
  6. [position]返回顶部
  7. 人脸识别研究任务及开源项目调研
  8. OSX系统编译cocos2dx andriod工程
  9. python迭代对象有哪些_Python可迭代对象操作示例
  10. python阈值计算_python – 在numpy中计算超过阈值的数组值的最快方法
  11. Django 模板系统
  12. HDU 2955 Robberies抢劫案(01背包,变形)
  13. 查看JVisualVM查看信息
  14. 王者荣耀是用什么代码变成MOBA游戏的,该怎么学?有前途吗?
  15. YDOOK:ANSYS 谐波分析的要点和主要应用场景 谐波效应的来源
  16. 概率论中的一些基础知识——条件概率 先验概率 后验概率 似然 概率分布函数 概率密度函数
  17. 计算机del键作用,计算机里的英文字母“DEL”键是干什么用的
  18. matlab GUI制作拼图小游戏
  19. python制作辅助和易语言的区别_为什么多数外挂都用易语言?
  20. 2017国产品牌台式计算机,2017三大热门国产平板电脑推荐

热门文章

  1. mysql-外键操作-级联删除
  2. Vmware16一打开虚拟机就蓝屏
  3. AndroidMainfest.xml具体解释——lt;activitygt;
  4. 使用CrashHandler来获取应用的crash信息
  5. POJ-Prime Gap 素数筛选+二分查找
  6. Hadoop---集群安装
  7. 浅谈移动互联网广告设计评论
  8. 程序员应具备的职业素质
  9. IBM将发布以固态硬盘为基础的全企业系统
  10. 数据状态更新时的差异 diff 及 patch 机制