TorchDrug教程–逆合成

教程来源TorchDrug开源

目录

  • TorchDrug安装
  • 分子数据结构
  • 属性预测
  • 预训练的分子表示
  • 分子生成
  • 逆合成
  • 知识图推理

反合成是药物发现的一项基本任务。给定一个目标分子,反合成的目标是确定一组可以产生目标的反应物。

在这个例子中,我们将展示如何使用G2Gs框架预测逆合成。G2Gs首先识别反应中心,即产物中产生的键。根据反应中心,产物被分解成几个合成子,每个合成子被转化为一个反应物。

准备数据

我们使用标准USPTO50k数据集。该数据集包含50k分子及其合成途径。首先,让我们下载并加载数据集。这可能需要一段时间。有两种模式来加载数据集。reaction模式将数据集加载为(reactants, product)对,用于中心识别。synthon模式将数据集作为(reactantsynthon)对加载,用于synthon完成。

from torchdrug import data, datasets, utils
reaction_dataset = datasets.USPTO50k("~/molecule-datasets/",atom_feature="center_identification",kekulize=True)
synthon_dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True,atom_feature="synthon_completion",kekulize=True)

然后我们将数据集中的一些样本可视化。对于反应数据集,我们可以使用connected components()将反应物图和生成物图拆分为单个分子。注意USPTO50k忽略了所有非目标产品,所以右边只有一个产品。

from torchdrug.utils import plotfor i in range(2):sample = reaction_dataset[i]reactant, product = sample["graph"]reactants = reactant.connected_components()[0]products = product.connected_components()[0]plot.reaction(reactants, products)


下面是synthon数据集中对应的示例。

for i in range(3):sample = synthon_dataset[i]reactant, synthon = sample["graph"]plot.reaction([reactant], [synthon])


为了确保两个数据集使用相同的split,我们可以在调用split()之前设置随机种子。

import torchtorch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()

中心识别

现在我们定义我们的模型。我们使用一个关系图卷积网络(RGCN)作为我们的表示模型,并包装它来完成中心识别任务。注意,这里也可以使用其他图表示学习模型。

from torchdrug import core, models, tasksreaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,hidden_dims=[256, 256, 256, 256, 256, 256],num_relation=reaction_dataset.num_bond_type,concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model,feature=("graph", "atom", "bond"))
reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3)
reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid,reaction_test, reaction_optimizer,gpus=[0], batch_size=128)
reaction_solver.train(num_epoch=50)
reaction_solver.evaluate("valid")
reaction_solver.save("g2gs_reaction_model.pth")

验证集上的计算结果可能如下所示

accuracy: 0.836367

我们可以从我们的模型中展示一些预测。为了多样性,我们收集了4种不同反应类型的样品。

batch = []
reaction_set = set()
for sample in reaction_valid:if sample["reaction"] not in reaction_set:reaction_set.add(sample["reaction"])batch.append(sample)if len(batch) == 4:break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
result = reaction_task.predict_synthon(batch)

下面的代码可视化了基本事实以及我们对样本的预测。我们用蓝色代表基本事实,红色代表错误的预测,紫色代表正确的预测。

def atoms_and_bonds(molecule, reaction_center):is_reaction_atom = (molecule.atom_map > 0) & \(molecule.atom_map.unsqueeze(-1) == \reaction_center.unsqueeze(0)).any(dim=-1)node_in, node_out = molecule.edge_list.t()[:2]edge_map = molecule.atom_map[molecule.edge_list[:, :2]]is_reaction_bond = (edge_map > 0).all(dim=-1) & \(edge_map == reaction_center.unsqueeze(0)).all(dim=-1)atoms = is_reaction_atom.nonzero().flatten().tolist()bonds = is_reaction_bond[node_in < node_out].nonzero().flatten().tolist()return atoms, bondsproducts = batch["graph"][1]
reaction_centers = result["reaction_center"]for i, product in enumerate(products):true_atoms, true_bonds = atoms_and_bonds(product, product.reaction_center)true_atoms, true_bonds = set(true_atoms), set(true_bonds)pred_atoms, pred_bonds = atoms_and_bonds(product, reaction_centers[i])pred_atoms, pred_bonds = set(pred_atoms), set(pred_bonds)overlap_atoms = true_atoms.intersection(pred_atoms)overlap_bonds = true_bonds.intersection(pred_bonds)atoms = true_atoms.union(pred_atoms)bonds = true_bonds.union(pred_bonds)red = (1, 0.5, 0.5)blue = (0.5, 0.5, 1)purple = (1, 0.5, 1)atom_colors = {}bond_colors = {}for atom in atoms:if atom in overlap_atoms:atom_colors[atom] = purpleelif atom in pred_atoms:atom_colors[atom] = redelse:atom_colors[atom] = bluefor bond in bonds:if bond in overlap_bonds:bond_colors[bond] = purpleelif bond in pred_bonds:bond_colors[bond] = redelse:bond_colors[bond] = blueplot.highlight(product, atoms, bonds, atom_colors, bond_colors)

合成纤维完成

类似地,我们在synthon数据集上训练synthon完成模型。

synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,hidden_dims=[256, 256, 256, 256, 256, 256],num_relation=synthon_dataset.num_bond_type,concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))
synthon_optimizer = torch.optim.Adam(synthon_task.parameters(), lr=1e-3)
synthon_solver = core.Engine(synthon_task, synthon_train, synthon_valid,synthon_test, synthon_optimizer,gpus=[0], batch_size=128)
synthon_solver.train(num_epoch=10)
synthon_solver.evaluate("valid")
synthon_solver.save("g2gs_synthon_model.pth")

我们可以得到一些结果

bond accuracy: 0.983013
node in accuracy: 0.967535
node out accuracy: 0.892999
stop accuracy: 0.929348
total accuracy: 0.844374

然后,我们执行束搜索,以产生候选反应物。

batch = []
reaction_set = set()
for sample in synthon_valid:if sample["reaction"] not in reaction_set:reaction_set.add(sample["reaction"])batch.append(sample)if len(batch) == 4:break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
reactants, synthons = batch["graph"]
reactants = reactants.ion_to_molecule()
predictions = synthon_task.predict_reactant(batch, num_beam=10, max_prediction=5)synthon_id = -1
i = 0
titles = []
graphs = []
for prediction in predictions:if synthon_id != prediction.synthon_id:synthon_id = prediction.synthon_id.item()i = 0graphs.append(reactants[synthon_id])titles.append("Truth %d" % synthon_id)i += 1graphs.append(prediction)if reactants[synthon_id] == prediction:titles.append("Prediction %d-%d, Correct!" % (synthon_id, i))else:titles.append("Prediction %d-%d" % (synthon_id, i))# reset attributes so that pack can work properly
mols = [graph.to_molecule() for graph in graphs]
graphs = data.PackedMolecule.from_molecule(mols)
graphs.visualize(titles, save_file="uspto50k_synthon_valid.png", num_col=6)

逆合成

给定训练过的模型,我们可以将它们组合成一个端点管道进行逆向合成。这是通过将两个子任务包裹在一个逆合成任务中来完成的。

注意,如果您从未声明reaction_tasksynthon_task的求解器,那么在将它们组合到管道中之前,您需要手动调用它们的preprocess()方法。

# reaction_task.preprocess(reaction_train, None, None)
# synthon_task.preprocess(synthon_train, None, None)
task = tasks.Retrosynthesis(reaction_task, synthon_task, center_topk=2,num_synthon_beam=5, max_prediction=10)

管道将对来自两个子任务的预测之间的所有可能组合执行波束搜索。为了演示,我们使用一个较小的光束尺寸,并且只对验证集的子集进行评估。注意,如果我们给光束搜索更多的预算,结果会更好。

from torch.utils import data as torch_datalengths = [len(reaction_valid) // 10,len(reaction_valid) - len(reaction_valid) // 10]
reaction_valid_small = torch_data.random_split(reaction_valid, lengths)[0]optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, reaction_train, reaction_valid_small, reaction_test,optimizer, gpus=[0], batch_size=32)

要加载两个子任务的参数,我们只需load_optimizer。注意负载优化器应该设置为False以避免冲突。

solver.load("g2gs_reaction_model.pth", load_optimizer=False)
solver.load("g2gs_synthon_model.pth", load_optimizer=False)
solver.evaluate("valid")

反合成的准确性可能接近于以下

top-1 accuracy: 0.47541
top-3 accuracy: 0.741803
top-5 accuracy: 0.827869
top-10 accuracy: 0.879098

以下是验证集中样本的前1个预测

batch = []
reaction_set = set()
for sample in reaction_valid:if sample["reaction"] not in reaction_set:reaction_set.add(sample["reaction"])batch.append(sample)if len(batch) == 4:break
batch = data.graph_collate(batch)
batch = utils.cuda(batch)
predictions, num_prediction = task.predict(batch)products = batch["graph"][1]
top1_index = num_prediction.cumsum(0) - num_prediction
for i in range(len(products)):reactant = predictions[top1_index[i]].connected_components()[0]product = products[i].connected_components()[0]plot.reaction(reactant, product)

TorchDrug教程--逆合成相关推荐

  1. TorchDrug教程--预训练的分子表示

    TorchDrug教程–预训练的分子表示 教程来源TorchDrug开源 目录 TorchDrug安装 属性预测 预训练的分子表示 分子生成 逆合成 知识图推理 在许多药物发现任务中,收集标记数据在时 ...

  2. TorchProtein教程--预训练的蛋白质结构表示(5)

    TorchProtein教程–预训练的蛋白质结构表示(5) 本教程来自唐建团队的开源框架torchprotein 目录 torchprotein安装 蛋白质数据结构 基于序列的蛋白质特性预测 基于结构 ...

  3. 【模型复现】逆合成预测/文本分类模型——MeGAN 快速复现

    MeGAN 快速复现教程 01 镜像详情 镜像简介: 模型论文2021年5月发表在JCIM上的关于逆合成路线规划一篇文章,标题为<Molecule Edit Graph Attention Ne ...

  4. 使用Docker搭建svn服务器教程

    使用Docker搭建svn服务器教程 svn简介 SVN是Subversion的简称,是一个开放源代码的版本控制系统,相较于RCS.CVS,它采用了分支管理系统,它的设计目标就是取代CVS.互联网上很 ...

  5. mysql修改校对集_MySQL 教程之校对集问题

    本篇文章主要给大家介绍mysql中的校对集问题,希望对需要的朋友有所帮助! 推荐参考教程:<mysql教程> 校对集问题 校对集,其实就是数据的比较方式. 校对集,共有三种,分别为:_bi ...

  6. mysql备份psb文件怎么打开_Navicat for MySQL 数据备份教程

    原标题:Navicat for MySQL 数据备份教程 一个安全和可靠的服务器与定期运行备份有密切的关系,因为错误有可能随时发生,由攻击.硬件故障.人为错误.电力中断等都会照成数据丢失.备份功能为防 ...

  7. php rabbmq教程_RabbitMQ+PHP 教程一(Hello World)

    介绍 RabbitMQ是一个消息代理器:它接受和转发消息.你可以把它当作一个邮局:当你把邮件放在信箱里时,你可以肯定邮差先生最终会把邮件送到你的收件人那里.在这个比喻中,RabbitMQ就是这里的邮箱 ...

  8. 【置顶】利用 NLP 技术做简单数据可视化分析教程(实战)

    置顶 本人决定将过去一段时间在公司以及日常生活中关于自然语言处理的相关技术积累,将在gitbook做一个简单分享,内容应该会很丰富,希望对你有所帮助,欢迎大家支持. 内容介绍如下 你是否曾经在租房时因 ...

  9. Google Colab 免费GPU服务器使用教程 挂载云端硬盘

    一.前言 二.Google Colab特征 三.开始使用 3.1在谷歌云盘上创建文件夹 3.2创建Colaboratory 3.3创建完成 四.设置GPU运行 五.运行.py文件 5.1安装必要库 5 ...

最新文章

  1. ubuntu for nvidia-drivers for AI
  2. python seek到指定行_python文件操作seek()偏移量,读取指正到指定位置操作
  3. 中青评论:家政本科招生难,专业名字误终身?
  4. [python网络编程]DNSserver
  5. 蓝桥学院2019算法题1.3
  6. 牛客网(剑指offer) 第十四题 链表中倒数第k个节点
  7. Android xUtils3.0使用手册(二) - 数据库操作
  8. 无法启动此程序因为计算机丢失gdiplus,gdiplus.dll 丢失
  9. ug冲模标准件库_基于UG建立模具标准件库
  10. Arduino DY-SV17F自动语音播报
  11. 离散数学知识点总结(3):等值演算,16个命题定律 / 基础等价式,重言式的替换规则,证明有效性和可满足性的方法
  12. 解析:浏览器事件冒泡及事件捕获
  13. Docker未授权漏洞复现(合天网安实验室)
  14. 解答为什么@Autowired使用在接口上而不是实现类上
  15. 计算机英语格式怎么写,26个英文字母,正确的书写格式,孩子真的会吗?
  16. 王道考研——操作系统(第一章 计算机系统概述)
  17. python语言程序设计基础考试题库_中国大学MOOC(慕课)_Python语言程序设计基础_测试题及答案...
  18. Excel导出带图片详解
  19. Linux系统编程 复习笔记
  20. 软件工程导论张海蕃书籍pdf_《软件工程导论》张海蕃 课后习题答案.docx

热门文章

  1. 千月影视乐彩影视,H5对接苹果CMS 安卓APP搜索接口苹果cms(2开苹果cms对接版H5数据)
  2. 丰和重仓股票本周涨幅
  3. 项目风险管理论文示例2
  4. 机械制图计算机类实验报告,工程制图与CAD实习实验报告模板
  5. 小路绫只会做料理 (ayaya)(树状数组 二分)
  6. VC,CString,UTF8与GBK互转
  7. 互联网远程办公 关于公司执行线上办公管理办法
  8. 基于暗通道先验的单幅图像去雾算法小结
  9. mysql6.5client下载_mysql-client多个版本客户端安装
  10. CAS4.0集成OpenLdap返回用户属性