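Note: The source code in this article comes from the MACHINE LEARNING 2022 SPRING course; I have only reworked it slightly. Thanks to the internet for giving us free access to such excellent courses.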

Prerequisites

  • What is a regression model? Put simply, it is a model whose output is a continuous value, such as a probability

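Objectives

  • Solve a regression problem with a DNN (Deep Neural Network)
  • Understand basic DNN training techniques, such as hyperparameter tuning, feature selection, and regularization
  • Given survey results from the past five days in a U.S. state, predict the percentage of new positive COVID-19 tests on the fifth day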

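Task description

  • COVID-19 case prediction
  • Data source: Delphi group @ CMU, a daily survey conducted through Facebook since April 2020
  • Given survey results from the most recent 5 days in a specific U.S. state, predict the percentage of new positive cases on the 5th day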

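Data composition

  • State code (37 states, encoded as one-hot vectors)

    • One-hot vector: a vector in which exactly one element is 1 and all others are 0. Commonly used in deep learning to encode discrete values
  • COVID-like symptoms (4 features)
  • Behavior indicators (8 features)
  • Mental health indicators (3 features)
  • Tested positive cases (the value we want to predict)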
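
To make the one-hot encoding and the column layout above concrete, here is a minimal sketch in plain Python. The helper `one_hot` and the state index `2` are hypothetical and chosen only for illustration; the counts (37 states, 16 features per day, 5 days, plus an id column) are taken from this article.

```python
def one_hot(index, num_classes):
    """Return a list of length num_classes with a single 1 at position index."""
    if not 0 <= index < num_classes:
        raise ValueError("index must be in [0, num_classes)")
    vec = [0] * num_classes
    vec[index] = 1
    return vec

NUM_STATES = 37          # number of states in the dataset
FEATURES_PER_DAY = 16    # 4 symptom + 8 behavior + 3 mental health + 1 tested positive
NUM_DAYS = 5

# Hypothetical example: the state stored at index 2.
state_vec = one_hot(2, NUM_STATES)

# Column count of the training file: id + one-hot state + per-day features.
n_train_columns = 1 + NUM_STATES + FEATURES_PER_DAY * NUM_DAYS
print(len(state_vec), sum(state_vec), n_train_columns)
```

The column count agrees with the training data shape quoted in the preprocessing code (118 columns); the test file has one column fewer because the last day's positive rate is withheld.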

Evaluation metric

  • Mean Squared Error (MSE)

    • $MSE=\frac{1}{N}\sum_{i=1}^{N}(y_i-\tilde{y}_i)^2$
    • $y_i$ is the ground truth and $\tilde{y}_i$ is the value predicted by the model
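
As a quick check of the formula, the following sketch computes the MSE by hand in plain Python. The function name `mse` and the sample values are made up for illustration; `nn.MSELoss(reduction='mean')`, used later in the training function, computes the same quantity on tensors.

```python
def mse(y_true, y_pred):
    """Mean squared error between two equal-length sequences of numbers."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

# Made-up ground truth and predictions.
y_true = [1.0, 2.0, 3.0]
y_pred = [1.5, 2.0, 2.0]

# Squared errors are 0.25, 0.0 and 1.0, so the mean is 1.25 / 3.
result = mse(y_true, y_pred)
print(result)
```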

Implementation approach

Source code walkthrough

Basics

Imports
# Numerical Operations
import math
import numpy as np

# Reading/Writing Data
import pandas as pd
import os
import csv

# For Progress Bar
from tqdm import tqdm

# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter
Utility functions
def same_seed(seed):
    '''Fixes random number generator seeds for reproducibility.'''
    # A bool that, if True, causes cuDNN to only use deterministic convolution algorithms.
    # cuDNN is a GPU-accelerated library of primitives for deep neural networks. It provides
    # highly tuned implementations of standard routines (such as forward and backward
    # convolution, pooling, normalization, and activation layers).
    torch.backends.cudnn.deterministic = True
    # A bool that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest.
    torch.backends.cudnn.benchmark = False
    # Seed NumPy's random number generator.
    np.random.seed(seed)
    # Sets the seed for generating random numbers.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Sets the seed for generating random numbers on all GPUs.
        # It's safe to call this function if CUDA is not available; in that case, it is silently ignored.
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set))
    train_set_size = len(data_set) - valid_set_size
    # Randomly split a dataset into non-overlapping new datasets of given lengths.
    # Optionally fix the generator for reproducible results.
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)
        with torch.no_grad():
            pred = model(x)
            preds.append(pred.detach().cpu())
    preds = torch.cat(preds, dim=0).numpy()
    return preds
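
The split helper relies on passing a seeded `torch.Generator` to `random_split`. The sketch below, which assumes PyTorch is installed, shows on a toy list that the same seed yields the same train/validation indices. The names `toy_data` and `split_indices` are hypothetical and not part of the course code.

```python
import torch
from torch.utils.data import random_split

toy_data = list(range(10))   # toy stand-in for the rows of the training set
valid_ratio = 0.2

# Same size arithmetic as train_valid_split: the validation size is rounded down.
valid_size = int(valid_ratio * len(toy_data))
train_size = len(toy_data) - valid_size

def split_indices(seed):
    """Return the (train, valid) index lists chosen by random_split for a given seed."""
    generator = torch.Generator().manual_seed(seed)
    train_subset, valid_subset = random_split(
        toy_data, [train_size, valid_size], generator=generator
    )
    return list(train_subset.indices), list(valid_subset.indices)

first = split_indices(seed=0)
second = split_indices(seed=0)
print(first == second)                 # identical seeds give identical splits
print(len(first[0]), len(first[1]))    # sizes of the two parts
```

Because the generator is created fresh from the seed on every call, the split does not depend on how many random numbers were drawn elsewhere in the program.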

Data

Downloading the data
# Download the data from Google Drive into the current directory.
!gdown --id '1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS' --output covid.train.csv
!gdown --id '1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg' --output covid.test.csv
Data preprocessing (feature selection and data split)
def select_feat(train_data, valid_data, test_data, select_all=True):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = [0,1,2,3,4] # TODO: Select suitable feature columns.

    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

# Set seed for reproducibility.
# Note: `config` is defined in the training configuration section later in this article;
# it must be defined before this code runs.
same_seed(config['seed'])

# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days)
# test_data size: 1078 x 117 (without last day's positive rate)
train_data, test_data = pd.read_csv('./covid.train.csv').values, pd.read_csv('./covid.test.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape}
valid_data size: {valid_data.shape}
test_data size: {test_data.shape}""")

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')
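
The feature-selection step is plain NumPy column indexing. The sketch below uses a tiny made-up matrix to show how the target column is split off and how a custom `feat_idx` picks columns. Skipping column 0 is only an assumption based on the code comment that the first column is an id; it is one possible choice for the TODO, not the course's answer.

```python
import numpy as np

# Toy stand-in for the training matrix: columns are [id, f1, f2, f3, target].
toy_train = np.array([
    [0.0, 0.1, 0.2, 0.3, 10.0],
    [1.0, 0.4, 0.5, 0.6, 20.0],
])

toy_y = toy_train[:, -1]        # last column is the regression target
toy_raw_x = toy_train[:, :-1]   # everything else is a candidate feature

# Assumption for illustration: drop the id column, keep the remaining features.
feat_idx = list(range(1, toy_raw_x.shape[1]))
toy_x = toy_raw_x[:, feat_idx]

print(toy_x.shape, toy_y.shape)
```

The same `feat_idx` has to be applied to the training, validation, and test matrices so that all three keep identical columns, which is what `select_feat` does.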
Building the data loaders (Dataset, DataLoader)
class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \
                                             COVID19Dataset(x_valid, y_valid), \
                                             COVID19Dataset(x_test)

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)
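
To see what the loaders hand to the training loop, here is a small self-contained sketch, assuming PyTorch is installed. `ToyDataset` mirrors the structure of the dataset class above but is a hypothetical stand-in fed with made-up numbers.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Minimal dataset: returns (x, y) pairs, or only x when no targets are given."""
    def __init__(self, x, y=None):
        self.x = torch.FloatTensor(x)
        self.y = None if y is None else torch.FloatTensor(y)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

# Five made-up samples with two features each.
features = [[float(i), float(i) + 0.5] for i in range(5)]
targets = [float(i) for i in range(5)]

loader = DataLoader(ToyDataset(features, targets), batch_size=2, shuffle=False)

# Collect the shape of every batch; the last batch is smaller because 5 is not divisible by 2.
batch_shapes = [(tuple(bx.shape), tuple(by.shape)) for bx, by in loader]
print(batch_shapes)
```

With targets present each batch is an `(x, y)` pair; without targets (the test case) the loader yields only `x`, which is why `predict` iterates with a single loop variable.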

Network architecture

Implementation
class My_Model(nn.Module):
    def __init__(self, input_dim):
        super(My_Model, self).__init__()
        # TODO: modify model's structure, be aware of dimensions.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x
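
A quick way to sanity-check the dimensions is to push a random batch through the same layer stack. The sketch below assumes PyTorch is installed and uses an input size of 117, which is the number of feature columns when all features are selected; the batch size of 4 is arbitrary.

```python
import torch
import torch.nn as nn

INPUT_DIM = 117  # feature columns when select_all is True (118 columns minus the target)

# Same layer stack as the model above, written inline for a shape check.
layers = nn.Sequential(
    nn.Linear(INPUT_DIM, 16),
    nn.ReLU(),
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
)

batch = torch.randn(4, INPUT_DIM)   # arbitrary batch of 4 random samples
raw_out = layers(batch)             # shape (4, 1)
squeezed = raw_out.squeeze(1)       # shape (4,), matches the shape of the targets

# Weights plus biases of the three linear layers.
n_params = sum(p.numel() for p in layers.parameters())
print(tuple(raw_out.shape), tuple(squeezed.shape), n_params)
```

The `squeeze(1)` matters because the targets have shape `(B)`. Comparing a `(B, 1)` prediction with a `(B)` target makes the loss broadcast the two against each other instead of comparing them element by element.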

Training and prediction

Training function
def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm.
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9)

    writer = SummaryWriter() # Writer of tensorboard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device.
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())

            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())

        mean_valid_loss = sum(loss_record)/len(loss_record)
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path']) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else:
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            return
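
The early-stopping bookkeeping at the end of the training function can be studied in isolation. The following plain-Python sketch replays the same counter logic on a made-up list of validation losses; `epochs_until_stop` and `patience` are hypothetical names, with `patience` playing the role of `config['early_stop']`.

```python
import math

def epochs_until_stop(valid_losses, patience):
    """Replay the early-stopping counter; return (last epoch run, best loss seen)."""
    best_loss = math.inf
    bad_epochs = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_loss = loss      # new best: this is where the model would be saved
            bad_epochs = 0
        else:
            bad_epochs += 1       # no improvement this epoch
        if bad_epochs >= patience:
            return epoch, best_loss
    return len(valid_losses), best_loss

# Made-up validation losses: the best value appears at epoch 2, then three epochs without improvement.
losses = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped = epochs_until_stop(losses, patience=3)
not_stopped = epochs_until_stop(losses, patience=4)
print(stopped, not_stopped)
```

With a patience of 3 the loop halts at epoch 5 and never sees the better loss at epoch 6, which illustrates the trade-off behind choosing a large value such as 400 in the configuration.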
Training configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'seed': 5201314,      # Your seed number, you can pick your lucky number. :)
    'select_all': True,   # Whether to use all features.
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    'n_epochs': 3000,     # Number of epochs.
    'batch_size': 256,
    'learning_rate': 1e-5,
    'early_stop': 400,    # If model has not improved for this many consecutive epochs, stop training.
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}
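
The training function leaves L2 regularization as a TODO. One common option is the `weight_decay` argument of the optimizer, sketched below on a toy model, assuming PyTorch is installed. The `'weight_decay'` key and its value `1e-4` are assumptions for illustration and are not part of the original configuration.

```python
import torch
import torch.nn as nn

# Hypothetical extension of the config: 'weight_decay' does not exist in the original.
toy_config = {
    'learning_rate': 1e-5,
    'weight_decay': 1e-4,   # assumed value, to be tuned like any other hyperparameter
}

toy_model = nn.Linear(4, 1)  # toy stand-in for the real model

# weight_decay adds an L2 penalty on the parameters inside the optimizer update.
optimizer = torch.optim.SGD(
    toy_model.parameters(),
    lr=toy_config['learning_rate'],
    momentum=0.9,
    weight_decay=toy_config['weight_decay'],
)

applied_decay = optimizer.param_groups[0]['weight_decay']
print(applied_decay)
```

Keeping the value in the configuration dictionary means it can be tuned together with the learning rate and batch size without touching the training function.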
Start training
model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
trainer(train_loader, valid_loader, model, config, device)
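Test function (saving the predictions)
def save_pred(preds, file):
    ''' Save predictions to specified file '''
    # newline='' lets the csv module control line endings (avoids blank rows on Windows).
    with open(file, 'w', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])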
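
To show the layout of the file that the function above produces, this sketch writes two made-up predictions into an in-memory buffer instead of a file. `write_preds` is a hypothetical variant that takes an open file object, so nothing is written to disk.

```python
import csv
import io

def write_preds(preds, fp):
    """Write an id/prediction table to an already opened text file object."""
    writer = csv.writer(fp)
    writer.writerow(['id', 'tested_positive'])
    for i, p in enumerate(preds):
        writer.writerow([i, p])

# In-memory text buffer; newline='' leaves the csv module's line endings untouched.
buffer = io.StringIO(newline='')
write_preds([12.5, 7.25], buffer)   # made-up prediction values

lines = buffer.getvalue().splitlines()
print(lines)
```

The header row is followed by one row per test sample, with the row index serving as the id.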
Start testing
model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device)
save_pred(preds, 'pred.csv')
