一、消息传递范式介绍

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到不规则数据领域,实现了图与神经网络的连接。此范式包含三个步骤:(1)邻接节点信息变换;(2)邻接节点信息聚合到中心节点;(3)聚合信息变换。

消息传递图神经网络可以描述为:
xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i)),\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)​=γ(k)(xi(k−1)​,□j∈N(i)​ϕ(k)(xi(k−1)​,xj(k−1)​,ej,i​)),
xi(k−1)∈RF\mathbf{x}^{(k-1)}_i\in\mathbb{R}^Fxi(k−1)​∈RF表示(k-1)层中节点i的节点特征,ej,i∈RD\mathbf{e}_{j,i} \in \mathbb{R}^Dej,i​∈RD表示从节点j到节点i的边的特征,□\square□表示可微分的、具有排列不变形的函数,具有排列不变形的函数有和函数、均值函数和最大值函数。γ\gammaγ和ϕ\phiϕ表示可微分的函数。

二、Pytorch Geometric中的MessagePassing基类

Pytorch Geometric提供了MessagePassing类,实现了消息传播的自动处理,继承该基类可以方便地构造消息传递图神经网络,我们只需要定义函数ϕ\phiϕ(即message函数)和函数γ\gammaγ(即update函数),以及消息聚合方案(aggr=“add”、aggr="mean"或aggr=“max”)。

  • MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2):
    aggr: 定义要使用的聚合方案(“add”、“mean"或"max”)
    flow: 定义消息传递的流向(“source_to_target"或"target_to_source”)
    node_dim: 定义沿着哪个轴线传播

  • MessagePassing.propagate(edge_index, size=None, **kwargs):
    开始传播消息的起始调用。它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。
    size=(N,M)设置对称邻接矩阵的形状。

  • MessagePassing.message(…)接受最初传递给propagate函数的所有参数。

  • MessagePassing.aggregate(…)将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum,mean和max。

  • MessagePassing.message_and_aggregate(…)融合了邻接节点信息变换和邻接节点信息聚合。

  • MessagePassing.update(aggr_out, …)为每个节点更新节点表征,即实现γ\gammaγ函数。该函数以聚合函数的输出为第一参数,并接收所有传递给propagate函数的参数。

三、继承MessagePassing类的GCNConv

GCNConv的数学定义为:
xi(k)=∑j∈N(i)∪{i}1deg⁡(i)⋅deg⁡(j)⋅(Θ⋅xj(k−1)),\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)​=j∈N(i)∪{i}∑​deg(i)​⋅deg(j)​1​⋅(Θ⋅xj(k−1)​),
其中相邻节点的特征通过权重矩阵Θ\mathbf{\Theta}Θ进行转换,然后按端点的度进行归一化处理,最后进行加总。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。
  2. 线性转换节点特征矩阵。
  3. 计算归一化系数。
  4. 归一化j中的节点特征。
  5. 将相邻节点特征相加。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')# "Add" aggregation (Step 5).# flow='source_to_target' 表示消息从源节点传播到目标节点self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.x = self.lin(x)# Step 3: Compute normalization.row, col = edge_indexdeg = degree(col, x.size(0), dtype=x.dtype)deg_inv_sqrt = deg.pow(-0.5)norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]# Step 4-5: Start propagating messages.return self.propagate(edge_index, x=x, norm=norm)def message(self, x_j, norm):# x_j has shape [E, out_channels]# Step 4: Normalize node features.return norm.view(-1, 1) * x_j# 初始化和调用
conv = GCNConv(16, 32)
x = conv(x, edge_index)

四、复写message函数

class GCNConv(MessagePassing):def forward(self, x, edge_index):# ....return self.propagate(edge_index, x=x, norm=norm, d=d)def message(self, x_j, norm, d_i):# x_j has shape [E, out_channels]return norm.view(-1, 1) * x_j * d_i

五、覆写aggregate函数

class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')def forward(self, x, edge_index):# ....return self.propagate(edge_index, x=x, norm=norm, d=d)def aggregate(self, inputs, index, ptr, dim_size):print(self.aggr)print("`aggregate` is called")return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

六、覆写aggregate函数

class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')def forward(self, x, edge_index):# ....return self.propagate(edge_index, x=x, norm=norm, d=d)def aggregate(self, inputs, index, ptr, dim_size):print(self.aggr)print("`aggregate` is called")return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

七、覆写message_and_aggregate函数

from torch_sparse import SparseTensorclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')def forward(self, x, edge_index):# ....adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))# 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrixreturn self.propagate(adjmat, x=x, norm=norm, d=d)def message(self, x_j, norm, d_i):# x_j has shape [E, out_channels]return norm.view(-1, 1) * x_j * d_i # 这里不管正确性def aggregate(self, inputs, index, ptr, dim_size):print(self.aggr)print("`aggregate` is called")return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)def message_and_aggregate(self, adj_t, x, norm):print('`message_and_aggregate` is called')

八、覆写update函数

class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super(GCNConv, self).__init__(aggr='add', flow='source_to_target')def update(self, inputs: Tensor) -> Tensor:return inputs

消息传递的图神经网络相关推荐

  1. 百度图神经网络学习——day03:图神经网络算法(一)

    文章目录 一.图卷积网络(Graph Convolutional Network) 1.核心公式 2.算法流程 二.图注意力算法(GAT) 1.计算方法 2.多头Attention 三.空间GNN 四 ...

  2. 图神经网络详解及其在交通预测方面的应用

    layout: mypost title: 图神经网络及其在交通预测方面的应用 categories: [Traffic prediction, Graph Neural Networks] 图神经网 ...

  3. PGL图学习之图神经网络GNN模型GCN、GAT

    在4922份提交内容中,主要涉及13个研究方向,具体有: 1.AI应用应用,例如:语音处理.计算机视觉.自然语言处理等 2.深度学习和表示学习 3.通用机器学习 4.生成模型 5.基础设施,例如:数据 ...

  4. 【笔记整理】图神经网络学习

    [笔记整理]图神经网络学习 文章目录 [笔记整理]图神经网络学习 一.GNN简介 1.图结构 & 图基础算法 1)引言("非欧几何, 处理图数据的NN") 2)图基本概念 ...

  5. 图神经网络解偏微分方程系列(一)

    图神经网络解偏微分方程系列(一) 1. 标题和概述 Learning continuous-time PDEs from sparse(稀疏) data with graph neural netwo ...

  6. Graph Decipher: A transparent dual-attention graph neural network 图解密器:一种透明的双注意图神经网络,用于理解节点分类的消息传递机制

    引用 Pang Y, Liu C. Graph Decipher: A transparent dual-attention graph neural network to understand th ...

  7. 【图神经网络DGL】GCN在Karate Club上的实战(消息传递范式 | 生成训练可视化动图)

    学习总结 回顾[图神经网络DGL]数据封装和消息传递机制 的数据封装,在做异构图神经网络时,DGL比PyG方便很多(尽管PyG已经支持了异构图Aminer和栗子,但对图结构数据做批处理还是需要自己实现 ...

  8. Graph Representation 图神经网络

    Graph Representation 图神经网络 图表示学习(representation learning)--图神经网络框架,主要涉及PyG.DGL.Euler.NeuGraph和AliGra ...

  9. 图神经网络快速爆发,最新进展都在这里了

    译者 | 刘畅 出品 | AI科技大本营(rgznai100) 近年来,图神经网络(GNNs)发展迅速,最近的会议上发表了大量相关的研究论文.本文作者正在整理一个GNN的简短介绍和最新研究报告的摘要. ...

最新文章

  1. 2020校招薪酬大比拼,你被倒挂了没?
  2. Redis集群:哨兵(Sentinel)
  3. CSS DIV Shadow
  4. oracle大表如何快速删除一列,Oracle 对表中的记录进行大批量删除
  5. python 元组捷豹_GitHub - jaguarzls/pyecharts: Python Echarts Plotting Library
  6. linux在xt文件写入内容,0728linux基础内容小记
  7. 面趣 | 那些面试没过的程序员,都去了哪里?答案真的挺励志
  8. github设置仓库可见性 私人仓库设置他人协作/可见
  9. 从折叠屏到AR 三星Galaxy新品预热宣传片大招频现
  10. Aspose.Words操作Word.PDF,让图片和文本垂直居中,水平居中解决方案
  11. 要读顶级会议上的论文
  12. 读Thinking in Java(1~4)
  13. windows 搭建HTTP文件服务器(Nginx 方式)
  14. 分享视频分析软件常用的几个C++库
  15. matlab图像画轮毂,轮毂设计及三维造型(全套图纸三维).doc
  16. 《东周列国志》第六十七回 卢蒲癸计逐庆封 楚灵王大合诸侯
  17. W25Q64简介(译)
  18. 输入数字怎么变成大写python_用Python将数字转换为中文大写
  19. QT串口助手(五):文件操作
  20. cad画钟表_cad应用环形矩阵制作钟表盘

热门文章

  1. django-聚合函数
  2. django-行对向的反向查找
  3. 关于GTID模式下备份时 --set-gtid-purged=OFF 参数的实验【转】
  4. 【gradle】mac下 gradle默认本地仓库位置
  5. 《产品设计与开发(原书第5版)》——3.8 步骤5:选出最佳机会方案
  6. flask笔记3-模板
  7. lintcode 中等题:Single number III 落单的数III
  8. 路径匹配C++变量文件夹下所有文件
  9. CentOS下如何配置LAMP环境
  10. 远离ARP*** ARP防火墙新版发布