PyG框架:Graph Classification
训练GNN用来做Graph Classification
一、原理
1、根据Message Passing得到每个节点的node embedding
2、readout layer
把所有节点的node embedding聚合成整个图的graph embedding。
【文献中有很多种不同的readout layer,但最常用的是mean】
【跟Node Classification的区别】:是否把每个节点的node embedding聚合成一个graph embedding?
针对mini-batch,PyG框架有封装好的模块,torch_geometric.nn.global_mean_pool 可以分别将mini-batch中每个图的所有node embedding聚合成一个graph embedding(一个batch中有多少个图,就有多少个graph embedding)。一个batch的graph embedding矩阵的shape为:[batch_size,hidden_channels]。hidden_channels:一个graph embedding(向量)的长度
3、训练一个针对graph embedding的分类器
二、代码实现
PyG框架是什么?如何安装?可以参照官方文档or我的上一篇博客:https://blog.csdn.net/qq_38432089/article/details/122152640?spm=1001.2014.3001.5501
1、数据集准备
import torch
from torch_geometric.datasets import TUDataset# 1、数据集下载
dataset = TUDataset(root='data/TUDataset', name='MUTAG')# 查看数据集信息
print()
print(f'Dataset:{dataset}:')
print('====================')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')data = dataset[0] # Get the first graph object.print()
print(data)
print('=============================================================')# Gather some statistics about the first graph.
print(f'Number of nodes:{data.num_nodes}')
print(f'Number of edges:{data.num_edges}')
print(f'Average node degree:{data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes:{data.has_isolated_nodes()}')
print(f'Has self-loops:{data.has_self_loops()}')
print(f'Is undirected:{data.is_undirected()}')
# 2、训练集、测试集准备
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
# 训练集、测试集数量
print(f'Number of training graphs:{len(train_dataset)}')
print(f'Number of test graphs:{len(test_dataset)}')
# 3、mini-batch
from torch_geometric.loader import DataLoadertrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 查看每个batch的信息:2⋅64+22=150 graphs.
for step, data in enumerate(train_loader):print(f'Step{step + 1}:')print('=======')print(f'Number of graphs in the current batch:{data.num_graphs}')print(data)print()
2、模型搭建
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_poolclass GCN(torch.nn.Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)self.conv2 = GCNConv(hidden_channels, hidden_channels)self.conv3 = GCNConv(hidden_channels, hidden_channels)self.lin = Linear(hidden_channels, dataset.num_classes)def forward(self, x, edge_index, batch):# 1. Obtain node embeddingsx = self.conv1(x, edge_index)x = x.relu()x = self.conv2(x, edge_index)x = x.relu()x = self.conv3(x, edge_index)# 2. Readout layerx = global_mean_pool(x, batch) # [batch_size, hidden_channels]# 3. Apply a final classifierx = F.dropout(x, p=0.5, training=self.training)x = self.lin(x)return xmodel = GCN(hidden_channels=64)
print(model)
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()for data in train_loader: # Iterate in batches over the training dataset.out = model(data.x, data.edge_index, data.batch) # Perform a single forward pass.loss = criterion(out, data.y) # Compute the loss.loss.backward() # Derive gradients.optimizer.step() # Update parameters based on gradients.optimizer.zero_grad() # Clear gradients.def test(loader):model.eval()correct = 0for data in loader: # Iterate in batches over the training/test dataset.out = model(data.x, data.edge_index, data.batch) pred = out.argmax(dim=1) # Use the class with highest probability.correct += int((pred == data.y).sum()) # Check against ground-truth labels.return correct / len(loader.dataset) # Derive ratio of correct predictions.for epoch in range(1, 171):train()train_acc = test(train_loader)test_acc = test(test_loader)print(f'Epoch:{epoch:03d}, Train Acc:{train_acc:.4f}, Test Acc:{test_acc:.4f}')
运行结果:
PyG框架:Graph Classification相关推荐
- PyG框架:mini-batch
一.mini-batch 在graph classification的一些基准数据集中,每个图的样本都很小,如果每次只操作一个,不能充分利用GPU资源.所以考虑把它们分成多个mini-batch. 1 ...
- SortPool (DGCNN) - An End-to-End Deep Learning Architecture for Graph Classification AAAI 2018
文章目录 1 背景介绍 图核方法 DGCNN和WL和PK的关系 2 Deep Graph Convolutional Neural Network (DGCNN) 深度图卷积神经网络 图卷积层 与We ...
- 论文阅读笔记:Multi-view adaptive graph convolutions for graph classification
论文阅读笔记:Multi-view adaptive graph convolutions for graph classification 文章目录 论文阅读笔记:Multi-view adapti ...
- graph classification and drug discovery
Capsule network Graph Capsule Convolutional Neural Networks (ICML 2018) code Capsule Graph Neural Ne ...
- 《CapsE—Graph Classification via Capsule Neural Networks》解读
<Capsule Neural Networks for Graph Classification using Explicit Tensorial Graph Representations& ...
- 论文阅读Graph-Hist: Graph Classification from LatentFeature Histograms with Application to Bot Detection
方法介绍: 在这项工作中,开发了一种受经典网络分析启发的新图分类架构.在大型网络分析中,通常的做法是计算局部特征并研究分布. 在这里,我们使用端到端的图卷积架构来提取局部潜在特征并根据这些特征的分布对 ...
- GNN Pooling(三):An End-to-End Deep Learning Architecture for Graph Classification,AAAI2018;以及图核
目录 核,图核,图卷积核 Deep Graph Convolutional Neural Network (DGCNN) Graph convolution layers Connection wit ...
- 论文解读(GLA)《Label-invariant Augmentation for Semi-Supervised Graph Classification》
- PyG图神经网络框架torch-geometric安装
最近需要使用到PyG框架,安装的时候需要注意一些问题,记录一下,方便后来者避坑! 步骤1 首先要先确定自己的torch版本 如果使用的Anaconda可以使用conda list命令查看版本号 进入官 ...
- 2018_WWW_Dual Graph Convolutional Networks for Graph-Based Semi-Supervised Classification
[论文阅读笔记]2018_WWW_Dual Graph Convolutional Networks for Graph-Based Semi-Supervised Classification-(T ...
最新文章
- 因为那里面有我,也有你
- 硬核科普:一文看懂人脸识别技术流程
- 手机qpython下载_QPython
- 详解CSS三大特性之层叠性、继承性和重要性——Web前端系列学习笔记
- jmeter性能分析_使用JMeter和Yourkit进行REST / HTTP服务的性能分析
- 为Lucene选择快速唯一标识符(UUID)
- [Leetcode][第392题][JAVA][判断子序列][动态规划][双指针]
- linux之openssh配置
- 加速Qt在线更新--使用traefik-1.7.24(不支持traefik-2.0以上版本
- 新编译的GDAL1.9 C/C++ C# Python版本
- 编程语言EF速度测试(4):nsieve-bits
- zabbix利用sendEmail邮件报警
- JAVA里的jsp网页背景_Java-带CSS的JSP不显示背景图像
- svnadmin hotcopy整库拷贝方式(转载)
- android im腾讯云,腾讯云即时通信 IMSDK 相关问题
- 表达无序列表语义的html标签是,HTML语义标签的介绍和常用的语义标签
- 宝宝的个人博客开通了
- 阿卡索口语学习(Learn And Talk 0)短语及单词(二)
- 三角形内切圆和外接圆半径及其面积计算
- 计算机网络实践的体会,计算机网络实训心得体会