本文主要是介绍如何用PyTorch Geometric快速实现Node2Vec节点分类,并对其结果进行可视化。

整个过程包含四个步骤:

  • 导入图数据(这里以Cora为例)
  • 创建Node2Vec模型
  • 训练和测试数据
  • TSNE降维后可视化

Node2vec方法的参数如下:

  • edge_index (LongTensor):邻接矩阵
  • embedding_dim (int):每个节点的embedding维度
  • walk_length (int):步长
  • context_size (int):正采样时的窗口大小
  • walks_per_node (int, optional) :每个节点走多少步
  • p (float, optional) :p值
  • q (float, optional) :q值
  • num_negative_samples (int, optional) :每个正采样对应多少负采样

代码如下:

import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import Node2Vecdataset = Planetoid(root='G:/torch_geometric_datasets', name='Cora')
data = dataset[0]device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,context_size=10, walks_per_node=10, num_negative_samples=1,sparse=True).to(device)
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)# 在pytorch旧版本中使用torch.optim.SparseAdam(model.parameters(), lr=0.01),新版本中需要转为list, 本文pytorch版本1.7.1
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)def train():model.train()total_loss = 0for pos_rw, neg_rw in loader:optimizer.zero_grad()loss = model.loss(pos_rw.to(device), neg_rw.to(device))loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(loader)@torch.no_grad()
def test():model.eval()z = model()acc = model.test(z[data.train_mask], data.y[data.train_mask],z[data.test_mask], data.y[data.test_mask], max_iter=150) # 使用train_mask训练一个分类器,用test_mask分类return accfor epoch in range(1, 101):loss = train()acc = test()print(f'Epoch:{epoch:02d}, Loss:{loss:.4f}, Acc:{acc:.4f}')@torch.no_grad()
def plot_points(colors):model.eval()z = model(torch.arange(data.num_nodes, device=device))z = TSNE(n_components=2).fit_transform(z.cpu().numpy())y = data.y.cpu().numpy()plt.figure(figsize=(8, 8))for i in range(dataset.num_classes):plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])plt.axis('off')plt.show()colors = ['#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700']
plot_points(colors)

输出结果如下:

Epoch:01, Loss: 8.0661, Acc: 0.1570
Epoch:02, Loss: 6.0309, Acc: 0.1800
Epoch:03, Loss: 4.9328, Acc: 0.2050
Epoch:04, Loss: 4.1206, Acc: 0.2400
Epoch:05, Loss: 3.4587, Acc: 0.2760
Epoch:06, Loss: 2.9389, Acc: 0.2950
Epoch:07, Loss: 2.5340, Acc: 0.3220
Epoch:08, Loss: 2.2042, Acc: 0.3410
Epoch:09, Loss: 1.9404, Acc: 0.3700
Epoch:10, Loss: 1.7295, Acc: 0.4050
Epoch:11, Loss: 1.5594, Acc: 0.4340
Epoch:12, Loss: 1.4231, Acc: 0.4660
Epoch:13, Loss: 1.3143, Acc: 0.4850
Epoch:14, Loss: 1.2242, Acc: 0.5100
Epoch:15, Loss: 1.1539, Acc: 0.5310
Epoch:16, Loss: 1.0997, Acc: 0.5560
Epoch:17, Loss: 1.0559, Acc: 0.5760
Epoch:18, Loss: 1.0199, Acc: 0.6020
Epoch:19, Loss: 0.9921, Acc: 0.6120
Epoch:20, Loss: 0.9671, Acc: 0.6190
Epoch:21, Loss: 0.9487, Acc: 0.6300
Epoch:22, Loss: 0.9335, Acc: 0.6390
Epoch:23, Loss: 0.9203, Acc: 0.6480
Epoch:24, Loss: 0.9106, Acc: 0.6580
Epoch:25, Loss: 0.8994, Acc: 0.6630
Epoch:26, Loss: 0.8924, Acc: 0.6600
Epoch:27, Loss: 0.8858, Acc: 0.6610
Epoch:28, Loss: 0.8792, Acc: 0.6670
Epoch:29, Loss: 0.8731, Acc: 0.6800
Epoch:30, Loss: 0.8697, Acc: 0.6830
Epoch:31, Loss: 0.8652, Acc: 0.6850
Epoch:32, Loss: 0.8618, Acc: 0.6840
Epoch:33, Loss: 0.8586, Acc: 0.6920
Epoch:34, Loss: 0.8550, Acc: 0.6900
Epoch:35, Loss: 0.8523, Acc: 0.6820
Epoch:36, Loss: 0.8507, Acc: 0.6800
Epoch:37, Loss: 0.8483, Acc: 0.6870
Epoch:38, Loss: 0.8469, Acc: 0.6930
Epoch:39, Loss: 0.8449, Acc: 0.6950
Epoch:40, Loss: 0.8433, Acc: 0.6920
Epoch:41, Loss: 0.8422, Acc: 0.6980
Epoch:42, Loss: 0.8398, Acc: 0.6960
Epoch:43, Loss: 0.8401, Acc: 0.6930
Epoch:44, Loss: 0.8374, Acc: 0.6930
Epoch:45, Loss: 0.8377, Acc: 0.6990
Epoch:46, Loss: 0.8363, Acc: 0.6970
Epoch:47, Loss: 0.8354, Acc: 0.7060
Epoch:48, Loss: 0.8339, Acc: 0.7130
Epoch:49, Loss: 0.8333, Acc: 0.7060
Epoch:50, Loss: 0.8340, Acc: 0.7090
Epoch:51, Loss: 0.8332, Acc: 0.7090
Epoch:52, Loss: 0.8325, Acc: 0.7090
Epoch:53, Loss: 0.8321, Acc: 0.7070
Epoch:54, Loss: 0.8316, Acc: 0.7160
Epoch:55, Loss: 0.8317, Acc: 0.7100
Epoch:56, Loss: 0.8297, Acc: 0.7130
Epoch:57, Loss: 0.8309, Acc: 0.7140
Epoch:58, Loss: 0.8296, Acc: 0.7230
Epoch:59, Loss: 0.8296, Acc: 0.7230
Epoch:60, Loss: 0.8276, Acc: 0.7190
Epoch:61, Loss: 0.8287, Acc: 0.7120
Epoch:62, Loss: 0.8294, Acc: 0.7120
Epoch:63, Loss: 0.8272, Acc: 0.7050
Epoch:64, Loss: 0.8286, Acc: 0.7040
Epoch:65, Loss: 0.8283, Acc: 0.7090
Epoch:66, Loss: 0.8278, Acc: 0.7110
Epoch:67, Loss: 0.8274, Acc: 0.7140
Epoch:68, Loss: 0.8283, Acc: 0.7190
Epoch:69, Loss: 0.8269, Acc: 0.7160
Epoch:70, Loss: 0.8271, Acc: 0.7210
Epoch:71, Loss: 0.8260, Acc: 0.7190
Epoch:72, Loss: 0.8273, Acc: 0.7130
Epoch:73, Loss: 0.8252, Acc: 0.7150
Epoch:74, Loss: 0.8264, Acc: 0.7120
Epoch:75, Loss: 0.8250, Acc: 0.7160
Epoch:76, Loss: 0.8253, Acc: 0.7190
Epoch:77, Loss: 0.8244, Acc: 0.7220
Epoch:78, Loss: 0.8263, Acc: 0.7220
Epoch:79, Loss: 0.8271, Acc: 0.7180
Epoch:80, Loss: 0.8253, Acc: 0.7110
Epoch:81, Loss: 0.8260, Acc: 0.7080
Epoch:82, Loss: 0.8246, Acc: 0.7140
Epoch:83, Loss: 0.8256, Acc: 0.7170
Epoch:84, Loss: 0.8257, Acc: 0.7210
Epoch:85, Loss: 0.8256, Acc: 0.7190
Epoch:86, Loss: 0.8244, Acc: 0.7170
Epoch:87, Loss: 0.8254, Acc: 0.7240
Epoch:88, Loss: 0.8249, Acc: 0.7170
Epoch:89, Loss: 0.8252, Acc: 0.7160
Epoch:90, Loss: 0.8243, Acc: 0.7010
Epoch:91, Loss: 0.8254, Acc: 0.7050
Epoch:92, Loss: 0.8249, Acc: 0.7030
Epoch:93, Loss: 0.8249, Acc: 0.7110
Epoch:94, Loss: 0.8233, Acc: 0.6990
Epoch:95, Loss: 0.8243, Acc: 0.6990
Epoch:96, Loss: 0.8248, Acc: 0.7140
Epoch:97, Loss: 0.8240, Acc: 0.7090
Epoch:98, Loss: 0.8247, Acc: 0.7100
Epoch:99, Loss: 0.8255, Acc: 0.7060
Epoch:100, Loss: 0.8242, Acc: 0.7160


从输出结果看出train的loss后面降低,但是精度却没有降低,有点过拟合了。

PYG教程【四】Node2Vec节点分类及其可视化相关推荐

  1. PyG基于DeepWalk实现节点分类及其可视化

    文章目录 前言 一.导入相关库 二.加载Cora数据集 三.定义DeepWalk 四.可视化 完整代码 前言 大家好,我是阿光. 本专栏整理了<图神经网络代码实战>,内包含了不同图神经网络 ...

  2. PyG搭建R-GCN实现节点分类

    目录 前言 数据处理 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 R-GCN的原理请见:ESWC 2018 | R-GCN:基于图卷积网络的关系数据建模 ...

  3. PyG搭建GAT实现节点分类

    目录 前言 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 GAT的原理比较简单,具体请见:ICLR 2018 | GAT:图注意力网络 模型搭建 首先导入 ...

  4. PyG基于Node2Vec实现节点分类及其可视化

    文章目录 前言 一.导入相关库 二.加载Cora数据集 三.定义Node2Vec 四.定义模型 五.模型训练 六.可视化 完整代码 前言 大家好,我是阿光. 本专栏整理了<图神经网络代码实战&g ...

  5. PYG教程【三】对Cora数据集进行半监督节点分类

    Cora数据集 PyG包含有大量的基准数据集.初始化数据集非常简单,数据集初始化会自动下载原始数据文件,并且会将它们处理成Data格式. 如下图所示,Cora数据集中只有一个图,该图包含2708个节点 ...

  6. 使用PyG进行图神经网络的节点分类、链路预测和异常检测

    图神经网络(Graph Neural Networks)是一种针对图结构数据(如社交图.网络安全网络或分子表示)设计的机器学习算法.它在过去几年里发展迅速,被用于许多不同的应用程序.在这篇文章中我们将 ...

  7. GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现

    目录 0 引言 1.Cora数据集 2.citeseer数据集 3.Pubmed数据集 4.DBLP数据集 5.Tox21 数据集 6.代码 嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换 ...

  8. PyG搭建异质图注意力网络HAN实现DBLP节点分类

    目录 前言 数据处理 模型搭建 1. 前向传播 2. 反向传播 3. 训练 4. 测试 实验结果 完整代码 前言 HAN的原理请见:WWW 2019 | HAN:异质图注意力网络. 数据处理 导入数据 ...

  9. Cesium教程(十四):简易三维模型的可视化

    Cesium教程(十四):简易三维模型的可视化 效果预览 1.高效三维数据格式:3D Tiles 3D Tiles是Cesium提出的处理三维地理大数据的数据格式,目前已是OGC数据标准之一,并在We ...

最新文章

  1. 2020年世界机器人报告
  2. 2018.12.28-bzoj-2006-[NOI2010]超级钢琴
  3. 实例解析linux内核I2C体系结构
  4. ubuntu 14.04下spark简易安装
  5. 利用QT实现X轴为时间动态显示曲线
  6. caffe框架下目标检测——faster-rcnn实战篇问题集锦
  7. Coding Contest HDU - 5988
  8. 售价19000元!华为发布全新5G折叠屏手机Mate Xs
  9. oracle 添加归档日志文件_oracle 归档日志文件路径设置
  10. 深入浅出 Proguard
  11. CSS3:伪类前的冒号和两个冒号区别
  12. linux 命令行下载bt,linux命令行下载BT种子和磁力链接
  13. 【瓦片地图】瓦片地图坐标转换
  14. Unity新创建的物体是灰色的,而且无法通过白色材质球给予纯白色(结果还是灰色)
  15. LeetCode #739 - Daily Temperatures
  16. go语言比java高级在哪里
  17. oracle分区表的作用
  18. java基础之Object类_繁星漫天_新浪博客
  19. selenium自动获取王者荣耀英雄海报并保存到本地
  20. 创建维基百科有什么作用?怎么编辑维基页面

热门文章

  1. java中集合的排序
  2. PHP的SOAP原理及实现
  3. PHPStorm的命令行配置成为Git bash的
  4. Linux的cmake3的安装 cmake3编译安装成功了的 yum对于cmake3表示成功但实际没成功
  5. JQUERY使选定DOM元素还原end
  6. main java game,playgame 一个JAVA编写的飞行小游戏,有基本完整的 框架,适合初学者参照学习 Other s 其他 238万源代码下载- www.pudn.com...
  7. dw指向html的根路径,dreamweaver中绝对、文档相对和站点根目录相对路径区分
  8. php跨域单点登录,SSO单点登录、跨域重定向、跨域设置Cookie、京东单点登录实例分析...
  9. 浏览器打开域名变成localhost_史上最全微信域名防封API原理及实现方案
  10. python3字符串操作_python3-字符串操作