GNN教程:DGL框架中的采样模型!
↑↑↑关注后"星标"Datawhale
每日干货 & 每月组队学习,不错过
Datawhale干货
作者:秦州,算法工程师,Datawhale成员
引言
本文为GNN教程的系列干货。之前介绍了DGL这个框架,以及如何使用DGL编写一个GCN模型,用在学术数据集上,这样的模型是workable的。然而,现实生活中我们还会遇到非常庞大的图数据,庞大到邻接矩阵和特征矩阵不能同时塞进内存中,这时如何解决这样的问题呢?
DGL采用了和GraphSAGE类似的邻居采样策略,通过构建计算子图缩小了每次计算的图规模,这篇博文将会介绍DGL提供的采样模型。
GCN中暴露的问题
首先我们回顾一下GCN的逐层embedding更新公式,给定图
, 我们用在程序中用邻接矩阵和及节点embedding表示它,那么一个-层的GCN网络采用如下的更新公式,层节点的embedding 取决于它所有在层的邻居embedding
其中,
是节点的邻居节点集合,是正规化后的,比如,是可训练的权重矩阵。
在节点分类的任务中,我们采用如下形式计算loss:
其中
可以是任意的损失函数,比如交叉熵损失。
之前我们在GCN博文中提到,因为计算节点embedding的更新需要载入整个邻接矩阵
和特征矩阵进入到内存中(如果利用显卡加速计算,那么这些矩阵将会被载入到显存中),这样就暴露出一个问题,当图的规模特别大的时候,会变得特别大,当图中每个节点的特征维数特别高的时候会变得特别大,这两种情况都会导致整个图没法载入到内存(或者显存)中从而无法计算。
解决这个问题的方式,正如我们在GraphSAGE的博文中提到的那样,通过mini-batch训练的方式,每次只构建该batch内节点的一个子图进行更新。DGL这个框架自version 0.3之后正式支持这种mini-batch的训练方式,下面我们重点介绍一下它的编程框架。
DGL
和GraphSAGE中一致,DGL将mini-batch 训练的前期准备工作分为两个层面,首先建立为了更新一个batch内节点embedding而需要的所有邻居节点信息的子图,其次为了保证子图的大小不会受到”超级节点“影响,通过采样的技术将每个节点的邻居个数保持一致,使得这一批节点和相关邻居的embedding能够构成一个Tensor放入显卡中计算。
这两个模块分别叫NodeFlow
和Neighbor Sampling
,下面来详细得介绍它们。
NodeFlow
记一个batch内需要更新embedding的节点集合为
,从这个节点集合出发,我们可以根据边信息查找计算所要用到的所有邻居节点,比如在下图的例子中,图结构如a)所示,假设我们使用的是2层GCN模型(每个节点的更新考虑其2度以内的邻居embedding),某个batch内我们需要更新节点的embedding,根据更新规则,为了更新,我们需要其一度节点的embedding信息,即需要节点,而这些节点的更新又需要节点的embedding。因此我们的计算图如图b)所示,先由(Layer 0)更新的embedding,再由的embedding更新的embedding。这样的计算图在DGL中叫做
NodeFlow
。
NodeFlow
是一种层次结构的图,节点被组织在
层之内(比如上面例子中2层的GCN节点分布在Layer0, Layer1 和 Layer2中),只有在相邻的层之间才存在边,两个相邻的层称为块(block)。
NodeFlow
是反向建立的,首先确立一个batch内需要更新的节点集合(即Layer2中的节点),然后这个节点的1阶邻居构成了NodeFlow的下一层(即Layer1中的节点),再将下一层的节点当做是需要更新的节点集合,重复该操作,直到建立好所有层节点信息。
通过这种方式,在每个batch的训练中,我们实际上将原图a)转化成了一个子图b),因此当原图很大无法塞进内存的时候,我们可以通过调小batch_size解决这样的问题。
根据逐层更新公式可知,每一个block之间的计算是完全独立的,因此NodeFlow
提供了函数block_compute
以提供底层embedding向高层的传递和计算工作。
Neighbor Sampling
现实生活中的图的节点分布常常是长尾的,这意味着有一些“超级节点”的度非常高,而还有一大部分节点的度很小。如果我们在NodeFlow
的建立过程中关联到“超级节点“的话,”超级节点“就会为NodeFlow
的下一层带来很多节点,使得整个NodeFlow
非常庞大,违背了设计小的计算子图的初衷。为了解决这样的问题,GraphSAGE提出了邻居采样的策略,通过为每个节点采样一定数量的邻居来近似
,加上采样策略之后,节点embedding的更新公式变为:
其中
表示采样后的邻居集合。假设表示层采样的邻居数量(D^(L)表示该batch的节点个数),称为第层的”感知野“(respective field),那么通过采样技术一个
NodeFlow
的节点数就能被控制在内。
具体实现
在具体实现中,采样和计算是两个独立的模型,也就是说,我们通过采样获得子图,再将这个子图输入到标准的GCN模型中训练,这种解耦合的方式使模型变得非常灵活,因为我们可以对采样的方式进行定制,比如Stochastic Training of Graph Convolutional Networks with Variance Reduction选择特定的邻居以将方差控制在一定的范围内。这种模型与采样分离的方式也是大部分支持超大规模图计算框架的方式(包括这里介绍的DGL,之后我们要介绍的Euler)。
DGL提供NeighborSampler
类来构建采样后的NodeFlow
,NeighborSampler
返回的是一个迭代器,生成NodeFlow
实例,们来看看DGL提供的一个结合采样策略的GCN实例代码:
# dropout probability
dropout = 0.2
# batch size
batch_size = 1000
# number of neighbors to sample
num_neighbors = 4
# number of epochs
num_epochs = 1# initialize the model and cross entropy loss
model = GCNSampling(in_feats, n_hidden, n_classes, L,mx.nd.relu, dropout, prefix='GCN')
model.initialize()
loss_fcn = gluon.loss.SoftmaxCELoss()# use adam optimizer
trainer = gluon.Trainer(model.collect_params(), 'adam',{'learning_rate': 0.03, 'wd': 0})for epoch in range(num_epochs):i = 0for nf in dgl.contrib.sampling.NeighborSampler(g, batch_size,num_neighbors,neighbor_type='in',shuffle=True,num_hops=L,seed_nodes=train_nid):# When `NodeFlow` is generated from `NeighborSampler`, it only contains# the topology structure, on which there is no data attached.# Users need to call `copy_from_parent` to copy specific data,# such as input node features, from the original graph.nf.copy_from_parent()with mx.autograd.record():# forwardpred = model(nf)batch_nids = nf.layer_parent_nid(-1).astype('int64')batch_labels = labels[batch_nids]# cross entropy lossloss = loss_fcn(pred, batch_labels)loss = loss.sum() / len(batch_nids)# backwardloss.backward()# optimizationtrainer.step(batch_size=1)print("Epoch[{}]: loss {}".format(epoch, loss.asscalar()))i += 1# We only train the model with 32 mini-batches just for demonstration.if i >= 32:break
上面的代码中,model由GCNsampling
定义,虽然它的名字里有sampling,但这只是一个标准的GCN模型,其中没有任何和采样相关的内容,和采样相关代码的定义在dgl.contrib.sampling.Neighborsampler
中,使用图结构g
初始化这个类,并且定义采样的邻居个数num_neighbors
,它返回的nf
即是NodeFlow
实例,采样后的子图。因为nf
只会返回子图的拓扑结构,不会附带节点Embedding,所以需要调用copy_from_parent()
方法来获取Embedding,layer_parent_nid
返回该nodeflow中每一层的节点id,根据上面的图示,当前batch内的节点(称为种子节点)位于最高层,所以layer_parent_nid(-1)
返回当前batch内的节点id。剩下的步骤就是一个标准的模型训练代码,包括前向传播,计算loss,反向传播在此不再赘述。
Control Variate
通过采样而估计的
是无偏的,但是方差会较大,因此需要采大量的邻居样本来减少方差,因此在GraphSAGE的原论文中,作者设定了,。但是这样做在每一次采样中我们都有大量的邻居需要聚合,因此control variate和核心思路是缓存历史上计算过的聚合值,根据和本次采样的邻居共同估计,同时在每一轮中更新。通过使用这种计算,每一个节点采样两个邻居就足够了。
Control variate方法的原理为:给定随机变量
,我们想要估计它的期望,为此我们寻找另一个随机变量,和强相关并且的期望能够被轻松地计算得到。通过估计期望的近似值 可以表示为:
具体到我们的场景上,
是某次采样节点邻居的聚合,是该节点所有邻居的聚合。基于control variate的方法训练GCN的过程为:
那么上面的代码可以按照这种思路改写为:
g.ndata['h_0'] = features
for i in range(L):g.ndata['h_{}'.format(i+1)] = mx.nd.zeros((features.shape[0], n_hidden))# With control-variate sampling, we only need to sample 2 neighbors to train GCN.for nf in dgl.contrib.sampling.NeighborSampler(g, batch_size, expand_factor=2,neighbor_type='in', num_hops=L,seed_nodes=train_nid):for i in range(nf.num_blocks):# aggregate history on the original graphg.pull(nf.layer_parent_nid(i+1),fn.copy_src(src='h_{}'.format(i), out='m'),lambda node: {'agg_h_{}'.format(i): node.mailbox['m'].mean(axis=1)})nf.copy_from_parent()h = nf.layers[0].data['features']for i in range(nf.num_blocks):prev_h = nf.layers[i].data['h_{}'.format(i)]# compute delta_h, the difference of the current activation and the historynf.layers[i].data['delta_h'] = h - prev_h# refresh the old historynf.layers[i].data['h_{}'.format(i)] = h.detach()# aggregate the delta_hnf.block_compute(i,fn.copy_src(src='delta_h', out='m'),lambda node: {'delta_h': node.data['m'].mean(axis=1)})delta_h = nf.layers[i + 1].data['delta_h']agg_h = nf.layers[i + 1].data['agg_h_{}'.format(i)]# control variate estimatornf.layers[i + 1].data['h'] = delta_h + agg_hnf.apply_layer(i + 1, lambda node : {'h' : layer(node.data['h'])})h = nf.layers[i + 1].data['h']# update historynf.copy_to_parent()
上文代码中,nf
是NeighborSampler
返回的对象,在nf
的对象的每一个block
内,首先调用pull
函数获取
(即代码中的
agg_h_{}
),然后计算delta_h
和agg_h
),最后将更新后的结果拷贝回原大图中。
后话
这一篇博文介绍了DGL这个框架怎么对大图进行计算的,总结起来,它吸取了GraphSAGE的思路,通过为每个mini-batch构建子图并采样邻居的方式将图规模控制在可计算的范围内。这种采样-计算分离的模型基本是目前所有图神经网络计算大图时所采用的策略。
有两个细节没有介绍,第一、具体的采样方法,对于邻居的采样方法有很多种,除了最容易想到的重采样/负采样策略很多学者还提出了一些更加优秀的策略,之后我们会在"加速计算、近似方法"模块中详细讨论这些方法的原理;第二、对于超大规模的图,很多框架采用的是分布式的方式,典型的如Euler,这一系列我们还将写一篇关于Euler的博文,介绍它与DGL的异同,它的分布式架构和在超大规模图计算上做的创新。
Reference
DGL Tutorial on NodeFlow and Neighbor Sampling
“整理不易,点赞三连↓
GNN教程:DGL框架中的采样模型!相关推荐
- laravel中的ORM模型修改created_at,updated_at,deleted_at三个时间字段类型
laravel框架中的ORM模型极大的简化了数据库操作,同时也提高了数据操作安全性. 在laravel框架ORM模型中默认会有三个时间字段,created_at,updated_at,deleted_ ...
- GNN教程:DGL框架实现GCN算法!
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:秦州,算法工程师,Datawhale成员 引言 本文为GNN教程的 ...
- python flask框架下登录注册界面_Python的Flask框架中实现简单的登录功能的教程
Python 的 Flask 框架中实现简单的登录功能的教程 , 登录是各个 web 框架中的基础功能 , 需要的朋友可以参考下 回顾 在前面的系列章节中, 我们创建了一个数据库并且学着用用户和邮件来 ...
- 华为开源自研AI框架昇思MindSpore模型体验:ModelZoo中的CRNN
目录 一.环境准备 1.进入ModelArts官网 2.使用CodeLab体验Notebook实例 二.脚本说明 三.数据集 四.训练过程 五.评估过程 六.推理过程 CRNN描述 CRNN是一种基于 ...
- yii引入php文件,Yii2框架中CSS、JS文件引入要领_PHP开发框架教程
在yii2中,因为yii2版本升级致使了,许多yii2的用法跟yii1有着很大的区分,这几天一直在view层的视图界面徜徉着,碰到什么问题呢? (引荐进修:yii框架) 问题就是搞不清我该怎样去引入C ...
- python框架 mysql数据库_在Python的框架中为MySQL实现restful接口的教程
最近在做游戏服务分层的时候,一直想把mysql的访问独立成一个单独的服务DBGate,原因如下: 请求收拢到DBGate,可以使DBGate变为无状态的,方便横向扩展 当请求量或者存储量变大时,mys ...
- jsx 调用php,JavaScript_JavaScript的React框架中的JSX语法学习入门教程,什么是JSX?
在用React写组件的 - phpStudy...
JavaScript的React框架中的JSX语法学习入门教程 什么是JSX? 在用React写组件的时候,通常会用到JSX语法,粗看上去,像是在Javascript代码里直接写起了XML标签,实质上 ...
- Unittest自动化测试框架教程(四)——Python中的数据驱动测试DDT
" 数据驱动测试DDT(Data Drivern test),是自动化测试领域优势中亮眼的闪光点,在unittest测试框架中对数据驱动更是提供了强大的支持,文章通过基础概念的引入,介绍了 ...
- [DT框架使用教程01]如何在DT框架中创建插件
[DT框架使用教程01]如何在DT框架中创建插件 DT框架代码地址: https://github.com/huifeng-kooboo/DT 由于国内访问速度的问题 也可以访问gitee的地址: h ...
最新文章
- Django配置开发环境和生产环境以及配置Jinja2模板引擎
- Python---寻找给定序列中相差最小的两个数字
- python人脸检测与微信小程序_python+requests对app和微信小程序进行接口测试
- java 程序是由什么组成的 java_【问答题】一个典型的JAVA程序结构是由什么组成。...
- python学习[第二篇] 基础二
- oracle安装教程以及使用注意事项
- 计算机机房安全管理问题与措施,机房管理中存在的问题及处理对策
- 安防与消防融合发展的现状与机遇分析
- # 图书馆网上销售系统(c#+sql server)
- Echarts地图合并提取
- 屏幕录像专家----百度百科
- 罗马数字数字1到10对照表
- 1,标准差的计算 2,标准分数z-score
- OpenGL ES (二)EGL介绍和使用
- 自己做量化交易软件(1)通通量化分析环境安装使用
- Java 并发 随笔 1-初尝并发
- 仿照三元组的抽象数据类型分别写出抽象数据类型复数和有理数的定义
- 大四实习生的日常(一)
- 设备漏电对计算机影响,机箱漏电会不会影响电脑 机箱漏电会不会烧电脑主机内部硬件吗...
- python工程师工作总结_工程师的第一份工作小结
热门文章
- jlink api sdk c# 离线数获取 标定
- 实现一个 能在O(1)时间复杂度 完成 Push、Pop、Min操作的 栈
- 用ILSpy查看Session.SessionID的生成算法
- ORACLE分页SQL
- 中国电子学会图形化四级编程题:加减法混合运算器
- 如何利用离散Hopfield神经网络进行数字识别(1)
- uml具有多种视图_UML建模与架构文档化
- 爱耳日腾讯天籁行动再升级 助力100位青年听障人才打破“屏障”
- 网友抱怨:「苹果除了每年收我的钱,似乎什么都不想做」
- 图像分析用 OpenCV 与 Skimage,哪一个更好?