MLP、GCN、GAT在数据集citeseer等上的节点分类任务

算是GNN的helloworld,直接上代码,注释很详细

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 18 19:10:05 2022@author: lz
"""from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeaturesdataset = Planetoid(root = 'dataset', name='CiteSeer', transform=NormalizeFeatures())print()
print(f'Dataset:{dataset}')
print(f'Number of Graph:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')data = dataset[0]print()
print(data)print()
print(f'Number of nodes:{data.num_nodes}')
print(f'Number of edges:{data.num_edges}')
print(f'Average node degree:{data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes:{data.train_mask.sum()}')
print(f'Training node label rate:{data.train_mask.sum() / data.num_nodes:.2f}')
print(f'Contains isolated nodes:{data.has_isolated_nodes()}')
print(f'Contains self-loops:{data.has_self_loops()}')
print(f'Is undirected:{data.is_undirected()}')'''
可视化节点表征分布的方法
'''
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def visualize(h, color):z = TSNE(n_components=2).fit_transform(out.detach().cpu().numpy())plt.figure(figsize=(10,10))plt.xticks([])plt.yticks([])plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")plt.show()'''
MLP神经网络的构造
'''
import torch
from torch.nn import Module
from torch.nn import Linear
import torch.nn.functional as Fclass MLP(Module):def __init__(self, hidden_channels):super(MLP, self).__init__()torch.manual_seed(12345)self.lin1 = Linear(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.lin2 = Linear(hidden_channels, dataset.num_classes)def forward(self, x):x = self.lin1(x)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.lin2(x)return xmodel = MLP(hidden_channels=16)
print()
print('MLP神经网络的构造')
print(model)print()
print('利用交叉熵损失和Adam优化器来训练这个简单的MLP神经网络')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay= 5e-4)def train():model.train()#!optimizer.zero_grad()out = model(data.x)loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()#!optimizer.step()return lossprint('开始训练')for epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('看看测试集上的表现')def test():model.eval()#!out = model(data.x)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]#检查标签是否正确test_acc = int(test_correct.sum()) / int(data.test_mask.sum()) return test_acctest_acc = test()
print(f'Test Accuracy:{test_acc:.4f}')print('将MLP中的torch.nn.Linear 替换为torch_geometric.nn.GCNConv,我们就可以得到一个GCN网络')
from torch_geometric.nn import GCNConvclass GCN(Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.conv2 = GCNConv(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.conv2(x, edge_index)return xmodel = GCN(hidden_channels=16)
print(model)        print()
print('可视化未经训练的GCN生成的节点表征')
model.eval()   out = model(data.x, data.edge_index)
visualize(out, color=data.y)print()
print('训练GCN图神经网络')
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)#进行一次正向计算loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('测试集上的准确性')
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]test_acc = int(test_correct.sum()) / int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')print()
print('可视化训练后的GCN生成的节点表征')
model.eval()out = model(data.x, data.edge_index)
visualize(out, color=data.y)print()
print('将MLP中的torch.nn.Linear 替换为torch_geometric.nn.GCNConv,我们就可以得到一个GCN网络')
from torch_geometric.nn import GATConv
class GAT(Module):def __init__(self, hidden_channels):super(GAT, self).__init__()torch.manual_seed(12345)self.conv1 = GATConv(dataset.num_features, hidden_channels)#dataset.num?不应该是dataset[0].num?难道dataset也有Num属性?self.conv2 = GATConv(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):x = self.conv1(x, edge_index)#等价于 self.lin1.forward(x),还是module call的forwardrelu = torch.nn.ReLU(inplace = True)x = relu(x)x = F.dropout(x, p=0.5, training=self.training)#!x = self.conv2(x, edge_index)return xmodel = GAT(hidden_channels=16)
print(model)print()
print('训练GAT图神经网络')
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)#进行一次正向计算loss = criterion(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')print('测试集上的准确性')
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim = 1)#选择概率最大的类test_correct = pred[data.test_mask] == data.y[data.test_mask]test_acc = int(test_correct.sum()) / int(data.test_mask.sum())return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')print()
print('可视化训练后的GAT生成的节点表征')
model.eval()out = model(data.x, data.edge_index)
visualize(out, color=data.y)

datawhalechina-GNN组队学习 作业:PyG不同模块在PyG数据集上的应用相关推荐

  1. 迁移学习的使用技巧和在不同数据集上的选择

    迁移学习的使用技巧和在不同数据集上的选择 1.迁移学习是指调整预训练的神经网络并应用到新的不同数据集上. 根据以下两个方面:新数据集的大小,以及新数据集和原始数据集之间的相似性 使用迁移学习的方式将不 ...

  2. Datawhale组队学习周报(第035周)

    希望开设的开源内容 目前Datawhale的开源内容分为两种:第一种是已经囊括在我们的学习路线图内的Datawhale精品课,第二种是暂未囊括在我们的学习路线图内的Datawhale打磨课.我们根据您 ...

  3. Datawhale组队学习周报(第032周)

    希望开设的开源内容 目前Datawhale的开源内容分为两种:第一种是已经囊括在我们的学习路线图内的Datawhale精品课,第二种是暂未囊括在我们的学习路线图内的Datawhale打磨课.我们根据您 ...

  4. 跟优秀的人一起进步:四月组队学习

    Datawhale学习 主办:Datawhale,人民邮电出版社异步社区 寄语:本次组队学习涵盖了机器学习算法.计算机视觉.Pandas.爬虫编程实践四个模块的内容. 第十期:Datawhale联合伯 ...

  5. Datawhale组队学习周报(第028周)

    吼一嗓子: 如果您有开源的内容希望通过组队学习的方式与大家分享,那么请跟我联系,我们来排期. 如果您对Datawhale某一门开源内容感兴趣,希望跟我们一起为学习者答疑解惑,那么请跟我联系,我们来排期 ...

  6. Datawhale组队学习周报(第027周)

    吼一嗓子: 如果您有开源的内容希望通过组队学习的方式与大家分享,那么请跟我联系,我们来排期. 如果您对Datawhale某一门开源内容感兴趣,希望跟我们一起为学习者答疑解惑,那么请跟我联系,我们来排期 ...

  7. AI学习笔记(十一)CNN之图像识别(上)

    AI学习笔记之CNN之图像识别(上) 图像识别 图像识别简介 模式识别 图像识别的过程 图像识别的应用 分类与检测 VGG Resnet 迁移学习&inception 卷积神经网络迁移学习fi ...

  8. HNU工训中心STC-B学习板大作业-基于OLED模块的多功能MP4

    主要功能在下面这张流程图里(直接用报告的流程图了) 下面展示一下效果(数码管的"welcome"比较抽象) ps. 后面新加的功能(我觉得MP4应该还具有看小说的功能,但是小说字太 ...

  9. python编程语言的优缺点_组队学习优秀作业 | Python的发展历史及其前景

    ↑↑↑关注后"星标"BioPython每日干货 & 每月组队学习,不错过BioPython学习 开源贡献: BioPython团队 创始人 Guido van Rossum ...

最新文章

  1. 获取图像的梯度,方向和方向梯度图像
  2. java 调用word插件_java一键生成word操作,比poi简单
  3. 68. Leetcode 669. 修剪二叉搜索树 (二叉搜索树-基本操作类)
  4. Map 的 key、value 是否允许为null
  5. 智慧交通day02-车流量检测实现05:小车匀加速案例
  6. string input must not be null解决办法
  7. ChinaJoy 第二天,是谁独得万千宠爱?
  8. 制作本地yum镜像站
  9. 大学生开学必备物品清单的详细介绍
  10. Centos7 安装 Kubernetes dashboard (安装篇)
  11. linux 下ftp的上传与下载
  12. 2、Zookeeper集群搭建、命令行Client操作
  13. Mariadb 安装教程 Windows版
  14. IOT物联网技术架构_物联网系统架构正式上架
  15. Educoder---Java继承与接口、文件
  16. 外贸供应链ERP怎么选?全流程综合管理解析
  17. 怎样黑进Microsoft:循序渐进指南 (转)
  18. 让网页FLASH变成黑白的css语句
  19. 循环辅助:continue和break
  20. 怎样解锁CAD图纸中被锁定的图层?

热门文章

  1. 前端架构之脚手架原理分析(一)
  2. 爱普生Epson EP-804AR 一体机驱动
  3. 苹果手机查看python代码_[代码全屏查看]-基于Python的苹果序列号官网查询接口调用代码实例...
  4. 蛋鸡养殖智能环控系统方案
  5. 关于ubuntu安装flash插件的问题
  6. GAN(生成式对抗网络)简介
  7. 浅谈 long long 和 int 的区别
  8. 网络黑客攻防学习平台之基础关第二题
  9. 微信公众号文章搬迁完成!
  10. taobao.trade.fullinfo.get( 获取单笔交易的详细信息 )