该代码是 python-torch 写的!
请看序列(一、二、三)

一、模型概述

文件分布

首先看文件的内容:STSGCN中包含两个文件夹:model,PeMSD7(M) 。model文件中包含:main.py,stsgcn.py ,utils.py三个文件。PeMSD7(M)中包含:矩阵文件 adj_mat.npy和特征数据node_values.npy两个文件。

PeMSD7(M)中的数据

特征数据为:(34722,207,2)
矩阵数据为:(207,207)
表明共有207个节点,每个节点2个特征

二、导入库和参数设置–main.py

import os
import argparse
import pickle as pk
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# 导入自定义函数
from stgcn import STGCN
from utils import generate_dataset, load_metr_la_data, get_normalized_adj
# 参数设置
use_gpu = False
num_timesteps_input = 12 # 输入时间步长
num_timesteps_output = 3 # 输出时间步长epochs = 1000   # 训练轮次
batch_size = 50 # 批量大小parser = argparse.ArgumentParser(description='STGCN')
parser.add_argument('--enable-cuda', action='store_true',help='Enable CUDA')
# 添加 可选参数 --enable-cuda
# action=‘store_true’,只要运行时该变量有传参就将该变量设为True
args = parser.parse_args()
args.device = None
if args.enable_cuda and torch.cuda.is_available():args.device = torch.device('cuda')
else:args.device = torch.device('cpu')
# 获得当前设备的类型

三、查看主函数–main.py

if __name__ == '__main__':torch.manual_seed(7) #为CPU中设置种子,生成随机数A, X, means, stds = load_metr_la_data()print("数据加载...")split_line1 = int(X.shape[2] * 0.6) # train分割线split_line2 = int(X.shape[2] * 0.8) # test和verify分割线train_original_data = X[:, :, :split_line1]val_original_data = X[:, :, split_line1:split_line2]test_original_data = X[:, :, split_line2:]training_input, training_target = generate_dataset(train_original_data,num_timesteps_input=num_timesteps_input,num_timesteps_output=num_timesteps_output)val_input, val_target = generate_dataset(val_original_data,num_timesteps_input=num_timesteps_input,num_timesteps_output=num_timesteps_output)test_input, test_target = generate_dataset(test_original_data,num_timesteps_input=num_timesteps_input,num_timesteps_output=num_timesteps_output)print("数据生成器")A_wave = get_normalized_adj(A)        # 矩阵归一化A_wave = torch.from_numpy(A_wave)     # np.ndarray-->torch.TensorA_wave = A_wave.to(device=args.device)# 数据放入设备print("矩阵归一化后并存在设备上")net = STGCN(A_wave.shape[0],training_input.shape[3],num_timesteps_input,num_timesteps_output).to(device=args.device)print("模型实例化并且存放在设备上")optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)# 优化net的所有参数,并设置学习率loss_criterion = nn.MSELoss() # 损失函数为MSEprint("建立损失函数和优化器")training_losses = []validation_losses = []validation_maes = []for epoch in range(epochs):loss = train_epoch(training_input, training_target,batch_size=batch_size)# 调用训练函数training_losses.append(loss) # 训练损失 保存# Run validationwith torch.no_grad():# 上下文管理器,with的内容将不会 track 梯度net.eval()val_input = val_input.to(device=args.device)    # 数据放入设备val_target = val_target.to(device=args.device)  # 数据放入设备out = net(A_wave, val_input) # 模型输入并启动val_loss = loss_criterion(out, val_target).to(device="cpu")  #调用损失函数:求结果和目标值的MSE,并放入设备validation_losses.append(np.asscalar(val_loss.detach().numpy()))# 将val_loss从模型中剥离(detach)后转为numpy数组再存入validation_lossesout_unnormalized = out.detach().cpu().numpy()*stds[0]+means[0] #out从model剥离放入cpu设备转为numpy数组,再逆归一化target_unnormalized = val_target.detach().cpu().numpy()*stds[0]+means[0]# target 同上操作mae = np.mean(np.absolute(out_unnormalized - target_unnormalized)) # 对out和target求 maevalidation_maes.append(mae) # mae存入out = Noneval_input = val_input.to(device="cpu")val_target = val_target.to(device="cpu")# 对每一epo打印当前 训练loss,val_loss,验证maeprint("Training loss: {}".format(training_losses[-1]))print("Validation loss: {}".format(validation_losses[-1]))print("Validation MAE: {}".format(validation_maes[-1]))plt.plot(training_losses, label="training loss")      # 画train——loss曲线图plt.plot(validation_losses, label="validation loss")  # 画val_-loss曲线图plt.legend()plt.show()# 设置检测路径,保存当前epo的训练结果,以防发生意外checkpoint_path = "checkpoints/"if not os.path.exists(checkpoint_path):os.makedirs(checkpoint_path)# 如果checkpoints文件夹不存在,则创建当前文件目录with open("checkpoints/losses.pk", "wb") as fd:pk.dump((training_losses, validation_losses, validation_maes), fd)# 将结果用pickle保存到路径fd

注释:

  1. with torch.no_grad():见链接no_grad()
  2. net.eval():用于测试或者评估之前。原因见参考博客园链接1和简书链接2。
  3. np.absolute=np.abs:对数组内每个元素求绝对值。在其上即求 mae

四、训练函数–main.py

def train_epoch(training_input, training_target, batch_size):"""Trains one epoch with the given data.:param training_input: Training inputs of shape (num_samples, num_nodes,num_timesteps_train, num_features).:param training_target: Training targets of shape (num_samples, num_nodes,num_timesteps_predict).:param batch_size: Batch size to use during training.:return: Average loss for this epoch."""permutation = torch.randperm(training_input.shape[0])#打乱索引,配合下文,用来随机抽取一个batch数据epoch_training_losses = []for i in range(0, training_input.shape[0], batch_size):net.train() # 见net.eval()注释optimizer.zero_grad()indices = permutation[i:i + batch_size]X_batch, y_batch = training_input[indices], training_target[indices]X_batch = X_batch.to(device=args.device)y_batch = y_batch.to(device=args.device)out = net(A_wave, X_batch)loss = loss_criterion(out, y_batch)loss.backward()optimizer.step()epoch_training_losses.append(loss.detach().cpu().numpy())return sum(epoch_training_losses)/len(epoch_training_losses)

注释:

  1. torch.randperm:返回一个0到n-1的数组。

torch.randperm(n, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False)
torch.randperm(4)
tensor([ 2, 1, 0, 3])

  1. optimizer.zero_grad()的原因:
  • 梯度下降的原理和实现步骤
  • 参数更新和反向传播
  • 梯度清零

下一篇:【stgcn】代码pytorch解读(二)

【stgcn】代码解读之主函数(一)相关推荐

  1. TEB局部轨迹规划代码解读

    teb局部规划代码地址:https://github.com/rst-tu-dortmund/teb_local_planner teb局部规划论文:C. Rösmann, F. Hoffmann a ...

  2. A-LOAM代码解读

    Comment 由于A-LOAM代码看的时间前前后后跨度时间较长,因此前半部分(scanRegistration.laserOdometry部分)的注释以及理解写的比较详细,而最后一个部分laserM ...

  3. matlab怎么调用主函数,Matlab中一个函数调用另外一个函数的操作步骤

    原创Matlab中一个函数调用另外一个函数的操作步骤 编辑:小安 来源:PC下载网时间:2019-11-18 13:27:35 最近很多伙伴才刚刚安装入手Matlab这款软件,而本节就重点介绍了关于M ...

  4. 【代码解读】超详细,YOLOV5之build_targets函数解读。

    文章目录 build_targets作用 注意 可视化结果 过程 详细代码解读 准备 第一遍筛选 扩增正样本 Reference build_targets作用 build_targets函数用于网络 ...

  5. ACMNO.27 Python的两行代码解决 C语言-字符逆序 写一函数。使输入的一个字符串按反序存放,在主函数中输入输出反序后的字符串。 输入 一行字符 输出 逆序后的字符串

    题目描述 写一函数,使输入的一个字符串按反序存放,在主函数中输入输出反序后的字符串. 输入 一行字符 输出 逆序后的字符串 样例输入 123456abcdef 样例输出 fedcba654321 来源 ...

  6. 【Groovy】Groovy 代码创建 ( 使用 Java 语法实现 Groovy 类和主函数并运行 | 按照 Groovy 语法改造上述 Java 语法规则代码 )

    文章目录 一.创建 Groovy 代码文件 二.使用 Java 语法实现 Groovy 类和主函数并运行 三.按照 Groovy 语法改造上述 Java 语法规则代码 一.创建 Groovy 代码文件 ...

  7. 【C 语言】字符串拷贝 ( 字符串拷贝业务逻辑代码 | 分离 主函数 与 字符串拷贝 业务模型 )

    文章目录 一.字符串拷贝业务逻辑代码 二.分离 主函数 与 字符串拷贝 业务模型 一.字符串拷贝业务逻辑代码 下面的代码 , 是 字符串 拷贝 最简单的代码 , 仅 使用 指针 遍历内存 , 实现了字 ...

  8. 【caffe解读】 caffe从数学公式到代码实现2-基础函数类

    文章首发于微信公众号<与有三学AI> [caffe解读] caffe从数学公式到代码实现2-基础函数类 接着上一篇,本篇就开始读layers下面的cpp,先看一下layers下面都有哪些c ...

  9. C语言main()主函数执行完毕后是否会再执行一段代码

    C语言main()主函数执行完毕后是否会再执行一段代码 分享到: QQ空间 新浪微博 腾讯微博 豆瓣 人人网 main() 主函数执行完毕后,是否可能会再执行一段代码?给出说明. main主函数是所有 ...

最新文章

  1. 联想笔记本电脑,重新安装系统之U盘启动方法
  2. python守护进程windows_如何把 python predict程序 做成 windows 守护进程
  3. 物联网中常见的传感器类型
  4. C/C++:Windows编程—创建进程、终止进程、枚举进程、枚举线程、枚举DLL
  5. 安卓listview点击空白事件_要权限才给用?安卓毒瘤APP滚蛋吧!
  6. Faster R-CNN——学习笔记~
  7. linux命令存放 bash: xxx command not found
  8. 湖南省对口升学c语言试题,湖南省对口升学计算机专业综合试卷试题.doc
  9. 描述计算机专业导论课程的内容结构,计算机专业导论课程学习内容.doc
  10. 新MacBook Pro软件安装记录
  11. 安卓脚本用什么写_什么是抖音脚本?脚本有什么用?
  12. border缩写属性
  13. 大数据定价方法的国内外研究综述及对比分析
  14. 给自己定个一年后的终极目标!
  15. 如何取消PPT的密码保护?
  16. 2022AcWing寒假算法每日一题之1934. 贝茜放慢脚步
  17. python实现去除图片水印
  18. 视觉世界中的“众里寻她”--开放环境下的人物特征表示
  19. Word中字号和磅值的对应关系
  20. 自从加入酒水捡漏群,京东自营酒水2折捡漏,我一下子屯了6个酒柜...

热门文章

  1. 权限管理中的RBAC与ABAC
  2. oracle怎么条件强制走索引,如何让oracle的select强制走索引
  3. 案例研究:设计令人震撼的名片!
  4. Linux IV ,IVM编辑 退出方法
  5. \Qt5\\bin\\d3dcompiler_47.dll
  6. 高新技术企业认定申请通过后补贴
  7. 感冒药盒上请看清这6个字,一定要注意! “美”:支气管炎患者慎用
  8. 陕西省ti杯竞赛题目_2017年全国大学生电子设计竞赛和陕西省(TI杯)校际联赛暨西安电子科技大学校内选拔赛...
  9. Centos7 安装DB2
  10. 这样构建的用户画像!想不懂你的用户都难