pytorch_geometric(pyg)复现T-GCN
前言
上一篇文章从pyg提供的基本工具出发,介绍了pyg。但是大家用三方库,一般是将其作为积木来构建一个比较大的模型,把它用在自己的数据集上,而不是满足于跑跑demo里的简单模型和标准数据集。因此本文将从复现T-GCN(论文和官方源码见此)的角度出发,讲述怎么使用pyg搭建一个GNN-RNN模型,包括数据集的构建和模型的搭建。
刚开始复现的时候,我踩了很多坑,有的坑是因为不熟悉pyg踩的,有的坑是因为作者论文里和源码里的模型不一致踩的。这里说是复现,但是是以作者源码里的模型为准。那你可能会问了,都有源码了,我在这里吵吵啥“复现”呢?因为源码不是采用pyg写的,而是使用了原始的GCN计算方式,使用一些矩阵乘法做的。
前置准备
虽然我不是作者团队的……但是我觉得还是有必要介绍一下这篇文章的模型和数据集
T-GCN介绍
这里介绍的T-GCN的全称是Traffic-GCN,同学们有可能会在别的地方看到这个简称,但是不一定指的是这个模型。
T-GCN的核心模型结构是使用了GCN+GRU二者组合,先使用GCN得到更丰富的节点特征,再将每个节点的特征都送入GRU中进行计算。相当于使用GCN聚合空间特征,再使用GRU聚合时序特征,具体的计算公式如下。需要注意的是,每次输入的都是当前时间的特征和GRU的隐层特征,二者拼接后作为T-GCN-Cell(可以认为是T-GCN内部的一层卷积)的输入。
Convt(Xt,ht)=L⋅concat(Xt,ht)GRU(Xt,ht)=ut=σ(WuConvt(Xt,ht)+bu)rt=σ(WrConvt(Xt,ht)+br)ct=tanh(WcConvt(Xt,ht)+bc)ht+1=ut∗ht+(1−ut)∗ctConv_t(X_{t},h_{t})=L\cdot concat(X_{t},h_{t})\\ GRU(X_{t},h_{t})= \begin{aligned} u_t & = \sigma(W_uConv_t(X_{t},h_{t})+b_u) \\ r_t & = \sigma(W_rConv_t(X_{t},h_{t})+b_r) \\ c_t & = tanh(W_cConv_t(X_{t},h_{t})+b_c)\\ h_{t+1} & = u_t*h_t+(1-u_t)*c_t \end{aligned}\\ Convt(Xt,ht)=L⋅concat(Xt,ht)GRU(Xt,ht)=utrtctht+1=σ(WuConvt(Xt,ht)+bu)=σ(WrConvt(Xt,ht)+br)=tanh(WcConvt(Xt,ht)+bc)=ut∗ht+(1−ut)∗ct
数据集介绍
这里只采用“shenzhen”数据集。该数据集是在深圳156条道路采集的交通流量数据,采集间隔为5分钟一次,维度为1。此外,还附带一个道路间的邻接矩阵。也就是作者在建模过程中,将道路当作节点,将道路是否连通作为建图标准。
开始复现
使用pyg复现T-GCN的过程是比较痛苦的,因为必须要削足适履。个人认为pyg对时序数据的支持似乎不那么友好,当然也有可能是因为我没找到适合时序数据使用的DataLoader和Data对象。
时序序列的GNN数据有什么问题?
在之前的pyg介绍文章就说过,DataLoader会将每个Data对象视为一个图,形成mini-batch时,将一个batch里的Data对象打包成一个大图。这在非时序样本时,没有任何问题,但是对于时序数据,每个样本中包含多个图,此时DataLoader打包出来的对象可能就不符合我们的心意了。
静态图
我们先假设一种最简单的情况,每个样本的图结构和连接关系完全相同(这也被称为“静态图”),因此我们给每个Data对象都是完全相同的邻接矩阵。为了追求并行化,我们通常把一个时间点的所有数据抽出来一起计算。假设初始时隐状态为零向量,T-GCN的大致伪码如下图所示。
# x.shape is [num_nodes, seq_len, num_features]
h = zeros()
for i in range(seq_len):h = gru(gcn(concat(x[:, i, :], h), edge_index))
这么乍一看,好像没问题,实际上也确实没问题。edge_index按照mini-batch的方式拼接成大图;x拼接之后形成了[batch_size*num_nodes, seq_len, num_features]
的矩阵,每次取其中一个时间点进行运算,输入的逻辑非常正确。
动态图
然后我们再看,假如样本中的图结构关系(这里只考虑边变化的情况)可以随着时间动态变化,那对于每个时间点,都需要一张独立的图,此时Data对象规定的edge_index结构就不满足我们的要求了,与之配合的DataLoader也会拼接出错误的mini-batch。
对于这种情况,我前思后想,辗转反侧,想到了一个相当削足适履的方法,就是改造DataLoader生成mini-batch的函数,使之对样本中每一个时间点的邻接矩阵进行拼接,然后生成一个List,维度为[seq_len, 2, num_edges]
,同时需要注意,每张图的num_edges
可能不一样,所以这样一个数据还没法打包到一个Tensor里,只能用List存下所有时间点的batch大图
不过幸好,T-GCN是静态图,没这么麻烦,这种情况只是自己在做磕盐的时候遇到的,如果大家有更好的方法,也欢迎讨论。
搭建削足适履的模型
DataSet
然后我们搭建DataSet对象
from typing import List, Union, Tupleimport numpy as np
import torchfrom torch_geometric.data import InMemoryDataset, Dataset, Data
from utils.utils import dataset_path
from constant import DATASET_NAME_TRAFFIC
import pandas as pdclass TrafficDataSet(InMemoryDataset):# 一个点是15分钟seq_len = 4predict_len = 1DATASET_TYPE = 'sz'PROCESSED_DATASET_FILENAME = '%s_seq%d_pre%d' % (DATASET_TYPE, seq_len, predict_len)speed_name = DATASET_TYPE + '_speed.csv'adj_name = DATASET_TYPE + '_adj.csv'def __init__(self):super().__init__(root=dataset_path(DATASET_NAME_TRAFFIC))self.data, self.slices, self.max_speed, self.num_nodes, self.seq_len, self.pre_len = torch.load(self.processed_paths[0])@propertydef raw_file_names(self) -> Union[str, List[str], Tuple]:return [TrafficDataSet.speed_name, TrafficDataSet.adj_name]@propertydef processed_file_names(self) -> Union[str, List[str], Tuple]:return TrafficDataSet.PROCESSED_DATASET_FILENAME + '.pt'def download(self):passdef process(self):# 一个文件里是有抬头的,一个没有speed = pd.read_csv(self.raw_paths[0]).valuesadj = pd.read_csv(self.raw_paths[1], header=None).valuesnum_nodes = len(adj)adj = process_adj(adj)# 对样本的输出进行归一化,归一化参数需要记录下来,计算测试集MSE时要用max_speed = np.max(speed)speed = speed / max_speedspeed = torch.tensor(speed, dtype=torch.float32)adj = torch.tensor(adj, dtype=torch.int64)time_len = speed.shape[0]seq_len = TrafficDataSet.seq_lenpre_len = TrafficDataSet.predict_lendata_list = []for i in range(time_len - seq_len - pre_len):# speed = [time_len, num_nodes]# x = [num_nodes, seq_len, num_features=1]x = speed[i: i + seq_len].transpose(0,1).reshape([num_nodes, seq_len, 1])# y = [pre_len, num_nodes] -> [num_nodes, pre_len]y = speed[i + seq_len: i + seq_len + pre_len].transpose(0, 1)pyg_data = Data(x, edge_index=adj, y=y)data_list.append(pyg_data)data, slices = self.collate(data_list)torch.save((data, slices, max_speed, num_nodes, seq_len, pre_len), self.processed_paths[0])# 数据集给的是邻接矩阵,需要转换成pyg接受的稀疏矩阵的形式
def process_adj(adj):node_cnt = len(adj)pyg_adj = [[],[]]for i in range(node_cnt):for j in range(node_cnt):if adj[i][j] == 1:pyg_adj[0].append(i)pyg_adj[1].append(j)return np.array(pyg_adj)
T-GCN模型
模型听起来也不复杂,因此对照着源码直接开始搭建
import torch
from torch_geometric.nn.conv import GCNConvimport torch.nn.functional as F
class TGCN_Conv_Module(torch.nn.Module):def __init__(self, args):super(TGCN_Conv_Module, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_out# 卷积层将输入与GRU的hidden_state拼接起来作为输入,输出hidden_size的特征self.conv1 = GCNConv(self.num_features+self.nhid, self.nhid)def forward(self, x, edge_index):# 实际上作者源码中只使用了一层GCN卷积,而论文中是两层x = F.relu(self.conv1(x, edge_index))x = torch.sigmoid(x)return xclass TGCNCell(torch.nn.Module):def __init__(self, args):super(TGCNCell, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.seq_len = args.seq_lenself.num_nodes = args.num_nodes# 这是仿照作者源码里的写法,实际上这是两个GCN,在forward函数中会将其输出拆成两半self.graph_conv1 = GCNConv(self.nhid+self.num_features, self.nhid * 2)self.graph_conv2 = GCNConv(self.nhid+self.num_features, self.nhid)self.reset_parameters()def reset_parameters(self):torch.nn.init.constant_(self.graph_conv1.bias, 1.0)def forward(self, x, edge_index, hidden_state):ru_input = torch.concat([x, hidden_state], dim=1)# 这里将一个GCN的输出拆成两半,如果熟悉其矩阵写法的话,实际上就是用了俩GCN# 但是这里的拆分函数也是仿照源码,个人觉得拆分的维度不对,但是这么写的准确率高ru = torch.sigmoid(self.graph_conv1(ru_input, edge_index))r, u = torch.chunk(ru.reshape([-1, self.num_nodes * 2 * self.nhid]), chunks=2, dim=1)r = r.reshape([-1, self.nhid])u = u.reshape([-1, self.nhid])c_input = torch.concat([x, r * hidden_state], dim=1)c = torch.tanh(self.graph_conv2(c_input, edge_index))new_hidden_state = u * hidden_state + (1.0 - u) * creturn new_hidden_state# 先进行图级别聚合,再进行序列建模
class RNNProcessHelper(torch.nn.Module):def __init__(self, args, rnn_cell):super(RNNProcessHelper, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.out_dim = args.out_dimself.seq_len = args.seq_lenself.num_nodes = args.num_nodesself.rnn_cell = rnn_celldef forward(self, data, hidden_state=None):x, edge_index = data.x, data.edge_indexif type(edge_index) is torch.Tensor:is_seq_edge_index = Falseelif type(edge_index) is list:is_seq_edge_index = Trueelse:raise '没有边连接信息!'if not hidden_state:hidden_state = torch.zeros([x.shape[0], self.nhid]).to(self.args.device)hidden_state_list = []for i in range(self.seq_len):# return gru_output.shape = [batch_size*num_nodes, hidden_size]if is_seq_edge_index:hidden_state = self.rnn_cell(x[:, i, :], edge_index[i], hidden_state)else:hidden_state = self.rnn_cell(x[:, i, :], edge_index, hidden_state)hidden_state_list.append(hidden_state)return hidden_state_list# 回归任务
class TGCN_Reg_Net(torch.nn.Module):def __init__(self, args):super(TGCN_Reg_Net, self).__init__()self.args = argsself.num_features = args.c_inself.nhid = args.c_outself.out_dim = args.out_dimself.seq_len = args.seq_lenself.num_nodes = args.num_nodes# self.tgcn_cell = TGCN_Cell(args)tgcn_cell = TGCNCell(args)self.seq_process_helper = RNNProcessHelper(args, tgcn_cell)# 将每个节点最终的hidden_state -> 该节点未来3小时的车速self.lin_out = torch.nn.Linear(self.nhid, self.out_dim)def forward(self, data):hidden_state_list = self.seq_process_helper(data)# 选最后一个output,用于预测hidden_state_last = hidden_state_list[-1]out = self.lin_out(hidden_state_last)# 按照数据集的构建方式,[batch*num_nodes, out_dim]return out@staticmethoddef test(model, loader, args) -> float:import mathmodel.eval()loss = 0.0max_speed = args.max_speed# 因为batch=1,所以一次是算一个样本的msefor data in loader:data = data.to(args.device)out = model(data)loss += F.mse_loss(out, data.y).item()mse_loss = loss / len(loader.dataset)rmse_loss = math.sqrt(mse_loss) * max_speed# print("val RMSE loss:{}".format(rmse_loss))return rmse_loss@staticmethoddef get_loss_function():from utils.loss_utils import mse_lossreturn mse_loss
主函数
import mathimport torch
from torch_geometric.loader import DataLoaderfrom utils.dataset_utils import split_dataset_by_ratio
from classfiers.tgcn import TGCN_Reg_Net
from datasets.traffic import TrafficDataSet
from utils.args_utils import get_args_pred
from utils.task_utils import trainif __name__ == '__main__':dataset = TrafficDataSet()train_set, test_set = split_dataset_by_ratio(dataset)args = get_args_pred(dataset)train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)test_loader = DataLoader(test_set, batch_size=1, shuffle=False)model = TGCN_Reg_Net(args).to(args.device)optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)train(model, train_loader, test_loader, optimizer, args)
整个模型的大致搭建过程就是这样,还有一些工具函数没有给出,但是看名字也大致能知道是干什么的
复现中碰到的其他问题
在我复现的过程中,好不容易完成了模型的搭建,看起来也和源码里的一模一样了,但是最终的指标差很多。然后尝试着看了看源码中对超参数的规定,发现有一个很微妙的参数,weight_decay
,作者并不是将其加在了优化器Adam的构造函数中(也就是令Adam的weight_decay=0
,而是在计算loss时,在mse_loss的基础上,加上了模型参数的l2正则化损失,其计算方式如下
def regular_loss(model, lamda=0):reg_loss = 0.0for param in model.parameters():reg_loss += torch.sum(param ** 2)return lamda * reg_lossdef mse_loss(out, label, model, reg_weight=0):classify_loss = F.mse_loss(out.squeeze(), label.squeeze())reg_loss = regular_loss(model, reg_weight)return classify_loss + reg_loss
后来查阅一些资料发现,这是因为Adam对模型的惩罚力度也会随着模型的训练进行自适应调整,使用AdamW可以解决这一问题,然而实际上也没什么卵用。因此,对于我这种半桶水来说,还是老老实实用作者调出来的参数吧……
后记
这里简要介绍了自己复现T-GCN的过程,把模型和DataSet的构建过程贴了出来。现在暂时没有整理可以直接运行的源码供大家下载,因为作者已经开放了源码,而我只不过是拿pyg重新实现了一下,对于学习pyg本身没有太大的用处,对于学习T-GCN也没有太大的用处。
pytorch_geometric(pyg)复现T-GCN相关推荐
- GCN图卷积神经网络综述
文章目录 一.GNN简史 二.GCN的常用方法及分类 2.1 基于频域的方法 2.2 基于空间域的方法 2.3 图池化模块 三. GCN常用的基准数据集 四.GCN的主要应用 4.1 计算机视觉 4. ...
- PGL图学习之图神经网络GNN模型GCN、GAT
在4922份提交内容中,主要涉及13个研究方向,具体有: 1.AI应用应用,例如:语音处理.计算机视觉.自然语言处理等 2.深度学习和表示学习 3.通用机器学习 4.生成模型 5.基础设施,例如:数据 ...
- 「紫禁之巅」四大图神经网络架构
近年来,人们对深度学习方法在图数据上的扩展越来越感兴趣.在深度学习的成功推动下,研究人员借鉴了卷积网络.循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构.图神经网络的火热使得各 ...
- GNN Tricks《Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks》
Wang Y. Bag of Tricks of Semi-Supervised Classification with Graph Neural Networks[J]. arXiv preprin ...
- (初学必看)deep graph library(dgl)库的入门引导
文章目录 前言 简单? 内置数据集 定义模型 定义dgl中的一个图 附录 前言 下载这个库要去官方网站:https://www.dgl.ai/,网站上会给你下载命令,这有点像下载pytorch的时候. ...
- ICML 2019 | SGC:简单图卷积网络
目录 前言 1. 简单图卷积 1.1 GCN 1.2 SGC 2. 谱分析 2.1 图卷积 2.2 SGC与低通滤波器 3. 实验 3.1 引文/社交网络 3.2 下游任务 4. 总结 前言 题目: ...
- [PyG] 1.如何使用GCN完成一个最基本的训练过程(含GCN实现)
0. 前言 为啥要学习Pytorch-Geometric呢?(下文统一简称为PyG) 简单来说,是目前做的项目有用到,还有1个特点,就是相比NYU的DeepGraphLibrary, DGL的问题是A ...
- GCN的几种模型复现笔记
引言 本篇笔记紧接上文,主要是上一篇看写了快2w字,再去接入代码感觉有点不太妙,后台都崩了好几次,因为内存不足,那就正好将内容分开来,可以水两篇,另外也给脑子放个假,最近事情有点多,思绪都有些乱,跳出 ...
- GCN学习:用PyG实现自定义layers的GCN网络及训练(五)
深度讲解PyG实现自定义layer的GCN 完整代码 自定义layer传播方式 从节点角度解读GCN原理 逐行讲解代码原理 init forward message 目前的代码讲解基本都是直接使用Py ...
- PyG利用GCN实现Cora、Citeseer、Pubmed引用论文节点分类
文章目录 前言 一.导入相关库 二.Cora.Citeseer.Pubmed数据集 三.定义配置类 四.定义工具类 五.加载数据集 六.定义GCN网络 七.定义模型 八.模型训练 九.模型验证 八.结 ...
最新文章
- html选择按钮selected,HTML Option defaultSelected用法及代码示例
- php画中画,画中画功能 怎么将两个视频叠加播放,制作成画中画效果
- isulad代替docker_云原生时代的华为新“引擎”:iSula | Linux 中国
- Linux下用at计划任务
- 在Angular单元测试代码的it方法里连续调用两次detectChange方法,会触发两次ngAfterViewInit吗
- python拼图游戏_乐趣无穷的Python课堂
- 7-290 鸡兔同笼 (10 分)
- LINUX的一些简单命令 时间修改
- 给一个div innerhtml 后 没有内容显示的问题_实战:仅用18行JavaScript构建一个倒数计时器...
- skynet源码阅读7--死循环检测
- python爬虫,以某小说网站为例
- 新版Google工具栏(For Firefox)发布
- ArcGIS面矢量挖洞
- PV、UV、CTR含义
- Kubernetes(K8s)Events介绍(上)
- SSH协议及免密码登录
- 关于所谓U盘有占用空间,却看不到文件的一些看法
- 设计模式真的能改善软件质量吗 (二)
- 2021校招京东物流新锐之星校招笔试面试总结
- strictmath_Java StrictMath cosh()方法与示例
热门文章
- Cause: java.lang.ArrayIndexOutOfBoundsException: 8
- 蓝牙车载 linux,《基于嵌入式Linux蓝牙在车载电子系统中的应用》.pdf
- CV之ModelScope:基于ModelScope框架的人脸人像数据集利用DCT-Net算法实现人像卡通化图文教程之详细攻略
- 运动生物力学软件OpenSim入门及进阶——(一)解剖生理学
- 云存储平台——Seafile搭建
- 高远球技术(羽毛球)
- 【解决方法】域名指向本地(127.0.0.1, 0.0.0.0)
- 口算加密php怎么使用,从数盲到口算 ——带你玩转RSA加密算法(一)
- 评价页面,随手写的评价简陋模板
- 记一次OpenStack排错Exceeded maximum number of re tries. Exhausted all hosts available for retrying build