图神经网络中最流行和广泛采用的任务之一就是节点分类,其中训练集/验证集/测试集中的每个节点从一组预定义的类别中分配一个真实类别。

为了对节点进行分类,图神经网络利用节点自身的特征,以及相邻节点和边的特征进行消息传递。消息传递可以重复多次,以聚合来自更大范围的邻居节点的信息。

dgl框架为我们提供了一些内置的图卷积模块,可以执行一轮的消息传递。

在本文中,我们使用dgl.nn.pytorch的SAGEConv模块,该模块来自这篇论文GraphSAGE:Inductive Representation Learning on Large Graphs

通常对于图上的深度学习模型,我们需要一个多层图神经网络,在这里我们进行多轮的消息传递。这可以通过如下方式堆叠图卷积模块来实现。

1 构造GNN模型

先导入必要包(本文dgl 版本为 0.5.2)

import dgl.nn as dglnn
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import * 

构造一个两层的gnn模型

class SAGE(nn.Module):def __init__(self, in_feats, hid_feats, out_feats, dropout=0.2):super().__init__()self.conv1 = dglnn.SAGEConv( in_feats=in_feats, out_feats=hid_feats, feat_drop=0.2, aggregator_type='gcn')self.conv2 = dglnn.SAGEConv(in_feats=hid_feats, out_feats=out_feats, feat_drop=0.2, aggregator_type='mean')self.dropout =  nn.Dropout(dropout)def forward(self, graph, inputs):# inputs 是节点的特征 [N, in_feas]h = self.conv1(graph, inputs)h = self.dropout(F.relu(h))h = self.conv2(graph, h)return h 

注意,我们不仅可以将上面的模型用于节点分类,还可以获取节点的特征表示为了其他下游任务,如边分类/回归、链接预测或图分类。

2 数据集与数据分析

dataset = CoraGraphDataset() # Cora citation network dataset
graph = dataset[0]
graph = dgl.remove_self_loop(graph)  # 消除自环
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1) 

print("图的节点数和边数: ", graph.num_nodes(), graph.num_edges())
print("训练集节点数:", train_mask.sum().item())
print("验证集集节点数:", valid_mask.sum().item())
print("测试集节点数:", test_mask.sum().item())
print("节点特征维数:", n_features)
print("标签类目数:", n_labels)

随机抽200个节点并画图展示:

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt G = graph.to_networkx()
res = np.random.randint(0, high=G.number_of_nodes(), size=(200))k = G.subgraph(res)
pos = nx.spring_layout(k)plt.figure()
nx.draw(k, pos=pos, node_size=8 )
plt.savefig('cora.jpg', dpi=600)
plt.show()

3 训练模型与评估

def evaluate(model, graph, features, labels, mask):model.eval()with torch.no_grad():logits = model(graph, features)logits = logits[mask]labels = labels[mask]_, indices = torch.max(logits, dim=1)correct = torch.sum(indices == labels)return correct.item() * 1.0 / len(labels)model = SAGE(in_feats=n_features, hid_feats=128, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())# 开始训练
best_val_acc = 0
for epoch in range(200): print('Epoch {}'.format(epoch))model.train()# 用所有的节点进行前向传播logits = model(graph, node_features)# 计算损失loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])# 计算验证集accuracyacc = evaluate(model, graph, node_features, node_labels, valid_mask)# backward propagationopt.zero_grad()loss.backward()opt.step()print('loss = {:.4f}'.format(loss.item()))if acc > best_val_acc:best_val_acc = acctorch.save(model.state_dict(), 'save_model/best_model.pth')print("current val acc = {}, best val acc = {}".format(acc, best_val_acc))

测试集评估

model.load_state_dict(torch.load("save_model/best_model.pth"))
acc = evaluate(model, graph, node_features, node_labels, test_mask)
print("test accuracy: ", acc)

完结:-) 觉得有用记得双击点赞呀!

pytorch 训练过程acc_【图节点分类】10分钟就学会的图节点分类教程,基于pytorch和dgl...相关推荐

  1. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  2. pytorch训练过程中loss出现NaN的原因及可采取的方法

    在pytorch训练过程中出现loss=nan的情况 1.学习率太高. 2.loss函数 3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决 4.数据本身,是否存在Nan,可以用n ...

  3. 【工具篇】10分钟快速学会React图表搭建

    10分钟快速学会React图表搭建 本次紧着之前的antd,接着学习有关react图表以及富文本编辑器的搭建. 本次的功能实现基于上次的[工具篇]10分钟学会Ant Design of React用法 ...

  4. pytorch 训练过程acc_Pytorch之Softmax多分类任务

    在上一篇文章中,笔者介绍了什么是Softmax回归及其原理.因此在接下来的这篇文章中,我们就来开始动手实现一下Softmax回归,并且最后要完成利用Softmax模型对Fashion MINIST进行 ...

  5. 深度神经网络训练过程中为什么验证集上波动很大_图神经网络的新基准

    作者 | 李光明 编辑 | 贾 伟 编者注:本文解读论文与我们曾发文章<Bengio 团队力作:GNN 对比基准横空出世,图神经网络的「ImageNet」来了>所解读论文,为同一篇,不同作 ...

  6. 钢铁侠头盔制作图纸下载_如何在10分钟内制作头盔图

    钢铁侠头盔制作图纸下载 我每天的大部分时间都涉及创建,修改和部署Helm图表以管理应用程序的部署. Helm是Kubernetes的应用程序包管理器,负责协调应用程序的下载,安装和部署. Helm图表 ...

  7. 指纹测试天赋测试软件,指纹也能测天赋 10分钟出结果(图)

    儿童只需把指头伸在特定的接触器上,测试系统就能掌握他的天赋状况. MAKEWELLTQ正在分析孩子的指纹纹理. 红网12月14日讯(记者 董雷)通过几个手纹就可测出一个孩子详细的TQ(天赋)报告,这么 ...

  8. pytorch 训练过程acc_pytorch入门练手:一个简单的CNN模型

    由于新型冠状肺炎疫情一直没能开学,在家自己学习了一下pytorch,本来说按着官网的60分钟教程过一遍的,但是CIFAR-10数据库的下载速度太慢了-- 这台电脑里也没有现成的数据库,想起之前画了一些 ...

  9. Ubuntu在pytorch训练过程中总是出现死机,重启

    问题解析:一般是gpu或者cpu在和内存io的时候,内存容量不足被强制kill了,举个例子,我训练的模型大小约占用显存16g,但是在存储模型的过程会被32g的内存撑爆 1 在pycharm的设置文件将 ...

最新文章

  1. EF code First数据迁移学习笔记
  2. 渗透测试-基于白名单执行payload--Compiler
  3. 只读域控制器RODC的安装
  4. SpringMVC教程--Validation校验
  5. 去掉PE文件随机基址的方法
  6. java class 结构_Java class文件的结构
  7. poj 3522(最小生成树应用)
  8. rational rose 逆向工程
  9. 定时任务 cron 表达式详解
  10. 判断CPU大小端模式
  11. java对数据库的操作_java对数据库的操作(jdbc)
  12. Java:GB18030字节数组与UTF8互转
  13. MFC开发IM-第十五篇、打包的MFC程序别人无法启动的原因
  14. 推行CMMI能在哪些方面为软件企业带来好处?
  15. UI设计中配色专辑素材|做图配色,一键搞定
  16. nfs+lvm解决磁盘空间扩容问题
  17. Android 开源项目分类汇总(转)
  18. 炒黄金短线交易如何放大收益
  19. 如何在html定位一张图片,css图片怎么定位?
  20. SAP ABAP ZBA_R003 查询用户下的角色里的公司

热门文章

  1. python if elif else 区别
  2. CodeGen CreateFile实用程序
  3. 编译器设计-解析类型
  4. 2021年大数据Hadoop(十):HDFS的数据读写流程
  5. Mysql中的递归层次查询(父子查询,无限极查询)
  6. python Django 管理站点1.3
  7. Android OpenCV 边缘检测 Canny 的使用
  8. android EditText 修改光标的颜色值
  9. 微信小程序地图的实现
  10. 上三角矩阵的特征值分解