使用图神经网络预测药物-药物相互作用

了解药物如何相互作用是医学研究和实践的关键问题。图形机器学习(GraphML)领域可用于以高可信度回答有关药物 - 药物相互作用的问题。本次学习通过GraphSAGE图卷积来预测DDI(药物-药物相互作用)。

使用图神经网络预测药物-药物相互作用

  • 使用图神经网络预测药物-药物相互作用
    • 1.数据
    • 2.定义词汇意义
    • 3.学习目标
    • 4.加载数据集
    • 5.消息传递概述
    • 6.数据集拆分
    • 7.模型构建
      • GraphSAGE构建
      • 链接预测器
    • 8.训练

1.数据

数据来源于Open Graph Benchmark (OGB) 是用于图形机器学习任务的开源基准数据集的集合。我们在本文中的重点是ogbl-ddi数据集,如上所述,它由单个药物 - 药物相互作用(DDI)网络组成。

我们在数学上将 DDI 图定义为 G = (V, E),其中 V 是节点集,E 是边的集合。图中的每个节点 v ∈ V 代表 FDA 批准或实验药物。两个节点u和v之间存在边缘(u,v)表明两种药物相互作用,使得同时服用两种药物的效果与药物彼此独立作用的预期效果有很大不同[1]。例如,靶向相同蛋白质的两种药物可能具有显着的相互作用。

2.定义词汇意义

首先,一些词汇,‘正边’是数据集中存在的边。每个正边代表由边缘端点表示的两种药物之间的已知显着相互作用。
如果两种药物不相互作用(即,它们一起服用与单独服用时具有相同的效果),则图表中不会存在边缘。这些“单独的节点”被称为负边;换句话说,就是图中不存在的边。

3.学习目标

本次学习的目标是开发一个图形机器学习模型来解决链接预测任务:给定两个药物作为输入,我们希望预测两种药物是否相互作用,即图中的这两个节点之间是否应该存在边缘。这应该允许我们通过将缺失的边缘理解为正边或负边来完成数据集。

4.加载数据集

按照 OGB 网站的示例,我们可以将 DDI 数据集加载到 PyTorch Geometric (PyG) 中:

from ogb.linkproppred import PygLinkPropPredDatasetdataset_name = 'ogbl-ddi'
dataset = PygLinkPropPredDataset(name=dataset_name)
print(f'The{dataset_name}dataset has{len(dataset)}graph(s).')
ddi_G = dataset[0]
print(f'DDI 图:{ddi_G}')
print(f'节点数量 |V|:{ddi_G.num_nodes}')
print(f'边的数量 |E|:{ddi_G.num_edges}')
print(f'无向图?:{ddi_G.is_undirected()}')
print(f'节点平均度:{ddi_G.num_edges / ddi_G.num_nodes:.2f}')
print(f'节点特征:{ddi_G.num_node_features}')
print(f'有孤立点?:{ddi_G.has_isolated_nodes()}')
print(f'有自循环?:{ddi_G.has_self_loops()}')
输出
Using backend: pytorch
Downloading http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip
Downloaded 0.04 GB: 100%|██████████████████████████████████████████████████████████████| 46/46 [00:31<00:00,  1.45it/s]
Extracting dataset\ddi.zip
Processing...
Loading necessary files...
This might take a while.
Processing graphs...
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 58.81it/s]
Converting graphs into PyG objects...
100%|████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<?, ?it/s]
Saving...
The ogbl-ddi dataset has 1 graph(s).
Done!
DDI 图: Data(num_nodes=4267, edge_index=[2, 2135822])
节点数量 |V|: 4267
边的数量 |E|: 2135822
无向图?: True
节点平均度: 500.54
节点特征: 0
有孤立点?: False
有自循环?: False

注意:DDI数据集中,图没有节点特征,后续需要进行加工处理。

5.消息传递概述

每个 GNN“层”由在每个节点与其邻居之间传递的一轮此消息、每个节点接收的邻居消息的聚合以及使用聚合消息计算更新的嵌入来定义的。
消息传递是节点能够合并来自其局部邻域结构的信息以确定其自身嵌入的机制。它由生成消息的每个节点组成,该消息在回合中沿该节点的传出边缘传递给其他节点。
消息传递的PyG实现,以及其他知识可以看我这一篇PyG消息传递

6.数据集拆分

OGB 为我们提供了数据集拆分。正边是图中存在的边:集合 {(u,v)∈E} ,其中 u,v∈V ,负边是图中不存在的边:集合 {(u,v)∉E} 。 我们需要正边和负边来训练我们的链接预测模型。 数据集的拆分详情可以查看ogb官网,

split_edges = dataset.get_edge_split()
train_edges, valid_edges, test_edges = split_edges['train'], split_edges['valid'], split_edges['test']
print(train_edges)
print(f'训练集正边的数量:{train_edges["edge"].shape[0]}')
print(f'验证集正边的数量:{valid_edges["edge"].shape[0]}')
print(f'验证集负边的数量:{valid_edges["edge_neg"].shape[0]}')
print(f'测试集正边的数量:{test_edges["edge"].shape[0]}')
print(f'测试集负边的数量:{valid_edges["edge_neg"].shape[0]}')
输出
{'edge': tensor([[4039, 2424],[4039,  225],[4039, 3901],...,[ 647,  708],[ 708,  338],[ 835, 3554]])}
训练集正边的数量: 1067911
验证集正边的数量: 133489
验证集负边的数量: 101882
测试集正边的数量: 133489
测试集负边的数量: 101882

ddi_graph.edge_index 的形状为 [2, 2 * E] 因为我们在 GNN 中使用 edge_index,这需要从 u 向 v 和 v 向 u 发送信息(因为我们将在下面看到)。药物1对药物2有作用,相反药物2对药物1也有作用,作用是相互的。

7.模型构建

模型有两个部分: 1)图神经网络生成节点嵌入 2) 输出链接预测概率的深度神经网络

GraphSAGE构建

import torch
import torch_geometric
import torch.nn as nn
import torch.nn.functional as Ffrom torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import negative_sampling
from tqdm import trangeclass GraphSAGE(torch.nn.Module):"""使用 GraphSAGE 架构构建的图神经网络。"""def __init__(self, conv, in_channels, hidden_channels, out_channels, num_layers, dropout):'''in_channels:初始节点嵌入的维度。由于药物没有节点特征,我们将随机初始化这些向量。hidden_channels:中间节点嵌入的维度。隐藏层的维度。out_channels:输出节点嵌入的维度。num_layers:我们的 GNN 中的层数K。这是应用 GraphSAGE 运算符的次数。dropout:Dropout 应用于权重矩阵 W1 和 W2。'''super(GraphSAGE, self).__init__()self.convs = torch.nn.ModuleList()assert (num_layers >= 2), 'Have at least 2 layers'##至少两层卷积# 在每一个layer中增加conv,上一层与下一层的维度必须一致# 我们还应用了归一化,之后输出节点嵌入。每个卷积层都是 L2 归一化的。self.convs.append(conv(in_channels, hidden_channels, normalize=True))for l in range(num_layers - 2):self.convs.append(conv(hidden_channels, hidden_channels, normalize=True))self.convs.append(conv(hidden_channels, out_channels, normalize=True))self.num_layers = num_layersself.dropout = dropoutdef forward(self, x, edge_index, edge_attr):if edge_attr is not None: ## 如果有edge_attrreturn self.forward_with_edge_attr(x, edge_index, edge_attr)# x 是初始节点嵌入的矩阵,形状 [N, in_channels]for i in range(self.num_layers - 1):# 第 i 层进行消息传递和聚合x = self.convs[i](x, edge_index)# x 的形状为 [N, hidden_channels]# 通过非线性激活函数relux = F.relu(x)x = F.dropout(x, p=self.dropout, training=self.training)# 生成最终嵌入, x 的形状为 [N, out_channels]x = self.convs[self.num_layers - 1](x, edge_index)return xdef forward_with_edge_attr(self, x, edge_index, edge_attr):# x 是初始节点嵌入的矩阵,形状 [N, in_channels]for i in range(self.num_layers - 1):# 第 i 层进行消息传递和聚合x = self.convs[i](x, edge_index, edge_attr)# x 的形状为 [N, hidden_channels]# 通过非线性激活函数relux = F.relu(x)x = F.dropout(x, p=self.dropout,training=self.training)# 生成最终嵌入, x 的形状为 [N, out_channels]x = self.convs[self.num_layers - 1](x, edge_index, edge_attr)return x
#设置参数
graphsage_in_channels = 256
graphsage_hidden_channels = 256
graphsage_out_channels = 256
graphsage_num_layers = 2
dropout = 0.5 ###注意,因为数据库ddi本身没有附带节点特征矩阵,所以我们要创立初始嵌入。torch.nn.Embedding
initial_node_embeddings = torch.nn.Embedding(ddi_graph.num_nodes, graphsage_in_channels).to(device)##图节点特征向量形状为[N,in_channels]
initial_node_embeddings
输出
Embedding(4267, 256)
## 实例化模型GraphSAGE
graphsage_model = GraphSAGE(SAGEConv, graphsage_in_channels, graphsage_hidden_channels,graphsage_out_channels,graphsage_num_layers, dropout).to(device)

链接预测器

link_predictor_in_channels = graphsage_out_channels
link_predictor_hidden_channels = link_predictor_in_channelsclass LinkPredictor(torch.nn.Module):"""将两个输入转换为单个输出的通用网络。"""def __init__(self, in_channels, hidden_channels, dropout, out_channels=1,concat=lambda x, y: x * y):super(LinkPredictor, self).__init__()self.model = nn.Sequential(nn.Linear(in_channels, hidden_channels), nn.ReLU(), nn.Dropout(p=dropout), nn.Linear(hidden_channels, out_channels), nn.Sigmoid())self.concat = concatdef forward(self, u, v):x = self.concat(u, v)return self.model(x)link_predictor = LinkPredictor(in_channels=link_predictor_in_channels, hidden_channels=link_predictor_hidden_channels, dropout=dropout).to(device)

8.训练

##训练我们的完整模型(GraphSAGE + LinkPredictor)
def train(graphsage_model, link_predictor, initial_node_embeddings, edge_index, pos_train_edges, optimizer, batch_size, edge_attr=None):total_loss, total_examples = 0, 0# 设置我们的模型进行训练graphsage_model.train()link_predictor.train()# 迭代成批的训练边(“正边”)# (最后一次迭代的边数可能比 batch_size 少)for pos_samples in DataLoader(pos_train_edges, batch_size, shuffle=True):optimizer.zero_grad()# 运行 GraphSAGE 前向传递node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)#对由 attr:'edge_index'给出的图的随机对负边进行采样。# neg_samples 是一个尺寸为 [2, batch_size] 的张量neg_samples = negative_sampling(edge_index, num_nodes=initial_node_embeddings.size(0),num_neg_samples=len(pos_samples),method='dense')# 在正边嵌入上运行链接预测器前向传递pos_preds = link_predictor(node_embeddings[pos_samples[:, 0]], node_embeddings[pos_samples[:, 1]])# 在负边嵌入上运行链接预测器前向传递neg_preds = link_predictor(node_embeddings[neg_samples[0]], node_embeddings[neg_samples[1]])preds = torch.concat((pos_preds, neg_preds))labels = torch.concat((torch.ones_like(pos_preds), torch.zeros_like(neg_preds)))loss = F.binary_cross_entropy(preds, labels)loss.backward()optimizer.step()num_examples = len(pos_preds)total_loss += loss.item() * num_examplestotal_examples += num_examplesreturn total_loss / total_examples
##参数
lr = 0.005
batch_size = 65536
epochs = 2
eval_steps = 5
optimizer = torch.optim.Adam(list(graphsage_model.parameters()) + list(link_predictor.parameters()),lr=lr)

我们根据 OGB 数据集中提供的验证和测试正负边缘来评估我们的模型:

pos_valid_edges = valid_edges['edge'].to(device)
neg_valid_edges = valid_edges['edge_neg'].to(device)
pos_test_edges = test_edges['edge'].to(device)
neg_test_edges = test_edges['edge_neg'].to(device)from ogb.linkproppred import Evaluatorevaluator = Evaluator(name = dataset_name)
@torch.no_grad()
def test(graphsage_model, link_predictor, initial_node_embeddings, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator, edge_attr=None):graphsage_model.eval()link_predictor.eval()final_node_embeddings = graphsage_model(initial_node_embeddings, edge_index, edge_attr)pos_valid_preds = []for pos_samples in DataLoader(pos_valid_edges, batch_size):pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], final_node_embeddings[pos_samples[:, 1]])pos_valid_preds.append(pos_preds.squeeze())pos_valid_pred = torch.cat(pos_valid_preds, dim=0)neg_valid_preds = []for neg_samples in DataLoader(neg_valid_edges, batch_size):neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], final_node_embeddings[neg_samples[:, 1]])neg_valid_preds.append(neg_preds.squeeze())neg_valid_pred = torch.cat(neg_valid_preds, dim=0)pos_test_preds = []for pos_samples in DataLoader(pos_test_edges, batch_size):pos_preds = link_predictor(final_node_embeddings[pos_samples[:, 0]], final_node_embeddings[pos_samples[:, 1]])pos_test_preds.append(pos_preds.squeeze())pos_test_pred = torch.cat(pos_test_preds, dim=0)neg_test_preds = []for neg_samples in DataLoader(neg_test_edges, batch_size):neg_preds = link_predictor(final_node_embeddings[neg_samples[:, 0]], final_node_embeddings[neg_samples[:, 1]])neg_test_preds.append(neg_preds.squeeze())neg_test_pred = torch.cat(neg_test_preds, dim=0)# Calculate Hits@20evaluator.K = 20valid_hits = evaluator.eval({'y_pred_pos': pos_valid_pred, 'y_pred_neg': neg_valid_pred})test_hits = evaluator.eval({'y_pred_pos': pos_test_pred, 'y_pred_neg': neg_test_pred})return valid_hits, test_hits
import matplotlib.pyplot as pltepochs_bar = trange(1, epochs + 1, desc='Loss n/a')edge_index = ddi_graph.edge_index.to(device)
pos_train_edges = train_edges['edge'].to(device)losses = []
valid_hits_list = []
test_hits_list = []
for epoch in epochs_bar:loss = train(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_train_edges, optimizer, batch_size)losses.append(loss)epochs_bar.set_description(f'Loss{loss:0.4f}')if epoch % eval_steps == 0:valid_hits, test_hits = test(graphsage_model, link_predictor, initial_node_embeddings.weight, edge_index, pos_valid_edges, neg_valid_edges, pos_test_edges, neg_test_edges, batch_size, evaluator)print()print(f'Epoch:{epoch}, Validation Hits@20:{valid_hits["hits@20"]:0.4f}, Test Hits@20:{test_hits["hits@20"]:0.4f}')valid_hits_list.append(valid_hits['hits@20'])test_hits_list.append(test_hits['hits@20'])else:valid_hits_list.append(valid_hits_list[-1] if valid_hits_list else 0)test_hits_list.append(test_hits_list[-1] if test_hits_list else 0)plt.title(dataset.name + ": GraphSAGE")
plt.xlabel("Epoch")
plt.plot(losses, label="Training loss")
plt.plot(valid_hits_list, label="Validation Hits@20")
plt.plot(test_hits_list, label="Test Hits@20")
plt.legend()
plt.show()
输出
Loss 0.4814: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:27<00:00, 13.65s/it]

我自己电脑太拉了就跑了2个epoch,看看能不能跑通代码试一下。
白嫖gpu跑了一下50epoch

使用图神经网络预测药物-药物相互作用相关推荐

  1. 【广告技术】使用图神经网络进行信息聚合与推理,解决多证据事实验证问题

    [Wiztalk腾讯广告专场]系列分享来袭,第三期由清华大学计算机系副教授.博士生导师刘知远老师与清华大学计算机系硕士生周界为大家深度介绍 <基于图结构的事实验证>. 从浅显的文本处理走向 ...

  2. 中科院、华为等提出Vision GNN,只使用图神经网络进行视觉任务

    ©作者 | 周春鹏 单位 | 浙江大学 研究方向 | 计算机视觉 网络结构在基于深度学习的计算机视觉系统中起着至关重要的作用.目前广泛应用的卷积神经网络和卷积神经转换器将图像视为网格或序列结构,难以灵 ...

  3. 图神经网络用于RNA-蛋白质相互作用的新预测

    <De novo p rediction of RNA-protein interactions with Graph Neural Networks> 时间:2021年9月28日 作者: ...

  4. 论文浅尝 - IJCAI2020 | KGNN:基于知识图谱的图神经网络预测药物与药物相互作用...

    转载公众号 |  AI TIME 论道 药物间相互作用(DDI)预测是药理学和临床应用中一个具有挑战性的问题,在临床试验期间,有效识别潜在的DDI对患者和社会至关重要.现有的大多数方法采用基于AI的计 ...

  5. DrugVQA | 用视觉问答技术预测药物蛋白质相互作用

    1.研究背景 鉴定新的药物-蛋白质相互作用对于药物发现至关重要,基于机器学习的方法利用药物描述符和一维(1D)蛋白质序列已经开发了许多鉴定方法.这些方法一般都是通过将配体,蛋白质及其相互作用的信息整合 ...

  6. 生物信息学|深度学习改善了药物药物相互作用和药物食物相互作用的预测精度

    本篇推文引自:Deep learning improves prediction of drug–drug and drug–food interactions 1. 摘要     药物相互作用(包括 ...

  7. 生物信息学|用于预测药物-药物相互作用事件的多模态深度学习框架

    本篇推文引自:A multimodal deep learning framework for predicting drug–drug interaction events 1. 摘要     动机 ...

  8. 生物信息学|机制驱动的可解释深度神经网络,用于药物组合的协同预测和通路反卷积

    0. 摘要     联合用药在癌症治疗方面显示出了巨大的潜力.不仅可以减轻耐药性,而且可以提高治疗效果.抗癌药物数量的快速增长已经导致所有药物组合的实验研究变得昂贵和耗时.计算技术可以提高药物联合筛选 ...

  9. Bioinformatics | 预测药物-药物相互作用的多模态深度学习框架

    今天给大家介绍来自华中农业大学信息学院章文教授课题组在Bioinformatics上发表的一篇关于预测药物与药物相互作用事件的文章.作者提出了一个多模态深度学习框架- DDIMDL.它将不同的药物特征 ...

最新文章

  1. 用了 HTTPS 就一定安全吗?HTTPS 原理分析——带着疑问层层深入
  2. 功能性平台创新-农业大健康·杨建国:谋定都江堰精华灌区
  3. enum的介绍以及和#define的区别
  4. POJ-1260 Pearls DP
  5. android 二级 滚动,android使用 ScrollerView 实现 可上下滚动的分类栏实例
  6. R packages for big data:data.table
  7. 【机器学习与数据挖掘】浅谈指标SSE,MSE,RMSE,R-square
  8. matlab矩阵的白化,白化原理及Matlab实现
  9. java for步长_Velocity模板循环支持自定义步长
  10. CAD VCL Multiplatform SDK 定制Crack
  11. 考HCIE大概需要多少钱?
  12. 腾讯云运维工程师认证TCA原题(含解析)
  13. 数据挖掘经典十大算法_对基本概念的理解
  14. MIT线性代数笔记十七讲 正交矩阵和施密特正交化
  15. 华为鸿蒙系统支持车型,华为助力,鸿蒙OS车载系统量产,北汽极狐智能驾驶车型登场...
  16. C++中cos,sin,asin,acos这些三角函数操作的是弧度,而非角度,
  17. React基础学习笔记(一)-react前端项目的两种搭建方式
  18. null == undefined ?
  19. DataGridView导出Excel 隐藏列不显示
  20. git merge本地合并分支出现文件冲突处理方法

热门文章

  1. chmod 和 chown 命令用法
  2. 1231321321
  3. sql substr oracle,Substr也可以使用索引吗?
  4. 表现层(UI)、业务逻辑层(BLL)、数据访问层(DAL)
  5. AI热潮来袭||网友:AI会不会抢自己的饭碗啊~~~
  6. 抖音企业号,抖音搜索框SEO优化系统搭建。
  7. 木马也办“假身份证” 数字签名面临信任危机
  8. vscode配置opencv环境,包括opencv源码编译(mingw64 + cmake)
  9. 赞雨林木风:从修改版到定制版
  10. 美森快船收费标准和操作流程是怎样的?