PyG MessagePassing机制源码分析


Google在2017发表的论文Neural Message Passing for Quantum Chemistry中提到的Message Passing Neural Networks机制成为了后来图机器学习计算的标准范式实现。

而PyG提供了信息传递(邻居聚合) 操作的框架模型。

其中,
□\square表示 可微、排列不变 的函数,比如说summeanmax
γ\gammaγϕ\phiϕ 表示 可微 的函数,比如说 MLP

在propagate中,依次会调用messageaggregateupdate函数。
其中,
message为公式中 ϕ\phiϕ 部分,表示特征传递
aggregate为公式中 □\square 部分,表示特征聚合
update为公式中 γ\gammaγ 部分,表示特征更新

MessagePassing类

PyG使用MessagePassing类作为实现 信息传递 机制的基类。我们只需要继承其即可。
下面,我们以GCN为例子
GCN信息传递公式如下:

源码分析

一般的图卷积层是通过的forward函数进行调用的,通常的调用顺序如下,那么是如何将自定义的参数kwargs与后续的函数的入参进行对应的呢?(图来源:https://blog.csdn.net/minemine999/article/details/119514944)

MessagePassing初始化构建了Inspector类, 其主要的作用是对子类中自定义的message,aggregate,message_and_aggregate,以及update函数的参数的提取。

class MessagePassing(torch.nn.Module):special_args: Set[str] = {'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size','size_i', 'size_j', 'ptr', 'index', 'dim_size'}def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):super().__init__()self.aggr = aggrassert self.aggr in ['add', 'sum', 'mean', 'min', 'max', 'mul', None]self.flow = flowassert self.flow in ['source_to_target', 'target_to_source']self.node_dim = node_dimself.decomposed_layers = decomposed_layersself.inspector = Inspector(self)self.inspector.inspect(self.message)self.inspector.inspect(self.aggregate, pop_first=True)self.inspector.inspect(self.message_and_aggregate, pop_first=True)self.inspector.inspect(self.update, pop_first=True)self.inspector.inspect(self.edge_update)self.__user_args__ = self.inspector.keys(['message', 'aggregate', 'update']).difference(self.special_args)self.__fused_user_args__ = self.inspector.keys(['message_and_aggregate', 'update']).difference(self.special_args)self.__edge_user_args__ = self.inspector.keys(['edge_update']).difference(self.special_args)

inspect函数中,inspect.signature(func).parameters, 获取了子类的函数入参,比如当func="message"时,params = inspect.signature(‘message’).parameters就会获得子类自定义message函数的参数,

class Inspector(object):def __init__(self, base_class: Any):self.base_class: Any = base_classself.params: Dict[str, Dict[str, Any]] = {}def inspect(self, func: Callable,pop_first: bool = False) -> Dict[str, Any]:## 注册func函数的入参,并建立func与入参之间的对应关系params = inspect.signature(func).parametersparams = OrderedDict(params)if pop_first:

参数的传递过程:
从上图可知,参数是从forward传递进来的,而propagate将参数传递后面到对应的函数中,这部分的参数对应关系主要由MessagePassing类的__collect__函数进行参数收集和数据赋值。

__collect__函数中的args主要对应子类中相关函数(message,aggregate,update等)的自定义参数self.__user_args__kwargs为子类的forward函数中调用propagate传递进来的参数。

self.__user_args___i_j后缀是非常重要的参数,其中i表示与target节点相关的参数,j表示source节点相关的参数,其图上的指向为j->i for j 属于N(i),后缀不包含_i_j的参数直接被透传。(默认:self.flow==source_to_target)

def __collect__(self, args, edge_index, size, kwargs):i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)out = {}for arg in args:# 遍历自定义函数中的参数if arg[-2:] not in ['_i', '_j']: # 不包含_i和_j的自定义参数直接透传out[arg] = kwargs.get(arg, Parameter.empty) # 从用户传递进来的kwargs参数中获取值else:dim = 0 if arg[-2:] == '_j' else 1 # 注意这里的取值维度data = kwargs.get(arg[:-2], Parameter.empty) # 取用户传递进来的kwargs前缀arg[:-2]的数据if isinstance(data, (tuple, list)):assert len(data) == 2if isinstance(data[1 - dim], Tensor):self.__set_size__(size, 1 - dim, data[1 - dim])data = data[dim]if isinstance(data, Tensor):self.__set_size__(size, dim, data)data = self.__lift__(data, edge_index,j if arg[-2:] == '_j' else i)out[arg] = dataif isinstance(edge_index, Tensor):out['adj_t'] = Noneout['edge_index'] = edge_indexout['edge_index_i'] = edge_index[i]out['edge_index_j'] = edge_index[j]out['ptr'] = Noneelif isinstance(edge_index, SparseTensor):out['adj_t'] = edge_indexout['edge_index'] = Noneout['edge_index_i'] = edge_index.storage.row()out['edge_index_j'] = edge_index.storage.col()out['ptr'] = edge_index.storage.rowptr()out['edge_weight'] = edge_index.storage.value()out['edge_attr'] = edge_index.storage.value()out['edge_type'] = edge_index.storage.value()out['index'] = out['edge_index_i']out['size'] = sizeout['size_i'] = size[1] or size[0]out['size_j'] = size[0] or size[1]out['dim_size'] = out['size_i']return out

propagate中依次从coll_dict中获取与messageaggregateupdate函数的参数进行调用。注意这里获取的参数是通过上述的self.inspector.distribute函数进行获取的。

def propagate(self,..):##...##...msg_kwargs = self.inspector.distribute('message', coll_dict)out = self.message(**msg_kwargs)##...##...aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)out = self.aggregate(out, **aggr_kwargs)update_kwargs = self.inspector.distribute('update', coll_dict)return self.update(out, **update_kwargs)

自定义 message , aggregate , update

   def message(self, x_i, x_j, norm):# x_j ::= x[edge_index[0]] shape = [E, out_channels]# x_i ::= x[edge_index[1]] shape = [E, out_channels]print("x_j", x_j.shape, x_j)print("x_i: ", x_i.shape, x_i)# norm.view(-1, 1).shape = [E, 1]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jdef aggregate(self, inputs: Tensor, index: Tensor,ptr: Optional[Tensor] = None,dim_size: Optional[int] = None) -> Tensor:# 第一个参数不能变化# index ::= edge_index[1]# dim_size ::= [number of node]print("agg_index: ",index)print("agg_dim_size: ",dim_size)# Step 5: Aggregate the messages.# out.shape = [number of node, out_channels]out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)print("agg_out:",out.shape,out)return outdef update(self, inputs: Tensor, x_i, x_j) -> Tensor:# 第一个参数不能变化# inputs ::= aggregate.out# Step 6: Return new node embeddings.print("update_x_i: ",x_i.shape,x_i)print("update_x_j: ",x_j.shape,x_j)print("update_inputs: ",inputs.shape, inputs)return inputs

GCN Demo

from typing import Optional
from torch_scatter import scatter
import torch
import numpy as np
import random
import os
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degreeclass GCNConv(MessagePassing):def __init__(self, in_channels, out_channels):super().__init__(aggr='add')  # "Add" aggregation (Step 5).self.lin = torch.nn.Linear(in_channels, out_channels)def forward(self, x, edge_index):# x has shape [N, in_channels]# edge_index has shape [2, E]# Step 1: Add self-loops to the adjacency matrix.edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))# Step 2: Linearly transform node feature matrix.x = self.lin(x) # x = lin(x)# Step 3: Compute normalization.row, col = edge_index # row, col is the [out index] and [in index]deg = degree(col, x.size(0), dtype=x.dtype) # [in_degree] of each node, deg.shape = [N]deg_inv_sqrt = deg.pow(-0.5)deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # deg_inv_sqrt.shape = [E]# Step 4-6: Start propagating messages.return self.propagate(edge_index, x=x, norm=norm)def message(self, x_i, x_j, norm):# x_j ::= x[edge_index[0]] shape = [E, out_channels]# x_i ::= x[edge_index[1]] shape = [E, out_channels]print("x_j", x_j.shape, x_j)print("x_i: ", x_i.shape, x_i)# norm.view(-1, 1).shape = [E, 1]# Step 4: Normalize node features.return norm.view(-1, 1) * x_jdef aggregate(self, inputs: Tensor, index: Tensor,ptr: Optional[Tensor] = None,dim_size: Optional[int] = None) -> Tensor:# 第一个参数不能变化# index ::= edge_index[1]# dim_size ::= [number of node]print("agg_index: ",index)print("agg_dim_size: ",dim_size)# Step 5: Aggregate the messages.# out.shape = [number of node, out_channels]out = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)print("agg_out:",out.shape,out)return outdef update(self, inputs: Tensor, x_i, x_j) -> Tensor:# 第一个参数不能变化# inputs ::= aggregate.out# Step 6: Return new node embeddings.print("update_x_i: ",x_i.shape,x_i)print("update_x_j: ",x_j.shape,x_j)print("update_inputs: ",inputs.shape, inputs)return inputsdef set_seed(seed=1029):random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed) # 为了禁止hash随机化,使得实验可复现np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = Trueif __name__ == '__main__':set_seed(0)# x.shape = [5, 2]x = torch.tensor([[1,2], [3,4], [3,5], [4,5], [2,6]], dtype=torch.float)# edge_index.shape = [2, 6]edge_index = torch.tensor([[0,1,2,3,1,4], [1,0,3,2,4,1]])print("num_node: ",x.shape[0])print("num_edge: ",edge_index.shape[1])in_channels = x.shape[1]out_channels = 3gcn = GCNConv(in_channels, out_channels)out = gcn(x, edge_index)print(out)

PyG MessagePassing机制源码分析相关推荐

  1. Apache Storm 实时流处理系统通信机制源码分析

    我们今天就来仔细研究一下Apache Storm 2.0.0-SNAPSHOT的通信机制.下面我将从大致思想以及源码分析,然后我们细致分析实时流处理系统中源码通信机制研究. 1. 简介 Worker间 ...

  2. Spark资源调度机制源码分析--基于spreadOutApps及非spreadOutApps两种资源调度算法

    Spark资源调度机制源码分析--基于spreadOutApps及非spreadOutApps两种资源调度算法 1.spreadOutApp尽量平均分配到每个executor上: 2.非spreadO ...

  3. ART虚拟机 | Cleaner机制源码分析

    目录 思考问题 1.Android为什么要将Finalize机制替换成Cleaner机制? 2.Cleaner机制回收Native堆内存的原理是什么? 3.Cleaner机制源码是如何实现的? 一.版 ...

  4. k8s 驱逐eviction机制源码分析

    原理部分 1. 驱逐概念介绍 kubelet会定期监控node的内存,磁盘,文件系统等资源,当达到指定的阈值后,就会先尝试回收node级别的资源,比如当磁盘资源不足时会删除不同的image,如果仍然在 ...

  5. Android——RIL 机制源码分析

    Android 电话系统框架介绍 在android系统中rild运行在AP上,AP上的应用通过rild发送AT指令给BP,BP接收到信息后又通过rild传送给AP.AP与BP之间有两种通信方式: 1. ...

  6. Linux Thermal机制源码分析之Governor

    一.thermal_init() 在开始源码分析之前,需要先说明一下.Linux 内核代码庞大而复杂,如何 reading the Fxxking source code 相信是很多从事 Linux ...

  7. Nacos 服务端健康检查及客户端服务订阅机制源码分析(三)

    Nacos 服务端健康检查 长连接 概念:长连接,指在一个连接上可以连续发送多个数据包,在连接保持期间,如果没有数据包发送,需要双方发送链路检测包 注册中心客户端 2.0 以后使用 gRPC 代替 h ...

  8. Handler机制源码分析

    一.Handler使用上需要注意的几点 1.1 handler使用不当造成的内存泄漏 public class MainActivity extends AppCompatActivity {priv ...

  9. HashMap扩容机制源码分析

    前几天写了一篇,ArrayList扩容源码分析.好像源码也没有我们想象的那么可怕?(当然了,只是简单的分析,后面等我知识充足了,将进一步的分析) 今天本来想打游戏的,但是网速太差了,真是的是让人火爆. ...

  10. Android -- 消息处理机制源码分析(Looper,Handler,Message)

    android的消息处理有三个核心类:Looper,Handler和Message.其实还有一个Message Queue(消息队列),但是MQ被封装到Looper里面了,我们不会直接与MQ打交道,因 ...

最新文章

  1. matlab 跳步循环,跳步急停是用单脚或双脚起跳,上体稍后仰,两脚同时平行落地。落地时()着地,用前脚掌内侧抵蹬住地面,两膝弯曲,降低重心,两臂屈肘微张,以保持身体平衡。...
  2. 分类算法:决策树(C4.5)
  3. 2、IDEA以新窗口的形式打开多个项目
  4. 15年3月c语言试卷,2015年3月计算机二级C语言试卷及答案..doc
  5. 由浅入深了解EventBus:(五)
  6. LINUX开机自启问题
  7. dbc数据库 与 mysql_【图片】DBC2000安装及数据库详细解析(不断更行中......)【dbc2000吧】_百度贴吧...
  8. ERP实施项目的计划阶段要点分析
  9. 给英文文章加音标,建生词表
  10. pageoffice在线编辑时向保存方法传递参数
  11. 6.3 探索性空间数据分析
  12. python读取excel合并单元_python 读写excel (合并单元格)
  13. Preparing Your Data for Use with robot_localization 准备 robot_localization 数据
  14. 如何删除Word中的边框线
  15. 《程序人生》2020无畏年少青春,迎风潇洒前行,杭漂程序员2019的心路历程,披荆斩棘终雨过天晴
  16. 读书笔记——WebKit技术内幕 WebKit架构和模块
  17. C++11(及现代C++风格)和快速迭代式开发 -- 刘未鹏
  18. HC32F460 SPI DMA 驱动 TFT显示屏
  19. Node.JS实战57:给图片加水印。
  20. 图论(十四)——图的着色

热门文章

  1. 最网最全python框架--scrapy(体系学习,爬取全站校花图片),学完显著提高爬虫能力(附源代码),突破各种反爬
  2. 组合数学 —— 母函数
  3. 解决MacOs10.15+ shimo 无法正常使用 PPTP协议问题
  4. thrift 问题梳理
  5. 微软官方dllcache恢复的批处理
  6. 苹果计算机取消用户名和密码进入不,苹果电脑怎么退出账户登录不了怎么办
  7. python基础教程第4版pdf百度云-Python入门书籍电子版PDF百度云网盘免费下载
  8. linux 卸载 sdcc,Linux sdcc安装
  9. 利用Backtrader进行期权回测之五:用backtrader_plotting查看回测结果
  10. 通过修改window本地hosts文件修改域名指向