CS224W-07:图神经网络二
图神经网络二
第六节主要是对图神经网络做了一个整体上的介绍,本节介绍几种经典的GNN 和设计GNN的基本思路。具体内容为
- 单层 GNN
- 单层 GNN 的一般形式
- 经典的 GNN 网络:GCN, GraphSAGE, GAT
- 多层 GNN 设计
- 如何确定 GNN 网络的层数
- 过平滑 (Over smoothing) 问题
- 跳跃连接 (Skip connections)
- 实际训练中图的操作
- 特征增广 (Feature augmentation)
- 结构变换 (Structure manipulation)
单层 GNN
单层 GNN 的一般形式
节点嵌入的计算包括消息传递和聚合两个步骤,消息传递是根据邻域节点的嵌入计算消息,消息聚合是将邻域节点传递的消息与自身的嵌入结合到一起计算新的嵌入。消息计算可以表示为
mu(l)=MSG(l)(hu(l−1)),∀u∈{N(v)∪v}\mathbf m_u^{(l)} = \text{MSG}^{(l)} (\mathbf h_u^{(l-1)}), \, \forall u \in \{N(v) \cup v \} mu(l)=MSG(l)(hu(l−1)),∀u∈{N(v)∪v}
例如线性连接 (Linear layer)为
mu(l)=W(l)hu(l−1),∀u∈N(v)mv(l)=B(l)hv(l−1)\mathbf m_u^{(l)} =\mathbf W^{(l)} \mathbf h_u^{(l-1)}, \quad \forall u \in N(v) \\ \mathbf m_v^{(l)} =\mathbf B^{(l)} \mathbf h_v^{(l-1)} \quad mu(l)=W(l)hu(l−1),∀u∈N(v)mv(l)=B(l)hv(l−1)
消息聚合可以表示为
hv(l)=AGG(l)({mu(l),∀u∈N(v)},mv(l))\mathbf h_v^{(l)} = \text{AGG}^{(l)}(\{ \mathbf m_u^{(l)}, \forall u \in N(v) \}, \mathbf m_v^{(l)}) hv(l)=AGG(l)({mu(l),∀u∈N(v)},mv(l))
Graph Convolution Networks (GCN)
GCN 是一种基于空间的图卷积网络,消息传递和聚合比较简单
$$
\begin{align}
& \text{Message: } \mathbf m_u^{(l)} = \frac {1}{|N(v)|} \mathbf W^{(l)} \mathbf h_u^{(l)} \
& \text{Agrregation: } \mathbf h_v^{(l)} = \sigma (\sum_{u \in N(v)} \mathbf m_u^{(l)}) \
& \text{Together: } \mathbf h_v^{(l)} = \sigma (\sum_{u \in N(v)} \frac {1}{|N(v)|} \mathbf W^{(l)} \mathbf h_u^{(l)})
\end{align}
$$
GraphSAGE
GraphSAGE 中介绍了三种不同的聚合方式,其计算公式如下
hu(l)=σ(W(l)⋅CONCAT(hv(l−1),AGG({hu(l−1),∀u∈N(v)}))\mathbf h_u^{(l)} = \sigma \left(\mathbf W^{(l)} \cdot \text{CONCAT} \left (\mathbf h_v^{(l-1)}, \text{AGG}(\{\mathbf h_u^{(l-1)}, \forall u \in N(v)\} \right) \right) hu(l)=σ(W(l)⋅CONCAT(hv(l−1),AGG({hu(l−1),∀u∈N(v)}))
消息传递的计算实际是在函数 AGG 中完成的,GraphSAGE 可以分解为两步
- 从邻域节点聚合消息 : hN(v)(l)=AGG({hu(l−1),∀u∈N(v)})\mathbf h_{N(v)}^{(l)} = \text{AGG}(\{\mathbf h_u^{(l-1)}, \forall u \in N(v)\})hN(v)(l)=AGG({hu(l−1),∀u∈N(v)})
- 将节点嵌入与邻域消息在此聚合:$\mathbf h_u^{(l)} = \sigma \left(\mathbf W^{(l)} \cdot \text{CONCAT} (\mathbf h_{N(v)}^{(l)}, \mathbf h_u^{(l-1)}) \right) $
聚合函数有三种选择
Mean: 计算邻域节点的加权平均
AGG=∑u∈N(v)hu(l−1)∣N(v)∣\text{AGG} = \sum_{u \in N(v)} \frac {\mathbf h_u^{(l-1)}}{|N(v)|} AGG=u∈N(v)∑∣N(v)∣hu(l−1)Pool: 对邻域节点嵌入向量做变换后,使用平均池化或最大池化
AGG=Mean(MLP(hu(l−1)),∀u∈N(v))\text{AGG} = \text{Mean}\left(\text{MLP}(\mathbf h_u^{(l-1)}), \forall u \in N(v) \right) AGG=Mean(MLP(hu(l−1)),∀u∈N(v))LSTM: 将节点顺序打乱后,再应用 LSTM
AGG=LSTM([hu(l−1),∀u∈π(N(v))])\text{AGG} = \text{LSTM} ([\mathbf h_u^{(l-1)}, \forall u \in \pi(N(v))]) AGG=LSTM([hu(l−1),∀u∈π(N(v))])
GraphSAGE 还使用了 L2 归一化,对于节点的嵌入
hv(l)=hv(l)∥hv(l)∥2\mathbf h_v^{(l)} = \frac {\mathbf h_v{(l)}}{\|\mathbf h_v{(l)} \|_2} hv(l)=∥hv(l)∥2hv(l)
归一化保证节点的嵌入向量都在同一尺度,在某些情况下可能会有更好的效果。
Graph Attention Networks (GAT)
GAT在消息聚合时给每个节点的消息增加了注意力权重,计算公式为
hv(l)=σ(∑u∈N(v)αvuW(l)hu(l−1))\mathbf h_v^{(l)} = \sigma \left( \sum_{u \in N(v)} \alpha_{vu} \mathbf W^{(l)} \mathbf h_u^{(l-1)} \right) hv(l)=σ⎝⎛u∈N(v)∑αvuW(l)hu(l−1)⎠⎞
在 GCN 和 GraphSAGE 中,αvu=1N(v)\alpha_{vu} = \frac {1}{N(v)}αvu=N(v)1, 这表示领域 N(v)N(v)N(v) 中每个节点对于节点 vvv 都是同等重要的。在 GAT 中,αvu\alpha_{vu}αvu 是可学习的参数。用 evue_{vu}evu 表示节点 uuu 传向节点 vvv 消息的重要性,aaa 表示重要性计算函数
evu=a(W(l)hv(l−1),W(l)hu(l−1))e_{vu} = a(\mathbf W^{(l)} \mathbf h_v^{(l-1)}, \mathbf W^{(l)} \mathbf h_u^{(l-1)}) evu=a(W(l)hv(l−1),W(l)hu(l−1))
在原论文中
evu=LeakyReLU(aT⋅CONCAT(hv(l−1),W(l)hu(l−1)))e_{vu} = \text{LeakyReLU} \left(\mathbf a^T \cdot \text{CONCAT}( \mathbf h_v^{(l-1)}, \mathbf W^{(l)} \mathbf h_u^{(l-1)}) \right) evu=LeakyReLU(aT⋅CONCAT(hv(l−1),W(l)hu(l−1)))
其中 a\mathbf aa 可学习参数。再使用 Softmax 对边重要性作归一化
αvu=exp(evu)∑k∈N(v)exp(evk)\alpha_{vu} = \frac {\exp (e_{vu})} {\sum_{k \in N(v)} \exp(e_{vk})} αvu=∑k∈N(v)exp(evk)exp(evu)
GAT 还可以增加多个注意模块 (Multi-head attention),这样对网络的稳定性有一定提升。具体做法就是对每条边的消息计算多个注意力,然后将这些消息聚合起来得到最终的嵌入向量,公式为
hv(l)[1]=σ(∑u∈N(v)αvu1W(l)hu(l−1))hv(l)[2]=σ(∑u∈N(v)αvu2W(l)hu(l−1))hv(l)[3]=σ(∑u∈N(v)αvu3W(l)hu(l−1))\mathbf h_v^{(l)}[1] = \sigma \left( \sum_{u \in N(v)} \alpha_{vu}^1 \mathbf W^{(l)} \mathbf h_u^{(l-1)} \right) \\ \mathbf h_v^{(l)}[2] = \sigma \left( \sum_{u \in N(v)} \alpha_{vu}^2 \mathbf W^{(l)} \mathbf h_u^{(l-1)} \right) \\ \mathbf h_v^{(l)}[3] = \sigma \left( \sum_{u \in N(v)} \alpha_{vu}^3 \mathbf W^{(l)} \mathbf h_u^{(l-1)} \right) \\ hv(l)[1]=σ⎝⎛u∈N(v)∑αvu1W(l)hu(l−1)⎠⎞hv(l)[2]=σ⎝⎛u∈N(v)∑αvu2W(l)hu(l−1)⎠⎞hv(l)[3]=σ⎝⎛u∈N(v)∑αvu3W(l)hu(l−1)⎠⎞
聚合方式使用 Concat 、求和或者其他方式
hv(l)=AGG(hv(l)[1],hv(l)[2],hv(l)[3])\mathbf h_v^{(l)} = \text{AGG}(\mathbf h_v^{(l)}[1], \mathbf h_v^{(l)}[2], \mathbf h_v^{(l)}[3]) hv(l)=AGG(hv(l)[1],hv(l)[2],hv(l)[3])
GAT 的优点有:
- 计算效率高:边的注意力和聚合操作都可以并行计算
- 存储效率高:只需要存储节点和边的计算权重,空间复杂度为 O(V+E)O(V + E)O(V+E) ; 不论图的大小如何,网络参数大小不变
- 局部性:注意力只作用于图中的局部结构
- 归纳性学习能力:训练得到的边的权重是共享的,与图的结构无关
GNN 中常用的模块
与 CNN 中类似,GNN 中也可以使用 Normalization, Dropout, Activation 等模块 ,常用的 GNN 包含的模块如下
Batch Normalization 计算如下
Dropout 作用于消息计算之后,以线性连接为例
激活函数与其他神经网络一样。授课课题组有一个方便的工具 GraphGym
多层 GNN
构建多层 GNN 的最简单的方式是将多个单层 GNN 网络按顺序堆叠起来,第1层的输入嵌入为节点特征 hv(0)=xv\mathbf h_v^{(0)} = \mathbf x_vhv(0)=xv ,依次使用 GNN 计算节点嵌入向量,最后一层输出所需的节点嵌入。
与CNN 等深度神经网络不同的是,多层 GNN 会遇到一个特殊问题——过平滑 (Over smoothing) 。当 GNN 层数太多时,每个节点的嵌入向量会逐渐收敛至同一个值,最后输出的节点嵌入不再能够区分不同节点。过平滑的出现与 GNN 的计算方式密切相关。 目前GNN 一般使用消息传递和聚合的方式来更新节点嵌入,而消息聚合的结果取决于邻域节点,也就是该节点的感受野。通过多层 GNN,节点的感受野不断扩大,获得的信息也从局部扩展到图的全局,但是与其他节点感受野的重合度也越来越大。如下图所示,经过三层 GNN 后,节点的邻域已经扩展到 3跳,但几乎囊括了图中所有节点,不同节点的邻域的重合度也变得非常大,计算得到的嵌入会变得非常接近,最终造成过平滑问题。概括起来就是:堆叠多层 GNN 网络 → 节点之间的感受野高度重合 → 节点嵌入变得十分接近 → 造成过平滑问题。
那么该如何解决过平滑问题呢?可以从两方面着手:一是使用浅层 GNN;二是借鉴 Resnet 的思想,增加跳跃连接 (Skip connections).
浅层 GNN
在设计 GNN 时,我们需要不能盲目的增加网络的深度,更深的 GNN 不一定有更好的效果。GNN 的层数可以通过分析节点需要的最大感受野来确定(例如图的直径来表示节点最大感受野),网络的层数往往比需要的感受野稍微大一点即可。虽然浅层 GNN 可以避免过平滑问题,那么该如何提高浅层 GNN 的表达力呢?解决方案有两种:
- 加强单层 GNN 内消息聚合的复杂性。在之前的例子中,消息计算通常只用了一次线性计算,我们可以将消息聚合变得更复杂,使用多层感知 (MLP),将消息传递和聚合变成一个深度神经网络。
- 增加不传递消息的计算层。在 GNN 前后分别增加前处理层和后处理层,这些层只对节点特征进行计算,不做任何消息传递。
跳跃连接
观察到深层 GNN 可能导致过平滑问题,浅层 GNN 的节点嵌入的区分度可能更大,所以我们可以将浅层网络的嵌入通过跳跃连接,与深层网络的嵌入相加,以保留节点的区分度。
比如在 GCN 中,原始的消息聚合方式为 $ \mathbf h_v^{(l)} = \sigma (\sum_{u \in N(v)} \frac {1}{|N(v)|} \mathbf W^{(l)} \mathbf h_u^{(l)})$ ,增加跳跃连接后
hv(l)=σ(∑u∈N(v)1∣N(v)∣W(l)hu(l)+hv(l))\mathbf h_v^{(l)} = \sigma (\sum_{u \in N(v)} \frac {1}{|N(v)|} \mathbf W^{(l)} \mathbf h_u^{(l)} + \mathbf h_v^{(l)}) hv(l)=σ(u∈N(v)∑∣N(v)∣1W(l)hu(l)+hv(l))
疑问:与上一节讲的公式 hv(l)=σ(∑u∈N(v)1∣N(v)∣W(l)hu(l)+B(l)hv(l))\mathbf h_v^{(l)} = \sigma (\sum_{u \in N(v)} \frac {1}{|N(v)|} \mathbf W^{(l)} \mathbf h_u^{(l)} + \mathbf B^{(l)}\mathbf h_v^{(l)})hv(l)=σ(∑u∈N(v)∣N(v)∣1W(l)hu(l)+B(l)hv(l)) 很相像,而且与GraphSAGE 中的聚合方式也比较像,这种跳跃连接真的有很大作用吗?
图的操作 (Graph Manipulation in GNNs)
在之间的内容中,输入到 GNN 中计算的图和图的原始数据是一样的,但是实际中,计算图和原始图可能存在不同。我们可能需要改变图的节点特征和图的结构。
- 改变特征:当节点缺少特征,需要对特征做扩展
- 改变结构:当图太稀疏(消息传递效率低),或太密集(消息传递计算量大),或图太大(无法用GPU 计算整张图)
特征扩展
当节点没有任何特征是,可以使用常数特征或者独热编码 (one-hot )作为节点特征,二者的优缺点如下
当然这两种特征很难区分图中的某些结构,具体例子参看课件。除此之外,还可使用第二讲中介绍的特征:Clustering coefficient, PageRank, Centerity 等等。
改变图结构
- 增加虚拟边:当图过于稀疏时,可以增加虚拟边,可以将 2 跳邻域内的节点都连接起来。邻接矩阵变成 A+A2A + A^2A+A2
- 增加虚拟节点:增加一个虚拟节点,它与所有节点都相连,这样所有节点之间的最大距离都变为2,有效提高消息传递的效率
- 邻域节点随机采样:当图过于稠密时,可以对邻域节点进行随机采样,以减少消息传递的计算量
CS224W-07:图神经网络二相关推荐
- 从图(Graph)到图卷积(Graph Convolution):漫谈图神经网络 (二)
在从图(Graph)到图卷积(Graph Convolution): 漫谈图神经网络 (一)中,我们简单介绍了基于循环图神经网络的两种重要模型,在本篇中,我们将着大量笔墨介绍图卷积神经网络中的卷积操作 ...
- 图神经网络(二)--GNNs
转自:https://zhuanlan.zhihu.com/p/75307407 目录 一.什么是图神经网络 二.有哪些图神经网络 三.图神经网络的应用 一.什么是图神经网络? 在过去的几年中,神经网 ...
- 图神经网络(Graph Neural Networks,GNN)综述
鼠年大吉 HAPPY 2020'S NEW YEAR 作者:苏一 https://zhuanlan.zhihu.com/p/75307407 本文仅供学术交流.如有侵权,可联系删除. 本篇文章是对论文 ...
- 论文浅尝 - IJCAI2020 | KGNN:基于知识图谱的图神经网络预测药物与药物相互作用...
转载公众号 | AI TIME 论道 药物间相互作用(DDI)预测是药理学和临床应用中一个具有挑战性的问题,在临床试验期间,有效识别潜在的DDI对患者和社会至关重要.现有的大多数方法采用基于AI的计 ...
- 漫谈图神经网络 (三)
恭喜你看到了本系列的第三篇!前面两篇分别介绍了基于循环的图神经网络和基于卷积的图神经网络,那么在本篇中,我们则主要关注在得到了各个结点的表示后,如何生成整个图的表示.其实之前我们也举了一些例子,比如最 ...
- 【CS224W】(task9)图神经网络的表示能力(GIN图同构模型)
note ranking by discriminative power(input):sum-multiset > mean-distribution > max-set [基础部分]G ...
- 图神经网络(二)GCN的性质(3)GCN是一个低通滤波器
图神经网络(二)GCN的性质(3)GCN是一个低通滤波器 在图的半监督学习任务中,通常会在相应的损失函数里面增加一个正则项,该正则项需要保证相邻节点之间的类别信息趋于一致,一般情况下,我们选用拉普拉 ...
- 图神经网络(二)GCN的性质(2)GCN能够对图数据进行端对端学习
图神经网络(二)GCN的性质(2)GCN能够对图数据进行端对端学习 近几年,随着深度学习的发展,端对端学习变得越来越重要,人们普遍认为,深度学习的成功离不开端对端学习的作用机制.端对端学习实现了一种 ...
- 百度图神经网络学习——day04:图神经网络算法(二)
文章目录 一.图采样 1.GraphSAGE 2.PinSAGE 二.邻居聚合 1.GIN模型的聚合函数 2.其他复杂的聚合函数 三.编程实现 1.GraphSage采样函数实现 2.GraphSag ...
最新文章
- 一个电脑白痴与黑客的对话
- VII Python(9)socket编程
- php职业认证,如何用 PHP 进行 HTTP 认证
- 【GOF23设计模式】迭代器模式
- 从binlog恢复数据及Mysqlbinlog文件删除
- 遍历map时删除不需要的元素方法
- linux登录用户目录,linux命令
- Hibernate3动态条件查询
- android 看门狗引起crash分析
- 番茄花园GHOST SP3无法安装IIS 信息服务的解决方法
- 17SWFObject使用
- 一个JSP页面打开另外一个JSP页面并传值
- Unity游戏开发前置知识
- 码云webhook node版
- 如何打印int整数的32位二进制数(位运算)
- 苹果输入法怎么换行_朋友圈不折叠的N种方法安卓苹果通用
- java全jit编译_Javac编译与JIT编译
- 高德地图 点击获取坐标
- 企业网络安全防护概述
- No qualifying bean of type ‘com.itheima.dao.BookDao1‘ available: expected single matching bean 问题解决
热门文章
- ligo 原理_在LIGO的实验中,Ubuntu被用来检测引力波
- 看纷享销客如何布局连接型CRM
- JS 的cookie三部曲
- 离散数学_九章:关系(5)
- 【P45】直流单电源24V JLH 1969 经典耳放参数优化
- 大数据最重要的算法是什么,最常用的算法有哪些?
- PPPOE拨号之六:华为路由器 PPPoE拨号配置(包含Client+NAT与服务器配置)
- 一成电计算机考研国家线2O 9,【九〇六 | 打卡】考研“国家线”只是起点,我们要挑战骇浪惊涛!...
- 【山外笔记-计算机网络·第7版】第13章:计算机网络名词缩写汇总
- 用两种方法改错,体会封装和友员的关系!