GraphGallery,一个基于TensorFlow 2.x与 PyTorch 的GNN benchmark 框架
GraphGallery
【导读】图神经网络(Graph Neural Networks,GNN)是近几年兴起的新的研究热点,其借鉴了传统卷积神经网络等模型的思想,在图结构数据上定义了一种新的神经网络架构。如果作为初入该领域的科研人员,想要快速学习并验证自己的idea,需要花费一定的时间搜集数据集,定义模型的训练测试过程,寻找现有的模型进行比较测试,这无疑是繁琐且不必要的。GraphGallery 为科研人员提供了一个简单方便的框架,用于在一些常用的数据集上快速建立和测试自己的模型,并且与现有的 benchmark 模型进行比较。其支持目前主流的两大机器学习框架:TensorFlow 和 PyTorch,为科研人员提供了一些简易操作的API。
安装
- 直接从源码安装(可以体验最新版本)
git clone https://github.com/EdisonLeeeee/GraphGallery.git
cd GraphGallery
python setup.py install
- 从 Pypi 安装(可以使用稳定版本)
# -U 表示升级使用最新版本
pip install -U graphgallery
快速上手
1. Dataset
数据集包含两种,一种是领域内划分好的数据集 Planetoid
,以及扩展性更强的以 npz
格式存储的数据集。
数据集详细信息请见 https://github.com/EdisonLeeeee/GraphData
- Planetoid
from graphgallery.data import Planetoid
# set `verbose=False` to avoid additional outputs
data = Planetoid('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split() # 使用固定的划分,即 每个类别20个结点作为训练集,剩余结点中选取500个作为验证集,1000个作为测试集
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 2708), labels(2708,))
目前包含 3 种数据集
>>> data.supported_datasets
('citeseer', 'cora', 'pubmed')
- NPZDataset
from graphgallery.data import NPZDataset;
data = NPZDataset('cora', verbose=False)
graph = data.graph
idx_train, idx_val, idx_test = data.split(random_state=42) # 采用 10%,10%,80%的划分
>>> graph
Graph(adj_matrix(2708, 2708), attr_matrix(2708, 2708), labels(2708,))
目前包含 13 种数据集
>>> data.supported_datasets
('citeseer', 'citeseer_full', 'cora', 'cora_ml', 'cora_full', 'amazon_cs', 'amazon_photo', 'coauthor_cs', 'coauthor_phy', 'polblogs', 'pubmed', 'flickr', 'blogcatalog')
定义自己的 npz
数据集
from graphgallery.data import Graph# Load the adjacency matrix A, attribute matrix X and labels vector y
# A - scipy.sparse.csr_matrix of shape [n_nodes, n_nodes]
# X - scipy.sparse.csr_matrix or np.ndarray of shape [n_nodes, n_atts]
# y - np.ndarray of shape [n_nodes]
...mydataset = Graph(adj_matrix=A, attr_matrix=X, labels=y)
# save dataset
mydataset.to_npz('path/to/mydataset.npz')
# load dataset
mydataset = Graph.from_npz('path/to/mydataset.npz')
2. Config
GraphGallery 支持 TensorFlow 和 PyTorch 两个后端(默认TensorFlow 后端),通过切换后端可以调用不同的API和模型
>>> from graphgallery import backend, set_backend
>>> backend()
TensorFlow 2.1.2 Backend>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend>>> set_backend('tf') # tensorflow or tf
TensorFlow 2.1.2 Backend
同时,支持定义运算过程中的张量 浮点数和整数类型
>>> from graphgallery import intx, floatx, set_intx, set_floatx
>>> intx() # TensorFlow 后端整数默认 int32, PyTorch后端默认 int64
>>> floatx() # 对于两个后端浮点数默认皆为 float32# 修改默认数据类型
>>> set_intx('int64')
>>> set_floatx('float64')
3. Tensor
GraphGallery 支持将任意输入转换为合适后端的张量(并给予合适的数据类型)
- 普通张量
>>> backend()
TensorFlow 2.1.2 Backend>>> from graphgallery import transforms as T
>>> arr = [1, 2, 3]
>>> T.astensor(arr)
<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
- 稀疏张量
>>> import scipy.sparse as sp
>>> sp_matrix = sp.eye(3) # 创建一个 3X3 的单位矩阵
>>> T.astensor(sp_matrix)
<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f1bbc205dd8>
类似的,只需要切换后端,亦可将输入转换为 PyTorch 张量
>>> set_backend('torch') # torch, pytorch or th
PyTorch 1.6.0+cu101 Backend>>> T.astensor(arr)
tensor([1, 2, 3])>>> T.astensor(sp_matrix)
tensor(indices=tensor([[0, 1, 2],[0, 1, 2]]),values=tensor([1., 1., 1.]),size=(3, 3), nnz=3, layout=torch.sparse_coo)
astensor
函数接收三个参数,
x
: 需要转化的Python对象dtype
: 转化的类型,若不指定则根据后端的 intx(), floatx() 函数推断devicie
: 参数所在的设备 (可以指定"CPU", “GPU”, “cuda”, “GPU:0” 等等),若不指定则为 “CPU:0”kind
: 转化成何种张量,“T” 表示 TensorFlow 张量,“P” 表示 PyTorch 张量,若不指定则模型转为当前后端适合的张量
4. Transforms
GraphGallery 的 transforms 模块包含各种对输入数据的变换操作,例如针对(稀疏)邻接矩阵的变换,(密集)特征矩阵的变换,以及包含上节所述的张量转换。
例如对稀疏邻接矩阵(adjacency matrix)做 GCN 常见的归一化操作
>>> from graphgallery import transforms as T
>>> T.normalize_adj(adj_matrix)
其默认实现了
KaTeX parse error: Undefined control sequence: \mbox at position 73: …frac{1}{2}},\\ \̲m̲b̲o̲x̲{where} \ \tild…
以及对结点特征矩阵(Attribute matrix)做行归一化
>>> from graphgallery import transforms as T
>>> T.normalize_attr(attr_matrix)
其默认实现了
KaTeX parse error: Undefined control sequence: \mbox at position 23: …} = D^{-1}X,\\ \̲m̲b̲o̲x̲{where}\, D^{-1…
5. Models
顾名思义,GraphGallery 是一个GNN模型的 Gallery。
GraphGallery 实现了一系列的半监督结点分类模型,具体可见项目主页:https://github.com/EdisonLeeeee/GraphGallery
以最常见的GCN模型为例
from graphgallery.nn.models import GCN
model = GCN(graph, adj_transform='normalize_adj', attr_transform='normalize_attr', device="GPU", seed=123)
model.build()
his = model.train(idx_train, idx_val, verbose=1, epochs=100)
loss, accuracy = model.test(idx_test, verbose=1)
print(f'Test loss {loss:.5}, Test accuracy {accuracy:.2%}')
graph
是输入的图,adj_transform
是对邻接矩阵的变换,attr_transform
是对结点特征矩阵的变换,并且可以指定运行设备device
和用于重现结果的随机种子seed
模型调用
build
快速搭建一个 GCN 模型,build 可以指定包含隐藏层单元个数(层数),激活函数,学习率等参数
# 一层隐藏层 (32单元),激活函数 RELU
>>> model.build(hiddens=32, activations='relu')# 两层隐藏层(32和64单元),两层的激活函数都是 RELU
>>> model.build(hiddens=[32, 64], activations='relu')# 两层隐藏层 (32和64单元),激活函数分别是 RELU 和 ELU
>>> model.build(hiddens=[32, 64], activations=['relu', 'elu'])
- 模型调用
train
方法进行训练。idx_train
是训练集结点,同理idx_val
是验证集结点(也可以不指定),verbose
可以指定 0, 1, 2, 3, 4 五种训练过程输出,返回的his
是 一个记录训练历史情况的类,可以通过调用his.history
查看训练过程的输出。 - 模型调用
test
方法进行测试,idx_test
是测试集结点,verbose
可指定 0 和1两种,最终返回 测试集的损失和准确率
在 Planetoid Cora
数据集上的结果
Training...
100/100 [==============================] - 1s 14ms/step - loss: 1.0161 - acc: 0.9500 - val_loss: 1.4101 - val_acc: 0.7740 - time: 1.4180
Testing...
1/1 [==============================] - 0s 62ms/step - test_loss: 1.4123 - test_acc: 0.8120 - time: 0.0620
Test loss 1.4123, Test accuracy 81.20%
至此,只需要几行代码即可完成对一个模型的调用和训练测试,并且当你切换不同的后端,调用的是不同后端实现的模型(甚至不需要更改上述调用代码)。
后续工作
- 实现更多的 GNN 模型(两种后端)
- 支持更多的任务(目前主要支持半监督的结点分类任务),未来会加入链路预测,图分类等任务
- 支持更多样的图数据结构(目前只支持单一无向同构图),未来会考虑异构图,多图
- 为项目提供更好的项目文档和注释(完善中…)
GraphGallery 项目主页:https://github.com/EdisonLeeeee/GraphGallery
GraphData 项目主页:https://github.com/EdisonLeeeee/GraphData
GraphGallery,一个基于TensorFlow 2.x与 PyTorch 的GNN benchmark 框架相关推荐
- 一个基于Tensorflow的神经网络机器翻译系统
一个基于Tensorflow的神经网络机器翻译系统 Github地址:https://github.com/zhaocq-nlp/NJUNMT-tf 系统完全基于Tensorflow最基本的array ...
- 论文浅尝 | ADRL:一个基于注意力机制的知识图谱深度强化学习框架
论文笔记整理:谭亦鸣,东南大学博士. 来源:Knowledge-Based Systems 197 (2020) 105910 链接:https://www.sciencedirect.com/sci ...
- 如何做一个基于python校园网站系统毕业设计毕设作品(Django框架)
分析架构 我们开发系统,常规有两个架构,一个BS架构(浏览器/服务器模式),一个CS(客户端/服务器端模式):基于Python(Django框架)的网站开发属于B/S架构(即浏览器和服务器架构模式), ...
- 如何做一个基于JAVA在线考试系统毕业设计毕设作品(springboot框架)
分析架构 我们开发系统,常规有两个架构,一个BS架构(浏览器/服务器模式),一个CS(客户端/服务器端模式):基于JAVA的网站开发属于B/S架构(即浏览器和服务器架构模式),架构如图 分析系统功能 ...
- 教程 | 一个基于TensorFlow的简单故事生成案例:带你了解LSTM
在深度学习中,循环神经网络(RNN)是一系列善于从序列数据中学习的神经网络.由于对长期依赖问题的鲁棒性,长短期记忆(LSTM)是一类已经有实际应用的循环神经网络.现在已有大量关于 LSTM 的文章和文 ...
- 基于 TensorFlow 的图像识别(R实现)
提到机器学习,深度学习这些,大家都会立马想起Python.但R的实力也不容小觑.今天就用R来演示一个基于TensorFlow的图像识别的例子.如果你想运行这些代码,就必须先安装配置好TensorFlo ...
- 令人激动!谷歌推强化学习新框架「多巴胺」,基于TensorFlow,已开源丨附github...
郭一璞 发自 凹非寺 量子位 报道 | 公众号 QbitAI 上周那个在DOTA2 TI8赛场上"装逼失败"的OpenAI Five,背后是强化学习的助推. 其实不仅仅是Open ...
- RNN循环神经网络的自我理解:基于Tensorflow的简单句子使用(通俗理解RNN)
解读tensorflow之rnn: 该开始接触RNN我们都会看到这样的张图: 如上图可以看到每t-1时的forward的结果和t时的输入共同作为这一次forward的输入 所以RNN存在一定的弊端, ...
- python基于tensorflow的人脸识别系统设计与实现.zip(论文+源码)
摘 要 人脸识别技术是模式是别和计算机视觉研究中的一个重要领域,在边防安全.视频监控.身份验证等方面有重要的应用价值.人脸检测是快速.准确识别人脸的前提,其目的是将人脸从图像背景中检测出来.传统的课堂 ...
最新文章
- 两台服务器安装redis集群_Redis Cluster搭建高可用Redis服务器集群
- 关于librtmp接收数据(接收网络电视的数据流)
- 北斗导航 | 精密单点定位软件之rtklib的静态定位测试(RTKlib)
- [转]百万数据查询优化技巧三四则
- SharePoint 2010中的内容类型集线器 - 内容类型发布与订阅
- python初学者用什么开发环境搭建_2019-04-11 python入门学习——配置机器及搭建开发环境...
- css简单的数学运算
- PyTorch 学习笔记(六):PyTorch的十八个损失函数
- Struts2不扫描jar包中的action
- 《诗经》诗无邪 —— 雅篇
- c语言清屏函数怎么用_怎么用好 Golang 的 init 函数
- 微信小程序常用api
- 转载:Xshell使用教程
- 远程连接服务器的命令工具,windows系统如何实现远程命令?远程命令工具您选哪个?...
- 电脑键盘equals在哪个位置_【电脑键盘在哪里调出来】电脑键盘在哪里找_电脑模拟键盘在哪里...
- Python学习第一天
- 微信支付当前页面的URL未注册问题[单页面]
- 文心一言 VS ChatGpt
- CLIP学习笔记:Learning Transferable Visual Models From Natural Language Supervision
- JAVA API1.8中文版 谷歌翻译 最准确最全的翻译版本!蓝奏下载
热门文章
- 2022经典生活感悟说说,句句值千金
- 圣诞音乐贺卡beepMusic_v6d;--铃儿响叮当;
- 线段树维护(最大区间和,最大子段和,最长连续上升子序列)
- Krita开发文档翻译——Introduction to Hacking Krita
- BIOS、UEFI及系统安装
- 高通SDM845平台Sensor学习——3.SLPI(Physical Sensor)
- 重装系统后电脑图片显示不出来怎么办
- gitbucket push卡住
- 六则励志故事,送给程序员的你,希望从中获得启发与帮助!
- 备份恢复Lesson 10. Restore and Recovery Concepts