2021SC@SDUSC
Complete code of the DGraphDTA training task:

model = GNNNet()
model.to(device)
model_st = GNNNet.__name__
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set the loss function
loss_fn = nn.MSELoss()
# set the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)

    best_mse = 1000
    best_test_mse = 1000
    best_epoch = -1
    model_file_name = 'models/model_' + model_st + '_' + dataset + '_' + str(fold) + '.model'

    for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)
        print('predicting for valid data')
        G, P = predicting(model, device, valid_loader)
        val = get_mse(G, P)
        print('valid result:', val, best_mse)
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('rmse improved at epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
        else:
            print('No improvement since epoch ', best_epoch, '; best_test_mse', best_mse, model_st, dataset, fold)
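The loop calls train, predicting and get_mse, which are defined elsewhere in the training script. The following is only a rough sketch of what such helpers can look like, assuming each batch yielded by the loaders is a (drug graph, protein graph) pair and that the label y is stored on the drug graph; the repository's actual implementations may differ in details.

import torch

# Hedged sketch of the helpers used by the loop above (not the repository's exact code).
# `loss_fn` refers to the nn.MSELoss() instance defined earlier in the script.
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, data in enumerate(train_loader):
        data_mol, data_pro = data[0].to(device), data[1].to(device)  # drug graph, protein graph
        optimizer.zero_grad()
        output = model(data_mol, data_pro)
        loss = loss_fn(output, data_mol.y.view(-1, 1).float().to(device))
        loss.backward()
        optimizer.step()

def predicting(model, device, loader):
    model.eval()
    total_preds, total_labels = torch.Tensor(), torch.Tensor()
    with torch.no_grad():
        for data in loader:
            data_mol, data_pro = data[0].to(device), data[1].to(device)
            output = model(data_mol, data_pro)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, data_mol.y.view(-1, 1).cpu()), 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()

def get_mse(G, P):
    # mean squared error between ground truth G and predictions P
    return ((G - P) ** 2).mean()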

Initialize the model and move it to the device (CUDA in this case):

model = GNNNet()
model.to(device)
model_st = GNNNet.__name__

Set the loss function to mean squared error (MSE) and the optimizer to Adam:

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)  # arguments are the model parameters and the learning rate
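For reference, MSELoss simply averages the squared differences between predictions and labels. A tiny illustrative example (made-up values, not from the dataset):

import torch
import torch.nn as nn

loss_fn = nn.MSELoss()
pred = torch.tensor([[7.0], [5.5]])    # predicted affinities
target = torch.tensor([[7.2], [5.0]])  # ground-truth affinities
print(loss_fn(pred, target))           # tensor(0.1450) = ((0.2)**2 + (0.5)**2) / 2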

This part splits the dataset and wraps the pieces into a training set and a validation set; the details are not examined in depth in this post.

for dataset in datasets:
    train_data, valid_data = create_dataset_for_5folds(dataset, fold)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                                               collate_fn=collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False,
                                               collate_fn=collate)
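Both loaders pass a custom collate function because every item in the dataset is a pair of graphs (drug graph, protein graph) rather than a single tensor. The following is only a sketch under that assumption, not necessarily the repository's exact collate: it batches the drug graphs and the protein graphs separately.

from torch_geometric.data import Batch

# Sketch of a collate function for (drug graph, protein graph) pairs:
# batch all drug graphs into one Batch object and all protein graphs into another.
def collate(data_list):
    batch_mol = Batch.from_data_list([pair[0] for pair in data_list])
    batch_pro = Batch.from_data_list([pair[1] for pair in data_list])
    return batch_mol, batch_pro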

A note on the loop above: datasets contains either davis or kiba, chosen by the command-line argument below, so the split-and-wrap operation is applied to the selected dataset.

# get the dataset
datasets = [['davis', 'kiba'][int(sys.argv[1])]]
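For illustration only (hypothetical argument values, not taken from the source), the indexing picks exactly one dataset name and wraps it in a list so the for loop above still iterates over it:

# Hypothetical illustration of the indexing: the command-line argument
# selects one dataset name, which is then wrapped in a list.
arg = '0'                                   # stands in for sys.argv[1]
datasets = [['davis', 'kiba'][int(arg)]]
print(datasets)                             # ['davis']; with arg = '1' this would be ['kiba']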

Data loading part:

def create_dataset_for_5folds(dataset, fold=0):
    # load dataset
    dataset_path = 'data/' + dataset + '/'  # dataset path
    train_fold_origin = json.load(open(dataset_path + 'folds/train_fold_setting1.txt'))  # load the txt file via json.load
    train_fold_origin = [e for e in train_fold_origin]  # for 5 folds; convert to a list
    ligands = json.load(open(dataset_path + 'ligands_can.txt'), object_pairs_hook=OrderedDict)  # load ligand SMILES sequences
    proteins = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)  # load protein FASTA sequences

    # load contact maps and aln files
    msa_path = 'data/' + dataset + '/aln'
    contac_path = 'data/' + dataset + '/pconsc4'
    msa_list = []
    contact_list = []
    # for each key in the protein dict, locate the corresponding .aln and contact files and append
    # them to msa_list and contact_list, keeping a one-to-one correspondence with the FASTA data
    for key in proteins:
        msa_list.append(os.path.join(msa_path, key + '.aln'))
        contact_list.append(os.path.join(contac_path, key + '.npy'))

    # load train, valid and test entries
    train_folds = []
    valid_fold = train_fold_origin[fold]  # one fold
    for i in range(len(train_fold_origin)):  # other folds
        if i != fold:
            train_folds += train_fold_origin[i]

    affinity = pickle.load(open(dataset_path + 'Y', 'rb'), encoding='latin1')
    drugs = []
    prots = []
    prot_keys = []
    drug_smiles = []
    # smiles
    for d in ligands.keys():
        lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)
        drugs.append(lg)
        drug_smiles.append(ligands[d])
    # seqs
    for t in proteins.keys():
        prots.append(proteins[t])
        prot_keys.append(t)
    if dataset == 'davis':
        affinity = [-np.log10(y / 1e9) for y in affinity]
    affinity = np.asarray(affinity)

    opts = ['train', 'valid']
    valid_train_count = 0
    valid_valid_count = 0
    for opt in opts:
        if opt == 'train':
            rows, cols = np.where(np.isnan(affinity) == False)
            rows, cols = rows[train_folds], cols[train_folds]
            train_fold_entries = []
            for pair_ind in range(len(rows)):
                if not valid_target(prot_keys[cols[pair_ind]], dataset):  # ensure the contact and aln files exist
                    continue
                ls = []
                ls += [drugs[rows[pair_ind]]]
                ls += [prots[cols[pair_ind]]]
                ls += [prot_keys[cols[pair_ind]]]
                ls += [affinity[rows[pair_ind], cols[pair_ind]]]
                train_fold_entries.append(ls)
                valid_train_count += 1

            csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
            data_to_csv(csv_file, train_fold_entries)
        elif opt == 'valid':
            rows, cols = np.where(np.isnan(affinity) == False)
            rows, cols = rows[valid_fold], cols[valid_fold]
            valid_fold_entries = []
            for pair_ind in range(len(rows)):
                if not valid_target(prot_keys[cols[pair_ind]], dataset):
                    continue
                ls = []
                ls += [drugs[rows[pair_ind]]]
                ls += [prots[cols[pair_ind]]]
                ls += [prot_keys[cols[pair_ind]]]
                ls += [affinity[rows[pair_ind], cols[pair_ind]]]
                valid_fold_entries.append(ls)
                valid_valid_count += 1

            csv_file = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + opt + '.csv'
            data_to_csv(csv_file, valid_fold_entries)

    print('dataset:', dataset)
    # print('len(set(drugs)),len(set(prots)):', len(set(drugs)), len(set(prots)))

    # entries with protein contact and aln files are marked as effective
    print('fold:', fold)
    print('train entries:', len(train_folds), 'effective train entries', valid_train_count)
    print('valid entries:', len(valid_fold), 'effective valid entries', valid_valid_count)

    compound_iso_smiles = drugs
    target_key = prot_keys

    # create smile graph
    smile_graph = {}
    for smile in compound_iso_smiles:
        g = smile_to_graph(smile)
        smile_graph[smile] = g
    # print(smile_graph['CN1CCN(C(=O)c2cc3cc(Cl)ccc3[nH]2)CC1'])  # for test

    # create target graph
    # print('target_key', len(target_key), len(set(target_key)))
    target_graph = {}
    for key in target_key:
        if not valid_target(key, dataset):  # ensure the contact and aln files exist
            continue
        g = target_to_graph(key, proteins[key], contac_path, msa_path)
        target_graph[key] = g

    # count the number of proteins with aln and contact files
    print('effective drugs,effective prot:', len(smile_graph), len(target_graph))
    if len(smile_graph) == 0 or len(target_graph) == 0:
        raise Exception('no protein or drug, run the script for datasets preparation.')

    # e.g. 'data/davis_fold_0_train.csv' or 'data/kiba_fold_0_train.csv'
    train_csv = 'data/' + dataset + '_' + 'fold_' + str(fold) + '_' + 'train' + '.csv'
    df_train_fold = pd.read_csv(train_csv)
    train_drugs, train_prot_keys, train_Y = list(df_train_fold['compound_iso_smiles']), list(
        df_train_fold['target_key']), list(df_train_fold['affinity'])
    train_drugs, train_prot_keys, train_Y = np.asarray(train_drugs), np.asarray(train_prot_keys), np.asarray(train_Y)
    train_dataset = DTADataset(root='data', dataset=dataset + '_' + 'train', xd=train_drugs,
                               target_key=train_prot_keys, y=train_Y, smile_graph=smile_graph,
                               target_graph=target_graph)

    df_valid_fold = pd.read_csv('data/' + dataset + '_' + 'fold_' + str(fold) + '_' + 'valid' + '.csv')
    valid_drugs, valid_prots_keys, valid_Y = list(df_valid_fold['compound_iso_smiles']), list(
        df_valid_fold['target_key']), list(df_valid_fold['affinity'])
    valid_drugs, valid_prots_keys, valid_Y = np.asarray(valid_drugs), np.asarray(valid_prots_keys), np.asarray(valid_Y)
    valid_dataset = DTADataset(root='data', dataset=dataset + '_' + 'train', xd=valid_drugs,
                               target_key=valid_prots_keys, y=valid_Y, smile_graph=smile_graph,
                               target_graph=target_graph)
    return train_dataset, valid_dataset
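One detail worth calling out: for davis the raw affinities are Kd values in nM, and the line affinity = [-np.log10(y / 1e9) for y in affinity] converts them to pKd. A quick worked example (illustrative values, not taken from the dataset):

import numpy as np

# Illustrative Kd values in nM; -log10(Kd / 1e9) converts nM to molar and takes the negative log, giving pKd.
kd_nm = np.array([10.0, 100.0, 10000.0])
pkd = -np.log10(kd_nm / 1e9)
print(pkd)  # [8. 7. 5.]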

Searching the file for "]" with Ctrl+F shows that it is a two-dimensional array containing 5 sub-arrays (5×n); each number in it is the key of a protein-drug binding pair, and through this number the corresponding drug-target pair can be located.

So the meaning of the following code is clear: valid_fold takes the array at index fold from the txt above as the validation set, and the remaining four arrays are all appended through the loop to form the training set.

train_folds = []
valid_fold = train_fold_origin[fold]  # one fold
for i in range(len(train_fold_origin)):  # other folds
    if i != fold:
        train_folds += train_fold_origin[i]
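The indices stored in the folds are flat positions into the list of non-NaN entries of the affinity matrix, which is why create_dataset_for_5folds later does rows, cols = np.where(np.isnan(affinity) == False) and then rows[train_folds], cols[train_folds]. A small sketch with made-up numbers shows the mapping:

import numpy as np

# Toy 2x3 affinity matrix (drugs x proteins) with two missing values.
affinity = np.array([[5.0, np.nan, 7.1],
                     [6.3, 8.2, np.nan]])

rows, cols = np.where(np.isnan(affinity) == False)  # positions of known affinities
# rows = [0 0 1 1], cols = [0 2 0 1]; flat index 3 therefore means the pair (drug 1, protein 1)
fold = [3, 0]                                        # made-up fold indices
print(rows[fold], cols[fold])                        # [1 0] [1 0]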
