第十六课.Pytorch-geometric入门(一)
目录
- PyG安装
- 图结构基础
- 基准数据集
- Mini-Batches
- 构建GCN
PyG安装
Pytorch-geometric即PyG,是一个基于pytorch的图神经网络框架。其官方链接为:PyG
在安装PyG之前,我们需要先安装好pytorch,建议使用更高版本的pytorch,比如 pytorch1.9.x + cuda11.1,然后使用pip安装,对于windows系统,我们可以做以下操作:
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
图结构基础
在PyG中,一个Graph被定义为g=(X,(I,E))g=(X,(I,E))g=(X,(I,E)),其中,XXX表示节点的特征矩阵,NNN为节点的个数,FFF为每个节点的特征维数;我们用元组(I,E)(I,E)(I,E)表示图的邻接矩阵(COO稀疏格式),III是边的索引,EEE是DDD维的边特征。
关于COO格式的稀疏矩阵
普通稀疏矩阵的最常见存储方式为坐标存储法(coordinate format),把矩阵的行列值(i,j,v)(i,j,v)(i,j,v)记录下来,假设现在获得一个邻接矩阵,我们可以用COO格式保存:
观察邻接矩阵,非零的元素表示有边存在,矩阵中共有9个非零元素,因此有9条单向的边。我们从左向右,从上到下记录这些边在矩阵中的索引,以及值;例如第一条边,值为1,位于矩阵的第0行,第0列。
特别的,对于Graph,我们只考虑边的连接关系时,邻接矩阵的值就只有0和1,因此,我们可以省略COO格式中的value这个对象。
一个Graph本质是torch_geometric.data.Data
的实例,它包括以下几个常见对象(属性,attributes):
data.x
:节点的特征矩阵,形状为[num_nodes,num_node_features]
data.edge_index
:图的边索引,用COO稀疏矩阵格式保存,形状为[2,num_edgs]
,数据类型为torch.long
;data.edge_attr
:边的特征矩阵,形状为[num_edges,num_edge_features]
;data.y
:计算损失所需的目标数据,target,针对训练的目标可能有不同的形状,比如节点级别的形状为[num_nodes,*]
,或者图级别的形状为[1,*]
;
下面我们构建一个简单的无权无向图,每个节点的特征维数为1:
import torch
from torch_geometric.data import Data"""
边的邻接矩阵为:
[[0,1,0],[1,0,1],[0,1,0]]
"""
edge_index=torch.tensor([[0,1,1,2],[1,0,2,1]],dtype=torch.long)x=torch.tensor([[-1],[0],[1]],dtype=torch.float)data=Data(x=x,edge_index=edge_index)print(data) # Data(edge_index=[2, 4], x=[3, 1])
注意到,图虽然只有两条边,但我们需要定义4个索引元组来说明一条边的两个方向。
我们可以将图数据迁移到GPU上:
device=torch.device("cuda")
data=data.to(device)
基准数据集
PyG 包含大量常见的基准数据集,例如所有 Planetoid 数据集(Cora、Citeseer、Pubmed)。初始化数据集很简单。 数据集的初始化将自动下载其原始文件并将其处理为之前描述的数据格式。 例如,要加载 ENZYMES 数据集(由 6 个类别中的 600 个图组成):
from torch_geometric.datasets import TUDataset# 第一次调用会将数据集下载保存至'./datasets'下
dataset=TUDataset(root='./datasets',name='ENZYMES')print(dataset) # ENZYMES(600)
print(len(dataset)) # 600
print(dataset.num_classes) # 6
print(dataset.num_node_features) # 3
我们现在可以访问数据集中的所有 600 个图:
data=dataset[0]print(data) # Data(edge_index=[2, 168], x=[37, 3], y=[1])
print(data.is_undirected()) # True
我们可以看到数据集中的第一个图包含 37 个节点,每个节点有 3 个特征。 有 168/2 = 84 条无向边,并且该图恰好分配给一个类。 此外,数据对象正好持有一个图级别目标。
现在,我们下载 Cora,一个半监督图节点分类的基准数据集:
from torch_geometric.datasets import Planetoiddataset=Planetoid(root='./datasets',name='Cora')print(len(dataset)) # 1, 只有1个图
print(dataset.num_classes) # 7
print(dataset.num_node_features) # 1433
在这里,数据集仅包含一个无向图:
data=dataset[0]print(data.is_undirected()) # True
print(data.num_nodes) # 2708
print(data.train_mask.sum().item()) # 140
print(data.val_mask.sum().item()) # 500
print(data.test_mask.sum().item()) # 1000
这次,Data
对象为每个节点保存了一个标签,以及附加的节点级属性:train_mask
,val_mask
,test_mask
,其中:
train_mask
:表示针对哪些节点进行训练(140个节点);val_mask
:表示针对哪些节点进行验证(500个节点);test_mask
:表示针对哪些节点进行测试(1000个节点);
Mini-Batches
神经网络通常以批处理方式进行训练。 PyG 通过创建稀疏块对角邻接矩阵(由 edge_index 定义)并在节点维度中连接特征和目标矩阵来实现小批量的并行化。 这种组合允许在一批中的示例上有不同数量的节点和边。
PyG 包含自己的 torch_geometric.loader.DataLoader
,它已经负责这个串联过程。 我们通过一个例子来了解它:
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoaderdataset=TUDataset(root='./datasets',name='ENZYMES')
loader=DataLoader(dataset,batch_size=32,shuffle=True)for batch in loader:print(batch.num_graphs) # 32
torch_geometric.data.Batch
继承自 torch_geometric.data.Data
并包含一个名为 batch
的附加属性。batch
是一个列向量,它将每个节点映射到批处理中的相应图中。
构建GCN
现在我们使用一个简单的GCN层并在Cora引文数据集上复现实验。
首先加载Cora数据集:
from torch_geometric.datasets import Planetoiddataset=Planetoid(root="./datasets",name="Cora")
现在实现一个两层GCN:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConvclass GCN(nn.Module):def __init__(self):super().__init__()self.conv1=GCNConv(dataset.num_node_features,16)self.conv2=GCNConv(16,dataset.num_classes)def forward(self,data):x,edge_index=data.x,data.edge_indexx=self.conv1(x,edge_index)x=F.relu(x)x=F.dropout(x)x=self.conv2(x,edge_index)# print(x.size()) # torch.Size([2708, 7])return F.log_softmax(x,dim=-1)
定义损失函数和优化方法,训练模型:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')model=GCN().to(device)data=dataset[0].to(device)
print(data) # Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])
print(data.num_nodes) # 2708optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)model.train()
for epoch in range(200):optimizer.zero_grad()out=model(data)loss=F.nll_loss(out[data.train_mask],data.y[data.train_mask])loss.backward()optimizer.step()
最后在测试节点上评估模型:
model.eval()
pred = model(data).argmax(dim=-1) # argmax返回最大值的索引print(pred.size()) # torch.Size([2708])
print(pred[data.test_mask].size()) # torch.Size([1000])
print(data.test_mask) # bool型, torch.Size([2708])correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print('Accuracy: {}'.format(acc)) # Accuracy: 0.775
第十六课.Pytorch-geometric入门(一)相关推荐
- NeHe OpenGL教程 第三十六课:从渲染到纹理
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...
- NeHe OpenGL第四十六课:全屏反走样
NeHe OpenGL第四十六课:全屏反走样 全屏反走样 当今显卡的强大功能,你几乎什么都不用做,只需要在创建窗口的时候该一个数据.看看吧,驱动程序为你做完了一切. 在图形的绘制中,直线的走样是非 ...
- OpenGL教程翻译 第二十六课 法线纹理
第二十六课 法线纹理 背景 我们之前使用的光照技术还算不错,光线在模型表面得到了很好的插值,为场景营造出真实感.但是这种效果还能够有非常大的提升.事实上,我们以前使用的这种插值方式在某种程度上来说是对 ...
- 量化交易 第十六课 单因子有效性分析之收益率分析
第十六课 单因子有效性分析之收益率分析 概述 因子收益率 因子收益率计算 计算数值结果 分为数分组结果 因子在周期内的平均收益率 概述 我们需要通过分析因子的收益率来确定因子在不同股票位置上的表现. ...
- NeHe OpenGL教程 第二十六课:反射
转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...
- 深入浅出CChart 每日一课——快乐高四第五十六课 絮絮叨叨,岁月杀猪刀之FAQ
CChart发布已有多年,QQ交流群也成立了很久.在和网友的交流中,发行了CChart的很多问题,也进行了很多改进和完善. 网友们接触CChart的时间有早有晚,不同的网友经常在群里或私聊的时候提出的 ...
- OpenGL3.0教程 第十六课:阴影贴图
OpenGL3.0教程 原文链接:http://www.opengl-tutorial.org/intermediate-tutorials/tutorial-12-opengl-extensions ...
- Android Things创客DIY第六课-Android Things入门配件包开发案例教程-4位数码管显示
4位数码管显示 之前的<Android Things创客DIY第三课-Android Things入门配件包开发案例教程-数码管显示>中,介绍了如何使用Android Things控制1位 ...
- ionic入门教程第十六课-在微信中使用ionic的解决方案(按需加载加强版)
对于微信端来说,其实使用ionic是一个比较大的前端框架. 有更多比较轻量化的前端框架可以选择. 但是使用ionic有一个明显的优点就是,能够做到一端开发,三端同步上线. 这个梗说了好多遍了,但确实是 ...
最新文章
- 让产品自己召唤人——马化腾
- 联想微型计算机2005款配置,2005款联想43厘米液晶显示屏,55寸液晶屏价格
- [ACM_几何] Wall
- 刚开始学Web前端,用什么软件好?
- DIV或者DIV里面的图片水平与垂直居中的方法 - 站住,别跑 - 博客园
- 论文浅尝 | 使用位置敏感的序列标注联合抽取实体和重叠关系
- 七年级上册计算机工作计划,清华大学版信息技术七年级上册学期教学工作计划...
- 【python】 邮件发送-----zmail
- 创业,如果不懂这9条路径规划,就等于走上了一条不归路
- vue前台导出zip文件_在Vue.js中使用JSZip实现在前端解压文件的方法_心病_前端开发者...
- Intel彻底封杀Skylake非黑盒版超频
- 一套提取自 Ant Design 的优质图标
- angular 注入器配置_Angular 的服务逻辑
- 通用时与儒略日代码解析
- 毕业设计指导教师评语 计算机,毕业设计指导教师评语
- python图片查看器
- 测量RT-Thread线程调度的时间的方法
- 《Java Web程序设计基础教程》简介
- 2014年计算机求职总结--准备篇 (顺便也带点自己在美国准备的总结吧)
- 【CS 1373】射命丸文(二维前缀和)
热门文章
- Spring Boot 青睐的数据库连接池HikariCP为什么是史上最快的?
- Spring 和 SpringBoot 最核心的 3 大区别,详解!
- 阿里员工绩效只拿3.25!自我反省:平时假装努力!晚上没加班!去厕所时间太长!还老买彩票!...
- 写给小白看的线程和进程,高手勿入
- JVM调优,面到了阿里性能优化师!
- 滴滴千万级ElasticSearch平台发展之路!
- 如何为MNIST手写数字分类开发CNN
- 月薪8k和月薪38K的程序员差距在哪里?
- 分享|智办事助力杭州佰勤医疗器械组织管理数字化过渡
- java 手机号脱敏,身份证号脱敏 工具类