前言

论文链接:https://arxiv.org/abs/1704.01212
github:https://github.com/ifding/graph-neural-networks
参考:https://blog.csdn.net/qq_27075943/article/details/106623059

MPNN 不是一个模型,而是一个框架。作者在这篇论文中主要将现有模型抽象其共性并提出成 MPNN 框架,同时利用 MPNN 框架在分子分类预测中取得了一个不错的成绩。

在这篇论文中,作者的目标是证明:能够应用于化学预测任务的模型可以直接从分子图中学习到分子的特征,并且不受到图同构的影响。为此,作者将应用于图上的监督学习框架称之为消息传递神经网络(MPNN),这种框架是从目前比较流行的支持图数据的神经网络模型中抽象出来的一些共性,抽象出来的目的在于理解它们之间的关系。

本文给出的一个例子是利用 MPNN 框架代替计算代价昂贵的 DFT 来预测有机分子的量子特性:

1. MPNN框架

首先定义变量名称及含义:
定义无向图GGG,节点vvv的特征向量为xvx_vxv​,边的特征为evwe_{vw}evw​连接节点vvv和www
前向传递的两个阶段分别为:消息传递阶段(Message Passing)读出阶段(Readout)

对于消息传递阶段(Message Passing):消息函数定义为MtM_tMt​,定点更新函数定义为UtU_tUt​,ttt为运行的时间步,更新过程如下:
mvt+1=∑w∈N(v)Mt(hvt,hwt,evw)m_v^{t+1}=\sum_{w \in N(v)}M_t(h_v^t,h_w^t,e_{vw})mvt+1​=w∈N(v)∑​Mt​(hvt​,hwt​,evw​)hvt+1=Ut(hvt,mvt+1)h_v^{t+1}=U_t(h_v^t,m_v^{t+1})hvt+1​=Ut​(hvt​,mvt+1​)其中,N(v)N(v)N(v) 表示图 GGG 中节点 vvv 的邻居。

对于 读出阶段(Readout) 使用一个读出函数 RRR 来计算整张图的特征向量:
y^=R(hvT∣v∈G)\hat{y}=R(h_v^T|v \in G)y^​=R(hvT​∣v∈G)

消息函数MtM_tMt​,向量更新函数UtU_tUt​ 和读出函数RRR都是可微函数。RRR作用于节点的状态集合,同时对节点的排列不敏感,这样才能保证 MPNN 对图同构保持不变。

此外,我们也可以通过引入边的隐藏层状态来学习图中的每一条边的特征,并且同样可以用上面的等式进行学习和更新。

接下来我们看下如何通过定义消息函数、更新函数和读出函数来适配不同种模型。

Paper1 : Convolutional Networks for Learning Molecular Fingerprints, Duvenaud et al. (2015)

这篇论文中消息函数为:
M(hv,hw,evw)=(hw,evw)M(h_v,h_w,e_{vw})=(h_w,e_{vw})M(hv​,hw​,evw​)=(hw​,evw​)其中 (.,.)(.,.)(.,.) 表示拼接(concat);

节点的更新函数为:
Ut(hvt,mvt+1)=σ(Htdeg(v)mvt+1)U_t(h_v^t,m_v^{t+1}) = \sigma(H_t^{deg(v)}m_v^{t+1})Ut​(hvt​,mvt+1​)=σ(Htdeg(v)​mvt+1​)

其中 σ\sigmaσ 为 sigmoid 函数,deg(v)deg(v)deg(v) 表示节点 vvv 的度,HtvH_t^vHtv​是一个可学习的矩阵,ttt 为时间步,NNN 为节点度;

读出函数 RRR 将先前所有隐藏层的状态hvth_v^thvt​进行连接:
R=f(∑v,tsoftmax(Wthvt))R = f(\sum_{v,t}softmax(W_th_v^t))R=f(v,t∑​softmax(Wt​hvt​))

其中 fff 是一个神经网络,WtW_tWt​是一个可学习的读出矩阵。

这种消息传递阶段可能会存在一些问题,比如说最终的消息向量分别对连通的节点和连通的边求和mvt+1=(∑hwt,∑evw)m_v^{t+1}=(\sum h_w^t,\sum e_{vw})mvt+1​=(∑hwt​,∑evw​)由此可见,该模型实现的消息传递无法识别节点和边之间的相关性。

Paper 2 : Gated Graph Neural Networks (GG-NN), Li et al. (2016)

这篇论文比较有名,作者后续也是在这个模型的基础上进行改进的。

GG-NN 使用的消息函数为:
Mt(hvt,hwt,evw)=AevwhwtM_t(h_v^t,h_w^t,e_{vw}) = A_{e_{vw}}h_w^tMt​(hvt​,hwt​,evw​)=Aevw​​hwt​
其中AevwA_{e_{vw}}Aevw​​是evwe_{vw}evw​的一个可学习矩阵,每条边都会对应那么一个矩阵;

更新函数为:
Ut(hvt,mvt+1)=GRU(hvt,mvt+1)U_t(h_v^t,m_v^{t+1})=GRU(h_v^t,m_v^{t+1})Ut​(hvt​,mvt+1​)=GRU(hvt​,mvt+1​)

其中GRUGRUGRU为门控制单元(Gate Recurrent Unit)。该工作使用了权值捆绑,所以在每一个时间步 ttt 下都会使用相同的更新函数;

读出函数 R 为:
R=∑v∈Vσ(i(hv(T)),hv0)⊙(j(hv(T)))R = \sum_{v \in V} \sigma(i(h_v^{(T)}),h_v^0) \odot (j(h_v^{(T)}))R=v∈V∑​σ(i(hv(T)​),hv0​)⊙(j(hv(T)​))

其中iii和 jjj为神经网络,⊙\odot⊙表示元素相乘。

Paper 3 : Interaction Networks, Battaglia et al. (2016)

这篇论文考虑图中的节点和图结构,同时也考虑每个时间步下的节点级的影响。这种情况下更新函数的输入会多一些(hv,xv,mv)(h_v,x_v,m_v)(hv​,xv​,mv​)其中 xvx_vxv​ 是一个外部向量,表示对顶点 vvv 的一些外部影响。

这篇论文的消息函数M(hv,hw,evw)M(h_v,h_w,e_{vw})M(hv​,hw​,evw​) 是一个以 (hv,hw,evw)(h_v,h_w,e_{vw})(hv​,hw​,evw​)为输入的神经网络,节点更新函数U(hv,xv,mv)U(h_v,x_v,m_v)U(hv​,xv​,mv​)是一个以(hv,xv,mv)(h_v,x_v,m_v)(hv​,xv​,mv​)为输入的神经网络,最终会有一个图级别的输出 R=f(∑v∈GhvT)R=f(\sum_{v\in G}h_v^T)R=f(∑v∈G​hvT​),其中 fff 是一个神经网络,输入是最终的隐藏层状态的和。在原论文中T=1T=1T=1

Paper 4 : Molecular Graph Convolutions, Kearnes et al. (2016)

这篇论文与其他 MPNN 稍微有些不同,主要区别在于考虑了边表示 ev,wte_{v,w}^tev,wt​,并且在消息传递阶段会进行更新。

消息传递函数用的是节点的消息:
Mt(hvt,hwt,evwt)=evwtM_t(h_v^t,h_w^t,e_{vw}^t)=e_{vw}^tMt​(hvt​,hwt​,evwt​)=evwt​

节点的更新函数为:
Ut(hvt,mvt+1)=α(W1(α(W0hvt),mvt+1))U_t(h_v^t,m_v^{t+1})=\alpha(W_1(\alpha(W_0h_v^t),m_v^{t+1}))Ut​(hvt​,mvt+1​)=α(W1​(α(W0​hvt​),mvt+1​))

其中(.,.)(.,.)(.,.)表示拼接(concat),α\alphaα为 ReLU 激活函数,W0,W1W_0,W_1W0​,W1​ 为可学习权重矩阵;

边状态的更新定义为:
evwt+1=Ut′(evwt,hvt,hwt)e_{vw}^{t+1}=U_t^{'}(e_{vw}^t,h_v^t,h_w^t)evwt+1​=Ut′​(evwt​,hvt​,hwt​) evwt+1=α(W4(α(W2,evwt),α(W3(hvt,hwt))))e_{vw}^{t+1}=\alpha(W_4(\alpha(W_2,e_{vw}^t),\alpha(W_3(h_v^t,h_w^t))))evwt+1​=α(W4​(α(W2​,evwt​),α(W3​(hvt​,hwt​))))

其中,WiW_iWi​为可学习权重矩阵。

Paper 5 : Deep Tensor Neural Networks, Schutt et al. (2017)

消息函数为:
Mt=tanh(Wfc((Wcfhwt+b1)⊙(Wdfevw+b2)))M_t = tanh(W^{fc}((W^{cf}h_w^t+b1)\odot(W^{df}e_{vw}+b2)))Mt​=tanh(Wfc((Wcfhwt​+b1)⊙(Wdfevw​+b2)))
其中 Wfc,Wcf,WdfW^{fc},W^{cf},W^{df}Wfc,Wcf,Wdf为矩阵,b1,b2b_1,b_2b1​,b2​ 为偏置向量;

更新函数为:
Ut(hvt,mvt+1)=hvt+mvt+1U_t(h_v^t,m_v^{t+1})=h_v^{t}+m_v^{t+1}Ut​(hvt​,mvt+1​)=hvt​+mvt+1​

读出函数通过单层隐藏层接受每个节点并且求和后输出:
R=∑vNN(hvT)R = \sum_vNN(h_v^T)R=v∑​NN(hvT​)

Paper 6 : Laplacian Based Methods, Bruna et al. (2013); Defferrard et al. (2016); Kipf & Welling (2016)

基于拉普拉斯矩阵的方法将图像中的卷积运算扩展到网络图 GGG 的邻接矩阵 AAA 中。

在 Bruna et al. (2013); Defferrard et al. (2016); 的工作中,消息函数为:
Mt(hvt,hwt)=CvwthwtM_t(h_v^t,h_w^t)=C_{vw}^th_w^tMt​(hvt​,hwt​)=Cvwt​hwt​

其中,矩阵 CvwtC_{vw}^tCvwt​为拉普拉斯矩阵 LLL 的特征向量组成的矩阵;

节点的更新函数为:
Ut(hvt,mvt+1)=σ(mvt+1)U_t(h_v^t,m_v^{t+1})=\sigma(m_v^{t+1})Ut​(hvt​,mvt+1​)=σ(mvt+1​)

其中,σ\sigmaσ为非线性的激活函数,比如说 ReLU。

在 Kipf & Welling (2016) 的工作中,消息函数为:
Mt(hvt,hwt)=CvwthwtM_t(h_v^t,h_w^t)=C_{vw}^th_w^tMt​(hvt​,hwt​)=Cvwt​hwt​
其中,Cvw=(deg(v)deg(w))−1/2AvwC_{vw}=(deg(v)deg(w))^{-1/2}A_{vw}Cvw​=(deg(v)deg(w))−1/2Avw​;

节点的更新函数为:
Uvt(hvt,mvt+1)=ReLU(Wtmvt+1)U_v^t(h_v^t,m_v^{t+1})=ReLU(W^tm_v^{t+1})Uvt​(hvt​,mvt+1​)=ReLU(Wtmvt+1​)

可以看到以上模型都是 MPNN 框架的不同实例,所以作者呼吁大家应该致力于将这一框架应用于某个实际应用,并根据不同情况对关键部分进行修改,从而引导模型的改进,这样才能最大限度的发挥模型的能力。

2. MPNN Variants

本节来介绍下作者将 MPNN 框架应用于分子预测领域,提出了 MPNN 的变种,并以 QM9 数据集为例进行了实验。

QM9 数据集中的分子大多数由碳氢氧氮等元素组成,并组成了约 134k 个有机分子,可以划分为四大类(具体类别不介绍了),任务是根据分子结构预测分子所属类别。

作者主要是基于 GG-NN 来探索 MPNN 的多种改进方式(不同的消息函数、输出函数等),之所以用 GG-NN 是因为这是一个很强的 baseline。

2.1 Message Functions

首先来看下消息函数,可以以 GG-NN 中使用的消息函数开始,GG-NN 用的是矩阵乘法:
M(hv,hw,evw)=AevwhwM(h_v,h_w,e_{vw}) = A_{e_{vw}}h_wM(hv​,hw​,evw​)=Aevw​​hw​

为了兼容边特征,作者提出了新的消息函数:
M(hv,hw,evw)=A(evw)hwM(h_v,h_w,e_{vw}) = A({e_{vw}})h_wM(hv​,hw​,evw​)=A(evw​)hw​
其中,A(evw)A(e_{vw})A(evw​)是将边的向量evwe_{vw}evw​映射到 d×d 维矩阵的神经网络。

矩阵乘法有一个特点,从节点 w 到节点 v 的函数仅与隐藏层状态hwh_whw​和边向量evwe_{vw}evw​有关,而和隐藏状态hvth_v^thvt​无关。理论上来说,如果节点消息同时依赖于源节点 w 和目标节点 v 的话,网络的消息通道将会得到更有效的利用。所以也可以尝试去使用一种消息函数的变种:
mvw=f(hwt,hvt,evw)m_{vw}=f(h_w^t,h_v^t,e_{vw})mvw​=f(hwt​,hvt​,evw​)其中,fff 为神经网络。

2.2 Virtual Graph Elements

其次看来下消息传递,作者探索了两种不同的消息传递方式。

最简单的修改就是为没有连接的节点添加一个虚拟的边,这样消息便具有了更长的传播距离;

此外,作者也尝试了使用潜在的“主”节点(master node),这个节点可以通过特殊的边来连接到图中任意一个节点。主节点充当了一个全局的暂存空间,每个节点都会在消息传递过程中通过主节点进行读取和写入。同时允许主节点具有自己的节点维度,以及内部更新函数(GRU)的单独权重。其目的同样是为了在传播阶段传播很长的距离。

2.3 Readout Functions

然后来看下读出函数,作者同样尝试了两种读出函数:

首先是 GG-NN 中的读出函数:
R=∑v∈Vσ(i(hv(T),hv0)⊙(j)(hv(T)))R=\sum_{v \in V}\sigma(i(h_v^{(T)},h_v^0) \odot(j)(h_v^{(T)}))R=v∈V∑​σ(i(hv(T)​,hv0​)⊙(j)(hv(T)​))
此外也考虑 set2set 模型。set2set 模型是专门为在集合运算而设计的,并且相比简单累加节点的状态来说具有更强的表达能力。模型首先通过线性映射将数据映射到元组(hvt,xv)(h_v^t, x_v)(hvt​,xv​),并将投影元组作为输入T={(hvT,xv)}T=\{(h_v^T,x_v) \}T={(hvT​,xv​)}然后经过 MMM 步计算后,set2set 模型会生成一个与节点顺序无关的 Graph-level 的 embeedding 向量,从而得到我们的输出向量。

2.4 Multiple Towers

最后考虑下 MPNN 的伸缩性。

对一个稠密图来说,消息传递阶段的每一个时间步的时间复杂度为O(n2d2)O(n^2d^2)O(n2d2)其中 nnn 为节点数,ddd 为向量维度,可以看到时间复杂度还是非常高的。
为了解决这个问题作者将向量维度 ddd 拆分成 kkk 份,就变成了 kkk 个 d/kd/kd/k 维向量,并在传播过程中每个子向量分别进行传播和更新,最后再进行合并。此时的子向量时间复杂度为 O(n2(d/k)2)O(n^2(d/k)^2)O(n2(d/k)2)考虑 kkk 个子向量的时间复杂度为O(n2d2/k)O(n^2d^2/k)O(n2d2/k)

3.Input Representation

这一节主要介绍 GNN 的输入。

对于分子来说有很多可以提取的特征,比如说原子组成、化学键等,详细的特征列表如下图所示:
对于邻接矩阵,作者模型尝试了三种边表示形式:

**化学图(Chemical Graph):**在不考虑距离的情况下,邻接矩阵的值是离散的键类型:单键,双键,三键或芳香键;

**距离分桶(Distance bins):**基于矩阵乘法的消息函数的前提假设是边信息是离散的,因此作者将键的距离分为 10 个 bin,比如说 [2,6][2,6][2,6] 中均匀划分 8 个 bin,[0,2][0,2][0,2] 为 1 个 bin,[6,+∞][6, +∞][6,+∞] 为 1 个 bin;

**原始距离特征(Raw distance feature):**也可以同时考虑距离和化学键的特征,这时每条边都有自己的特征向量,此时邻接矩阵的每个实例都是一个 5 维向量,第一维是距离,其余思维是四种不同的化学键。

4. Experiment

来看一下实验结果,以 QM-9 数据集为例,共包含 130462 个分子,以 MAE 为评估指标。

下图为现有算法和作者改进的算法之间的对比:

下图为不考虑空间信息的结果:

下图为考虑多塔模型和结果:

5.Conclusion

总结:作者从诸多模型中抽离出了 MPNN 框架,并且通过实验表明,具有消息函数、更新函数和读出函数的 MPNN 具有良好的归纳能力,可以用于预测分析特性,优于目前的 Baseline,并且无需进行复杂的特征工程。此外,实验结果也揭示了全局主节点和利用 set2set 模型的重要性,多塔模型也使得 MPNN 更具伸缩性,方便应用于大型图中。

论文笔记:Neural Message Passing for Quantum Chemistry相关推荐

  1. 消息传递框架MPNN: Neural Message Passing for Quantum Chemistry

    来源:ICML 2017 论文链接: https://arxiv.org/abs/1704.01212 代码: https://github.com/ifding/graph-neural-netwo ...

  2. 论文笔记Neural Ordinary Differential Equations

    论文笔记Neural Ordinary Differential Equations 概述 参数的优化 连续标准化流(Continuous Normalizing Flows) 生成式的隐轨迹时序模型 ...

  3. toch_geometric 笔记:message passing GCNConv

    1 message passing介绍 将卷积算子推广到不规则域通常表示为一个邻域聚合(neighborhood aggregation)或消息传递(message passing )方案       ...

  4. 推荐系统论文笔记---Neural News Recommendation with Attentive Multi-View Learning

    文章目录 一.概述 二.主要解决问题 三.解决思路 1.News Encoder 2.User Encoder 3.Click Predictor 4.Model Training 四.实验结果 一. ...

  5. 图深度学习论文笔记整理活动 | ApacheCN

    整体进度:https://github.com/apachecn/graph-emb-dl-notes/issues/1 贡献指南:https://github.com/apachecn/graph- ...

  6. 浅谈Spherical Message Passing for 3D Graph Networks

    目录 ​背景 MPNN SMPNN Definition of Graph in paper Computational Steps Spherical message passing New Com ...

  7. 【论文阅读笔记 KDD2021】《Relational Message Passing for Knowledge Graph Completion》

    论文链接:https://arxiv.org/pdf/2002.06757.pdf 代码和数据集:https://github.com/hwwang55/PathCon 文章目录 ABSTRACT 1 ...

  8. 论文笔记 Semantics-Guided Neural Networks for Efficient Skeleton-Based Human Action Recognition - CVPR

    Semantics-Guided Neural Networks for Efficient Skeleton-Based Human Action Recognition 2020 CVPR | c ...

  9. 论文笔记:FILLING THE G AP S: MULTIVARIATE TIME SERIES IMPUTATION BY GRAPH NEURAL NETWORKS

    0 abstract & introduction 之前的补全方法并不能很好地捕获/利用 不同sensor之间的非线性时间/空间依赖关系 高效的时间序列补全方法,不仅应该考虑过去(或者未来)的 ...

最新文章

  1. java可变长字符串类型,Java 常用类——StringBufferamp;StringBuilder【可变字符序列】_IT技术_软件云...
  2. 如何在Marketing Cloud Launchpad里创建新的tile
  3. .NET 实现并行的几种方式(二)
  4. mysql有dataguard吗_Oracle查看是否搭建DataGuard
  5. svn 导入的 web项目怎么变成了java项目了
  6. 音视频编解码:NVIDIA Jetson Linux Multimedia API(总结)
  7. AngularJS的ng-click阻止冒泡
  8. java获取电脑配置_Java.Utils:获取电脑配置信息
  9. 计算机新建文件夹的步骤打开,如何制作文件夹!(新建文件夹的操作步骤)
  10. 字符数组初始化c语言,C语言字符数组
  11. esp8266 windows烧录问题
  12. 本地数据库IndexedDB - 初学者
  13. CFI Flash, JEDEC Flash ,Parellel Flash, SPI Flash, Nand Flash,Nor Flash的区别和联系
  14. 对比学习(contrastive learning)
  15. C# System.BadImageFormatException 解决方法
  16. 龙门标局商标SaaS服务系统,商标知产业务模块功能如此强大!
  17. width mismatch when connecting input pin '/processing system 7_0/irq_f2p'(2) to net 'xlconcat_0_dout
  18. 【转】MUD教程--巫师入门教程2
  19. python100例排列组合_Python列表list排列组合操作示例
  20. Android语音转文字一识别语音

热门文章

  1. 上手机器学习系列-第3篇(上)-聊聊logistic回归
  2. 智慧新零售网络解决方案,助力新零售企业数智化转型
  3. python与excel常用的第三方库_Python读写Excel文件第三方库汇总
  4. 计算机网络知识点总结(每日更新)
  5. python去除视频马赛克_DeepMosaics
  6. 关于低功耗输电线路在线监测摄像头,你知道多少
  7. 学会清除上网记录,防范于千里之外
  8. 2022 年全国职业院校技能大赛(中职组) 网络安全竞赛试题A 模块评分标准
  9. rtsp协议中数据的分包
  10. macOS+matlab 2020b matlab_bgl工具箱使用时 MEX文件编译出错