本教程将演示如何构建一个基于半监督的节点分类任务的GNN网络,任务基于一个小数据集Cora,这是一个将论文作为节点,引用关系作为边的网络结构。

任务就是预测一个论文的所属分类。每一个论文包含一个词频信息作为属性特征。

首先安装dgl

pip install dgl -i https://pypi.douban.com/simple/

加载Cora数据集

import dgl.datadataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)

这样会自动下载Cora数据集到Extracting file to C:\Users\vincent\.dgl\cora_v2\目录下,输出结果如下:

Downloading C:\Users\vincent\.dgl\cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
Extracting file to C:\Users\vincent\.dgl\cora_v2
Finished data loading and preprocessing.NumNodes: 2708NumEdges: 10556NumFeats: 1433NumClasses: 7NumTrainingSamples: 140NumValidationSamples: 500NumTestSamples: 1000
Done saving data into cached files.
Number of categories: 7

一个DGL数据集可能包含多个Graph,但是Cora数据集仅包含一个Graph:

g = dataset[0]

一个DGL图可以通过字典的形式存储节点的属性ndata和边的属性edata。在DGL Cora数据集中,graph包含下面几个节点特征:

  • train_mask:一个bool 类型的tensor,表示一个节点是不是属于training set
  • val_mask: 一个bool 类型的tensor,表示一个节点是不是属于validation set
  • test_mask:一个bool 类型的tensor,表示一个节点是不是属于test set
  • label:节点的分类标签
  • feat:节点的属性
print('Node features')
print(g.ndata)
print('Edge features')
print(g.edata)

输出结果:

Node features
{'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],...,[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.],[0., 0., 0.,  ..., 0., 0., 0.]]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'val_mask': tensor([False, False, False,  ..., False, False, False])}
Edge features
{}

定义一个GNN网络

我们将构建一个两层的GCN网络,每一层通过聚合邻居信息来计算一个节点表示。

为了构建这样一个多层的GCN,我们可以简单的堆叠dgl.nn.GraphConv模块,这个模块继承了torch.nn.Module

import torch
import torch.nn as nn
import dgl.data
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as Fdataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]
print('Node features')
print(g.ndata)
print('Edge features')
print(g.edata)class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return h# Create the model with given dimensions
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)

DGL实现了很多当下流行的聚合邻居的模块,我们可以只用一行代码就可以使用。

训练GCN

使用DGL训练GCN与训练其他Pytorch神经网络过程类似:

import torch
import torch.nn as nn
import dgl.data
from dgl.nn.pytorch import GraphConv
import torch.nn.functional as Fdataset = dgl.data.CoraGraphDataset()
print('Number of categories:', dataset.num_classes)
g = dataset[0]
print('Node features')
print(g.ndata)
print('Edge features')
print(g.edata)class GCN(nn.Module):def __init__(self, in_feats, h_feats, num_classes):super(GCN, self).__init__()self.conv1 = GraphConv(in_feats, h_feats)self.conv2 = GraphConv(h_feats, num_classes)def forward(self, g, in_feat):h = self.conv1(g, in_feat)h = F.relu(h)h = self.conv2(g, h)return hdef train(g, model):optimizer = torch.optim.Adam(model.parameters(), lr=0.01)best_val_acc = 0best_test_acc = 0features = g.ndata['feat']labels = g.ndata['label']train_mask = g.ndata['train_mask']val_mask = g.ndata['val_mask']test_mask = g.ndata['test_mask']for e in range(100):# Forwardlogits = model(g, features)# Compute predictionpred = logits.argmax(1)# Compute loss# Note that you should only compute the losses of the nodes in the training set.loss = F.cross_entropy(logits[train_mask], labels[train_mask])# Compute accuracy on training/validation/testtrain_acc = (pred[train_mask] == labels[train_mask]).float().mean()val_acc = (pred[val_mask] == labels[val_mask]).float().mean()test_acc = (pred[test_mask] == labels[test_mask]).float().mean()# Save the best validation accuracy and the corresponding test accuracy.if best_val_acc < val_acc:best_val_acc = val_accbest_test_acc = test_acc# Backwardoptimizer.zero_grad()loss.backward()optimizer.step()if e % 5 == 0:print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(e, loss, val_acc, best_val_acc, test_acc, best_test_acc))# Create the model with given dimensions
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes)
print(model)
train(g, model)

输出结果:

In epoch 0, loss: 1.946, val acc: 0.134 (best 0.134), test acc: 0.138 (best 0.138)
In epoch 5, loss: 1.892, val acc: 0.506 (best 0.522), test acc: 0.499 (best 0.539)
In epoch 10, loss: 1.806, val acc: 0.600 (best 0.612), test acc: 0.633 (best 0.636)
In epoch 15, loss: 1.698, val acc: 0.594 (best 0.612), test acc: 0.626 (best 0.636)
In epoch 20, loss: 1.567, val acc: 0.632 (best 0.632), test acc: 0.653 (best 0.653)
In epoch 25, loss: 1.417, val acc: 0.712 (best 0.712), test acc: 0.700 (best 0.700)
In epoch 30, loss: 1.251, val acc: 0.738 (best 0.738), test acc: 0.737 (best 0.737)
In epoch 35, loss: 1.079, val acc: 0.746 (best 0.746), test acc: 0.751 (best 0.751)
In epoch 40, loss: 0.909, val acc: 0.746 (best 0.748), test acc: 0.758 (best 0.756)
In epoch 45, loss: 0.751, val acc: 0.738 (best 0.748), test acc: 0.766 (best 0.756)
In epoch 50, loss: 0.612, val acc: 0.744 (best 0.748), test acc: 0.767 (best 0.756)
In epoch 55, loss: 0.494, val acc: 0.752 (best 0.752), test acc: 0.773 (best 0.773)
In epoch 60, loss: 0.399, val acc: 0.762 (best 0.762), test acc: 0.776 (best 0.776)
In epoch 65, loss: 0.322, val acc: 0.762 (best 0.766), test acc: 0.776 (best 0.776)
In epoch 70, loss: 0.262, val acc: 0.764 (best 0.768), test acc: 0.778 (best 0.775)
In epoch 75, loss: 0.215, val acc: 0.766 (best 0.768), test acc: 0.778 (best 0.775)
In epoch 80, loss: 0.178, val acc: 0.766 (best 0.768), test acc: 0.779 (best 0.775)
In epoch 85, loss: 0.149, val acc: 0.766 (best 0.768), test acc: 0.780 (best 0.775)
In epoch 90, loss: 0.126, val acc: 0.768 (best 0.768), test acc: 0.779 (best 0.775)
In epoch 95, loss: 0.107, val acc: 0.768 (best 0.768), test acc: 0.776 (best 0.775)

在GPU上进行训练

在GPU上训练需要将模型和数据通过to()方法放到GPU上:

g = g.to('cuda')
model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
train(g, model)

DGL教程【一】使用Cora数据集进行分类相关推荐

  1. GCN——DGL教程

    官方教程 中文版 安装DGL 内置函数和消息传递API dgl.function 查阅需要的函数,官方教程没有注释 消息传递框架(Message Passing Paradigm): 可以对" ...

  2. PYG教程【三】对Cora数据集进行半监督节点分类

    Cora数据集 PyG包含有大量的基准数据集.初始化数据集非常简单,数据集初始化会自动下载原始数据文件,并且会将它们处理成Data格式. 如下图所示,Cora数据集中只有一个图,该图包含2708个节点 ...

  3. PyG(PyTorch Geometric)安装教程(附Cora数据集)

    PyG(PyTorch Geometric)安装教程(附Cora数据集) PyG是多特蒙德工业大学(Technische University Dortmund)的Matthias Fey博士基于Py ...

  4. 【DGL教程】第4章 图数据集

    官方文档:https://docs.dgl.ai/en/latest/guide/data.html dgl.data实现了很多常用的图数据集,这些数据集都是dgl.data.DGLDataset的子 ...

  5. 从零开始的图像语义分割:FCN快速复现教程(Pytorch+CityScapes数据集)

    从零开始的图像语义分割:FCN复现教程(Pytorch+CityScapes数据集) 前言 一.图像分割开山之作FCN 二.代码及数据集获取 1.源项目代码 2.CityScapes数据集 三.代码复 ...

  6. 图神经网络框架DGL教程-第4章:图数据处理管道

    更多图神经网络和深度学习内容请关注: 第4章:图数据处理管道 DGL在 dgl.data 里实现了很多常用的图数据集.它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道. ...

  7. python训练数据集_python – 如何训练大型数据集进行分类

    我有一个1600000推文的训练数据集.我该如何训练这类巨大的数据. 我尝试过使用nltk.NaiveBayesClassifier.如果我跑步,训练需要5天以上. def extract_featu ...

  8. ML之SVM:利用SVM算法(超参数组合进行多线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测、评估

    ML之SVM:利用SVM算法(超参数组合进行多线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测.评估 目录 输出结果 设计思路 核心代码 输出结果 Fitting 3 folds for ...

  9. ML之SVM:利用SVM算法(超参数组合进行单线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测、评估

    ML之SVM:利用SVM算法(超参数组合进行单线程网格搜索+3fCrVa)对20类新闻文本数据集进行分类预测.评估 目录 输出结果 设计思路 核心代码 输出结果 Fitting 3 folds for ...

最新文章

  1. mysql 主从相关
  2. 网站推广——站长助力创业期企业网站优化推广的好选择
  3. [BZOJ4556][Tjoi2016Heoi2016]字符串 主席树+二分+倍增+后缀自动机
  4. 利用反射,泛型,静态方法快速获取表单值到Model。
  5. ios刷android8.0,颤抖吧 iOS, Android 8.0正式发布!
  6. 二维数组最大连通子数组之和
  7. ajax分片上传,ajax异步实现文件分片上传
  8. android imageview 等比例放大缩小,imageView的使用(进行原样的保持和按照比例的缩放:)...
  9. 从MVC到前后端分离(REST-个人也认为是目前比较流行和比较好的方式)
  10. 中国水稻大省创新大米销售模式 启动2019首场拍卖
  11. IOS开发之——屏幕适配-AutoLayout动画(05)
  12. chrome浏览器 json插件【WEB前端助手】
  13. ZigBee协议栈(一)--协议栈介绍
  14. 【网络工程师必备】怎么使用route命令实现内外网切换
  15. A Neural Algorithm of Artistic Style : Neural Style Transfer with Eager Executon
  16. 知道一点怎么设直线方程_两点直线方程怎么求
  17. 如何将刷题的效率提升10倍
  18. XP停止服务:不必难过 千里相送终有一别
  19. 如何写一个好checker
  20. 女孩取名起名字:带染字的古风女孩名字

热门文章

  1. JsonException: Max allowed object depth reached while trying to export from type System.Single
  2. 网页中嵌入Excel控件
  3. Failed to create the Java Virtual Machine
  4. laravel5.8笔记六:公共函数和常量设置
  5. postman模拟HTTP请求
  6. 求虚拟机11.0密钥
  7. PHP上传方式base64图片的接收方式
  8. regedit start mysql_MySQL安装完成配置的时候start service报错
  9. c++ protected_合理使用protected关键字,确保类属性的安全性
  10. HTTP使用BASIC认证的原理及实现方法