本文主要以Deep Graph Library(DGL)为基础,利用图神经网络来进行图节点分类任务。本篇针对的图为同构图。

1. DGL 介绍

DGL是一个python包,用以在现有的深度学习框架上(包括Pytorch、MXNet和TensorFlow)来实现图神经网络系列模型。它提供了对消息传递的通用控制,通过自动批处理和高度调整的稀疏矩阵内核进行速度优化,以及多 GPU/CPU 训练以扩展到数亿个节点和边缘的图形。
DGL拥有丰富的文档及相关接口,而且文档有中文版本,十分容易学习和上手。
DGL的github链接:https://github.com/dmlc/dgl

2. 图节点分类实践

2.1 数据集加载

本文使用的数据集为DGL中已经有的Cora数据集,该数据集为论文引用数据集,包含论文节点和论文之间的引用关系,通过论文本身的特征和引用关系来对论文进行分类,其共包括以下七类:

  • 基于案例
  • 遗传算法
  • 神经网络
  • 概率方法
  • 强化学习
  • 规则学习
  • 理论
import dgl.data
from dgl.nn import GraphConv
import torch.nn as nn
from dgl.nn.pytorch.conv import SAGEConv
import torch
import torch.nn.functional as Fdataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]
print("结点信息",g.ndata)
print("边信息",g.edata)

通过上述代码,可以加载Cora数据集,并能够看到数据集的基本情况,数据集共包含2708个节点,10556条边。

2.2 图神经网络模块定义

简单的GCN构建:
以下代码构建了一个两层图卷积网络(GCN),每一层通过聚合邻居信息来计算新的节点表示。
如果想要构建多层 GCN,您可以简单地堆叠dgl.nn.GraphConv 模块,这些模块继承自torch.nn.Module.

class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h

GraphSAGE构建:
GraphSAGE 是图神经网络中比较经典的模型,GraphSAGE 包含采样和聚合 (Sample and aggregate),首先使用节点之间连接信息,对邻居进行采样,然后通过多层聚合函数不断地将相邻节点的信息融合在一起。本文参照DGL中的例子来实现的GraphSAGE,代码如下:

class GraphSAGE(nn.Module):def __init__(self,in_feats,n_hidden,n_classes,n_layers,activation,dropout,aggregator_type):super(GraphSAGE, self).__init__()self.layers = nn.ModuleList()self.dropout = nn.Dropout(dropout)self.activation = activation# input layerself.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type))# hidden layersfor i in range(n_layers - 1):self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type))# output layerself.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type)) # activation Nonedef forward(self, graph, inputs):h = self.dropout(inputs)for l, layer in enumerate(self.layers):h = layer(graph, h)if l != len(self.layers) - 1:h = self.activation(h)h = self.dropout(h)return h

2.3 评价函数

def evaluate(model, graph, features, labels, nid):model.eval()with torch.no_grad():logits = model(graph, features)logits = logits[nid]labels = labels[nid]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)

2.4 图神经网络的训练

全图(使用所有的节点和边的特征)上的训练只需要使用上面定义的模型进行前向传播计算,并通过在训练节点上比较预测和真实标签来计算损失,从而完成后向传播。
节点特征和标签存储在其图上,训练、验证和测试的分割也以布尔掩码的形式存储在图上。

features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze()
test_nid = test_mask.nonzero().squeeze()def train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0for e in range(100):# Forwardlogits = model(g, features)# Compute predictionpred = logits.argmax(1)# Compute loss# Note that you should only compute the losses of the nodes in the training set.loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()acc = evaluate(model, g, features, labels, val_nid)print("Epoch {:05d}  | Loss {:.4f} | Accuracy {:.4f} | ".format(e, loss.item(), acc))

双层GNN训练:

model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)
print()
acc = evaluate(model, g, features, labels, test_nid)
print("Test Accuracy {:.4f}".format(acc))

GraphSAGE训练:

modeSAGE = GraphSAGE(g.ndata['feat'].shape[1],16,dataset.num_classes,2,F.relu,0.5,"gcn")
train(g, modeSAGE)
acc = evaluate(modeSAGE, g, features, labels, test_nid)
print("Test Accuracy {:.4f}".format(acc))

运行上述代码即可得到分类的效果,一般来说GraphSAGE的效果会略好于双层的GCN,但差距并不太大。

3. 总结

本文主要在DGL包自带的同构图数据集上进行了一个简单的图节点分类的尝试,之后会尝试在其他数据集(异构图/知识图谱)上进行图节点分类的任务。

图神经网络实践之图节点分类(一)相关推荐

  1. Graph Decipher: A transparent dual-attention graph neural network 图解密器:一种透明的双注意图神经网络,用于理解节点分类的消息传递机制

    引用 Pang Y, Liu C. Graph Decipher: A transparent dual-attention graph neural network to understand th ...

  2. 图神经网络17-DGL实战:节点分类/回归

    对于图神经网络来说,最常见和被广泛使用的任务之一就是节点分类. 图数据中的训练.验证和测试集中的每个节点都具有从一组预定义的类别中分配的一个类别,即正确的标注. 节点回归任务也类似,训练.验证和测试集 ...

  3. 图神经网络(三):节点分类

    节点分类问题 数据集:Cora 包含七类学术论文,论文与论文之间存在引用和被引用的关系 数据集导入 from torch_geometric.datasets import Planetoid fro ...

  4. 论文阅读笔记:《一种改进的图卷积网络半监督节点分类》

    论文阅读笔记:<一种改进的图卷积网络半监督节点分类> 文章目录 论文阅读笔记:<一种改进的图卷积网络半监督节点分类> 摘要: 引言 非欧几里得数据 1 深度池化对偶图神经网络 ...

  5. 图神经网络基础--基于图神经网络的节点表征学习

    图神经网络基础–基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,首先需要生成节点表征(Node Representation).我们使用图神经网络来生成节点表征,并通过基于监督学习的对 ...

  6. activiti动态增加节点_图神经网络之动态图

    图这种结构普遍存在于人类社会生活中,如互联网中网页间的互相链接会构成图.网民购买商品会构成"网民-商品"图.人和人的交流会构成图.论文的互相引用也会构成图.有许多任务需要根据这些图 ...

  7. 【论文解读|AAAI2021】HGSL - Heterogeneous Graph Structure Learning for Graph Neural Networks 图神经网络的异构图结构学习

    文章目录 1 摘要 2 引言 相关工作 3 方法 3.1 特征图产生器 3.1.1 特征相似图 3.1.2特征传播图 3.2 语义图生成器 4 实验 5 结论 论文链接: http://shichua ...

  8. 基于图神经网络的异构图表示学习和推荐算法研究(完整代码+数据)

    基于图神经网络的异构图表示学习和推荐算法研究.包含基于对比学习的关系感知异构图神经网络(Relation-aware Heterogeneous Graph Neural Network with C ...

  9. 图神经网络系列-Graph图基本介绍、度中心性、特征向量中心性、中介中心性、连接中心性

    图神经网络系列-Graph图基本介绍.度中心性.特征向量中心性.中介中心性.连接中心性 目录 图的定义 图的类型 空图形 简单图 多重图 有向图 无向图 连通与断开图 正则图 完全图 循环图 二部图 ...

最新文章

  1. Redux 入门教程(二):中间件与异步操作
  2. 配置gradle时,一直报错提示:ERROR: JAVA_HOME is set to an invalid directory: D:\Java\jdk1.8.0_144;
  3. leetcode 226. Invert Binary Tree
  4. ABAP业务涉及到的相关数据库表 .
  5. php 实验室管理系统,生物信息实验室管理系统-Metalims安装
  6. WD与循环 组合数学
  7. 【每日SQL打卡】​​​​​​​​​​​​​​​DAY 2丨连续出现的数字【难度中等】
  8. 解决Windows x64bit环境下无法使用PLSQL Developer连接到Oracle DB中的问题
  9. mysql宾馆客房管理系统视频_java swing mysql实现的酒店宾馆管理系统项目源码附带视频指导运行教程...
  10. 【电路基础】第1章-电路的基本规律(1)
  11. Java系统日志管理
  12. 区块链技术之P2P网络(一)
  13. 基于颜色的R2V软件快速矢量化
  14. 2022-2028全球独立水疗浴缸行业调研及趋势分析报告
  15. vue项目使用vue-amap调用高德地图api详细步骤
  16. 微信小程序开发什么工具好?
  17. js代码优化8个优点
  18. 新闻管理系统源码java_小虫新闻管理系统 .rar - WEB源码|JSP源码/Java|源代码 - 源码中国...
  19. Voluntarily Relinquishing the Processor-----《Pro_Java_8_Programming_(3rd_edition)》
  20. 长沙计算机应届生工资水平,长沙毕业生期望的平均月薪是多少?答案在这里

热门文章

  1. 预防和减少交警在事故处置现场的伤亡 美国有哪些经验可借鉴?
  2. 2015年互联网医疗如何“移动医疗”?如何“医疗移动”?
  3. 好用的chatgpt网站推荐
  4. c语言复数运算 除法,c语言 复数的运算
  5. 去除页面上的图片和视频
  6. 如何更详细查看SAP 系统版本信息
  7. NOr flash onenand
  8. 汽车靠发动机带动发电机发电
  9. 国标消消乐---6.国标编码设计
  10. 想知道截图翻译软件有哪些?我来告诉你