[ASTGCN with a Single Feature] Code Walkthrough (torch): train_ASTGCN_r (Part 2)
Contents
- 1. Main file structure
- 2. lib.utils function `load_graphdata_channel1`
- 3. lib.utils function `get_adjacency_matrix`
- 4. Function `train_main`
- 5. Function `predict_main`
- 6. lib.metrics function `masked_mae`
- 7. lib.utils function `compute_val_loss_mstgcn`
- 8. lib.utils function `predict_and_save_results_mstgcn`
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from time import time
import shutil
import argparse
import configparser
from model.ASTGCN_r import make_model
from lib.utils import load_graphdata_channel1, get_adjacency_matrix, compute_val_loss_mstgcn, predict_and_save_results_mstgcn
from tensorboardX import SummaryWriter
This post walks through the training main script and the functions it calls.
1. Main file structure
######## training device ########
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######## argparse arguments ########
# 1. create a parser (an ArgumentParser object holds all the information needed
#    to parse the command line into Python data types)
parser = argparse.ArgumentParser()
# 2. add an argument (default - value used when the argument is omitted;
#    type - the type the command-line string is converted to;
#    help - help message for the argument (argparse.SUPPRESS hides it);
#    required - whether an optional argument may be omitted)
# add_argument: reads a command-line argument
parser.add_argument("--config", default='configurations/PEMS04_astgcn.conf', type=str, help="configuration file path")  # 'configurations/METR_LA_astgcn.conf'
# 3. parse the arguments
args = parser.parse_args()
# (4. quick check) --> print(args.config)
######## configparser ########
config = configparser.ConfigParser()  # create the parser object
print('Read configuration file: %s' % (args.config))
config.read(args.config)  # read the configuration file
data_config = config['Data']          # the [Data] section of configurations/PEMS04_astgcn.conf
training_config = config['Training']  # the [Training] section of args.config
######## read the parameters from the configuration file ########
adj_filename = data_config['adj_filename']
graph_signal_matrix_filename = data_config['graph_signal_matrix_filename']
if config.has_option('Data', 'id_filename'):
    id_filename = data_config['id_filename']
else:
    id_filename = None
num_of_vertices = int(data_config['num_of_vertices'])
points_per_hour = int(data_config['points_per_hour'])
num_for_predict = int(data_config['num_for_predict'])
len_input = int(data_config['len_input'])
dataset_name = data_config['dataset_name']
model_name = training_config['model_name']
''' ctx = mx.gpu(int(ctx[ctx.index('-') + 0:])) '''
ctx = training_config['ctx']  # e.g. ctx = '3' from the config
os.environ["CUDA_VISIBLE_DEVICES"] = ctx  # select which GPU to use
USE_CUDA = torch.cuda.is_available() # GPU是否可用
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA:", USE_CUDA, DEVICE)learning_rate = float(training_config['learning_rate'])
epochs = int(training_config['epochs'])
start_epoch = int(training_config['start_epoch'])
batch_size = int(training_config['batch_size'])
num_of_weeks = int(training_config['num_of_weeks'])
num_of_days = int(training_config['num_of_days'])
num_of_hours = int(training_config['num_of_hours'])
time_strides = num_of_hours
nb_chev_filter = int(training_config['nb_chev_filter'])
nb_time_filter = int(training_config['nb_time_filter'])
in_channels = int(training_config['in_channels'])
nb_block = int(training_config['nb_block'])
K = int(training_config['K'])
loss_function = training_config['loss_function']
metric_method = training_config['metric_method']
missing_value = float(training_config['missing_value'])
folder_dir = '%s_h%dd%dw%d_channel%d_%e' % (model_name, num_of_hours, num_of_days, num_of_weeks, in_channels, learning_rate)
print('folder_dir:', folder_dir)
params_path = os.path.join('experiments', dataset_name, folder_dir)
print('params_path:', params_path)
######## data loading ########
train_loader, train_target_tensor, val_loader, val_target_tensor, test_loader, test_target_tensor, _mean, _std = \
    load_graphdata_channel1(graph_signal_matrix_filename, num_of_hours, num_of_days, num_of_weeks, device, batch_size)  # see section 2
adj_mx, distance_mx = get_adjacency_matrix(adj_filename, num_of_vertices, id_filename)  # see section 3
# adj_mx = A (307, 307) for PEMS04 -- unweighted
# distance_mx = distaneA (307, 307) -- weighted
######## model instantiation ########
### def train_main    # see section 4
### def predict_main  # see section 5
######## train or predict ########
if __name__ == "__main__":
    train_main()
    # predict_main()
2. lib.utils function load_graphdata_channel1
- Input variables

Variable | Type | Example |
---|---|---|
graph_signal_matrix_filename | str | ./data/PEMS04/PEMS04.npz |
num_of_weeks | int | 0 |
num_of_days | int | 0 |
num_of_hours | int | 1 |
DEVICE | torch.device | cpu |
batch_size | int | 32 |
- Function flow

1. Build the file path: ./data/PEMS04\PEMS04_r1_d0_w0_astcgn.npz
2. Load the data into file_data.
3. Extract the feature data and the target data; there are F=3 features, and only the first one is kept via [0:1].
4. The resulting shapes are (seq, N, F, T):
   * train: [10181, 307, 1, 12], train_target: [10181, 307, 12]
   * val: [3394, 307, 1, 12], val_target: [3394, 307, 12]
   * test: [3394, 307, 1, 12], test_target: [3394, 307, 12]
5. Read the mean and standard deviation along the feature axis:
   * PEMS04 mean: (1, 1, 1, 1)
   * PEMS04 std: (1, 1, 1, 1)
6. Build train_loader: create the tensors train_x_tensor and train_target_tensor, wrap them with TensorDataset into train_dataset; iterating `for tx, ty in train_dataset` loops 10181 times with tx.shape = (N, F, T) = (307, 1, 12) and ty.shape = (N, T) = (307, 12).
7. Build test_loader: create test_x_tensor and test_target_tensor, wrap them with TensorDataset into test_dataset; iterating `for tx, ty in test_dataset` loops 3394 times with tx.shape = (307, 1, 12) and ty.shape = (307, 12).
8. Build val_loader: create val_x_tensor and val_target_tensor, wrap them with TensorDataset into val_dataset; iterating `for tx, ty in val_dataset` loops 3394 times with tx.shape = (307, 1, 12) and ty.shape = (307, 12).
- Output values

Variable | Type | Shape | Purpose |
---|---|---|---|
train_loader | DataLoader | len=319 | each batch yields dx, dy; dx is (32, 307, 1, 12), dy is (32, 307, 12) |
train_target_tensor | Tensor | [10181, 307, 12] | training labels |
val_loader | DataLoader | len=107 | each batch yields dx, dy; dx is (32, 307, 1, 12), dy is (32, 307, 12) |
val_target_tensor | Tensor | [3394, 307, 12] | validation labels |
test_loader | DataLoader | len=107 | each batch yields dx, dy; dx is (32, 307, 1, 12), dy is (32, 307, 12) |
test_target_tensor | Tensor | [3394, 307, 12] | test labels |
mean | np.ndarray | (1, 1, 1, 1) | mean of the selected feature F[0:1] over the whole dataset |
std | np.ndarray | (1, 1, 1, 1) | standard deviation of the selected feature F[0:1] over the whole dataset |
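A minimal sketch to verify the shapes in the table above, assuming the loaders returned by `load_graphdata_channel1`:

```python
# Minimal sketch: check the batch shapes produced by load_graphdata_channel1 (assumed already called).
for dx, dy in train_loader:
    print(dx.shape)  # torch.Size([32, 307, 1, 12])  -- (B, N, F, T)
    print(dy.shape)  # torch.Size([32, 307, 12])     -- (B, N, T)
    break

print(len(train_loader), len(val_loader), len(test_loader))  # 319 107 107
```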
def load_graphdata_channel1(graph_signal_matrix_filename, num_of_hours, num_of_days, num_of_weeks, DEVICE, batch_size, shuffle=True):
    '''
    This function prepares the PEMS data.
    Each sample contains the readings of all monitoring stations, so the data built here
    feeds a spatio-temporal sequence prediction model; the function strings the hour, day
    and week components together.
    Note: in the file, x has already been normalized while y keeps the raw values.
    This function is designed for MSTGCN/ASTGCN: the returned x is normalized by subtracting
    the mean and dividing by the std, and y is the raw values.
    :return:
    three DataLoaders, each dataloader contains:
    test_x_tensor: (B, N_nodes, in_feature, T_input)
    test_decoder_input_tensor: (B, N_nodes, T_output)
    test_target_tensor: (B, N_nodes, T_output)
    '''
    file = os.path.basename(graph_signal_matrix_filename).split('.')[0]
    # take the final file name from the path and split on '.'; keep only the name without the extension
    dirpath = os.path.dirname(graph_signal_matrix_filename)
    # the path with the final file name removed
    filename = os.path.join(dirpath, file + '_r' + str(num_of_hours) + '_d' + str(num_of_days) + '_w' + str(num_of_weeks)) + '_astcgn'
    # path joining; a separator is inserted automatically if a component does not end with '/' or '\'
    print('load file:', filename)
    file_data = np.load(filename + '.npz')
    train_x = file_data['train_x']  # (10181, 307, 3, 12)
    train_x = train_x[:, :, 0:1, :]  # keep only the first feature
    train_target = file_data['train_target']  # (10181, 307, 12)
    val_x = file_data['val_x']
    val_x = val_x[:, :, 0:1, :]
    val_target = file_data['val_target']
    test_x = file_data['test_x']
    test_x = test_x[:, :, 0:1, :]
    test_target = file_data['test_target']
    mean = file_data['mean'][:, :, 0:1, :]  # (1, 1, 1, 1) after keeping only the first feature
    std = file_data['std'][:, :, 0:1, :]    # (1, 1, 1, 1)
    # ------- train_loader -------
    train_x_tensor = torch.from_numpy(train_x).type(torch.FloatTensor).to(DEVICE)            # (B, N, F, T)
    train_target_tensor = torch.from_numpy(train_target).type(torch.FloatTensor).to(DEVICE)  # (B, N, T)
    train_dataset = torch.utils.data.TensorDataset(train_x_tensor, train_target_tensor)
    # TensorDataset wraps the tensors into a dataset; each sample is retrieved by indexing the tensors:
    # for t_x, t_target in train_dataset:
    #     t_x.shape = (N, F, T) = (307, 1, 12)
    #     t_target.shape = (N, T) = (307, 12)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    # DataLoader(dataset, batch_size=32, shuffle=True) yields batched data
    # len(train_loader) = 10181 / 32 = 318.15625, so 319 batches
    # for tx, ty in train_loader:
    #     tx.shape = (32, 307, 1, 12)
    #     ty.shape = (32, 307, 12)
    # ------- val_loader -------
    val_x_tensor = torch.from_numpy(val_x).type(torch.FloatTensor).to(DEVICE)            # (B, N, F, T)
    val_target_tensor = torch.from_numpy(val_target).type(torch.FloatTensor).to(DEVICE)  # (B, N, T)
    val_dataset = torch.utils.data.TensorDataset(val_x_tensor, val_target_tensor)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    # len(val_loader) = 3394 / 32 = 106.0625, so 107 batches
    # ------- test_loader -------
    test_x_tensor = torch.from_numpy(test_x).type(torch.FloatTensor).to(DEVICE)            # (B, N, F, T)
    test_target_tensor = torch.from_numpy(test_target).type(torch.FloatTensor).to(DEVICE)  # (B, N, T)
    test_dataset = torch.utils.data.TensorDataset(test_x_tensor, test_target_tensor)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # len(test_loader) = 3394 / 32 = 106.0625, so 107 batches
    # print
    print('train:', train_x_tensor.size(), train_target_tensor.size())
    print('val:', val_x_tensor.size(), val_target_tensor.size())
    print('test:', test_x_tensor.size(), test_target_tensor.size())
    return train_loader, train_target_tensor, val_loader, val_target_tensor, test_loader, test_target_tensor, mean, std
Notes:

- On torch.utils.data.DataLoader: DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True) yields batches of 32, reshuffles the data at every epoch, and drops the last batch, because the last batch may contain fewer than batch_size samples.
- Open question: the last batch here has batch_size ≠ 32, so why not discard it with drop_last=True? (A sketch of the difference is shown below.)
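A minimal sketch of what drop_last changes, using a toy dataset with the same number of samples as the training split:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy stand-in for the training split: 10181 samples, which is not divisible by 32.
x = torch.randn(10181, 4)
ds = TensorDataset(x)

loader_keep = DataLoader(ds, batch_size=32, drop_last=False)  # what the code above does
loader_drop = DataLoader(ds, batch_size=32, drop_last=True)   # the alternative raised in the question

print(len(loader_keep))  # 319 -- the final batch holds the remaining 10181 - 318*32 = 5 samples
print(len(loader_drop))  # 318 -- the short final batch is discarded
```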
3. lib.utils function get_adjacency_matrix
- Input variables

Variable | Type | Example |
---|---|---|
distance_df_filename | str | ./data/PEMS04/distance.csv |
num_of_vertices | int | 307 |
id_filename | str or None | None |
- Function flow

1. If the adjacency matrix is stored as .npy, load it directly with np.load and return A, None.
2. Otherwise read the .csv file.
3. Create two all-zero arrays A and distanceA of shape (num_of_vertices, num_of_vertices).
4. If id_filename is given, map the raw node ids to indices starting from 0, fill A and distanceA accordingly, and return A, distanceA.
5. Otherwise open distance_df_filename with csv, fill A and distanceA row by row, and return them.
   * A is the unweighted (0-1) adjacency matrix
   * distanceA is the weighted adjacency matrix
def get_adjacency_matrix(distance_df_filename, num_of_vertices, id_filename=None):
    '''
    Returns
    ----------
    A: np.ndarray, adjacency matrix
    '''
    if 'npy' in distance_df_filename:  # False for PEMS04
        adj_mx = np.load(distance_df_filename)
        return adj_mx, None
    # --------------------------------------------- read from here
    else:
        import csv
        ###### create the all-zero arrays A and distaneA ######
        A = np.zeros((int(num_of_vertices), int(num_of_vertices)), dtype=np.float32)
        distaneA = np.zeros((int(num_of_vertices), int(num_of_vertices)), dtype=np.float32)
        # ------------ Ignore
        if id_filename:  # False here
            with open(id_filename, 'r') as f:
                # map the raw node ids to indices starting from 0
                # enumerate iterates over any iterable (list, tuple, string, ...) with an index
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
            with open(distance_df_filename, 'r') as f:
                f.readline()  # skip the header line
                reader = csv.reader(f)
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[id_dict[i], id_dict[j]] = 1
                    distaneA[id_dict[i], id_dict[j]] = distance
            return A, distaneA
        # ------------- Continue reading
        else:
            with open(distance_df_filename, 'r') as f:
                f.readline()  # skip the header line
                reader = csv.reader(f)
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[i, j] = 1
                    distaneA[i, j] = distance
            return A, distaneA
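A minimal usage sketch. The layout of distance.csv is inferred from the parsing loop above (one header line that is skipped, then rows of `i,j,distance`); the sample rows below are purely illustrative:

```python
# Minimal usage sketch, assuming a CSV of the form inferred from the code above:
#   from,to,cost        <- header line, skipped by f.readline()
#   73,5,352.6          <- illustrative values only
#   5,154,347.2
adj_mx, distance_mx = get_adjacency_matrix('./data/PEMS04/distance.csv', 307, id_filename=None)
print(adj_mx.shape, distance_mx.shape)  # (307, 307) (307, 307)
print(adj_mx.max(), adj_mx.min())       # 1.0 0.0 -- unweighted 0/1 entries
```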
4. Function train_main
- Related variables

Variable | Type | Example | Purpose |
---|---|---|---|
start_epoch | int | 0 | epoch at which training starts |
params_path | str | experiments\PEMS04\astgcn_r_h1d0w0_channel1_1.000000e-03 | where parameters are saved |
masked_flag | int | 0 or 1 | selects the loss function |
global_step | int | init 0 | global counter of training steps (batches) |
best_epoch | int | init 0 | epoch with the best validation result |
best_val_loss | float | init infinity | best validation loss so far |
- Function flow

####### setup #######

1. Use start_epoch together with params_path to make sure the parameter-saving directory is set up correctly.
2. Print the relevant parameters.
3. masked_flag and loss_function select between the two loss functions criterion and criterion_masked.
4. Initialize the optimizer with Adam.
5. Initialize a SummaryWriter to write information into the logs directory.
6. Print the net, the net's state_dict and the optimizer's state_dict.
7. If training does not start from epoch 0, build the parameter file path and load the weights into net.
8. Initialize the global bookkeeping variables: global_step = 0, best_epoch = 0, best_val_loss = infinity.

##### training #####

for epoch in range(start_epoch, epochs):
  a. build the parameter file path for this epoch
  b. validation: compute val_loss and, if it improves, save the parameters
  c. training:
     net.train()
     for ... in train_loader:
       fetch inputs and labels
       zero the gradients
       outputs = net(inputs)
       compute the loss according to masked_flag
       loss.backward()
       optimizer.step()
       training_loss = loss.item()
       global_step += 1
       sw.add_scalar writes to the log

##### pick the best net #####

Print best_val_loss and best_epoch, load the best net accordingly, and call predict_main to predict with the best model.
def train_main():
    ########## make sure params_path (where parameters are saved) is valid ##########
    if (start_epoch == 0) and (not os.path.exists(params_path)):
        # start_epoch == 0 and params_path does not exist: create the directory
        # (note: os.makedirs raises an error if the directory already exists)
        os.makedirs(params_path)
        print('create params directory %s' % (params_path))
    elif (start_epoch == 0) and (os.path.exists(params_path)):
        # start_epoch == 0 and params_path exists: delete the directory (and its contents) and recreate it
        shutil.rmtree(params_path)
        os.makedirs(params_path)
        print('delete the old one and create params directory %s' % (params_path))
    elif (start_epoch > 0) and (os.path.exists(params_path)):
        print('train from params directory %s' % (params_path))
    else:
        raise SystemExit('Wrong type of model!')
    ########## print the parameters ##########
    print('param list:')
    print('CUDA\t', device)
    print('in_channels\t', in_channels)
    print('nb_block\t', nb_block)
    print('nb_chev_filter\t', nb_chev_filter)
    print('nb_time_filter\t', nb_time_filter)
    print('time_strides\t', time_strides)
    print('batch_size\t', batch_size)
    print('graph_signal_matrix_filename\t', graph_signal_matrix_filename)
    print('start_epoch\t', start_epoch)
    print('epochs\t', epochs)
    ### masked_flag and loss_function select between criterion and criterion_masked ###
    masked_flag = 0
    criterion = nn.L1Loss().to(device)
    # L1Loss: mean absolute error between prediction and ground truth
    # MAE is the L1 loss; MSE is the L2 loss
    criterion_masked = masked_mae  # see section 6
    if loss_function == 'masked_mse':
        criterion_masked = masked_mse  # nn.MSELoss().to(DEVICE)
        masked_flag = 1
    elif loss_function == 'masked_mae':
        criterion_masked = masked_mae
        masked_flag = 1
    elif loss_function == 'mae':
        criterion = nn.L1Loss().to(device)
        masked_flag = 0
    elif loss_function == 'rmse':
        criterion = nn.MSELoss().to(device)
        masked_flag = 0
    # initialize a trainer to train the model
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    # initialize a SummaryWriter to write information into the logs dir
    sw = SummaryWriter(logdir=params_path, flush_secs=5)
    ###### print net, net's state_dict and optimizer's state_dict ######
    print(net)
    print('Net\'s state_dict:')
    total_param = 0
    for param_tensor in net.state_dict():
        print(param_tensor, '\t', net.state_dict()[param_tensor].size())
        total_param += np.prod(net.state_dict()[param_tensor].size())
    print('Net\'s total params:', total_param)
    print('Optimizer\'s state_dict:')
    for var_name in optimizer.state_dict():
        print(var_name, '\t', optimizer.state_dict()[var_name])
    ###### initialize the global bookkeeping variables ######
    global_step = 0
    best_epoch = 0
    best_val_loss = np.inf  # best validation loss so far, initialized to infinity
    start_time = time()     # training start time
    ###### if not training from scratch, locate the parameter file and load it into net ######
    if start_epoch > 0:
        params_filename = os.path.join(params_path, 'epoch_%s.params' % start_epoch)
        net.load_state_dict(torch.load(params_filename))
        print('start epoch:', start_epoch)
        print('load weight from: ', params_filename)
    ####### train the model #######
    for epoch in range(start_epoch, epochs):
        ### build the parameter file path for this epoch
        params_filename = os.path.join(params_path, 'epoch_%s.params' % epoch)
        ############ validation ############
        # pick the loss function according to masked_flag and compute val_loss
        if masked_flag:
            val_loss = compute_val_loss_mstgcn(net, val_loader, criterion_masked, masked_flag, missing_value, sw, epoch)  # see section 7
        else:
            val_loss = compute_val_loss_mstgcn(net, val_loader, criterion, masked_flag, missing_value, sw, epoch)  # see section 7
        # if val_loss improves, update the best value, record the epoch and save the parameters
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(net.state_dict(), params_filename)
            print('save parameters to file: %s' % params_filename)
        ############ training ############
        net.train()  # ensure dropout layers are in train mode
        for batch_index, batch_data in enumerate(train_loader):
            encoder_inputs, labels = batch_data  # encoder_inputs: (32, 307, 1, 12), labels: (32, 307, 12)
            optimizer.zero_grad()  # reset gradients
            outputs = net(encoder_inputs)
            if masked_flag:
                loss = criterion_masked(outputs, labels, missing_value)
            else:
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            training_loss = loss.item()
            global_step += 1
            sw.add_scalar('training_loss', training_loss, global_step)
            if global_step % 1000 == 0:
                print('global step: %s, training loss: %.2f, time: %.2fs' % (global_step, training_loss, time() - start_time))
    print('best epoch:', best_epoch)
    # apply the best model on the test set
    predict_main(best_epoch, test_loader, test_target_tensor, metric_method, _mean, _std, 'test')  # see section 5
5. Function predict_main
- Input variables

Variable | Type | Example | Purpose |
---|---|---|---|
global_step | int | >=0 | epoch whose parameters are loaded |
data_loader | DataLoader | test_loader | test data used for prediction |
data_target_tensor | Tensor | test_target_tensor | labels of the test data |
metric_method | str | unmask | |
_mean | (1,1,1,1) | from load_graphdata_channel1 | mean |
_std | (1,1,1,1) | from load_graphdata_channel1 | standard deviation |
type | str | test | indicates which data split is used |
- Function flow

1. Build the parameter file path.
2. Load the parameters into net.
3. Call predict_and_save_results_mstgcn to obtain the prediction results.
def predict_main(global_step, data_loader, data_target_tensor, metric_method, _mean, _std, type):
    params_filename = os.path.join(params_path, 'epoch_%s.params' % global_step)
    print('load weight from:', params_filename)
    net.load_state_dict(torch.load(params_filename))
    predict_and_save_results_mstgcn(net, data_loader, data_target_tensor, global_step, metric_method, _mean, _std, params_path, type)
6. lib.metrics function masked_mae
def masked_mae(preds, labels, null_val=np.nan):
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = (labels != null_val)
    mask = mask.float()
    mask /= torch.mean((mask))
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(preds - labels)
    loss = loss * mask
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)
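A small usage sketch of the masking behaviour with toy tensors (entries equal to null_val are excluded from the average):

```python
import numpy as np
import torch
# Assumes masked_mae from lib.metrics (shown above) is available.

preds  = torch.tensor([2.0, 5.0, 4.0])
labels = torch.tensor([1.0, 0.0, 4.0])  # the second label equals the "missing" value 0.0

print(masked_mae(preds, labels, null_val=0.0))     # tensor(0.5000) -> (|2-1| + |4-4|) / 2, the masked entry is ignored
print(masked_mae(preds, labels, null_val=np.nan))  # tensor(2.)     -> plain MAE, since no label is NaN
```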
7. lib.utils function compute_val_loss_mstgcn
- Input variables

Variable | Type | Example | Purpose |
---|---|---|---|
net | nn.Module | | the model |
val_loader | DataLoader | | validation data |
criterion | function | nn.MSELoss | loss function |
masked_flag | int | 0 or 1 | works together with the loss function |
sw | tensorboardX.SummaryWriter | | training log writer |
epoch | int | >=0 | current training epoch |
- Function flow

1. net.train(False) disables dropout (evaluation mode).
2. Under torch.no_grad(): tmp = [] stores the per-batch losses; for each batch from val_loader, compute outputs = net(inputs), pick the loss according to masked_flag and append it to tmp; validation_loss is the mean loss over all batches; write epoch and validation_loss to the log sw; return validation_loss.

See also: "Understanding Dropout: principle, implementation, pros and cons".
def compute_val_loss_mstgcn(net, val_loader, criterion, masked_flag, missing_value, sw, epoch, limit=None):
    '''
    :return: val_loss
    '''
    net.train(False)  # ensure dropout layers are in evaluation mode
    with torch.no_grad():
        val_loader_length = len(val_loader)  # number of batches
        tmp = []  # records the loss of every batch
        for batch_index, batch_data in enumerate(val_loader):
            encoder_inputs, labels = batch_data
            outputs = net(encoder_inputs)
            if masked_flag:
                loss = criterion(outputs, labels, missing_value)
            else:
                loss = criterion(outputs, labels)
            tmp.append(loss.item())
            if batch_index % 100 == 0:
                print('validation batch %s / %s, loss: %.2f' % (batch_index + 1, val_loader_length, loss.item()))
            if (limit is not None) and batch_index >= limit:
                break
        validation_loss = sum(tmp) / len(tmp)
        sw.add_scalar('validation_loss', validation_loss, epoch)
    return validation_loss
8. lib.utils function predict_and_save_results_mstgcn
Note: this function repeatedly uses variables named type and input, which shadow Python built-ins; this usage is not recommended.
- Input variables

Variable | Type | Example | Purpose |
---|---|---|---|
net | nn.Module | | |
data_loader | DataLoader | test_loader | test data used for prediction |
data_target_tensor | Tensor | test_target_tensor | labels of the test data |
global_step | int | >=0 | selects which epoch's parameters are used |
metric_method | str | unmask | |
_mean | np.ndarray | (1,1,1,1) | mean |
_std | np.ndarray | (1,1,1,1) | standard deviation |
params_path | str | | path where results are saved |
type | str | "test" | indicates the data split |
- Function flow

A. net.train(False)  # disable dropout
B. Everything below runs under torch.no_grad():

1. prediction = [] stores the network outputs; input = [] stores the network inputs.
2. Note that everything coming out of the network is converted via detach().cpu().numpy().
3. For each batch from data_loader: outputs = net(encoder_inputs); append the inputs to input and the outputs to prediction.
4. Concatenate input along axis=0 and de-normalize it with _mean and _std (see the re_normalization note after the code below).
5. Concatenate prediction along axis=0.
6. Print the shapes of input, prediction and data_target_tensor.
7. excel_list = [] stores the error metrics.
8. for i in range(prediction.shape[2]): check that data_target_tensor and prediction have consistent shapes, then compare data_target_tensor[:, :, i] with prediction[:, :, i]; if metric_method == 'mask', compute the masked mae, rmse and mape and print them; if metric_method == 'unmask', compute the plain mae, rmse and mape and print them; append [mae, rmse, mape] to excel_list.
9. If metric_method == 'mask': compute mae, rmse and mape on data_target_tensor.reshape(-1, 1) and prediction.reshape(-1, 1).
10. If metric_method == 'unmask': compute mae, rmse and mape on data_target_tensor.reshape(-1, 1) and prediction.reshape(-1, 1).
11. Print the overall mae, rmse and mape and append them to excel_list.
def predict_and_save_results_mstgcn(net, data_loader, data_target_tensor, global_step, metric_method, _mean, _std, params_path, type):
    net.train(False)  # ensure dropout layers are in test mode
    with torch.no_grad():
        data_target_tensor = data_target_tensor.cpu().numpy()
        loader_length = len(data_loader)  # number of batches
        prediction = []  # stores the outputs of every batch
        input = []       # stores the inputs of every batch
        for batch_index, batch_data in enumerate(data_loader):
            encoder_inputs, labels = batch_data
            input.append(encoder_inputs[:, :, 0:1].cpu().numpy())  # (batch, T', 1)
            outputs = net(encoder_inputs)
            prediction.append(outputs.detach().cpu().numpy())
            if batch_index % 100 == 0:
                print('predicting data set batch %s / %s' % (batch_index + 1, loader_length))
        input = np.concatenate(input, 0)
        input = re_normalization(input, _mean, _std)
        prediction = np.concatenate(prediction, 0)  # (batch, T', 1)
        print('input:', input.shape)
        print('prediction:', prediction.shape)
        print('data_target_tensor:', data_target_tensor.shape)
        output_filename = os.path.join(params_path, 'output_epoch_%s_%s' % (global_step, type))
        np.savez(output_filename, input=input, prediction=prediction, data_target_tensor=data_target_tensor)
        # compute the errors
        excel_list = []
        prediction_length = prediction.shape[2]
        for i in range(prediction_length):
            assert data_target_tensor.shape[0] == prediction.shape[0]
            print('current epoch: %s, predict %s points' % (global_step, i))
            if metric_method == 'mask':
                mae = masked_mae_test(data_target_tensor[:, :, i], prediction[:, :, i], 0.0)
                rmse = masked_rmse_test(data_target_tensor[:, :, i], prediction[:, :, i], 0.0)
                mape = masked_mape_np(data_target_tensor[:, :, i], prediction[:, :, i], 0)
            else:
                mae = mean_absolute_error(data_target_tensor[:, :, i], prediction[:, :, i])
                rmse = mean_squared_error(data_target_tensor[:, :, i], prediction[:, :, i]) ** 0.5
                mape = masked_mape_np(data_target_tensor[:, :, i], prediction[:, :, i], 0)
            print('MAE: %.2f' % (mae))
            print('RMSE: %.2f' % (rmse))
            print('MAPE: %.2f' % (mape))
            excel_list.extend([mae, rmse, mape])
        # print overall results
        if metric_method == 'mask':
            mae = masked_mae_test(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1), 0.0)
            rmse = masked_rmse_test(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1), 0.0)
            mape = masked_mape_np(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1), 0)
        else:
            mae = mean_absolute_error(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1))
            rmse = mean_squared_error(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1)) ** 0.5
            mape = masked_mape_np(data_target_tensor.reshape(-1, 1), prediction.reshape(-1, 1), 0)
        print('all MAE: %.2f' % (mae))
        print('all RMSE: %.2f' % (rmse))
        print('all MAPE: %.2f' % (mape))
        excel_list.extend([mae, rmse, mape])
        print(excel_list)
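re_normalization is imported from lib.utils and not shown in this post. Since the inputs were normalized as (x - mean) / std during data preparation, it presumably just inverts that transform; the following is an assumed sketch, not the repository's code:

```python
# Assumed sketch of lib.utils.re_normalization (not copied from the repo):
# it undoes the z-score normalization x_norm = (x - mean) / std,
# broadcasting mean/std of shape (1, 1, 1, 1) over the input array.
def re_normalization(x, mean, std):
    return x * std + mean
```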