疾病-基因与图神经网络和图自动编码器的相互作用:

学习图自编码器


PyG实战

  • 1.学习内容:
  • 2.GAE自编码器:
  • 3.数据集的处理
  • 4.总结

1.学习内容:

  1. 图卷积GNN知识
  2. GCN作为图形自动编码器GAE架构
  3. GAE架构在疾病-基因相互作用预测的应用

在PYG库中的卷积层中有许多不同的变体,但每层的核心是三个步骤:消息传递、聚合和更新
在pytorch_geometric中,可以使用一行代码构建GCN层:

from torch_geometric.nn import GCNConv
conv = GCNConv(in_channels, out_channels)

in_channels和out_channels分别表示节点的输入表示维度和输出表示维度的大小。一般来说 in_channels=X.shape[1](X是节点特征矩阵)。
GCN虽然是最简单的GNN,但在实践中效果很好,GCN的变体通常排在图形数据集基准的首位。可以查看开放图形基准测试的数据集OGB的排行榜。在排行榜中单纯的GCN可能使用的数据集比较少。


2.GAE自编码器:

首先,我了解一点机器学习领域中的自编码器(AutoEncode)。自编码器包含两个主要的部分:Encode(编码)和 Decode(解码)。AE的作用大体就是把一个高维向量X编码成低维的隐变量h,然后h通过解码器解码到初始维度,最好的情况就是解码器能够近似或者完美恢复原来的输入。这就要求编码器尽可能地学习最有信息量的特征。
那么在图中原理差不多也是一样的,在GAE中,我们有一个编码器,其工作是将输入图映射到较低维空间,以及一个解码器,用于从低维嵌入重建输入图。也就是说,我们将解码器输出解释为重建的邻接矩阵 A^\hat AA^。目标是优化模型,使重建损失(A^\hat AA^和原始图形输入A之间的差异)最小化。

我们将定义一个具有两个图形卷积层、一个 ReLU 和一个 dropout 的 GCN,以帮助建模性能。当然,在编码器部分我们可以多样化,插入其他GNN卷积层。
GCNEncode:

导入库
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import train_test_split_edges, negative_sampling, degree
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric.transforms as T
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import random
import string
from sklearn import metrics
from torch_geometric.data import Data, download_url, extract_gz
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class GAEncoder(nn.Module):def __init__(self, in_channels, hidden_size, out_channels, dropout):super(GAEncoder, self).__init__()self.conv1 = GCNConv(in_channels, hidden_size, cached=True)self.conv2 = GCNConv(hidden_size, out_channels, cached=True)self.dropout = nn.Dropout(dropout)def forward(self, x, edge_index):x = self.conv1(x, edge_index)x = F.relu(x)x = self.dropout(x)out = self.conv2(x, edge_index)return out

model = GAE(GAEncoder(20, 200, 20, 0.5)).to(device)

查看模型结构
GAE((encoder): GAEncoder((conv1): GCNConv(20, 200)(conv2): GCNConv(200, 20)(relu): ReLU()(dropout): Dropout(p=0.5, inplace=False))(decoder): InnerProductDecoder()  ## 默认解码器点积运算符
)
def train(train_data, model, optimizer):model.train()optimizer.zero_grad()z = model.encode(train_data.x,train_data.edge_index)loss = model.recon_loss(z, train_data.pos_edge_label_index.to(device))loss.backward(retain_graph=True)optimizer.step()return float(loss)@torch.no_grad()
def gae_test(test_data,model):model.eval()z = model.encode(test_data.x, test_data.edge_index)loss = model.test(z, test_data.pos_edge_label_index, test_data.neg_edge_label_index)return loss

以上就是模型GAE的网络架构。

3.数据集的处理

url = 'http://snap.stanford.edu/biodata/datasets/10012/files/DG-AssocMiner_miner-disease-gene.tsv.gz'
extract_gz(download_url(url, '.'), '.')
data_path = "./DG-AssocMiner_miner-disease-gene.tsv"
df = pd.read_csv(data_path, sep="\t")
df.head()
输出# Disease ID   Disease Name            Gene ID
0   C0036095    Salivary Gland Neoplasms    1462
1   C0036095    Salivary Gland Neoplasms    1612
2   C0036095    Salivary Gland Neoplasms    182
3   C0036095    Salivary Gland Neoplasms    2011
4   C0036095    Salivary Gland Neoplasms    2019df.shape
(21357, 3)

导入数据

def load_data(data_path,class_node=519):df = pd.read_csv(data_path, sep='\t')dise_id = df['# Disease ID']Gene_id = df['Gene ID']dis_mapping = {index_id: int(i) + 0 for i, index_id in enumerate(dise_id.unique())}gen_mapping = {index_id: int(i) + class_node for i, index_id in enumerate(Gene_id.unique())}src_nodes = [dis_mapping[index] for index in df['# Disease ID']]dst_nodes = [gen_mapping[index] for index in df['Gene ID']]edge_index = torch.tensor([src_nodes, dst_nodes])rev_edge_index = torch.tensor([dst_nodes, src_nodes])data = Data()data.num_nodes = len(dis_mapping) + len(gen_mapping)data.edge_index = torch.concat([edge_index, rev_edge_index],dim=1)data.x = torch.ones((data.num_nodes, 20))return data, gen_mapping, dis_mapping

通过上述,我们得到了无向图的data图数据

data_object, gene_mapping, dis_mapping = load_data(data_path)
print(data_object)
print("Number of genes:", len(gene_mapping))
print("Number of diseases:", len(dz_mapping))输出
Data(num_nodes=7813, edge_index=[2, 42714], x=[7813, 20])
Number of genes: 7294
Number of diseases: 519

在pytorch_geometric中使用 RandomLinkSplit 方法创建训练集、验证集和测试集。

transform = T.Compose([T.NormalizeFeatures(),T.ToDevice(device),T.RandomLinkSplit(num_val=0.05, num_test=0.15, is_undirected=True,split_labels=True, add_negative_train_samples=True),
])
train_datasets, val_datasets, test_datasets = transform(data_object)
print("Train Data:", train_datasets)
print("Validation Data:", val_datasets)
print("Test Data:", test_datasets)
查看训练测试集
Train Data: Data(num_nodes=7813, edge_index=[2, 34174], x=[7813, 20], pos_edge_label=[17087], pos_edge_label_index=[2, 17087], neg_edge_label=[17087], neg_edge_label_index=[2, 17087])
Validation Data: Data(num_nodes=7813, edge_index=[2, 34174], x=[7813, 20], pos_edge_label=[1067], pos_edge_label_index=[2, 1067], neg_edge_label=[1067], neg_edge_label_index=[2, 1067])
Test Data: Data(num_nodes=7813, edge_index=[2, 36308], x=[7813, 20], pos_edge_label=[3203], pos_edge_label_index=[2, 3203], neg_edge_label=[3203], neg_edge_label_index=[2, 3203])

接下来就是训练

optimizer = optim.Adam(model.parameters(),lr=0.1)
losses = []
test_auc = []
test_ap = []
train_aucs = []
train_aps = []
for epoch in range(1, 50):loss = train(train_datasets, model, optimizer)losses.append(loss)auc, ap = gae_test(test_datasets, model)test_auc.append(auc)test_ap.append(ap)train_auc, train_ap = gae_test(train_datasets, model)train_aucs.append(train_auc)train_aps.append(train_ap)print('Epoch: {:03d}, test AUC: {:.4f}, test AP: {:.4f}, train AUC: {:.4f}, train AP: {:.4f}, loss:{:.4f}'.format(epoch, auc, ap, train_auc, train_ap, loss))
Epoch: 001, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2068
Epoch: 002, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2009
Epoch: 003, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1856
Epoch: 004, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0031
Epoch: 005, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0008
Epoch: 006, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0743
Epoch: 007, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0811
Epoch: 008, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1047
Epoch: 009, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0700
Epoch: 010, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1119
Epoch: 011, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1230
Epoch: 012, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0303
Epoch: 013, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1233
Epoch: 014, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0642
Epoch: 015, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0782
Epoch: 016, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0172
Epoch: 017, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0432
Epoch: 018, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0900
Epoch: 019, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0705
Epoch: 020, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9925
Epoch: 021, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2228
Epoch: 022, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0448
Epoch: 023, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0485
Epoch: 024, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9706
Epoch: 025, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1790
Epoch: 026, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9618
Epoch: 027, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0103
Epoch: 028, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1602
Epoch: 029, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1063
Epoch: 030, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0232
Epoch: 031, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1293
Epoch: 032, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0388
Epoch: 033, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1522
Epoch: 034, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1759
Epoch: 035, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1749
Epoch: 036, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2155
Epoch: 037, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1248
Epoch: 038, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1518
Epoch: 039, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0606
Epoch: 040, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0665
Epoch: 041, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0464
Epoch: 042, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0648
Epoch: 043, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0416
Epoch: 044, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.2085
Epoch: 045, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0527
Epoch: 046, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:2.9981
Epoch: 047, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1345
Epoch: 048, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.0027
Epoch: 049, test AUC: 0.9426, test AP: 0.9421, train AUC: 0.9436, train AP: 0.9385, loss:3.1076

我不知道为啥训练和测试没有变化,但是loss能看到变化。

4.总结

通过机器学习中的自编码器AE,迁移到图的自编码器GAE。还有一种VAE变分自编码器,VGAE是GAE的变体。先学一波。

以上来自https://snap.stanford.edu/class/cs224w-2020/

【实战】疾病-基因与图神经网络和图自动编码器的相互作用相关推荐

  1. 当图网络遇上计算机视觉!计算机视觉中基于图神经网络和图Transformer的方法和最新进展...

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 点击进入-> CV 微信技术交流群 可能是目前最全面的<当图网络遇上计算机视觉>综述!近四 ...

  2. 图神经网络与图注意力网络相关知识概述

    #图神经网络# #图注意力网络# 随着计算机行业和互联网时代的不断发展与进步,图神经网络已经成为人工智能和大数据的重要研究领域.图神经网络是对相邻节点间信息的传播和聚合的重要技术,可以有效地将深度学习 ...

  3. 图神经网络的图网络学习(上)

    图神经网络的图网络学习(上) 原文:Learning the Network of Graphs for Graph Neural Networks 摘要 图神经网络 (GNN) 在许多使用图结构数据 ...

  4. 图神经网络基础--图结构数据

    图神经网络基础–图结构数据 注:本节大部分内容(包括图片)来源于"Chapter 2 - Foundations of Graphs, Deep Learning on Graphs&quo ...

  5. 【图神经网络】图分类学习研究综述[2]:基于图神经网络的图分类

    基于GNN的图分类学习研究综述[2]:基于图神经网络的图分类 论文阅读:基于GNN的图分类学习研究综述 3. 基于图神经网络的图分类 3.1 卷积 3.2 池化 论文阅读:基于GNN的图分类学习研究综 ...

  6. 图神经网络的图网络学习(下)

    原文:Learning the Network of Graphs for Graph Neural Networks 1. 文章信息 作者 Yixiang Shan, Jielong Yang, X ...

  7. 「图神经网络复杂图挖掘」 的研究进展

    来源:专知 图神经网络对非欧式空间数据建立了深度学习框架,相比传统网络表示学习模型,它对图结构能够实施更加深层的信息聚合操作.近年来,图神经网络完成了向复杂图结构的迁移,诞生了一系列基于复杂图的图神经 ...

  8. 图神经网络之图卷积网络——GCN

    图卷积网络--GCN 一.前置基础知识回顾 图的基本概念 构造图神经网络的目的 训练方式 二.回顾卷积神经网络在图像及文本上的发展 图像上的卷积网络 文本上的卷积网络 图卷积网络的必要性 三.图卷积网 ...

  9. 图神经网络 | (6) 图分类(SAGPool)实战

    近期买了一本图神经网络的入门书,最近几篇博客对书中的一些实战案例进行整理,具体的理论和原理部分可以自行查阅该书,该书购买链接:<深入浅出的图神经网络>. 该书配套代码 本节我们通过代码来实 ...

  10. 论文浅尝 - IJCAI2020 | KGNN:基于知识图谱的图神经网络预测药物与药物相互作用...

    转载公众号 |  AI TIME 论道 药物间相互作用(DDI)预测是药理学和临床应用中一个具有挑战性的问题,在临床试验期间,有效识别潜在的DDI对患者和社会至关重要.现有的大多数方法采用基于AI的计 ...

最新文章

  1. ubuntu MNN编译安装
  2. 带你了解加速度传感器的几种应用
  3. 禁止vim生成 un~文件
  4. jQuery源码解读
  5. Python---编程检查并判断密码字符串的安全强度
  6. C#开发笔记之04-如何用C#优雅的计算个人所得税?
  7. java反编译源码_java反编译获取源码
  8. 4999元!iQOO 9 Pro赛道版今日预售:创新性采用芳纶纤维材质
  9. uva 11419 最大匹配(最小点覆盖)
  10. SetWindowPos详解
  11. c语言mfc步骤,C语言工程MFC
  12. 一键清理垃圾的bat文件
  13. Python百度文库爬虫终极版
  14. heaps入门---1
  15. 仿微信设置字体大小控件
  16. 【Visual C++】游戏开发笔记四十 浅墨DirectX教程之八 绘制真实质感的三维世界:光照与材质专场
  17. 基于BLE + LoRa人员定位技术下的室内定位-Lora人员定位-新导智能
  18. 搜狗语音云开发入门--移动端轻松添加高大上的语音识别
  19. 然后上传到linux主机上,Xshell实现Windows上传文件到Linux主机
  20. 河北省 河南省 安徽省 黑龙江省 辽宁省 吉林省 贵州省 陕西省 山东省 云南省 广西省二级建造师 一级建造师...

热门文章

  1. 程序员需要建立的对技术、业务、行业、管理、投资的认知
  2. linux内存占用率高怎么办,Linux下如何解决高内存使用率问题?
  3. 5秒内克隆你的声音,并生成任何内容,这个工具细思极恐...还特么的开源~
  4. 2020身高体重标准表儿童_婴儿身高体重对照表2020
  5. 错误Could not locate executable null\bin\winutils.exe in the Hadoop binaries的解决方案
  6. Limbo模拟器的三两事
  7. 使用pytest 出现collected 0 items解决
  8. 中国科技大学计算机系导师,中国科学技术大学
  9. 洋媳妇教育孩子的方法,令中国婆婆大开眼界 - 人人都是艺术
  10. java成员变量的访问权限_Java学习笔记10---访问权限修饰符如何控制成员变量、成员方法及类的访问范围...