3D Infomax improves GNNs for Molecular Property Prediction

出处

  • 作者:Hannes Stark等
  • 机构:Massachusetts Institute of Technology等
  • 期刊:Proceedings of the 39 th International Conference on Machine Learning,2022/06/04
  • Code :github

摘要

  • 使用现有的三维分子数据集来预先训练一个模型,以推理出仅有二维分子图的分子的几何形状。
  • 模型的名称为 3D Infomax,最大化学习到的3D summary vector和GNN的表征之间的相互信息(mutual info)。
  • 使用未知几何形状的分子进行微调,GNN仍然能够提供一些隐性的3D信息并用于下游任务。
  • 在很多属性上有着较大的进步,比如在QM9量子力学特性上,MAE减少了22%

介绍

现有的分子特性预测方法和3D infomax的动机:

  • 标准方法:利用GNN和2D的分子图,结果快但差;
  • 显性的3D方法:使用经典的方法或者机器学习的方法计算3D坐标,然后作为输入进行预测。结果准确但是对于实际应用来说计算坐标太慢。
  • 3D Infomax:① 预训练:用一个2D网络对有3D信息的分子进行训练,得到有着隐性3D信息的表征。② 将2D网络的参数微调。结果真是又快又好。

背景

  • 2D分子图

    • G=(ν,ϵ),其中ν是节点−原子,ϵ是边−共价键G=(\nu,\epsilon),其中\nu是节点-原子,\epsilon 是边-共价键G=(ν,ϵ),其中ν是节点原子,ϵ是边共价键,边可以包含键类型信息,节点可以包含一些特征数据,比如原子编号,但至此都无3D坐标信息。
  • 3D 分子构象
    • 不同构象会带来不同化学性质,为保证抓住3D信息,所以需要考虑几乎所有的构象。
    • 当考虑一个分子ccc个已知的构象时,把他们表示成一组点云{R}j1⋯cj\{R\}^j_{j1\cdots c}{R}j1cj。每个点云R={rv}v∈νR=\{r_v\}_{v \in \nu}R={rv}vν表示分子中所有原子ν\nuν的坐标(即一组点云是一个构象的所有坐标点集合)。
    • RD-Kit的ETKDG算法能快速计算构象但是不准确;最流行的是CREST,速度和准确率兼备,但仍然需要大约6小时(per cpu)完成一个药类分子的计算。
  • 分子的对称性
    • 当所有原子坐标 jointly translated或者围着一点旋转(SE(3)对称),那么分子的构象就不会改变。同时,分子的性质会被他们的手性决定。我们的方法也能在表征中体现对称性。
  • 图神经网络
    • 大部分GNN可以被一个MPNN框架描述,比如我们用的PNA模型。
    • MPNN的目的是为了学习一个图的表征。他们通过不断迭代地去应用消息传递层,然后将所有点的表征结合。一个消息传递层通过使用置换不变性函数(mean,max,sum,不管数据如何置换位置,结果都不变)计算该点的邻点和其之间的边的值以用于更新该点信息。消息传递层之后,另一个置换不变性函数被用于提取点层的embedding到图层的embedding。

相关文献

  • 分子属性预测

    • MPNN框架问世之后,GNN就被广泛运用于量子化学、药物发现和分子性质的预测。利用3D信息的一个简单方法就是利用键长作为边信息(SchNet);DimeNet提取键角;SMP又包含另一个角度信息;GemNet也提取扭角,这样所有原子的相对位置就都被定义了;EGNN则是使用成对的原子距离。
  • 自监督学习
    • 对比学习是在计算机视觉中比较流行的自监督方法,通过对比相似的输入和不相似的输入的embedding来学习表征。
    • 分子化学的数据集很小所以自监督学习很重要。不少研究者也利用对比学习来表征分子性质,但是很有限制也难以泛化。
    • 过往大多是使用2D,我们和GraphMVP都是额外利用了3D结构来获取更多信息的表征。GraphMVP提出了一个生成式和一个对比式的3D预训练模型。生成模型可以纳入多个构象信息,然后利用对比式预训练提高下游任务表现。
    • 我们与GraphMVP的区别:3D infomax不需要额外的生成预训练任务,直接在一个新的对比损失函数中包含了这些信息。此外,我们的评估包括量子力学任务,我们发现在这个领域可能的改进比非量子属性的改进要大得多。

方法(3D Infomax)

使用对比学习完成了输入为2D信息,但可以推断出3D几何信息的模型。预训练模型为PNA(将图多种方法聚合,sota和简单)

图1

  • 最大化一个使用2D分子图的2D GNN和使用3D构象的3D GNN之间的互信息(在机器学习中,理想情况下,当互信息最大,可以认为从数据集中拟合出来的随机变量的概率分布与真实分布相同。)
  • 与训练之后,我们将它迁移到属性预测任务,并微调。在微调过程中,GNN产出的3D信息会被用于提高预测。
  • 在图1中,有两个模型。需要预训练的是2D网络 networkfanetwork f^anetworkfa,它可以产出一个表征fa(G)=za∈Rdzf^a(G)=z^a \in \R^{d_z}fa(G)=zaRdz;另一个将R={rv}v∈νR=\{r_v\}_{v \in \nu}R={rv}vν编码的3D网络networkfbnetwork f^bnetworkfb给出一个表征fb(G)=zb∈Rdzf^b(G)=z^b \in \R^{d_z}fb(G)=zbRdz。可以当成是一个对比蒸馏,因为student 2D网络可以从teacher 3D网络那学会生产3D信息。
    对比框架
    为了教会2D networkfanetwork f^anetworkfa从2D图输入中学到3D信息,我们最大化了潜在2D表征zaz^aza和3D表征zbz^bzb的互信息。因为当两者来自同一个分子,那我们希望zaz^azazbz^bzb尽可能的一致,所以利用了图2的对比学习。
    对于一组batch,中间包含N个分子图{Gi}i∈{1⋯N}\{G_i\}_{i \in \{1\cdots N\}}{Gi}i{1N},点坐标{Ri}i∈{1⋯N}\{R_i\}_{i \in \{1\cdots N\}}{Ri}i{1N},然后得到多个表征ziaz_i^aziazibz_i^bzib
    图2
    对比学习的第一个目的就是如果两者是正样本对,那么就要最大化表征的相似度,表示他们是来自同一个分子(同一个index i)。第二个目标就是强迫负样本对ziaz_i^aziazkb,i≠kz_k^b, i\ne kzkb,i=k不相似。
    这两个目标都是通过修改NTXent loss实现(如何实现相似的越相似)的:

    其中

    是余弦相似度,τ\tauτ是一个温度参数(超参),可以当做最相似的负样本对的权重(也就是当负样本对很相似时,调整τ\tauτ)。不同的对比损失组合和自监督学习是有可能学会一个2D和3D表征之间的联合嵌入空间,上面的函数是表现最好的。
    使用多构象
    使用ccc个最高概率的第 i 个构象{Rij}j∈{1⋯c}\{R_i^j\}_{j \in \{1\cdots c\}}{Rij}j{1c},如果不够c个就把能量最低的重复。图2右边,就是将分子的2D表征和每一个构象进行比较。

3D网络

3D网络输入是原子坐标作为点云,然后输出一个置换不变的表征zbz^bzb,尽可能多的把3D结构信息编码,但是不能够接触2D信息,不然的话互信息可能会因为两个模型的交互变得更大。
我们的模型将每一对atom的欧几里得距离进行编码,这样表征可以定义所有原子的相对位置并且保证旋转平移不变性,并且也是反射不变的,但是对手性分子没办法区别。
duvd_{uv}duv表示u原子和v原子之间的距离,会先使用高频的sine和cosine去投影到一个高维空间(因为键长之间的区别比较小)。然后以F=4的频率map(有点类似position encoding),更详细的操作可见MPNN框架。

数据

  • 3D数据集是来自QM9(134k个平均18个原子的只有一个构象的小分子,kaggle下载)、GEOM-Drugs(304k)和QMugs(665k)。后两者有较大的且是多构象的药类分子(44.4和30.6平均原子个数)
  • 微调:预测十个来自QM9和GEOM-Drug的量子特性,这些数据不与预训练的数据相交。
  • 预训练用了50k单构象来自QM9,140k5构象来自GEON-Drugs,620k3构象来自QMugs

对比

Baseline

  • 距离预测:使用已有的最低能量构象去预训练一个GNN,以直接预测所有原子之间的距离。然后将任意两个u,v原子的表征简单地拼接在一起(uv,vu),随后放入mlp(U,直接降到1维),||表示拼接

    softplus(x)=log(1+ex)softplus(x) = log(1 + e^x )softplus(x)=log(1+ex),loss function是MSE。
  • 构象生成:GeoMol(sota生成分子构象的模型),一个生成式模型,产生一个分子的可能的3D结构的分布,从而获取到多构象信息。利用他们模型做预训练任务然后提取网络用于不同下游任务
  • GraphCL:一个卷积增强预训练模型with JOAO配置,模型通过学习产出一个对增强不变的表征来完成自监督目标。
    结果
    数值为MAE,RAND INIT模型随机初始参数,PROPRED指用GEOM-Drugs的Gibbs自由能来做的预训练,DISPRED指用有最高概率的构象去预测所有原子的距离,CONFGEN指与训练的时候预测10个构象,3D Infomax分别使用三个数据集做预训练,RDKIT SMP使用RDKit生成的3D坐标输入SMP(一个GNN)做训练,True 3D SMP最后一列是用真实的3d坐标使用SMP预测的,蓝色表示improvement,橙色表示worse。

    对QM9数据集中的8中特性做预测

结论

相当于一个2D分子图的预训练模型,能得到隐含3D信息的表征,并且具有一定泛化能力(不会有负迁移),可以借助同一个分子的多构象信息来帮助下游属性预测任务。

【读文献】3D Infomax 小分子预训练模型相关推荐

  1. 【综述】分子预训练模型综述

    A Systematic Survey of Molecular Pre-trained Models 目录 总结 一.Introduction 二.Molecular Descriptors 三.P ...

  2. 一文读懂最强中文NLP预训练模型ERNIE

    基于飞桨开源的持续学习的语义理解框架ERNIE 2.0,及基于此框架的ERNIE 2.0预训练模型,在共计16个中英文任务上超越了BERT和XLNet, 取得了SOTA效果.本文带你进一步深入了解ER ...

  3. 腾讯优图开源业界首个3D医疗影像大数据预训练模型

    整理 | Jane出品 | AI科技大本营(ID:rgznai100) 近日,腾讯优图首个医疗AI深度学习预训练模型 MedicalNet 正式对外开源.这也是全球第一个提供多种 3D 医疗影像专用预 ...

  4. 全球首个!腾讯优图开源3D医疗影像大数据预训练模型

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自腾讯优图. 近日,腾讯优图首个医疗AI深度学习预训练模型MedicalNet正式对外开源.这也是全球第一个提供多种3D医疗影像专用预训练模型的 ...

  5. 香侬读 | 让预训练模型学习知识:使用多学习器增强知识建模能力

    论文标题: K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters 论文作者: Ruize Wang, Duyu Tan ...

  6. 让预训练模型学习知识:使用多学习器增强知识建模能力

    论文标题: K-Adapter: Infusing Knowledge into Pre-Trained Models with Adapters 论文作者: Ruize Wang, Duyu Tan ...

  7. 翟季冬:基于国产超算的百万亿参数超大预训练模型训练方法

    [前沿进展]训练参数规模万亿的预训练模型,对于超级计算机而言是不小的挑战.如何提升超算的计算效率,实现更大规模的参数训练,成为近年来研究者探索的课题.在近日举办的Big Model Meetup第二期 ...

  8. NLP千亿预训练模型的“第四范式”之Prompt Learning介绍分享

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 论文转载自知乎专栏:ai炼丹师 作者:避暑山庄梁朝伟 一.背景 随着GPT-3诞生,最 ...

  9. 腾讯开源首个医疗AI项目,业内首个3D医疗影像大数据预训练模型

    乾明 发自 凹非寺  量子位 报道 | 公众号 QbitAI 腾讯AI,开源又有新动作. 旗下顶级AI实验室腾讯优图,对外开源了腾讯首个医疗AI项目--深度学习预训练模型MedicalNet. 这一项 ...

最新文章

  1. SpringBoot+Mybatis+ Druid+PageHelper 实现多数据源并分页
  2. 2005年上半年 网络工程师 上下午试卷【附带答案】
  3. Android 基本 Jackson Marshalling(serialize)/Unmarshalling(deserialize)
  4. Paddington2
  5. USACO3.32Shopping Offers(DP)
  6. [云炬创业基础笔记]第五章创业机会评估测试11
  7. MikroTik RouterOS x86最大内存只能支持2G
  8. KAFKA 最新版 单机安装、配置、部署(linux环境)
  9. c语言case标号是连续的吗,在switch语句中,case后的标号只能是什么?_后端开发...
  10. BGP链路冗余使用直接接口和回环口分析
  11. Qt 学习之路 :访问网络(4)
  12. 灰度世界算法(Gray World Algorithm)
  13. 手把手写Demo系列之车道线检测
  14. 计算机恢复原始桌面图标,Win10桌面图标如何恢复原来排列?
  15. Android开根号运算
  16. nas 和 远程文件夹同步_我应该如何使用Qsync来同步我计算机和NAS上的档案?
  17. Unity跳一跳小游戏简单代码
  18. pd.diff()函数详解
  19. Linux下查看CPU核数
  20. Chapter7.1:频域分析法理论基础

热门文章

  1. 打开chm文件,不显示文件内容的解决办法
  2. Centos7 动起来
  3. 巴比特 | 元宇宙每日必读:美版权局判定用AI工具生成的图片不受版权保护,官方解释:AI生成具有不可预测性,但并非一刀切...
  4. 雷铭电子商务系统7.0【开源版本】全面介绍
  5. 数字信号处理X——MBD开发流程与自动化测试
  6. CentOS个人版 各种软件安装
  7. typecho图标_Typecho浏览器图标favicon.ico添加方法 - 新手站长网
  8. 【PSO三维路径规划】基于matlab粒子群算法无人机山地三维路径规划【含Matlab源码 1405期】
  9. 【Java基础】循环、嵌套、跳转控制break/continue、调试器、函数
  10. java 网络抓包_基于java的网络抓包方法