CS224W图机器学习笔记自用:GNN Augmentation and Training
Recap:
today’s outline:
- (4)Graph augmentation
- (5)Learning objective
1. GNN 的图增强(Graph Augmentation for GNNs)
两种图增强的方法:
- 图特征增强
- 图结构增强
1.1 为什么要增强图
需要增强图的原因:
- 特征(Features)
- 输入图缺乏特征
- 图结构(Graph structure)
- 图太稀疏 -> 消息传递效率低下
- 图太密集 -> 消息传递成本太高
- 图太大 -> 无法将计算图存入GPU中
综上所述,输入图不是嵌入的最佳计算图 。
1.2 图增强的方法
- 图特征增强
- 输入图缺乏特征 -> 特征增强
- 图结构增强
- 图太稀疏 -> 添加虚拟节点或边
- 图太密集 -> 消息传递时只采样部分邻居节点进行传递
- 图太大 -> 计算嵌入时对子图进行采样
1.2.1 图特征增强
为什么我们需要特征增强?
第一种情况:
- 输入图没有节点特征,我们只有这个图的邻接矩阵,此时需要进行图特征增强。
解决方法:
a)为节点分配常数特征
b)为节点分配唯一的ID,这些ID值可以被转换为独热向量
两种方法的比较:Constant vs. one-hot
第二种方法比第一种方法表达能力更强;第一种方法归纳能力更强,能够很容易地推广到新节点,第二种方法则不行;同时第一种方法的计算开销也更小;第一种方法适用于任意图,同时具有归约能力,能推广到新节点,第二种方法适用于小图,只适用于transductive setting,不适合inductive setting。
第二种情况:
- GNN 很难学习某些特殊结构
- 例如:环节点数特征(Cycle count feature)
- GNN无法学习 v 1 v_1 v1所在环的长度,也无法区分 v 1 v_1 v1所在的是哪个图形
- 因为这两张图中的所有节点度数都为2
- 其计算图也是完全相同的二叉树
解决方法:我们可以使用循环计数作为增强的节点特征
其他常用的增强特征有:
- 节点度数
- 聚类系数(Clustering coefficient)
- PageRank
- Centrality
我们在第二节所提到的节点特征都可以使用。
1.2.2 图结构增强
针对图稀疏:添加虚拟节点或边
- 添加虚拟边
常用方法:通过虚拟边连接 2 跳邻居
想法:用 A + A 2 A + A^2 A+A2代替邻接矩阵 A A A进行GNN的计算
例子:二部图- Author-to-papers (他们撰写的)
- 2 跳虚拟边构成作者-作者协作图
- 添加虚拟节点:虚拟节点将连接到图中的所有节点
- 假设在一个稀疏图中,两个节点的最短路径距离为 10
- 添加虚拟节点后,所有节点的距离为 2
- Node A - Virtual node - Node B
- 好处:大大提高了稀疏图中的消息传递
- 添加虚拟边
针对图密集的问题:节点邻域采样
思想:在之前的设计中,所有节点参与消息传递,现在,我们**(随机)对节点的邻域进行采样以进行消息传递**,以解决图密集的问题。
例子:例如,我们可以随机选择 2 个邻居在给定层中传递消息
在下一层,当我们计算嵌入时,我们可以采样不同的邻居(对于类似于社交网络的图,也可以仅采样一些重要的节点,不必采样那些不重要的节点)
最后在预期中,我们得到类似于使用所有邻居的情况的嵌入。
这种方法的好处:可以大大降低计算成本,并且允许scaling to 大图,在实践中的效果也很好。
2. Training with GNNs
Learning so far:
2.1 Prediction head:如何从节点嵌入到实际预测
**预测头(prediction head)**有以下几种类型:
- 节点级任务
- 边级别任务
- 图级别任务
不同的任务级别需要不同的预测头
2.1.1 节点级预测头
1. 节点级预测:我们可以直接使用节点嵌入进行预测
- 在 GNN 计算之后,我们有d维的节点嵌入 { h v ( L ) ∈ R d , ∀ v ∈ G } \{ h_v^{(L)} \in R^d,\forall v \in G \} {hv(L)∈Rd,∀v∈G}
- 假设我们要进行一个k类别的预测
- 分类问题:在k个类别中分类
- 回归问题:回归k个目标
- y ^ v = H e a d n o d e ( h v ( L ) ) = W ( H ) h v ( L ) \hat{y}_v = Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)} y^v=Headnode(hv(L))=W(H)hv(L)
- W ( H ) ∈ R k × d W^{(H)} \in R^{k \times d} W(H)∈Rk×d: 我们映射节点嵌入从 h v ( L ) ∈ R d h_v^{(L)} \in ℝ^d hv(L)∈Rd到 y ^ v ∈ R k \hat{y}_v \in ℝ^k y^v∈Rk,这样我们就可以计算损失
2.1.2 边级别预测头
2. 边级别预测:使用节点嵌入对进行预测
- 假设我们要进行一个k类别的预测
- y ^ u v = H e a d e d g e ( h u ( L ) , h v ( L ) ) \hat{y}_{uv} = Head_{edge}(h_u^{(L)},h_v^{(L)}) y^uv=Headedge(hu(L),hv(L))
- H e a d n o d e ( h v ( L ) ) = W ( H ) h v ( L ) Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)} Headnode(hv(L))=W(H)hv(L)有多种选择
- (1) 串联 + 线性
- 在图注意力网络也有类似的架构
- y ^ u v = L i n e a r ( C o n c a t ( h u ( L ) , h v ( L ) ) ) \hat{y}_{uv} = Linear(Concat(h_u^{(L)},h_v^{(L)})) y^uv=Linear(Concat(hu(L),hv(L)))
- 这里线性映射函数Linear(.)会把2d维的嵌入向量映射到k维(k个类别)的嵌入中
- (2)点积
- y ^ u v = ( h u ( L ) ) T h v ( L ) \hat{y}_{uv} = (h_u^{(L)})^T h_v^{(L)} y^uv=(hu(L))Thv(L)
- 这种方法仅适用于 1-way 预测(例如,链接预测:预测边缘的存在)
- 应用到 k-way 预测上,类似于多头注意力机制, W ( 1 ) , . . . , W ( k ) W^{(1)},... ,W^{(k)} W(1),...,W(k)是可训练的参数
- (1) 串联 + 线性
2.1.3 图级别预测头
3. 图级别预测:使用图中的所有节点嵌入进行预测
- 假设我们要进行一个k类别的预测
- y ^ G = H e a d g r a p h ( { h v ( L ) ∈ R d , ∀ v ∈ G } \hat{y}_G = Head_{graph}(\{h_v^{(L)} \in R^d, \forall v \in G\} y^G=Headgraph({hv(L)∈Rd,∀v∈G}
- H e a d g r a p h ( ⋅ ) Head_{graph}(\cdot) Headgraph(⋅)类似于 GNN 层中的聚合函数 A G G ( ⋅ ) AGG(\cdot) AGG(⋅)
- H e a d g r a p h ( { h v ( L ) ∈ R d , ∀ v ∈ G } Head_{graph}(\{h_v^{(L)} \in R^d, \forall v \in G\} Headgraph({hv(L)∈Rd,∀v∈G}有多种选择:
- 全局平均池化层:与节点数无关,mean pooling可用于比较大小相差很大的图形
- y ^ G = M e a n ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Mean(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Mean({hv(L)∈Rd,∀v∈G})
- 全局最大池化层
- y ^ G = M a x ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Max(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Max({hv(L)∈Rd,∀v∈G})
- 全局求和池化层:max pooling可以发现图中的节点数和图的结构
- y ^ G = S u m ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Sum(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Sum({hv(L)∈Rd,∀v∈G})
- 全局平均池化层:与节点数无关,mean pooling可用于比较大小相差很大的图形
- 全局池化层的问题:以上选项的全局池化层都只适用于小规模的图形,在大图上应用全局池化层会有信息丢失的问题,例如
- 解决方法:分层全局池化(分层聚合所有节点嵌入)
- example:先聚合前两个节点,在聚合后两个节点
现在,我们就能区分图1和图2。
- example:先聚合前两个节点,在聚合后两个节点
那我们如何分层呢?
- DiffPool:
- 分层池化节点嵌入:利用图的社区结构,如果我们可以提前发现这些社区,那么我们就可以把每个社区当作一层,聚合社区内的节点信息,接着我们可以进一步将社区嵌入汇总到超级社区嵌入。如下图,输入图用社区检测或图分区算法分成了5个簇,这里用不同颜色表示,接着我们再汇总社区内的信息为每个社区生成一个超级节点,之后我们根据社区之间的联系再进行分簇,聚合,得到另一个超节点并不断聚合直到得到一个超级节点为止,然后就可以将其输入到预测头中:
- Ying 等人(2018)提出的DiffPool在每个级别利用 2 个独立的 GNN
- GNN A:计算节点嵌入
- GNN B:进行图分区,判断节点所属的集群
- 每个级别的 GNN A 和 B 可以并行执行
- 对于每个池化层
- 根据 GNN B 的聚类社区分配结果来聚合由 GNN A 生成的节点嵌入
- 为每个集群创建一个新节点,维护集群之间的边以生成新的池化网络
- 联合训练 GNN A 和 GNN B
2.2 Predictions and Labels
2.2.1 监督学习 VS. 无监督学习
- 监督学习 VS. 无监督学习
- 图上的监督学习:标签来自外部来源,例如,预测分子图的药物相似性
- 图上的无监督学习:信号来自图本身,例如,链接预测:预测两个节点是否连接
- 有时这些差异的界限是模糊的:我们在无监督学习中仍然有“监督”,例如,训练一个 GNN 来预测节点聚类系数,“无监督”也被称为“自我监督”
- 图上的监督标签
- 监督标签来自特定的用例,例如:
- 节点标签 y v y_v yv:在引文网络中,节点标签是节点属于哪个学科领域
- 边的标签 y u v y_{uv} yuv:在交易网络中,边的标签是边是否具有欺诈性
- 图的标签 y G y_G yG:分子图中,图标签是图的药物相似度
-Advice:将您的任务减少到节点/边/图形标签,这样我们就能使用现有的框架。
- 监督标签来自特定的用例,例如:
- 图上的非监督标签
- Problem:有时我们只有一个图,没有任何外部标签
- 解决方法:自我监督学习,我们可以在图中找到监督信号
- 以下任务不需要任何外部标签:
- 节点级别 y v y_v yv:节点统计(如聚类系数、PageRank、…)或预测节点的属性
- 边级别 y u v y_{uv} yuv:**链接预测(**隐藏两个节点之间的边,预测是否应该有链接)
- 图级别 y G y_{G} yG:图统计(例如,预测两个图是否同构)
2.3 Loss Function
2.3.1 分类 VS. 回归
- 分类(Classification):节点的标签 y ( i ) y^{(i)} y(i)具有离散值
- 例如,节点分类:节点属于哪个类别
- 回归(Regression):节点的标签 y ( i ) y^{(i)} y(i)具有连续值
- 例如,预测分子图的药物相似性或毒性水平
- GNNs可以应用于这两类问题, 不同的在于损失函数和评估指标
2.3.2 分类问题损失函数
- 交叉熵 (cross entropy CE) 是分类中非常常见的损失函数
- 我们要预测第i个数据点的类别(一共有K类)
- 其它类型损失函数: H i n g e L o s s (铰链损失) Hinge Loss(铰链损失) HingeLoss(铰链损失)
- 在"maximum-margin"的分类任务中,如支持向量机,表示预测输出,通常都是软结果(输出不是0,1这种,可能是0.87), 表示正确的类别,我们用下式作为分类函数:
H i n g e L o s s = m a x ( 0 , m − y ^ y ) Hinge Loss = max(0, m - \hat{y}y) HingeLoss=max(0,m−y^y) - 很多时候我们希望训练的是两个样本之间的相似关系,而非样本的整体分类,所以很多时候我们会用下面的公式:
H i n g e L o s s = m a x ( 0 , m − y + y ^ ) Hinge Loss = max(0, m - y + \hat{y}) HingeLoss=max(0,m−y+y^) - 其中,是y正样本的得分,是 y ^ \hat{y} y^负样本的得分,m是margin,即我们希望正样本分数越高越好,负样本分数越低越好,但二者得分之差最多到m就足够了,差距增大并不会有任何奖励。
- 在"maximum-margin"的分类任务中,如支持向量机,表示预测输出,通常都是软结果(输出不是0,1这种,可能是0.87), 表示正确的类别,我们用下式作为分类函数:
2.3.3 回归问题损失函数
- 对于回归任务,我们经常使用均方误差 (MSE) 也就是 L2 损失
- 数据点i的k-way回归
2.4 Evaluation metrics
在回归问题上,我们使用 GNN 的标准评估指标,在实践中我们通常使用sklearn程序包来实现,假设我们对 N 个数据点进行预测
2.4.1 回归问题分类指标
在图上评估回归任务,我们可以使用根均方差(RMSE) 和 平均绝对误差(MAE) 这两个指标来评价:
2.4.2 分类问题分类指标
在图上评估分类任务:
- (1) 多类分类
- 只报告准确性
- (2)二类分类
对分类阈值敏感的指标
- Accuracy (准确率)
- Precision(精确率) / Recall (召回率)
- 如果预测的范围是 [0,1],我们将使用 0.5 作为阈值
与分类阈值无关的指标
- ROC Curve:捕获 TPR 和 FPR 的权衡,因为二元分类器的分类阈值是变化的(虚线表示随机分类器的性能)
- ROC AUC:RUC曲线下的面积(Area under the ROC Curve),是分类器将随机选择的正实例得分高于随机选择的负实例的概率
3. 数据集拆分(训练/验证/测试集)
3.1 常规拆分方案
- 固定拆分:我们将一次性分割我们的数据集
- 训练集:用于优化 GNN 参数
- 验证集:用于调整超参数和各种常数及决策选择
- 测试集:只用于评估模型的最终性能
我们用训练集和验证集确定最终模型,然后将模型应用到测试集
- 随机拆分:我们将数据集随机拆分为训练/验证/测试
- 我们报告了不同随机种子的拆分方案平均性能
3.2 图的数据集拆分方案
- 拆分图和拆分一般的数据集不一样,会造成数据泄露的问题。
- 在文档数据集或图像数据集中,我们拆分数据集时,假设数据点之间相互独立,这样很容易将其拆分成三个数据集,并且没有数据泄漏
- 然而拆分图数据集是不一样的,图的问题在于节点之间相互连接,不是相互独立的,节点会从其他节点收集信息,这样会造成信息泄露的问题。
- 解决方案 1(Transductive setting):只拆分节点标签,保持图的结构不变,整个输入图在所有数据集中都是可见的(即使用整个图计算嵌入)
- 只拆分(节点)标签:
- 在训练时,我们使用整个图计算嵌入,并使用节点 1 和 2 的标签进行训练
- 在验证时,我们使用整个图计算嵌入,并评估节点 3 和 4 的标签
- 只拆分(节点)标签:
- 解决方案 2(Inductive setting):删除拆分出的数据集之间连接的边
- 现在我们有 3 个独立的图。节点 5 将不再影响我们对节点 1 的预测
- 在训练时,我们使用节点 1&2 上的图计算嵌入,并使用节点 1&2 的标签进行训练
- 在验证时,我们使用节点 3&4 上的图计算嵌入,并评估节点 3&4 的标签
- 两种方案的比较:Transductive / Inductive Settings
- Transductive Settings: 训练/验证/测试集在同一张图上
- 数据集由单个图组成
- 可以在所有数据集拆分中观察到整个图,只拆分标签
- 仅适用于节点/边预测任务
- Transductive Settings: 训练/验证/测试集在同一张图上
- Inductive Settings:训练/验证/测试集在不同的图表上
- 数据集由多个图组成
- 每个拆分只能观察拆分内的图,这使我们能够真正测试如何将其推广到看不见的图形,一个成功的模型应该泛化到看不见的图
- 适用于节点/边/图任务
3.3 图的拆分示例
- 节点分类
- 图分类
在图分类问题中,由于我们分类独立的图,因此归纳设置不需要删除边就能应用,我们可以方便地将其分为训练、验证和测试集。
- 链接预测
- 链接预测设置是图机器学习中最棘手的任务:它是一项无监督/自我监督的任务,我们需要自己创建标签和数据集拆分。
- 具体来说,我们需要对 GNN 隐藏一些边,并让 GNN 预测这些边是否存在
- 对于链接预测,我们将两次分割边
- 第 1 步:在原始图中分配 2 种类型的边
- 消息边:用于 GNN中的 消息传递
- 监督边:用于计算目标
- 第一步之后
- 图中仅保留消息边,移除监督边
- 监督边用作模型对边预测的监督,不会被输入 GNN
- 第 2 步:将边拆分为训练/验证/测试
- 选项 1:归纳链接预测拆分
- 假设我们有一个包含 3 个图的数据集。每个归纳拆分将包含一个独立的图
- 假设我们有一个包含 3 个图的数据集。每个归纳分裂将包含一个独立的图
- 在训练或验证或测试集中,每个图将有 2 种类型的边:消息边 + 监督边(监督边不是 GNN 的输入)
- 选项2:Transductive链路预测分割(默认选项)
- 根据“转导”的定义,可以在所有数据集拆分中观察到整个图
- 训练时:使用训练消息边预测训练监督边
- 验证时:使用训练消息边和训练监督边预测验证边
- 测试时:使用训练消息边、训练监督边和验证边 预测 测试边 - Transductive链路预测分割将图的边分为四类:训练消息边、训练监督边、验证边、测试边,链接预测设置既棘手又复杂,您可能会发现论文以不同的方式进行链接预测。幸运的是,我们完全支持 PyG 和 GraphGym来帮助我们进行链接预测。
4. Summary: GNN Training Pipeline
实现资源:
- DeepSNAP 为该管道提供核心模块
- GraphGym 进一步实现全流水线以方便 GNN 设计
5. Tips:When Things Don’t Go As Planned
5.1 通用提示
- 数据预处理很重要
- 节点属性的变化范围很大,从(0,1)到(-1000,1000)都有可能
- 因此需要进行标准化
- 优化器的选择
- ADAM 对学习率相对稳健
- 激活函数
- ReLU 激活函数通常效果很好
- 其他替代方案:LeakyReLU、SWISH、rational activation
- 输出层没有激活函数
- 在每一层中包含偏置项
- 嵌入维度:32、64 和 128 通常是很好的起点
5.2 调试深度网络
调试问题:损失/准确性在训练期间未收敛
- 检查管道(例如在 PyTorch 中我们需要 zero_grad)
- 调整学习率等超参数
- 注意权重参数初始化
对模型开发很重要的问题:
- 在(部分)训练数据上过拟合:
- 对于一个小的训练数据集,损失应该基本上接近于 0,对于一个表达神经网络
- 如果神经网络不能过拟合单个数据点,那是错误的
- 仔细检查损失函数!
- 仔细检查可视化!
5.3 图神经网络资源
论文阅读:
CS224W图机器学习笔记自用:GNN Augmentation and Training相关推荐
- 斯坦福CS224W图机器学习笔记自用:A General Perspective on Graph Neural Networks
1. Recap Deep Graph Encoders(深度图编码器) Graph Neural Networks GNN核心思想:节点的邻域定义了一个计算图 Aggregate from Neig ...
- CS224W图机器学习笔记8-图神经网络三大应用
图神经网络 课程和PPT主页 Prediction with GNNs 目前我们只学习了从输入图->节点嵌入,也就是学习图神经网络怎么获取节点嵌入,并没有学习针对特定下游任务的一些内容(从节点嵌 ...
- CS224W图机器学习笔记5-消息传递与节点分类
消息传递与节点分类 课程和PPT主页 本文主要解决的问题:给定一个只有部分已知标签节点的图,如何给图中其他节点分配正确的标签? 本文主要讨论一个名为"message passing" ...
- [Datawhale][CS224W]图机器学习(三)
目录 一.简介与准备 二.教程 2.1 下载安装 2.2 创建图 2.2.1 常用图创建(自定义图创建) 1.创建图对象 2.添加图节点 3.创建连接 2.2.2 经典图结构 1.全连接无向图 2.全 ...
- 【斯坦福大学公开课CS224W——图机器学习】三、节点和图嵌入
[斯坦福大学公开课CS224W--图机器学习]三.节点和图嵌入 文章目录 [斯坦福大学公开课CS224W--图机器学习]三.节点和图嵌入 1. 节点嵌入 1.1 编码器与解码器 1.2 节点嵌入的游走 ...
- 【斯坦福大学公开课CS224W——图机器学习】五、消息传递和节点分类
[斯坦福大学公开课CS224W--图机器学习]五.消息传递和节点分类 文章目录 [斯坦福大学公开课CS224W--图机器学习]五.消息传递和节点分类 1. Message Passing and No ...
- 【CS224w图机器学习】第一章 图机器学习导论
一.前言 笔记参考b站同济子豪兄的视频而成,源于斯坦福CS224W 学完本章,你将会对图神经网络有初步的了解,同时对于应用层面也有初步的印象 1.1关键词 图机器学习.图数据挖掘.图神经网络(GNN) ...
- CS224W图机器学习课,斯坦福大牛主讲 | 视频、课件
博雯 发自 凹非寺 量子位 报道 | 公众号 QbitAI 斯坦福大学的CS224W 2021冬季公开课,最近上线了. 在分析.处理大规模图形的过程中,往往在计算.算法和建模等方面充斥着挑战. 而本课 ...
- 机器学习笔记 invariance data augmentation
1 Invariance vs. Sensitivity 无论是对于图像.文本还是视频,我们都希望找到好的向量表示 好的向量表示需要对我们任务所关心的特征敏感: 动物识别问题中,动物的品种就是一个值得 ...
最新文章
- Linux下Mysql数据库的基础操作
- 设计模式--动态代理
- dbgrideh的功能
- [react] 如何给非控组件设置默认的值?
- Django讲课笔记11:视图函数的请求和响应
- C. Memory and De-Evolution 逆向思维
- java匿名内部类 内部类_java中的匿名内部类详细总结
- InTouch软件介绍
- 【科研】计算社会科学与复杂科学
- 经纬度转换坐标接口 查询位置信息
- 威猛“路威“,全新启航!
- ivx动效按钮 基础按钮制作 01
- Android通知渠道
- 自动化运维之k8s——Helm、普罗米修斯、EFK日志管理、k8s高可用集群(未完待续)
- springboot RedisTemplate 提示没有双引号序列化失败问题
- 初识MIMO(六):MU-MIMO的仿真
- 一文教你快速搞懂速度曲线规划之S形曲线(超详细+图文+推导+附件代码)
- 公司“内部管理混乱,工作很难开展”!
- 全面云化的变革悄然而至,IPLOOK助力加速云网融合
- 为什么工具类App,都要做一个社区?