目录

  • 前言
  • 数据处理
  • 模型搭建
    • 1. 前向传播
    • 2. 反向传播
    • 3. 训练
    • 4. 测试
  • 实验结果
  • 完整代码

前言

HAN的原理请见:WWW 2019 | HAN:异质图注意力网络。

数据处理

导入数据:

path = os.path.abspath(os.path.dirname(os.getcwd())) + '\data\DBLP'
dataset = DBLP(path)
graph = dataset[0]
print(graph)

输出如下:

HeteroData(author={x=[4057, 334],y=[4057],train_mask=[4057],val_mask=[4057],test_mask=[4057]},paper={ x=[14328, 4231] },term={ x=[7723, 50] },conference={ num_nodes=20 },(author, to, paper)={ edge_index=[2, 19645] },(paper, to, author)={ edge_index=[2, 19645] },(paper, to, term)={ edge_index=[2, 85810] },(paper, to, conference)={ edge_index=[2, 14328] },(term, to, paper)={ edge_index=[2, 85810] },(conference, to, paper)={ edge_index=[2, 14328] }
)

可以发现,DBLP数据集中有作者(author)、论文(paper)、术语(term)以及会议(conference)四种类型的节点。DBLP中包含14328篇论文(paper), 4057位作者(author), 20个会议(conference), 7723个术语(term)。作者分为四个领域:数据库、数据挖掘、机器学习、信息检索。

任务:对author节点进行分类,一共4类。

由于conference节点没有特征,因此需要预先设置特征:

graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))

所有conference节点的特征都初始化为[1]

获取一些有用的数据:

num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y

模型搭建

首先导入包:

from torch_geometric.nn import HANConv

模型参数:

  1. in_channels:输入通道,比如节点分类中表示每个节点的特征数,一般设置为-1。
  2. out_channels:输出通道,最后一层GCNConv的输出通道为节点类别数(节点分类)。
  3. heads:多头注意力机制中的头数。值得注意的是,GANConv和GATConv不一样的地方在于,GANConv模型是把多头注意力的结果直接展平,而不是进行concat操作。
  4. negative_slope:LeakyRELU的参数。

于是模型搭建如下:

class HAN(nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(HAN, self).__init__()# H, D = self.heads, self.out_channels // self.headsself.conv1 = HANConv(in_channels, hidden_channels, graph.metadata(), heads=8)self.conv2 = HANConv(hidden_channels, out_channels, graph.metadata(), heads=4)def forward(self, data):x_dict, edge_index_dict = data.x_dict, data.edge_index_dictx = self.conv1(x_dict, edge_index_dict)x = self.conv2(x, edge_index_dict)x = F.softmax(x['author'], dim=1)return x

输出一下模型:

model = HAN(-1, 64, num_classes).to(device)
HAN((conv1): HANConv(64, heads=8)(conv2): HANConv(4, heads=4)
)

1. 前向传播

查看官方文档中HANConv的输入输出要求:

可以发现,HANConv中需要输入的是节点特征字典x_dict和邻接关系字典edge_index_dict

因此有:

x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
x = self.conv1(x_dict, edge_index_dict)

此时我们不妨输出一下x['author']及其size:

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],[0.0969, 0.0601, 0.0000,  ..., 0.0000, 0.0000, 0.0251],[0.0000, 0.0000, 0.0000,  ..., 0.1288, 0.0000, 0.0602],...,[0.0000, 0.0000, 0.0000,  ..., 0.0096, 0.0000, 0.0240],[0.0000, 0.0000, 0.0000,  ..., 0.0096, 0.0000, 0.0240],[0.0801, 0.0558, 0.0837,  ..., 0.0277, 0.0347, 0.0000]],device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([4057, 64])

此时的x一共4057行,每一行表示一个author节点经过第一层卷积更新后的状态向量。

那么同理,由于:

x = self.conv2(x, edge_index_dict)

所以经过第二层卷积后得到的x['author']的size应该为:

torch.Size([4057, 4])

即每个author节点的维度为4的状态向量。

由于我们需要进行4分类,所以最后需要加上一个softmax:

x = F.softmax(x, dim=1)

dim=1表示对每一行进行运算,最终每一行之和加起来为1,也就表示了该节点为每一类的概率。输出此时的x:

tensor([[0.2591, 0.2539, 0.2435, 0.2435],[0.3747, 0.2067, 0.2029, 0.2157],[0.2986, 0.2338, 0.2338, 0.2338],...,[0.2740, 0.2453, 0.2403, 0.2403],[0.2740, 0.2453, 0.2403, 0.2403],[0.3414, 0.2195, 0.2195, 0.2195]], device='cuda:0',grad_fn=<SoftmaxBackward0>)

2. 反向传播

在训练时,我们首先利用前向传播计算出输出:

f = model(graph)

f即为最终得到的每个节点的4个概率值,但在实际训练中,我们只需要计算出训练集的损失,所以损失函数这样写:

loss = loss_function(f[train_mask], y[train_mask])

然后计算梯度,反向更新!

3. 训练

训练时返回验证集上表现最优的模型:

def train():model = HAN(-1, 64, num_classes).to(device)print(model)optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)loss_function = torch.nn.CrossEntropyLoss().to(device)min_epochs = 5best_val_acc = 0final_best_acc = 0model.train()for epoch in range(100):f = model(graph)loss = loss_function(f[train_mask], y[train_mask])optimizer.zero_grad()loss.backward()optimizer.step()# validationval_acc, val_loss = test(model, val_mask)test_acc, test_loss = test(model, test_mask)if epoch + 1 > min_epochs and val_acc > best_val_acc:best_val_acc = val_accfinal_best_acc = test_accprint('Epoch {:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.format(epoch, loss.item(), val_acc, test_acc))return final_best_acc

4. 测试

def test(model, mask):model.eval()with torch.no_grad():out = model(graph)loss_function = torch.nn.CrossEntropyLoss().to(device)loss = loss_function(out[mask], y[mask])_, pred = out.max(dim=1)correct = int(pred[mask].eq(y[mask]).sum().item())acc = correct / int(test_mask.sum())return acc, loss.item()

实验结果

数据集采用DBLP网络,训练100轮,分类正确率为78.54%:

HAN Accuracy: 0.7853853239177156

完整代码

代码地址:GNNs-for-Node-Classification。原创不易,下载时请给个follow和star!感谢!!

PyG搭建异质图注意力网络HAN实现DBLP节点分类相关推荐

  1. 图注意力网络_EMNLP 2019开源论文:针对短文本分类的异质图注意力网络

    本文同步发表在 PaperWeekly EMNLP 2019开源论文:针对短文本分类的异质图注意力网络​mp.weixin.qq.com 本文由北邮和南洋理工联合发表在自然语言处理顶会 EMNLP 2 ...

  2. GNN学习笔记(四):图注意力神经网络(GAT)节点分类任务实现

    目录 0 引言 1.Cora数据集 2.citeseer数据集 3.Pubmed数据集 4.DBLP数据集 5.Tox21 数据集 6.代码 嘚嘚嘚,唠叨小主,闪亮登场,哈哈,过时了过时了,闪亮登场换 ...

  3. HAN - Heterogeneous Graph Attention Network 异构图注意力网络 WWW2019

    论文题目:Heterogeneous Graph Attention Network (HAN)异构图注意力网络 作者:北京邮电大学Xiao Wang,Houye Ji等人 来源:WWW2019 论文 ...

  4. 知识图注意力网络 KGAT

    自经典的GAT之后,各种花式图注意力网络层出不穷. 例如, 动态图注意力网络,异质图注意力网络, 知识图注意力网络. 本文介绍了用于推荐的知识图注意力网络KGAT,发表在KDD2019. 作者: 黄海 ...

  5. HAN - Heterogeneous Graph Attention Network 异构图注意力网络 WWW 2019

    文章目录 1 相关介绍 背景 元路径 meta-path 异构图和同构图 相关工作 Graph Neural Network Network Embedding 贡献 2 HAN模型 2.1 Node ...

  6. 异构图注意力网络Heterogeneous Graph Attention Network ( HAN )

    文章目录 前言 一.基础知识 1.异构图(Heterogeneous Graph) 2.元路径 3.异构图注意力网络 二.异构图注意力网络 1.结点级别注意力(Node-level Attention ...

  7. HGAT-用于半监督短文本分类的异构图注意力网络

    来源:EMNLP 2019 论文链接 代码及数据集链接 摘要 短文本分类在新闻和推特中找到了丰富和有用的标记,以帮助用户找到相关信息.由于在许多实际应用案例中缺乏有标记的训练数据,因此迫切需要研究半监 ...

  8. HGANMDA:用于miRNA与疾病关联预测的分层图注意力网络(Molecular Therapy)

    HGANMDA:Hierarchical graph attention network for miRNA-disease association prediction https://www.sc ...

  9. GNN动手实践(二):复现图注意力网络GAT

    参考论文:Graph Attention Networks 一.前言 GAT(图注意力网络)是GNNs中重要的SOTA模型,该模型是从空域角度来进行定义,能够用消息传递范式来进行解释.GAT与GCN最 ...

最新文章

  1. Python需求增速达174%,AI人才缺口仍超百万!这份来自2017年的实际招聘数据如是说
  2. 《神魔降世》隐私政策
  3. 微型计算机在现代通信中的应用,计算机基础单元试卷
  4. Ubuntu 10.10(64位)编译Android 2.3
  5. 2021-07-29
  6. 男人是大猪蹄子的证据找到了!
  7. C++服务器设计(七):聊天系统服务端实现
  8. CCIE理论-第十五篇-IPV6-重分布+ACL+前缀列表
  9. 超级计算机阿波罗11,Apollo 8000推进超算科学发展
  10. Linux服务器安全策略配置-PAM身份验证模块(二)
  11. linux PHY驱动
  12. JAVA中运行看不见窗口_eclipse中已经把窗口设置为可视,为什么运行 时还是看不到窗口?...
  13. Oracle刷建表语句
  14. GEE:快速下载数字高程DEM数据
  15. 在C语言中如何计算根号
  16. Linux 管理多个软件版本的方法总结
  17. U盘系统安装步骤超级简单,弄懂ghost不管是windows7win10都不难
  18. 数据预处理部分的思维导图
  19. 阿里巴巴淘系开源首个多模态直播服饰检索数据集
  20. Linux中常用的文件目录,Linux学习笔记2——Linux中常用文件目录操作命令

热门文章

  1. Abp报 Castle.Proxies.XXXAppServiceProxy的错误问题
  2. BitTorrent概述(选自维基)
  3. Java(二)分支循环、数组、字符串、方法
  4. Unity中通过mask组件裁剪出圆形图片,制作出圆形头像
  5. 【C++】C++行操作,退格,退行,空格输出。
  6. c++无限小数加法实现
  7. PVE使用AMD CPU 5600G 核显直通
  8. 办公室服务器安装系统,教你如何架设办公室FTP服务器以Serv-U为例
  9. 底部弹出PopupWindow并且背景变为半透明效果
  10. python语言程序设计_梁勇—第五章练习题重点题目答案