更多图神经网络和深度学习内容请关注:

第3章:构建图神经网络(GNN)模块

DGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。 在DGL NN模块中,构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。 DGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。

DGL已经集成了很多常用的 Conv Layers、 Dense Conv Layers、 Global Pooling Layers 和 Utility Modules。欢迎给DGL贡献更多的模块!

本章将使用PyTorch作为后端,用 SAGEConv 作为例子来介绍如何构建用户自己的DGL NN模块。

3.1 DGL NN模块的构造函数

构造函数__init__完成以下几个任务:

  • 设置选项。
  • 注册可学习的参数或者子模块。
  • 初始化参数。
import torch.nn as nnfrom dgl.utils import expand_as_pairclass SAGEConv(nn.Module):def __init__(self,in_feats,out_feats,aggregator_type,bias=True,norm=None,activation=None):super(SAGEConv, self).__init__()self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)#函数可以返回一个二维元组。self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation
Using backend: pytorch

在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 对于图神经网络输入维度可被分为源节点特征维度和目标节点特征维度

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: hv=hv/∥hv∥2h_v = h_v / \lVert h_v \rVert_2hv​=hv​/∥hv​∥2​

# 聚合类型:mean、max_pool、lstm、gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()

注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。 构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'max_pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

完整代码

import torch.nn as nnfrom dgl.utils import expand_as_pairclass SAGEConv(nn.Module):def __init__(self,in_feats,out_feats,aggregator_type,bias=True,norm=None,activation=None):super(SAGEConv, self).__init__()self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation# 聚合类型:mean、max_pool、lstm、gcnif aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))if aggregator_type == 'max_pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)if aggregator_type in ['mean', 'max_pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)self.reset_parameters()def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'max_pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

3.2 编写DGL NN模块的forward函数

在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

  • 检测输入图对象是否符合规范。
  • 消息传递和聚合。
  • 聚合后,更新特征作为输出。

下文展示了SAGEConv示例中的 forward() 函数。

输入图对象的规范检测

def forward(self, graph, feat):with graph.local_scope():# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)

graph.local_scope():限定语句块内为局部作用域,对数据特征的操作不影响原始图的特征,常用于forward方法中,
用法:

def foo(g):with g.local_scope():g.edata['h'] = torch.ones((g.num_edges(), 3))g.edata['h2'] = torch.ones((g.num_edges(), 3))return g.edata['h']

in-place操作会影响原始图数据
如:

def foo(g):with g.local_scope():# in-place operationg.edata['h'] += 1return g.edata['h']

参考dgl API

forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。

SAGEConv的数学公式如下:

hN(dst)(l+1)=aggregate({hsrcl,∀src∈N(dst)})h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)hN(dst)(l+1)​=aggregate({hsrcl​,∀src∈N(dst)})

hdst(l+1)=σ(W⋅concat(hdstl,hN(dst)l+1)+b)h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat} (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{l+1}) + b \right)hdst(l+1)​=σ(W⋅concat(hdstl​,hN(dst)l+1​)+b)

hdst(l+1)=norm(hdstl+1)h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{l+1})hdst(l+1)​=norm(hdstl+1​)

源节点特征 feat_src 和目标节点特征 feat_dst 需要根据图类型被指定。 用于指定图类型并将 feat 扩展为 feat_srcfeat_dst 的函数是 expand_as_pair()。 该函数的细节如下所示。

def expand_as_pair(input_, g=None):if isinstance(input_, tuple):# 二分图的情况return input_elif g is not None and g.is_block:# 子图块的情况if isinstance(input_, Mapping):input_dst = {k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))for k, v in input_.items()}else:input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())return input_, input_dstelse:# 同构图的情况return input_, input_

对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。

在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)。 当输入特征 feat 是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。

在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block)。 在区块创建的阶段,dst nodes 位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()] 可以找到 feat_dst

确定 feat_srcfeat_dst 之后,以上3种图类型的计算方法是相同的。

消息传递和聚合

def forward(self, graph, feat):with graph.local_scope():import dgl.function as fnimport torch.nn.functional as Ffrom dgl.utils import check_eq_shape# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)if self._aggre_type == 'mean':graph.srcdata['h'] = feat_srcgraph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))h_neigh = graph.dstdata['neigh']elif self._aggre_type == 'gcn':check_eq_shape(feat)graph.srcdata['h'] = feat_srcgraph.dstdata['h'] = feat_dstgraph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))# 除以入度degs = graph.in_degrees().to(feat_dst)h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)elif self._aggre_type == 'max_pool':graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))h_neigh = graph.dstdata['neigh']else:raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))# GraphSAGE中gcn聚合不需要fc_selfif self._aggre_type == 'gcn':rst = self.fc_neigh(h_neigh)else:rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all() API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。

聚合后,更新特征作为输出

# 激活函数
if self.activation is not None:rst = self.activation(rst)
# 归一化
if self.norm is not None:rst = self.norm(rst)
return rst

forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。

完整代码

import torch.nn as nnfrom dgl.utils import expand_as_pairclass SAGEConv(nn.Module):def __init__(self,in_feats,out_feats,aggregator_type,bias=True,norm=None,activation=None):super(SAGEConv, self).__init__()self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)self._out_feats = out_featsself._aggre_type = aggregator_typeself.norm = normself.activation = activation# 聚合类型:mean、max_pool、lstm、gcnif aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))if aggregator_type == 'max_pool':self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)if aggregator_type == 'lstm':self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)if aggregator_type in ['mean', 'max_pool', 'lstm']:self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)self.reset_parameters()def reset_parameters(self):"""重新初始化可学习的参数"""gain = nn.init.calculate_gain('relu')if self._aggre_type == 'max_pool':nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)if self._aggre_type == 'lstm':self.lstm.reset_parameters()if self._aggre_type != 'gcn':nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)def forward(self, graph, feat):with graph.local_scope():import dgl.function as fnimport torch.nn.functional as Ffrom dgl.utils import check_eq_shape# 指定图类型,然后根据图类型扩展输入特征feat_src, feat_dst = expand_as_pair(feat, graph)if self._aggre_type == 'mean':graph.srcdata['h'] = feat_srcgraph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))h_neigh = graph.dstdata['neigh']elif self._aggre_type == 'gcn':check_eq_shape(feat)graph.srcdata['h'] = feat_srcgraph.dstdata['h'] = feat_dstgraph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))# 除以入度degs = graph.in_degrees().to(feat_dst)h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)elif self._aggre_type == 'max_pool':graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))h_neigh = graph.dstdata['neigh']else:raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))# GraphSAGE中gcn聚合不需要fc_selfif self._aggre_type == 'gcn':rst = self.fc_neigh(h_neigh)else:rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)# 激活函数if self.activation is not None:rst = self.activation(rst)# 归一化if self.norm is not None:rst = self.norm(rst)return rst

3.3 异构图上的GraphConv模块

异构图上的GraphConv模块

图神经网络框架DGL教程-第3章:构建图神经网络(GNN)模块相关推荐

  1. 图神经网络框架DGL教程-第4章:图数据处理管道

    更多图神经网络和深度学习内容请关注: 第4章:图数据处理管道 DGL在 dgl.data 里实现了很多常用的图数据集.它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道. ...

  2. DGL教程【三】构建自己的GNN模块

    有时,利用现有的GNN模型进行堆叠无法满足我们的需求,例如我们希望通过考虑节点重要性或边权值来发明一种聚合邻居信息的新方法. 本节将介绍: DGL的消息传递API 自己实现一个GraphSage卷积模 ...

  3. 开源图神经网络框架DGL升级:GCMC训练时间从1天缩到1小时,RGCN实现速度提升291倍...

    乾明 编辑整理  量子位 报道 | 公众号 QbitAI 又一个AI框架迎来升级. 这次,是纽约大学.亚马逊联手推出图神经网络框架DGL. 不仅全面上线了对异构图的支持,复现并开源了相关异构图神经网络 ...

  4. 图神经网络框架DGL实现Graph Attention Network (GAT)笔记

    参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 --基础操作&消息传递 [3]Cora数据集介绍+python读取 一.DGL实现GAT分类机器学习论文 程序摘自[1],该 ...

  5. 神经网络学习小记录2——利用tensorflow构建循环神经网络(RNN)

    神经网络学习小记录2--利用tensorflow构建循环神经网络(RNN) 学习前言 RNN简介 tensorflow中RNN的相关函数 tf.nn.rnn_cell.BasicLSTMCell tf ...

  6. 亚马逊+纽约大学开源图神经网络框架DGL:新手友好,与主流框架无缝衔接

    量子位 授权转载 | 公众号 QbitAI 最近,纽约大学.纽约大学上海分校.AWS上海研究院以及AWS MXNet Science Team共同开源了一个面向图神经网络及图机器学习的全新框架,命名为 ...

  7. 图神经网络框架DGL学习 103——信息传递 (Message Passing Tutorial)

    在图神经网络中,信息的传递和特征的转变,用户可以自定义的.当然在DGL中,也有高级别的API供调用. 现在来看一个网页排名简单的模型.每一个节点都有相同的PV值,PV=0.01, 每一个节点首先会均匀 ...

  8. 图神经网络框架DGL学习 102——图、节点、边及其特征赋值

    101(入门)以后就是开始具体逐项学习图神经网络的各个细节.下面介绍: 1.如何构建图 2.将特征赋给节点或者边,及查询方法 这算是图神经网络最基础最基础的部分了. 一.如何构建图 DGL中创建的图的 ...

  9. 重磅:腾讯正式开源图计算框架Plato,十亿级节点图计算进入分钟级时代

    整理 | 唐小引 来源 | CSDN(ID:CSDNnews) 腾讯开源进化 8 年,进入爆发期. 继刚刚连续开源 TubeMQ.Tencent Kona JDK.TBase.TKEStack 四款重 ...

最新文章

  1. iloc loc 区别
  2. 中国电信天翼Live究竟胜算几何?
  3. 原地不动 福玛特机器人_智能扫地机器人一直在原地打转是怎么回事以及解决办法...
  4. JAVA数据结构-稀疏数组
  5. mysql主从架构备份,mysql数据库容灾实时备份主从架构
  6. Centos 配置多个虚拟IP
  7. [众包]Eclipse 运行简单亚马逊AMT模板
  8. “The server requested authentication method unknown to the client.”的解决方案
  9. 10打开没有反应_118个遇水反应化学品清单及高压反应釜操作经验
  10. CH 6202 黑暗城堡
  11. 计算机窗口闪屏,电脑闪屏怎么办?如何解决电脑经常闪屏问题
  12. Node.JS实战57:给图片加水印。
  13. dell服务器更换硬盘raid,DELL T620服务器硬盘坏,更换硬盘做RAID同步
  14. 亲身经历!4个月写完硕士毕业论文一稿过,我是如何做到的?
  15. 戴尔笔记本提示“您已插入低瓦数电源适配器 在bios设置中可以禁用此警告”
  16. win7制作ntp服务器,win7系统搭建ntp服务器的操作方法
  17. myeclipse新建项目部署到tomcat中,点击finish键没反应
  18. 【面试准备之】HR面试时100个关键问题
  19. java编写火车订票系统_毕业设计(论文)-基于JavaWeb技术的火车订票系统.doc
  20. 【转载】儒林外史人物——严贡生和严监生(一)

热门文章

  1. python repr函数_python编程 魔法函数之__str__和__repr__
  2. 浏览器开发者工具用法
  3. 百度智能云OCR身份证识别-SDK
  4. signature=2f0e364618bd844a5fe88c26cefcaa33,Microsoft Word - CEFC_Failure_Detection_Resume.doc
  5. python求字符串中循环节个数
  6. 弹性布局(伸缩布局)
  7. 这个代码生成器火了…SmartSoftHelp
  8. Uncaught (in promise) 的解决方法
  9. DataScience:基于GiveMeSomeCredit数据集利用特征工程处理、逻辑回归LoR算法实现构建风控中的金融评分卡模型
  10. #define的一些用法