GraphSAGE: 算法原理,实现和应用
在上一篇文章中介绍了GCN
浅梦:【Graph Neural Network】GCN: 算法原理,实现和应用zhuanlan.zhihu.com
GCN是一种在图中结合拓扑结构和顶点属性信息学习顶点的embedding表示的方法。然而GCN要求在一个确定的图中去学习顶点的embedding,无法直接泛化到在训练过程没有出现过的顶点,即属于一种直推式(transductive)的学习。
本文介绍的GraphSAGE则是一种能够利用顶点的属性信息高效产生未知顶点embedding的一种归纳式(inductive)学习的框架。
其核心思想是通过学习一个对邻居顶点进行聚合表示的函数来产生目标顶点的embedding向量。
GraphSAGE算法原理
GraphSAGE 是Graph SAmple and aggreGatE的缩写,其运行流程如上图所示,可以分为三个步骤
1. 对图中每个顶点邻居顶点进行采样
2. 根据聚合函数聚合邻居顶点蕴含的信息
3. 得到图中各顶点的向量表示供下游任务使用
采样邻居顶点
出于对计算效率的考虑,对每个顶点采样一定数量的邻居顶点作为待聚合信息的顶点。设采样数量为k,若顶点邻居数少于k,则采用有放回的抽样方法,直到采样出k个顶点。若顶点邻居数大于k,则采用无放回的抽样。
当然,若不考虑计算效率,我们完全可以对每个顶点利用其所有的邻居顶点进行信息聚合,这样是信息无损的。
生成向量的伪代码
这里K是网络的层数,也代表着每个顶点能够聚合的邻接点的跳数,如K=2的时候每个顶点可以最多根据其2跳邻接点的信息学习其自身的embedding表示。
在每一层的循环k中,对每个顶点v,首先使用v的邻接点的k-1层的embedding表示 来产生其邻居顶点的第k层聚合表示 ,之后将 和顶点v的第k-1层表示 进行拼接,经过一个非线性变换产生顶点v的第k层embedding表示 。
聚合函数的选取
由于在图中顶点的邻居是天然无序的,所以我们希望构造出的聚合函数是对称的(即改变输入的顺序,函数的输出结果不变),同时具有较高的表达能力。
- MEAN aggregator
上式对应于伪代码中的第4-5行,直接产生顶点的向量表示,而不是邻居顶点的向量表示。 mean aggregator将目标顶点和邻居顶点的第k-1层向量拼接起来,然后对向量的每个维度进行求均值的操作,将得到的结果做一次非线性变换产生目标顶点的第k层表示向量。
- Pooling aggregator
Pooling aggregator 先对目标顶点的邻接点表示向量进行一次非线性变换,之后进行一次pooling操作(maxpooling or meanpooling),将得到结果与目标顶点的表示向量拼接,最后再经过一次非线性变换得到目标顶点的第k层表示向量。
- LSTM aggregator
LSTM相比简单的求平均操作具有更强的表达能力,然而由于LSTM函数不是关于输入对称的,所以在使用时需要对顶点的邻居进行一次乱序操作。
参数的学习
在定义好聚合函数之后,接下来就是对函数中的参数进行学习。文章分别介绍了无监督学习和监督学习两种方式。
- 无监督学习形式
基于图的损失函数希望临近的顶点具有相似的向量表示,同时让分离的顶点的表示尽可能区分。 目标函数如下
其中v是通过固定长度的随机游走出现在u附近的顶点, 是负采样的概率分布, 是负样本的数量。
与DeepWalk不同的是,这里的顶点表示向量是通过聚合顶点的邻接点特征产生的,而不是简单的进行一个embedding lookup操作得到。
- 监督学习形式
监督学习形式根据任务的不同直接设置目标函数即可,如最常用的节点分类任务使用交叉熵损失函数。
GraphSAGE的实现
这里以MEAN aggregator简单讲下聚合函数的实现
features, node, neighbours = inputsnode_feat = tf.nn.embedding_lookup(features, node)
neigh_feat = tf.nn.embedding_lookup(features, neighbours)concat_feat = tf.concat([neigh_feat, node_feat], axis=1)
concat_mean = tf.reduce_mean(concat_feat,axis=1,keep_dims=False)output = tf.matmul(concat_mean, self.neigh_weights)
if self.use_bias:output += self.bias
if self.activation:output = self.activation(output)
对于第 层的aggregator,features
为第 层所有顶点的向量表示矩阵,node
和neighbours
分别为第k层采样得到的顶点集合及其对应的邻接点集合。
首先通过embedding_lookup
操作获取得到顶点和邻接点的第 层的向量表示。然后通过concat
将他们拼接成一个(batch_size,1+neighbour_size,embeding_size)
的张量,使用reduce_mean
对每个维度求均值得到一个(batch_size,embedding_size)
的张量。
最后经过一次非线性变换得到output
,即所有顶点的第 层的表示向量
- GraphSAGE 下面是完整的GraphSAGE方法的代码
def GraphSAGE(feature_dim, neighbor_num, n_hidden, n_classes, use_bias=True, activation=tf.nn.relu,aggregator_type='mean', dropout_rate=0.0, l2_reg=0):features = Input(shape=(feature_dim,))node_input = Input(shape=(1,), dtype=tf.int32)neighbor_input = [Input(shape=(l,),dtype=tf.int32) for l in neighbor_num]if aggregator_type == 'mean':aggregator = MeanAggregatorelse:aggregator = PoolingAggregatorh = featuresfor i in range(0, len(neighbor_num)):if i > 0:feature_dim = n_hiddenif i == len(neighbor_num) - 1:activation = tf.nn.softmaxn_hidden = n_classesh = aggregator(units=n_hidden, input_dim=feature_dim, activation=activation, l2_reg=l2_reg, use_bias=use_bias,dropout_rate=dropout_rate, neigh_max=neighbor_num[i])([h, node_input,neighbor_input[i]])#output = hinput_list = [features, node_input] + neighbor_inputmodel = Model(input_list, outputs=output)return model
其中feature_dim
表示顶点属性特征向量的维度,neighbor_num
是一个list
表示每一层抽样的邻居顶点的数量,n_hidden
为聚合函数内部非线性变换时的参数矩阵的维度,n_classes
表示预测的类别的数量,aggregator_type
为使用的聚合函数的类别。
GraphSAGE应用
本例中的训练,评测和可视化的完整代码在下面的git仓库中
shenweichen/GraphNeuralNetworkgithub.com
这里我们使用引文网络数据集Cora进行测试,Cora数据集包含2708个顶点, 5429条边,每个顶点包含1433个特征,共有7个类别。
按照论文的设置,从每个类别中选取20个共140个顶点作为训练,500个顶点作为验证集合,1000个顶点作为测试集。 采样时第1层采样10个邻居,第2层采样25个邻居。
- 节点分类任务结果
通过多次运行准确率在0.80-0.82之间。
- 节点向量可视化
参考资料
- Hamilton W, Ying Z, Leskovec J. Inductive representation learning on large graphs[C]//Advances in Neural Information Processing Systems. 2017: 1024-1034.(https://papers.nips.cc/paper/6703-inductive-representation-learning-on-large-graphs.pdf)
GraphSAGE: 算法原理,实现和应用相关推荐
- GAT 算法原理介绍与源码分析
GAT 算法原理介绍与源码分析 文章目录 GAT 算法原理介绍与源码分析 零. 前言 (与正文无关, 请忽略) 广而告之 一. 文章信息 二. 核心观点 三. 核心观点解读 四. 源码分析 4.1 G ...
- CRF(条件随机场)与Viterbi(维特比)算法原理详解
摘自:https://mp.weixin.qq.com/s/GXbFxlExDtjtQe-OPwfokA https://www.cnblogs.com/zhibei/p/9391014.html C ...
- 三维目标检测算法原理
三维目标检测算法原理 输入输出接口 Input: (1)图像视频分辨率(整型int) (2)图像视频格式(RGB,YUV,MP4等) (3)左右两边的车道线位置信息摄像头标定参数(中心位置(x,y) ...
- 3D-2D:PnP算法原理
3D-2D:PnP算法原理 1.问题背景-- 什么是PnP问题 ? 2.PnP问题的求解方法 2.1 P3P 2.1.1 算法的实际理解 2.1.2 算法的数学推导 2.1.3 算法的缺陷 2.2 直 ...
- MySQL索引背后的数据结构及算法原理【转】
http://blog.codinglabs.org/articles/theory-of-mysql-index.html MySQL索引背后的数据结构及算法原理[转] 摘要 本文以MySQL数据库 ...
- 文本分类的基本思想和朴素贝叶斯算法原理
文本分类的基本思想和朴素贝叶斯算法原理
- Bagging与随机森林算法原理小结
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 在集成学习原理小结中,我们讲到了集成学习有两个流派,一个是boos ...
- 干货 | 非常全面的谱聚类算法原理总结
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 谱聚类算法是目前最流行的聚类算法之一,其性能及适用场景优于传统的聚 ...
- 层次聚类算法原理总结
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 层次聚类(hierarchical clustering)基于簇间 ...
最新文章
- 疯狂java讲义之流程控制与数组
- Mysql学习总结(41)——MySql数据库基本语句再体会
- Chrome新的语言API,让您的浏览器说话
- “开源、共享、创新”, 中国最具前景开发者峰会落幕魔都
- android显示绘图动画,Android自定义View绘图实现渐隐动画
- Java工作笔记-AJAX实现整体不变,局部更新(与整体刷新比较)
- shiro框架的使用
- 漫画:优秀程序员的必备特质有哪些?
- 【博客管理】博客目录导航【置顶】
- 一文看懂NXP汽车电机控制解决方案(NXP整理)
- 如何新建一个keil工程 的详细步骤
- 变电站红外图像数据集
- php匿名聊天室开源,[开源项目]基于WebSocket的匿名聊天室
- 如何在excel中挑选出奇数行和偶数行
- 微信消息实现自动推送--方式一 成功啦 进来学
- Vue nvm重装node和npm与vue3报错Emitted ‘error‘ event on ChildProcess instance at errno: -4058
- html+JavaScript 实现贪吃蛇程序
- 劝人学医,天打雷劈?给医学新生的 10 条入学建议
- MDK解决方案:Warning L6989W
- 入学校计算机社团申请书,学校社团成立申请书
热门文章
- 华为和H3C命令对比
- 《Serverless 与容器决战在即?有了弹性伸缩就不一样了》
- Unable to simultaneously satisfy constraints
- 范德堡大学用机器学习预测自杀,准确率在80%以上
- MIT麻省理工最新研究揭示GAN生成数据可视化分析
- 我是如何通过拉勾教育学习《java高薪训练营》课程突破困境的
- [fuzz论文阅读] Symbolic execution for software testing: three decades later
- 架构设计:一种远程调用服务的设计构思(zookeeper的一种应用实践)
- bind()使用方法
- HTML旅游网页设计制作 DW旅游网站官网滚动网页 DIV旅游风景介绍网页设计与实现