前言

上一篇文章从pyg提供的基本工具出发,介绍了pyg。但是大家用三方库,一般是将其作为积木来构建一个比较大的模型,把它用在自己的数据集上,而不是满足于跑跑demo里的简单模型和标准数据集。因此本文将从复现T-GCN(论文和官方源码见此)的角度出发,讲述怎么使用pyg搭建一个GNN-RNN模型,包括数据集的构建和模型的搭建。

刚开始复现的时候,我踩了很多坑,有的坑是因为不熟悉pyg踩的,有的坑是因为作者论文里和源码里的模型不一致踩的。这里说是复现,但是是以作者源码里的模型为准。那你可能会问了,都有源码了,我在这里吵吵啥“复现”呢?因为源码不是采用pyg写的,而是使用了原始的GCN计算方式,使用一些矩阵乘法做的。

前置准备

虽然我不是作者团队的……但是我觉得还是有必要介绍一下这篇文章的模型和数据集

T-GCN介绍

这里介绍的T-GCN的全称是Traffic-GCN,同学们有可能会在别的地方看到这个简称,但是不一定指的是这个模型。

T-GCN的核心模型结构是使用了GCN+GRU二者组合,先使用GCN得到更丰富的节点特征,再将每个节点的特征都送入GRU中进行计算。相当于使用GCN聚合空间特征,再使用GRU聚合时序特征,具体的计算公式如下。需要注意的是,每次输入的都是当前时间的特征和GRU的隐层特征,二者拼接后作为T-GCN-Cell(可以认为是T-GCN内部的一层卷积)的输入。

Convt(Xt,ht)=L⋅concat(Xt,ht)GRU(Xt,ht)=ut=σ(WuConvt(Xt,ht)+bu)rt=σ(WrConvt(Xt,ht)+br)ct=tanh(WcConvt(Xt,ht)+bc)ht+1=ut∗ht+(1−ut)∗ctConv_t(X_{t},h_{t})=L\cdot concat(X_{t},h_{t})\\ GRU(X_{t},h_{t})= \begin{aligned} u_t & = \sigma(W_uConv_t(X_{t},h_{t})+b_u) \\ r_t & = \sigma(W_rConv_t(X_{t},h_{t})+b_r) \\ c_t & = tanh(W_cConv_t(X_{t},h_{t})+b_c)\\ h_{t+1} & = u_t*h_t+(1-u_t)*c_t \end{aligned}\\ Convt​(Xt​,ht​)=L⋅concat(Xt​,ht​)GRU(Xt​,ht​)=ut​rt​ct​ht+1​​=σ(Wu​Convt​(Xt​,ht​)+bu​)=σ(Wr​Convt​(Xt​,ht​)+br​)=tanh(Wc​Convt​(Xt​,ht​)+bc​)=ut​∗ht​+(1−ut​)∗ct​​

数据集介绍

这里只采用“shenzhen”数据集。该数据集是在深圳156条道路采集的交通流量数据,采集间隔为5分钟一次,维度为1。此外,还附带一个道路间的邻接矩阵。也就是作者在建模过程中,将道路当作节点,将道路是否连通作为建图标准。

开始复现

使用pyg复现T-GCN的过程是比较痛苦的,因为必须要削足适履。个人认为pyg对时序数据的支持似乎不那么友好,当然也有可能是因为我没找到适合时序数据使用的DataLoader和Data对象。

时序序列的GNN数据有什么问题?

在之前的pyg介绍文章就说过,DataLoader会将每个Data对象视为一个图,形成mini-batch时,将一个batch里的Data对象打包成一个大图。这在非时序样本时,没有任何问题,但是对于时序数据,每个样本中包含多个图,此时DataLoader打包出来的对象可能就不符合我们的心意了。

静态图

我们先假设一种最简单的情况,每个样本的图结构和连接关系完全相同(这也被称为“静态图”),因此我们给每个Data对象都是完全相同的邻接矩阵。为了追求并行化,我们通常把一个时间点的所有数据抽出来一起计算。假设初始时隐状态为零向量,T-GCN的大致伪码如下图所示。

# x.shape is [num_nodes, seq_len, num_features]
h = zeros()
for i in range(seq_len):h = gru(gcn(concat(x[:, i, :], h), edge_index))

这么乍一看,好像没问题,实际上也确实没问题。edge_index按照mini-batch的方式拼接成大图;x拼接之后形成了[batch_size*num_nodes, seq_len, num_features]的矩阵,每次取其中一个时间点进行运算,输入的逻辑非常正确。

动态图

然后我们再看,假如样本中的图结构关系(这里只考虑边变化的情况)可以随着时间动态变化,那对于每个时间点,都需要一张独立的图,此时Data对象规定的edge_index结构就不满足我们的要求了,与之配合的DataLoader也会拼接出错误的mini-batch。

对于这种情况,我前思后想,辗转反侧,想到了一个相当削足适履的方法,就是改造DataLoader生成mini-batch的函数,使之对样本中每一个时间点的邻接矩阵进行拼接,然后生成一个List,维度为[seq_len, 2, num_edges],同时需要注意,每张图的num_edges可能不一样,所以这样一个数据还没法打包到一个Tensor里,只能用List存下所有时间点的batch大图

不过幸好,T-GCN是静态图,没这么麻烦,这种情况只是自己在做磕盐的时候遇到的,如果大家有更好的方法,也欢迎讨论。

搭建削足适履的模型

DataSet

然后我们搭建DataSet对象

from typing import List, Union, Tupleimport numpy as np
import torchfrom torch_geometric.data import InMemoryDataset, Dataset, Data
from utils.utils import dataset_path
from constant import DATASET_NAME_TRAFFIC
import pandas as pdclass TrafficDataSet(InMemoryDataset):# 一个点是15分钟seq_len = 4predict_len = 1DATASET_TYPE = 'sz'PROCESSED_DATASET_FILENAME = '%s_seq%d_pre%d' % (DATASET_TYPE, seq_len, predict_len)speed_name = DATASET_TYPE + '_speed.csv'adj_name = DATASET_TYPE + '_adj.csv'def __init__(self):super().__init__(root=dataset_path(DATASET_NAME_TRAFFIC))self.data, self.slices, self.max_speed, self.num_nodes, self.seq_len, self.pre_len = torch.load(self.processed_paths[0])@propertydef raw_file_names(self) -> Union[str, List[str], Tuple]:return [TrafficDataSet.speed_name, TrafficDataSet.adj_name]@propertydef processed_file_names(self) -> Union[str, List[str], Tuple]:return TrafficDataSet.PROCESSED_DATASET_FILENAME + '.pt'def download(self):passdef process(self):# 一个文件里是有抬头的,一个没有speed = pd.read_csv(self.raw_paths[0]).valuesadj = pd.read_csv(self.raw_paths[1], header=None).valuesnum_nodes = len(adj)adj = process_adj(adj)# 对样本的输出进行归一化,归一化参数需要记录下来,计算测试集MSE时要用max_speed = np.max(speed)speed = speed / max_speedspeed = torch.tensor(speed, dtype=torch.float32)adj = torch.tensor(adj, dtype=torch.int64)time_len = speed.shape[0]seq_len = TrafficDataSet.seq_lenpre_len = TrafficDataSet.predict_lendata_list = []for i in range(time_len - seq_len - pre_len):# speed = [time_len, num_nodes]# x = [num_nodes, seq_len, num_features=1]x = speed[i: i + seq_len].transpose(0,1).reshape([num_nodes, seq_len, 1])# y = [pre_len, num_nodes] -> [num_nodes, pre_len]y = speed[i + seq_len: i + seq_len + pre_len].transpose(0, 1)pyg_data = Data(x, edge_index=adj, y=y)data_list.append(pyg_data)data, slices = self.collate(data_list)torch.save((data, slices, max_speed, num_nodes, seq_len, pre_len), self.processed_paths[0])# 数据集给的是邻接矩阵,需要转换成pyg接受的稀疏矩阵的形式
def process_adj(adj):node_cnt = len(adj)pyg_adj = [[],[]]for i in range(node_cnt):for j in range(node_cnt):if adj[i][j] == 1:pyg_adj[0].append(i)pyg_adj[1].append(j)return np.array(pyg_adj)

T-GCN模型

模型听起来也不复杂,因此对照着源码直接开始搭建

import torch
from torch_geometric.nn.conv import GCNConvimport torch.nn.functional as F
class TGCN_Conv_Module(torch.nn.Module):def __init__(self, args):super(TGCN_Conv_Module, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_out# 卷积层将输入与GRU的hidden_state拼接起来作为输入,输出hidden_size的特征self.conv1 = GCNConv(self.num_features+self.nhid, self.nhid)def forward(self, x, edge_index):# 实际上作者源码中只使用了一层GCN卷积,而论文中是两层x = F.relu(self.conv1(x, edge_index))x = torch.sigmoid(x)return xclass TGCNCell(torch.nn.Module):def __init__(self, args):super(TGCNCell, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.seq_len = args.seq_lenself.num_nodes = args.num_nodes# 这是仿照作者源码里的写法,实际上这是两个GCN,在forward函数中会将其输出拆成两半self.graph_conv1 = GCNConv(self.nhid+self.num_features, self.nhid * 2)self.graph_conv2 = GCNConv(self.nhid+self.num_features, self.nhid)self.reset_parameters()def reset_parameters(self):torch.nn.init.constant_(self.graph_conv1.bias, 1.0)def forward(self, x, edge_index, hidden_state):ru_input = torch.concat([x, hidden_state], dim=1)# 这里将一个GCN的输出拆成两半,如果熟悉其矩阵写法的话,实际上就是用了俩GCN# 但是这里的拆分函数也是仿照源码,个人觉得拆分的维度不对,但是这么写的准确率高ru = torch.sigmoid(self.graph_conv1(ru_input, edge_index))r, u = torch.chunk(ru.reshape([-1, self.num_nodes * 2 * self.nhid]), chunks=2, dim=1)r = r.reshape([-1, self.nhid])u = u.reshape([-1, self.nhid])c_input = torch.concat([x, r * hidden_state], dim=1)c = torch.tanh(self.graph_conv2(c_input, edge_index))new_hidden_state = u * hidden_state + (1.0 - u) * creturn new_hidden_state# 先进行图级别聚合,再进行序列建模
class RNNProcessHelper(torch.nn.Module):def __init__(self, args, rnn_cell):super(RNNProcessHelper, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.out_dim = args.out_dimself.seq_len = args.seq_lenself.num_nodes = args.num_nodesself.rnn_cell = rnn_celldef forward(self, data, hidden_state=None):x, edge_index = data.x, data.edge_indexif type(edge_index) is torch.Tensor:is_seq_edge_index = Falseelif type(edge_index) is list:is_seq_edge_index = Trueelse:raise '没有边连接信息!'if not hidden_state:hidden_state = torch.zeros([x.shape[0], self.nhid]).to(self.args.device)hidden_state_list = []for i in range(self.seq_len):# return gru_output.shape = [batch_size*num_nodes, hidden_size]if is_seq_edge_index:hidden_state = self.rnn_cell(x[:, i, :], edge_index[i], hidden_state)else:hidden_state = self.rnn_cell(x[:, i, :], edge_index, hidden_state)hidden_state_list.append(hidden_state)return hidden_state_list# 回归任务
class TGCN_Reg_Net(torch.nn.Module):def __init__(self, args):super(TGCN_Reg_Net, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.out_dim = args.out_dimself.seq_len = args.seq_lenself.num_nodes = args.num_nodes# self.tgcn_cell = TGCN_Cell(args)tgcn_cell = TGCNCell(args)self.seq_process_helper = RNNProcessHelper(args, tgcn_cell)# 将每个节点最终的hidden_state -> 该节点未来3小时的车速self.lin_out = torch.nn.Linear(self.nhid, self.out_dim)def forward(self, data):hidden_state_list = self.seq_process_helper(data)# 选最后一个output,用于预测hidden_state_last = hidden_state_list[-1]out = self.lin_out(hidden_state_last)# 按照数据集的构建方式,[batch*num_nodes, out_dim]return out@staticmethoddef test(model, loader, args) -> float:import mathmodel.eval()loss = 0.0max_speed = args.max_speed# 因为batch=1,所以一次是算一个样本的msefor data in loader:data = data.to(args.device)out = model(data)loss += F.mse_loss(out, data.y).item()mse_loss = loss / len(loader.dataset)rmse_loss = math.sqrt(mse_loss) * max_speed# print("val RMSE loss:{}".format(rmse_loss))return rmse_loss@staticmethoddef get_loss_function():from utils.loss_utils import mse_lossreturn mse_loss

主函数

import mathimport torch
from torch_geometric.loader import DataLoaderfrom utils.dataset_utils import split_dataset_by_ratio
from classfiers.tgcn import TGCN_Reg_Net
from datasets.traffic import TrafficDataSet
from utils.args_utils import get_args_pred
from utils.task_utils import trainif __name__ == '__main__':dataset = TrafficDataSet()train_set, test_set = split_dataset_by_ratio(dataset)args = get_args_pred(dataset)train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)test_loader = DataLoader(test_set, batch_size=1, shuffle=False)model = TGCN_Reg_Net(args).to(args.device)optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)train(model, train_loader, test_loader, optimizer, args)

整个模型的大致搭建过程就是这样,还有一些工具函数没有给出,但是看名字也大致能知道是干什么的

复现中碰到的其他问题

在我复现的过程中,好不容易完成了模型的搭建,看起来也和源码里的一模一样了,但是最终的指标差很多。然后尝试着看了看源码中对超参数的规定,发现有一个很微妙的参数,weight_decay,作者并不是将其加在了优化器Adam的构造函数中(也就是令Adam的weight_decay=0,而是在计算loss时,在mse_loss的基础上,加上了模型参数的l2正则化损失,其计算方式如下

def regular_loss(model, lamda=0):reg_loss = 0.0for param in model.parameters():reg_loss += torch.sum(param ** 2)return lamda * reg_lossdef mse_loss(out, label, model, reg_weight=0):classify_loss = F.mse_loss(out.squeeze(), label.squeeze())reg_loss = regular_loss(model, reg_weight)return classify_loss + reg_loss

后来查阅一些资料发现,这是因为Adam对模型的惩罚力度也会随着模型的训练进行自适应调整,使用AdamW可以解决这一问题,然而实际上也没什么卵用。因此,对于我这种半桶水来说,还是老老实实用作者调出来的参数吧……

后记

这里简要介绍了自己复现T-GCN的过程,把模型和DataSet的构建过程贴了出来。现在暂时没有整理可以直接运行的源码供大家下载,因为作者已经开放了源码,而我只不过是拿pyg重新实现了一下,对于学习pyg本身没有太大的用处,对于学习T-GCN也没有太大的用处。

pytorch_geometric(pyg)复现T-GCN相关推荐

  1. GCN图卷积神经网络综述

    文章目录 一.GNN简史 二.GCN的常用方法及分类 2.1 基于频域的方法 2.2 基于空间域的方法 2.3 图池化模块 三. GCN常用的基准数据集 四.GCN的主要应用 4.1 计算机视觉 4. ...

  2. PGL图学习之图神经网络GNN模型GCN、GAT

    在4922份提交内容中,主要涉及13个研究方向,具体有: 1.AI应用应用,例如:语音处理.计算机视觉.自然语言处理等 2.深度学习和表示学习 3.通用机器学习 4.生成模型 5.基础设施,例如:数据 ...

  3. 「紫禁之巅」四大图神经网络架构

    近年来,人们对深度学习方法在图数据上的扩展越来越感兴趣.在深度学习的成功推动下,研究人员借鉴了卷积网络.循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构.图神经网络的火热使得各 ...

  4. GNN Tricks《Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks》

    Wang Y. Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks[J]. arXiv preprin ...

  5. (初学必看)deep graph library(dgl)库的入门引导

    文章目录 前言 简单? 内置数据集 定义模型 定义dgl中的一个图 附录 前言 下载这个库要去官方网站:https://www.dgl.ai/,网站上会给你下载命令,这有点像下载pytorch的时候. ...

  6. ICML 2019 | SGC:简单图卷积网络

    目录 前言 1. 简单图卷积 1.1 GCN 1.2 SGC 2. 谱分析 2.1 图卷积 2.2 SGC与低通滤波器 3. 实验 3.1 引文/社交网络 3.2 下游任务 4. 总结 前言 题目: ...

  7. [PyG] 1.如何使用GCN完成一个最基本的训练过程(含GCN实现)

    0. 前言 为啥要学习Pytorch-Geometric呢?(下文统一简称为PyG) 简单来说,是目前做的项目有用到,还有1个特点,就是相比NYU的DeepGraphLibrary, DGL的问题是A ...

  8. GCN的几种模型复现笔记

    引言 本篇笔记紧接上文,主要是上一篇看写了快2w字,再去接入代码感觉有点不太妙,后台都崩了好几次,因为内存不足,那就正好将内容分开来,可以水两篇,另外也给脑子放个假,最近事情有点多,思绪都有些乱,跳出 ...

  9. GCN学习:用PyG实现自定义layers的GCN网络及训练(五)

    深度讲解PyG实现自定义layer的GCN 完整代码 自定义layer传播方式 从节点角度解读GCN原理 逐行讲解代码原理 init forward message 目前的代码讲解基本都是直接使用Py ...

  10. PyG利用GCN实现Cora、Citeseer、Pubmed引用论文节点分类

    文章目录 前言 一.导入相关库 二.Cora.Citeseer.Pubmed数据集 三.定义配置类 四.定义工具类 五.加载数据集 六.定义GCN网络 七.定义模型 八.模型训练 九.模型验证 八.结 ...

最新文章

  1. html选择按钮selected,HTML Option defaultSelected用法及代码示例
  2. php画中画,画中画功能 怎么将两个视频叠加播放,制作成画中画效果
  3. isulad代替docker_云原生时代的华为新“引擎”:iSula | Linux 中国
  4. Linux下用at计划任务
  5. 在Angular单元测试代码的it方法里连续调用两次detectChange方法,会触发两次ngAfterViewInit吗
  6. python拼图游戏_乐趣无穷的Python课堂
  7. 7-290 鸡兔同笼 (10 分)
  8. LINUX的一些简单命令 时间修改
  9. 给一个div innerhtml 后 没有内容显示的问题_实战:仅用18行JavaScript构建一个倒数计时器...
  10. skynet源码阅读7--死循环检测
  11. python爬虫,以某小说网站为例
  12. 新版Google工具栏(For Firefox)发布
  13. ArcGIS面矢量挖洞
  14. PV、UV、CTR含义
  15. Kubernetes(K8s)Events介绍(上)
  16. SSH协议及免密码登录
  17. 关于所谓U盘有占用空间,却看不到文件的一些看法
  18. 设计模式真的能改善软件质量吗 (二)
  19. 2021校招京东物流新锐之星校招笔试面试总结
  20. strictmath_Java StrictMath cosh()方法与示例

热门文章

  1. Cause: java.lang.ArrayIndexOutOfBoundsException: 8
  2. 蓝牙车载 linux,《基于嵌入式Linux蓝牙在车载电子系统中的应用》.pdf
  3. CV之ModelScope:基于ModelScope框架的人脸人像数据集利用DCT-Net算法实现人像卡通化图文教程之详细攻略
  4. 运动生物力学软件OpenSim入门及进阶——(一)解剖生理学
  5. 云存储平台——Seafile搭建
  6. 高远球技术(羽毛球)
  7. 【解决方法】域名指向本地(127.0.0.1, 0.0.0.0)
  8. 口算加密php怎么使用,从数盲到口算 ——带你玩转RSA加密算法(一)
  9. 评价页面,随手写的评价简陋模板
  10. 记一次OpenStack排错Exceeded maximum number of re tries. Exhausted all hosts available for retrying build