我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征。高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提。

获取数据集

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# NormalizeFeatures在将数据输入神经网络之前对节点特征进行归一化
dataset = Planetoid(root='dataset', name='Cora', transform=NormalizeFeatures())print(f'Dataset: {dataset}')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

MLP神经网络

1.创建MLP神经网络

import torch
from torch.nn import Linear
import torch.nn.functional as Fclass MLP(torch.nn.Module):def __init__(self, hidden_channels):super(MLP, self).__init__()torch.manual_seed(12345)self.lin1 = Linear(dataset.num_features, hidden_channels)self.lin2 = Linear(hidden_channels, dataset.num_classes)def forward(self, x):# 第一个线性层将1433维的节点表征嵌入到低维空间(hidden_channels=16)x = self.lin1(x)x = x.relu()x = F.dropout(x, p=0.5, training=self.training)# 第二个线性层将节点表征嵌入到类别空间(num_classes=7)x = self.lin2(x)return x

2.训练MLP神经网络

model = MLP(hidden_channels=16)
print(model)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)def train():model.train()optimizer.zero_grad()out = model(data.x)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss:{loss:.4f}')

3.MLP神经网络测试

def test():model.eval()out = model(data.x)pred = out.argmax(dim=1)test_correct = pred[data.test_mask] == data.y[data.test_mask]test_acc = int(test_correct.sum()) / int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

GCN图神经网络

1.创建GCN图神经网络

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConvclass GCN(torch.nn.Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_features,hidden_channels)self.conv2 = GCNConv(hidden_channels,dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = x.relu()x = F.dropout(x, p=0.5, training=self.training)x = self.conv2(x, edge_index)return x

2.将未训练的节点表征可视化

import matplotlib.pyplot as plt
from sklearn.manifold import TSNEdef visualize(h, color):z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())plt.figure(figsize=(10,10))plt.xticks([])plt.yticks([])plt.scatter(z[:,0],z[:,1],s=70,c=color,cmap="Set2")plt.show()model = GCN(hidden_channels=16)
model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)

3.GCN图神经网络的训练

model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

4.GCN图神经网络的测试

def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)test_correct = pred[data.test_mask]==data.y[data.test_mask]test_acc = int(test_correct.sum())/int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')

5.可视化训练后的节点表征

model.eval()
out = model(data.x, data.edge_index)
visualize(out, color=data.y)

图注意力神经网络(GAT)

import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConvclass GAT(torch.nn.Module):def __init__(self, hidden_channels):super(GAT, self).__init__()torch.manual_seed(12345)self.conv1 = GATConv(dataset.num_features, hidden_channels)self.conv2 = GATConv(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = x.relu()x = F.dropout(x, p=0.5, training=self.training)x = self.conv2(x, edge_index)return x

原文地址

基于图神经网络的节点表征学习相关推荐

  1. 图神经网络基础--基于图神经网络的节点表征学习

    图神经网络基础–基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,首先需要生成节点表征(Node Representation).我们使用图神经网络来生成节点表征,并通过基于监督学习的对 ...

  2. 图神经网络/GNN(三)-基于图神经网络的节点表征学习

    Task3概览: 在图任务当中,首要任务就是要生成节点特征,同时高质量的节点表征也是用于下游机器学习任务的前提所在.本次任务通过GNN来生成节点表征,并通过基于监督学习对GNN的训练,使得GNN学会产 ...

  3. 图神经网络GNN(三):基于图神经网络的节点表征学习

    1. 写在前面 这个系列整理的关于GNN的相关基础知识, 图深度学习是一个新兴的研究领域,将深度学习与图数据连接了起来,推动现实中图预测应用的发展. 之前一直想接触这一块内容,但总找不到能入门的好方法 ...

  4. 基于图神经网络的节点表征

    我们使用图神经网络来生成节点表征,并通过基于监督学习的对图神经网络的训练,使得图神经网络学会产生高质量的节点表征.高质量的节点表征能够用于衡量节点的相似性,同时高质量的节点表征也是准确分类节点的前提. ...

  5. Datawhale 6月学习——图神经网络:超大图上的节点表征学习

    前情回顾 图神经网络:图数据表示及应用 图神经网络:消息传递图神经网络 图神经网络:基于GNN的节点表征学习 图神经网络:基于GNN的节点预测任务及边预测任务 1 超大图上的节点表征学习 1.1 简述 ...

  6. 腾讯游戏自研学术成果:基于图分割的网络表征学习初始化技术

    图是一种通用的数据表现形式,图算法逐渐在大数据处理中展现其价值.网络表征学习算法作为目前比较主流的一种图数据处理算法,引起学术界和工业界的极大兴趣. 本文介绍了 IEG 在网络表征学习方面的一个自研学 ...

  7. Datawhale 图神经网络 Task05 超大图上的节点表征学习

    学习课程:gitee_Datawhale_GNN 学习论坛:Datawhale CLUB 公众号:Datawhale 本次学习的内容是有关于超大图的,具体的论文是Cluster-GCN: An Eff ...

  8. 超大图上的节点表征学习

    一.Cluster-GCN 论文 Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional ...

  9. 节点表征学习与节点预测和边预测

    基于图神经网络的节点表征学习 引言 在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键.在此篇文章中,我们将学习如何基于图神经网 ...

最新文章

  1. UA MATH567 高维统计II 随机向量4 Frame、凸性与各向同性
  2. 机器学习:如何用相关性实现特征选择?
  3. 计算机开机显示已删除,教大家电脑开机出现部分便签的元数据已被损坏怎么办...
  4. 重温6 ListView相关|单位dp/sp
  5. 9:01 2009-7-20
  6. oracle权限的分配
  7. 国风这么火,少不了古风建筑PNG格式
  8. Cron 表达式一篇通
  9. 在Android中查看和管理sqlite数据库
  10. 手机怎么安装py thon_Python调试器– Py​​thon pdb
  11. java 调用 c# webservice 压缩 Liststring示例
  12. Outlook邮箱设置签名
  13. PyTorch神经网络框架
  14. Jenkins 与 GitLab 的自动化构建之旅
  15. android+桌面歌词,Android6.0系统适配桌面歌词效果
  16. python评价指标_[Python人工智能] 六.神经网络的评价指标、特征标准化和特征选择...
  17. Galera Cluster 实现mysql群集
  18. 解决C盘大小不足的问题
  19. 【方向盘】Spring Cloud 2021.0.0正式发布,FeignClient调用结果可一键缓存
  20. 华为模拟器ensp学习笔记

热门文章

  1. mongodb数据库的安装 for windows版本 0916
  2. 使用supervisor使Laravel的queue保持后台常驻
  3. TIDB事务过大transction too large解决方法
  4. Helios Service Release 2安装SVN
  5. Linux环境下Tomcat部署Solr4.x
  6. zabbix-3.0.4安装部署
  7. Facebook想用机器人取代App
  8. Symantec Backup Exec 2014 备份Exchange 2013之二安装主备服务器
  9. 《Sibelius 脚本程序设计》连载(三十九) - 4.9 SystemStaff
  10. 面试官系统精讲Java源码及大厂真题 - 14 简化工作:Guava Lists Maps 实际工作运用和源码