对于图神经网络来说,最常见和被广泛使用的任务之一就是节点分类。
图数据中的训练、验证和测试集中的每个节点都具有从一组预定义的类别中分配的一个类别,即正确的标注。
节点回归任务也类似,训练、验证和测试集中的每个节点都被标注了一个正确的数字。

概述

为了对节点进行分类,图神经网络执行了 guide_cn-message-passing
中介绍的消息传递机制,利用节点自身的特征和其邻节点及边的特征来计算节点的隐藏表示。
消息传递可以重复多轮,以利用更大范围的邻居信息。

编写神经网络模型

DGL提供了一些内置的图卷积模块,可以完成一轮消息传递计算。
本章中选择 :class:dgl.nn.pytorch.SAGEConv 作为演示的样例代码(针对MXNet和PyTorch后端也有对应的模块),
它是GraphSAGE模型中使用的图卷积模块。

对于图上的深度学习模型,通常需要一个多层的图神经网络,并在这个网络中要进行多轮的信息传递。
可以通过堆叠图卷积模块来实现这种网络架构,具体如下所示。

# 构建一个2层的GNN模型import dgl.nn as dglnnimport torch.nn as nnimport torch.nn.functional as Fclass SAGE(nn.Module):def __init__(self, in_feats, hid_feats, out_feats):super().__init__()# 实例化SAGEConve,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregator_type是聚合函数的类型self.conv1 = dglnn.SAGEConv(in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')def forward(self, graph, inputs):# 输入是节点的特征h = self.conv1(graph, inputs)h = F.relu(h)h = self.conv2(graph, h)return h

模型的训练

全图(使用所有的节点和边的特征)上的训练只需要使用上面定义的模型进行前向传播计算,并通过在训练节点上比较预测和真实标签来计算损失,从而完成后向传播。

本节使用DGL内置的数据集 :class:dgl.data.CiteseerGraphDataset 来展示模型的训练。
节点特征和标签存储在其图上,训练、验证和测试的分割也以布尔掩码的形式存储在图上。

 node_features = graph.ndata['feat']node_labels = graph.ndata['label']train_mask = graph.ndata['train_mask']valid_mask = graph.ndata['val_mask']test_mask = graph.ndata['test_mask']n_features = node_features.shape[1]n_labels = int(node_labels.max().item() + 1)

下面是通过使用准确性来评估模型的一个例子。

def evaluate(model, graph, features, labels, mask):model.eval()with torch.no_grad():logits = model(graph, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)

用户可以按如下方式实现模型的训练。

    model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)opt = torch.optim.Adam(model.parameters())for epoch in range(10):model.train()# 使用所有节点(全图)进行前向传播计算logits = model(graph, node_features)# 计算损失值loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])# 计算验证集的准确度acc = evaluate(model, graph, node_features, node_labels, valid_mask)# 进行反向传播计算opt.zero_grad()loss.backward()opt.step()print(loss.item())# 如果需要的话,保存训练好的模型。本例中省略。

DGL的GraphSAGE样例 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_full.py>__
提供了一个端到端的同构图节点分类的例子。用户可以在 GraphSAGE 类中看到模型实现的细节。
这个模型具有可调节的层数、dropout概率,以及可定制的聚合函数和非线性函数。

异构图上的节点分类模型的训练

如果图是异构的,用户可能希望沿着所有边类型从邻居那里收集消息。
用户可以使用 :class:dgl.nn.pytorch.HeteroGraphConv
模块(针对MXNet和PyTorch后端也有对应的模块)在所有边类型上执行消息传递,
并为每种边类型使用一种图卷积模块。

下面的代码定义了一个异构图卷积模块。模块首先对每种边类型进行单独的图卷积计算,然后将每种边类型上的消息聚合结果再相加,
并作为所有节点类型的最终结果。

# Define a Heterograph Conv modelclass RGCN(nn.Module):def __init__(self, in_feats, hid_feats, out_feats, rel_names):super().__init__()# 实例化HeteroGraphConv,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregate是聚合函数的类型self.conv1 = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(in_feats, hid_feats)for rel in rel_names}, aggregate='sum')self.conv2 = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(hid_feats, out_feats)for rel in rel_names}, aggregate='sum')def forward(self, graph, inputs):# 输入是节点的特征字典h = self.conv1(graph, inputs)h = {k: F.relu(v) for k, v in h.items()}h = self.conv2(graph, h)return h

dgl.nn.HeteroGraphConv 接收一个节点类型和节点特征张量的字典作为输入,并返回另一个节点类型和节点特征的字典。

本章的 guide_cn-training-heterogeneous-graph-example
中已经有了 useritem 的特征,用户可用如下代码获取。

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)user_feats = hetero_graph.nodes['user'].data['feature']item_feats = hetero_graph.nodes['item'].data['feature']labels = hetero_graph.nodes['user'].data['label']train_mask = hetero_graph.nodes['user'].data['train_mask']

然后,用户可以简单地按如下形式进行前向传播计算:

node_features = {'user': user_feats, 'item': item_feats}h_dict = model(hetero_graph, {'user': user_feats, 'item': item_feats})h_user = h_dict['user']h_item = h_dict['item']

异构图上模型的训练和同构图的模型训练是一样的,只是这里使用了一个包括节点表示的字典来计算预测值。
例如,如果只预测 user 节点的类别,用户可以从返回的字典中提取 user 的节点嵌入。

opt = torch.optim.Adam(model.parameters())for epoch in range(5):model.train()# 使用所有节点的特征进行前向传播计算,并提取输出的user节点嵌入logits = model(hetero_graph, node_features)['user']# 计算损失值loss = F.cross_entropy(logits[train_mask], labels[train_mask])# 计算验证集的准确度。在本例中省略。# 进行反向传播计算opt.zero_grad()loss.backward()opt.step()print(loss.item())# 如果需要的话,保存训练好的模型。本例中省略。

完整例子大家可以参考

DGL提供了一个用于节点分类的RGCN的端到端的例子
RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify.py>__
。用户可以在 RGCN模型实现文件 <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/model.py>__
中查看异构图卷积 RelGraphConvLayer 的具体定义。

图神经网络17-DGL实战:节点分类/回归相关推荐

  1. Graph Decipher: A transparent dual-attention graph neural network 图解密器:一种透明的双注意图神经网络,用于理解节点分类的消息传递机制

    引用 Pang Y, Liu C. Graph Decipher: A transparent dual-attention graph neural network to understand th ...

  2. 图神经网络(三):节点分类

    节点分类问题 数据集:Cora 包含七类学术论文,论文与论文之间存在引用和被引用的关系 数据集导入 from torch_geometric.datasets import Planetoid fro ...

  3. 图神经网络框架DGL实现Graph Attention Network (GAT)笔记

    参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 --基础操作&消息传递 [3]Cora数据集介绍+python读取 一.DGL实现GAT分类机器学习论文 程序摘自[1],该 ...

  4. 图神经网络框架DGL教程-第4章:图数据处理管道

    更多图神经网络和深度学习内容请关注: 第4章:图数据处理管道 DGL在 dgl.data 里实现了很多常用的图数据集.它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道. ...

  5. 开源图神经网络框架DGL升级:GCMC训练时间从1天缩到1小时,RGCN实现速度提升291倍...

    乾明 编辑整理  量子位 报道 | 公众号 QbitAI 又一个AI框架迎来升级. 这次,是纽约大学.亚马逊联手推出图神经网络框架DGL. 不仅全面上线了对异构图的支持,复现并开源了相关异构图神经网络 ...

  6. 论文阅读笔记:《一种改进的图卷积网络半监督节点分类》

    论文阅读笔记:<一种改进的图卷积网络半监督节点分类> 文章目录 论文阅读笔记:<一种改进的图卷积网络半监督节点分类> 摘要: 引言 非欧几里得数据 1 深度池化对偶图神经网络 ...

  7. GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现

    目录 0 引言 1.Cora数据集 2.citeseer数据集 3.Pubmed数据集 4.DBLP数据集 5.Tox21 数据集 6.代码 嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换 ...

  8. 图神经网络框架DGL学习 102——图、节点、边及其特征赋值

    101(入门)以后就是开始具体逐项学习图神经网络的各个细节.下面介绍: 1.如何构建图 2.将特征赋给节点或者边,及查询方法 这算是图神经网络最基础最基础的部分了. 一.如何构建图 DGL中创建的图的 ...

  9. (DataWhale)图神经网络Task03:基于图神经网络GCN/GAT的节点表征与分类

    文章目录 Cora数据集的准备与分析 TSNE可视化节点表征分布 图节点分类模型实现与对比(MLP vs. GCN vs. GAT) MLP分类模型 GCN分类模型 GAT分类模型 结果比较与分析 参 ...

最新文章

  1. 踏入职场后,差距来自哪里
  2. 惠普z640服务器装系统,顾问文档: HP Z440、Z640 和 Z840 工作站 - 在采用 Broadwell 处理器的系统上安装 HP ZTurbo Quad Pro 后,出现黑屏...
  3. 单词拆分—leetcode139
  4. 分布式锁用Redis还是Zookeeper?
  5. Python3迭代器和生成器
  6. 南下事业篇——深圳 深圳(回顾)
  7. 基于无监督深度学习的单目视觉的深度和自身运动轨迹估计的深度神经模型
  8. Ubuntu安装过程中的问题
  9. 资深程序员是用五年时间攒够100万,老婆是关键
  10. 五笔打字简明教程(86版)
  11. PcShare服务端改造
  12. CSS3之颜色渐变效果
  13. 这本书非常值得一读!《微习惯》读后感
  14. linux环境下pytesseract的安装和央行征信中心的登录验证码识别
  15. 华为路由器接口编号与接口的对应关系
  16. kali 桥接上网_kali新手入门教学(16)--如何在校园网下使用桥接模式上网
  17. pfx证书转pem、crt、key
  18. Python代码画喜羊羊怎么画_利用Python让你的命令行像蔡徐坤一样会打篮球
  19. 深入理解Java虚拟机(周志明版)总结—WSYW126
  20. CSDN《IT人才成长路线图》重磅开源!60+专家,13个技术领域,绘出35张图谱

热门文章

  1. 发电机是如何被发明的?又是如何工作的?
  2. ROS机器人操作系统(roscpp)
  3. BUUCTF [BJDCTF2020]Mark loves cat
  4. Access denied for user \'root\'@\'localhost\'” 解决办法
  5. 汽车诊断之UDS入门-UDS概述
  6. 再造一个甘肃建投 甘肃省建设投资计划继续推进改革和结构调整
  7. EasyUI获取DataGrid中某一列的所有值
  8. 基于改进逆透视变换的智能车测距技术_禾赛科技:助力2020智能网联汽车C-V2X大规模测试活动_E理财云掌号...
  9. 5分钟教你轻松掌握箱线图
  10. Xmind 2021--做思维导图的软件有哪些,免费又好用的那种