目录

0 引言

1、Cora数据集

2、citeseer数据集

3、Pubmed数据集

4、DBLP数据集

5、Tox21 数据集

6、代码


嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换成大驾光临,哈哈,这样才颇有气势,哼哼...

(唠叨小主哼哼了两声,对新改的词表示满意,微拉裙侧,留下了高跟鞋的声音...)

0 引言

近年来,人们对深度学习方法在图上的扩展越来越感兴趣。在多方因素的成功推动下,研究人员借鉴了卷积网络、循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构,由此一个新的研究热点——“图神经网络(Graph Neural Networks,GNN)”应运而生。

图神经网络的研究与图嵌入或网络嵌入密切相关,图嵌入或网络嵌入是数据挖掘和机器学习界日益关注的另一个课题。许多图嵌入算法通常是无监督的算法,它们可以大致可以划分为三个类别,即矩阵分解、随机游走和深度学习方法。同时图嵌入的深度学习方法也属于图神经网络,包括基于图自动编码器的算法(如DNGR和SDNE)和无监督训练的图卷积神经网络(如GraphSage)。我们将图神经网络划分为五大类别,分别是:图卷积网络(Graph Convolution Networks,GCN)、 图注意力网络(Graph Attention Networks)、图自编码器( Graph Autoencoders)、图生成网络( Graph Generative Networks) 和图时空网络(Graph Spatial-temporal Networks)。


今天的任务——参照GNN学习笔记中的代码使用PyG中的图卷积模块在PyG的数据集上实现节点分类或回归任务,之前用到的是MLP、图卷积神经网络、图注意力神经网络,数据集是Cora数据集。

1、Cora数据集

Cora数据集由机器学习论文组成,是近年来图深度学习很喜欢使用的数据集。在数据集中,论文分为以下七类之一:基于案例、遗传算法、神经网络、概率方法、强化学习、规则学习、理论。论文的选择方式是,在最终语料库中,每篇论文引用或被至少一篇其他论文引用。整个语料库中有2708篇论文。在词干堵塞和去除词尾后,只剩下1433个独特的单词。文档频率小于10的所有单词都被删除。cora数据集包含1433个独特单词,所以特征是1433维。0和1描述的是每个单词在paper中是否存在。

数据下载地址:https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

2、citeseer数据集

数据集下载地址:http://www.cs.umd.edu/~sen/lbc-proj/data/citeseer.tgz

3、Pubmed数据集

PubMed数据集包括来自Pubmed数据库的19717篇关于糖尿病的科学出版物,分为三类:

Diabetes Mellitus, Experimental
Diabetes Mellitus Type 1
Diabetes Mellitus Type 2
引文网络由44338个链接组成。数据集中的每个出版物都由一个由500个唯一单词组成的字典中的TF/IDF加权词向量来描述。
数据集下载地址:https://linqs-data.soe.ucsc.edu/public/Pubmed-Diabetes.tgz

4、DBLP数据集

DBLP数据集用XML描述,字段信息包括:author、title、pages、year、booktitle、url、crossref、publisher、ee、cdrom、isbn、cite_label等。其中作者名属性信息的格式是统一的,处理比较方便。目前,DBLP对作者重名问题的处理已经有不错的效果。例如:输入一作者名“wei wang”,可以得到16个不同的作者及其工作单位,并能链接得到每个作者的发表论文情况、个人主页和合作者列表等信息。(不存在问题了吗?)此外,引文信息中除了基本信息:作者名、文章名、会议名之外,加入新的信息:author keywords,对应于论文中的keywords。但是,并非所有的论文都包含有author keywords信息,也并非所有作者都有个人主页,在个人主页链接识别上还存在问题。

5、Tox21 数据集

此数据集来源于一个PubChem网站的一个2014年的竞赛:https://tripod.nih.gov/tox21/challenge/about.jsp
PubChem是美国国立卫生研究院(NIH)的开放化学数据库,是世界上最大的免费化学物信息集合。
PubChem的数据由数百个数据源提供,包括:政府机构,化学品供应商,期刊出版商等。

数据集可在此下载:https://tripod.nih.gov/tox21/challenge/data.jsp#

训练集和测试集都是由多个分子结构构成的sdf格式的文件。


6、代码

我们今天基于Citeseer数据构建图注意力神经网络模型。

# 获取并分析数据集
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConvdataset = Planetoid(root='data/Planetoid', name='citeseer', transform=NormalizeFeatures())###神经网络的构造
class GCN(torch.nn.Module):def __init__(self, hidden_channels):super(GCN, self).__init__()torch.manual_seed(12345)self.conv1 = GCNConv(dataset.num_features, hidden_channels)self.conv2 = GCNConv(hidden_channels, dataset.num_classes)# self.lin1 = Linear(dataset.num_features, hidden_channels)# self.lin2 = Linear(hidden_channels, dataset.num_classes)def forward(self, x, edge_index):# x = self.lin1(x)x = self.conv1(x, edge_index)x = x.relu()x = F.dropout(x, p=0.5, training=self.training)x = self.conv2(x, edge_index)#x = self.lin2(x)return xmodel = GCN(hidden_channels=16)
print(model)###模型的训练
model = GCN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) ##定义Adam优化器
criterion = torch.nn.CrossEntropyLoss() ##交叉熵损失def train():model.train()optimizer.zero_grad()  # Clear gradients.out = model(data.x, data.edge_index)  # Perform a single forward pass.loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.loss.backward()  # Derive gradients.optimizer.step()  # Update parameters based on gradients.return lossfor epoch in range(1, 201):loss = train()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')##模型的测试 这部分与MLP神经网络相同
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)  # Use the class with highest probability.test_correct = pred[data.test_mask] == data.y[data.test_mask]  # Check against ground-truth labels.test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  # Derive ratio of correct predictions.return test_acctest_acc = test()
print(f'Test Accuracy: {test_acc:.4f}')##可视化
model.eval()out = model(data.x, data.edge_index)
visualize(out, color=data.y)

参考资料:
【1】https://zhuanlan.zhihu.com/p/75307407?from_voters_page=true

【2】https://blog.csdn.net/yyl424525/article/details/100831452

【3】https://blog.csdn.net/yyl424525/article/details/100831452

【4】https://blog.csdn.net/qq_32797059/article/details/106577815

GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现相关推荐

  1. 【图神经网络】图神经网络(GNN)学习笔记:图分类

    图神经网络GNN学习笔记:图分类 1. 基于全局池化的图分类 2. 基于层次化池化的图分类 2.1 基于图坍缩的池化机制 1 图坍缩 2 DIFFPOOL 3. EigenPooling 2.2 基于 ...

  2. 【图神经网络】图神经网络(GNN)学习笔记:图的基础理论

    图神经网络GNN学习笔记:图的基础理论 1. 图的概述 2.图的基本类型 2.1 有向图和无向图 2.2 非加权图与加权图 2.3 连通图与非连通图 2.4 二部图 2.5 邻居和度 2.6 子图和路 ...

  3. 【Pytorch神经网络实战案例】22 基于Cora数据集实现图注意力神经网络GAT的论文分类

    注意力机制的特点是,它的输入向量长度可变,通过将注意力集中在最相关的部分来做出决定.注意力机制结合RNN或者CNN的方法. 1 实战描述 [主要目的:将注意力机制用在图神经网络中,完成图注意力神经网络 ...

  4. 深度学习 十四讲 循环神经网络例子--名字分类

    任务:根据输入的不同名字,分出所属国家 模型如下 数据两列:名字,国家 实现过程 准备数据 用ASCII表作为字典长度,字典长度为128 实际上这个77对应的是一个one_hot向量,这个向量一共有1 ...

  5. 【图神经网络】图神经网络(GNN)学习笔记:GNN的应用简介

    @TOC GNN的应用简述 GNN的适用范围非常广泛: 显式关联结构的数据:药物分子.电路网络等 隐式关联结构的数据:图像.文本等 生物化学领域中:分子指纹识别.药物分子设计.疾病分类等 交通领域中: ...

  6. 【图神经网络】图神经网络(GNN)学习笔记:GNN的通用框架

    图神经网络GNN学习笔记:GNN的通用框架 1. MPNN 2. NLNN 3. GN 参考资料 所谓通用框架,是对多种变体GNN网络结构的一般化总结,也是GNN编程的通用范式,这里介绍3类通用框架: ...

  7. 【图神经网络】图神经网络(GNN)学习笔记:基于GNN的图表示学习

    图神经网络GNN学习笔记:基于GNN的图表示学习 1. 图表示学习 2. 基于GNN的图表示学习 2.1 基于重构损失的GNN 2.2 基于对比损失的GNN 参考资料 本文主要就基于GNN的无监督图表 ...

  8. 注意力机制 神经网络_图注意力网络(GAT)

    引言 作者借鉴图神经网络中的注意力机制,提出了图注意力神经网络架构,创新点主要包含如下几个:①采用masked self-attention层,②隐式的对邻居节点采用不同权重③介绍了多头注意力机制. ...

  9. 【图结构】之图注意力网络GAT详解

    作者:張張張張 github地址:https://github.com/zhanghekai [转载请注明出处,谢谢!] GATGATGAT源代码地址:https://github.com/Petar ...

最新文章

  1. Ajax同步和异步的区别
  2. C++基础代码--20余种数据结构和算法的实现
  3. lazyload.js详解
  4. 树莓派AI视觉云台——5.SSH文件传输
  5. Python可视化神器之pyecharts
  6. miniui datagrid 隐藏列默认赋值_Qt商业组件DataGrid:内置视图和布局详解(一)
  7. vue echarts动态数据定时刷新
  8. 《深入浅出统计学》之统计学知识小结
  9. authentication failed : unrecognized kernel32 module. / NM
  10. ArcGIS Pro 学习路径
  11. 如何用Python给自己做一个年终总结
  12. HC05蓝牙模块(主从一体)简单使用
  13. FreeSWITCH的端口设置
  14. 修改系统时区(基于Debian的系统)--用Enki学Linux系列(15)
  15. 程序人生-hello`s P2P
  16. 加州大学欧文计算机排名,2019加州大学欧文分校排名(USNews排名)
  17. 计算机专业毕业设计题目大全文库,计算机专业毕业设计论文题目.doc
  18. UI设计师就业发展前景如何?
  19. 金融机构的反洗钱(AML)合规工作和系统建设
  20. Java多线程-将全量用户表70万数据压缩并生成CSV文件和推送到FTP上(最快快方式)

热门文章

  1. NOIP 2004 合唱队形
  2. (二)树莓派系列教程:树莓派4B手动连接wifi,远程控制。命令行界面、桌面界面
  3. spinnaker-简介
  4. SpringBoot项目入门,使用Eclipse创建Springboot项目
  5. python django实验室药物管理预警系统
  6. cursor 鼠标样式的几种样式
  7. windows10下模拟器运行LVGL记录
  8. 【sql注入】二次注入
  9. 油漆算法问题_不同类型的油漆(以及何时使用它们)
  10. oracle rac 11.2.0.4 镜像copy迁移数据到新存储