一、使用InMemoryDataset数据集类

import os.path as ospimport torch
from torch_geometric.data import (InMemoryDataset, download_url)
from torch_geometric.io import read_planetoid_dataclass PlanetoidPubMed(InMemoryDataset):r""" 节点代表文章,边代表引文关系。训练、验证和测试的划分通过二进制掩码给出。参数:root (string): 存储数据集的文件夹的路径transform (callable, optional): 数据转换函数,每一次获取数据时被调用。pre_transform (callable, optional): 数据转换函数,数据保存到文件前被调用。"""url = 'https://gitee.com/rongqinchen/planetoid/tree/master/data'def __init__(self, root, transform=None, pre_transform=None):super(PlanetoidPubMed, self).__init__(root, transform, pre_transform)self.data, self.slices = torch.load(self.processed_paths[0])@propertydef raw_dir(self):return osp.join(self.root, 'raw')@propertydef processed_dir(self):return osp.join(self.root, 'processed')@propertydef raw_file_names(self):names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']return ['ind.pubmed.{}'.format(name) for name in names]@propertydef processed_file_names(self):return 'data.pt'def download(self):for name in self.raw_file_names:download_url('{}/{}'.format(self.url, name), self.raw_dir)def process(self):data = read_planetoid_data(self.raw_dir, 'pubmed')data = data if self.pre_transform is None else self.pre_transform(data)torch.save(self.collate([data]), self.processed_paths[0])def __repr__(self):return '{}()'.format(self.name)dataset = PlanetoidPubMed('dataset/Cora')
data = dataset[0]
print(dataset.num_classes)
print(dataset[0].num_nodes)
print(dataset[0].num_edges)
print(dataset[0].num_features)

二、节点预测任务

  1. 定义一个GAT图神经网络
from torch_geometric.nn import GATConv,Sequential
from torch.nn import Linear, ReLU
import torch.nn.functional as Fclass GAT(torch.nn.Module):def __init__(self, num_features, hidden_channels_list, num_classes):super(GAT, self).__init__()torch.manual_seed(12345)hns = [num_features] + hidden_channels_listconv_list = []for idx in range(len(hidden_channels_list)):conv_list.append((GATConv(hns[idx], hns[idx+1]), 'x, edge_index -> x'))conv_list.append(ReLU(inplace=True),)self.convseq = Sequential('x, edge_index', conv_list)self.linear = Linear(hidden_channels_list[-1], num_classes)def forward(self, x, edge_index):x = self.convseq(x, edge_index)x = F.dropout(x, p=0.5, training=self.training)x = self.linear(x)return x
  1. 实例化模型并设置参数
model = GAT(num_features=dataset.num_features,hidden_channels_list=[200, 100],num_classes=dataset.num_classes)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
  1. 进行训练
def train():model.train()optimizer.zero_grad()  # Clear gradients.out = model(data.x, data.edge_index)  # Perform a single forward pass.# Compute the loss solely based on the training nodes.loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()  # Derive gradients.optimizer.step()  # Update parameters based on gradients.return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
  1. 测试结果
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)  # Use the class with highest probability.test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

三、边预测任务实践

import torch
from torch_geometric.nn import GCNConvclass Net(torch.nn.Module):def __init__(self, in_channels, out_channels):super(Net, self).__init__()self.conv1 = GCNConv(in_channels, 128)self.conv2 = GCNConv(128, out_channels)def encode(self, x, edge_index):x = self.conv1(x, edge_index)x = x.relu()return self.conv2(x, edge_index)def decode(self, z, pos_edge_index, neg_edge_index):edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)def decode_all(self, z):prob_adj = z @ z.t()return (prob_adj > 0).nonzero(as_tuple=False).t()

原文地址

节点预测与边预测任务实践相关推荐

  1. 【转】节点预测与边预测任务实践

    节点预测与边预测任务实践 引言 在此小节我们将利用PlanetoidPubMed数据集类,来实践节点预测与边预测任务. 注:边预测任务实践中的代码来源于link_pred.py. 节点预测任务实践 之 ...

  2. PaddleOCR加载chinese_ocr_db_crnn_modile模型进行中英文混合预测(Http服务)实践

    1. 环境搭建 参考:<PaddleOCR加载chinese_ocr_db_crnn_server模型进行中英文混合预测(命令行)实践> 2. 服务端部署 hub serving star ...

  3. 节点相似性与链路预测

    一.问题描述与评价标准 刻画节点的相似性有很多种方法,最简单直接的就是利用节点的属性.近年来,基于网络结构信息的节点相似性刻画得到了越来越多的重视. 节点相似性分析的一个典型应用就是链路预测,它是指如 ...

  4. 节点表征学习与节点预测和边预测

    基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键.在此篇文章中,我们将学习如何基于图神经网 ...

  5. 使用PyG进行图神经网络的节点分类、链路预测和异常检测

    图神经网络(Graph Neural Networks)是一种针对图结构数据(如社交图.网络安全网络或分子表示)设计的机器学习算法.它在过去几年里发展迅速,被用于许多不同的应用程序.在这篇文章中我们将 ...

  6. 定量预测方法总结及案例实践

    文章目录 1 前序 2 预测方法及案例 2.1 回归分析 2.1.1 含有哑变量的线性回归分析案例 2.1.2 自变量之间有交互作用的回归分析案例 2.1.3 非线性回归分析--预测第三产业国内生产总 ...

  7. 大数据毕业设计 LSTM时间序列预测算法 - 股票预测 天气预测 房价预测

    文章目录 0 简介 1 基于 Keras 用 LSTM 网络做时间序列预测 2 长短记忆网络 3 LSTM 网络结构和原理 3.1 LSTM核心思想 3.2 遗忘门 3.3 输入门 3.4 输出门 4 ...

  8. 毕业设计 LSTM的预测算法 - 股票预测 天气预测 房价预测

    文章目录 0 简介 1 基于 Keras 用 LSTM 网络做时间序列预测 2 长短记忆网络 3 LSTM 网络结构和原理 3.1 LSTM核心思想 3.2 遗忘门 3.3 输入门 3.4 输出门 4 ...

  9. R语言构建xgboost模型、预测推理:输出预测概率、预测标签

    R语言构建xgboost模型.预测推理:输出预测概率.预测标签 目录 R

  10. ML之分类预测:分类预测评估指标之AUC计算的的两种函数具体代码案例实现

    ML之分类预测:分类预测评估指标之AUC计算的的两种函数具体代码案例实现 目录 分类预测评估指标之AUC计算的的两种函数代码案例实现 输出结果 实现代码

最新文章

  1. Acwing--朴素dijkstra
  2. 笔记,提醒,pytorch安装命令(conda)
  3. CMake结合PCL库学习(2)
  4. Codeforces 337D Book of Evil:树的直径【结论】
  5. Vuejs发送Ajax请求
  6. (大纲)三小时学会openCV
  7. C#控件访问调用它的父级页面
  8. [Tip]ActiveScaffold本地化
  9. 统计字符串出现的次数(参照传智播客视频)
  10. Android拍照返回图片
  11. 企业发卡系统源码/带有代理功能发卡平台源码
  12. Python 调用谷歌翻译(2021年9月测试可用)
  13. 大数据培训课资源调度器详解
  14. 前端CSS核心部分盒子模型
  15. 08.音频系统:第002节_Android音频系统框架简述
  16. 内网渗透DC-1靶场通关(CTF)
  17. CAD如何快速标注尺寸?CAD标注尺寸教程
  18. OC和swift混合工程更新库时报:target has transitive dependencies that include statically linked binaries
  19. AMD将坚持x86架构,不会投身ARM架构怀抱
  20. nested exception is java.io.FileNotFoundException: class path resource [springmvc.xml] cannot be ope

热门文章

  1. 前端开发 简单表格的编写练习 0228
  2. dj鲜生-34-存档-用户中心地址页重复查询默认地址的优化-利用自定义模型管理器的方法来实现
  3. git-创建版本仓库-创建版本-查看版本
  4. prometheus修改数据保留时间
  5. 有孚网络北京云数据中心荣获绿色建筑国际LEED金牌认证和国家CQC A级机房认证...
  6. 安装JDK以及配置Java运行环境
  7. 微信小程序开发之路(二)
  8. Android Ubuntu 安装问题FAQ
  9. 使用solrj和EasyNet.Solr进行原子更新
  10. 面试:Websocket