【源码解析】如何从零实现一个回归模型?
说明:本文源代码来源于MACHINE LEARNING 2022 SPRING课程,我只是针对源代码进行了一些加工处理。感谢互联网,让我们能免费接触到这些优秀的课程。
前置知识
- 什么是回归模型?简单说就是模型的输出是连续的,如概率大小等
目标
- 借助DNN(Deep Neural Networks)网络解决一个回归问题
- 理解基本的DNN训练技巧,如超参数的微调、特征选取、正则化
- 根据美国某州过去五天中前四天的调查结果,预测第五天新冠测试阳性的病例数
任务描述
- COVID-19情况预测
- 数据来源:Delphi group@CMU 自2020年4月以来,通过FaceBook进行的每日调查
- 根据美国特定州最近5天的调查结果,预测第5天的新确诊患者比率
数据组成
- 州代码(37个州,已编码成独热向量)
- 独热向量:仅有一个元素置为1,而其它元素均置为0的向量。在深度学习中常用于编码离散值
- COVID相似症状(4组)
- 行为指标(8组)
- 心理健康指标(3组)
- 阳性病例(我们想预测的数据)
性能指标
- Mean Squared Error(MSE)
- MSE=1N∑i=1N(yi−y~i)2MSE=\frac{1}{N}\sum_{i=1}^{N}(y_i-\tilde{y}_i)^2MSE=N1∑i=1N(yi−y~i)2
- yiy_iyi代表Ground truth,y~i\tilde{y}_iy~i代表模型输出的预测值
实现思路
源码解析
基础部分
导包
# Numerical Operations
import math
import numpy as np# Reading/Writing Data
import pandas as pd
import os
import csv# For Progress Bar
from tqdm import tqdm# Pytorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split# For plotting learning curve
from torch.utils.tensorboard import SummaryWriter
功能函数
def same_seed(seed): '''Fixes random number generator seeds for reproducibility.'''# A bool that, if True, causes cuDNN to only use deterministic convolution algorithms.# cudnn: 是经GPU加速的深度神经网络基元库。cuDNN可大幅优化标准例程(例如用于前向传播和反向传播的卷积层、池化层、归一化层和激活层)的实施。torch.backends.cudnn.deterministic = True# A bool that, if True, causes cuDNN to benchmark multiple convolution algorithms and select the fastest.torch.backends.cudnn.benchmark = False# 用于生成指定的随机数np.random.seed(seed)# Sets the seed for generating random numbers. torch.manual_seed(seed)if torch.cuda.is_available():# Sets the seed for generating random numbers for the current GPU. # It’s safe to call this function if CUDA is not available; in that case, it is silently ignored.torch.cuda.manual_seed_all(seed)def train_valid_split(data_set, valid_ratio, seed):'''Split provided training data into training set and validation set'''valid_set_size = int(valid_ratio * len(data_set)) train_set_size = len(data_set) - valid_set_size# Randomly split a dataset into non-overlapping new datasets of given lengths. # Optionally fix the generator for reproducible resultstrain_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))return np.array(train_set), np.array(valid_set)def predict(test_loader, model, device):model.eval() # Set your model to evaluation mode.preds = []for x in tqdm(test_loader):x = x.to(device) with torch.no_grad(): pred = model(x) preds.append(pred.detach().cpu()) preds = torch.cat(preds, dim=0).numpy() return preds
数据
数据的下载
# 下面这些是从谷歌云盘上下载数据到当前目录下
!gdown --id '1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS' --output covid.train.csv
!gdown --id '1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg' --output covid.test.csv
数据的预处理(特征选取、数据划分)
def select_feat(train_data, valid_data, test_data, select_all=True):'''Selects useful features to perform regression'''y_train, y_valid = train_data[:,-1], valid_data[:,-1]raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_dataif select_all:feat_idx = list(range(raw_x_train.shape[1]))else:feat_idx = [0,1,2,3,4] # TODO: Select suitable feature columns.return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid# Set seed for reproducibility
same_seed(config['seed'])
# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days)
# test_data size: 1078 x 117 (without last day's positive rate)
train_data, test_data = pd.read_csv('./covid.train.csv').values, pd.read_csv('./covid.test.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])# Print out the data size.
print(f"""train_data size: {train_data.shape}
valid_data size: {valid_data.shape}
test_data size: {test_data.shape}""")# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')
数据加载器的构造(DataSet、DataLoader)
class COVID19Dataset(Dataset):'''x: Features.y: Targets, if none, do prediction.'''def __init__(self, x, y=None):if y is None:self.y = yelse:self.y = torch.FloatTensor(y)self.x = torch.FloatTensor(x)def __getitem__(self, idx):if self.y is None:return self.x[idx]else:return self.x[idx], self.y[idx]def __len__(self):return len(self.x)
train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \COVID19Dataset(x_valid, y_valid), \COVID19Dataset(x_test)# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)
网络结构
结构的实现
class My_Model(nn.Module):def __init__(self, input_dim):super(My_Model, self).__init__()# TODO: modify model's structure, be aware of dimensions. self.layers = nn.Sequential(nn.Linear(input_dim, 16),nn.ReLU(),nn.Linear(16, 8),nn.ReLU(),nn.Linear(8, 1))def forward(self, x):x = self.layers(x)x = x.squeeze(1) # (B, 1) -> (B)return x
训练与预测
训练函数
def trainer(train_loader, valid_loader, model, config, device):criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.# Define your optimization algorithm. # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.# TODO: L2 regularization (optimizer(weight decay...) or implement by your self).optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9) # ?writer = SummaryWriter() # Writer of tensoboard.if not os.path.isdir('./models'):os.mkdir('./models') # Create directory of saving models.n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0for epoch in range(n_epochs):model.train() # Set your model to train mode.loss_record = []# tqdm is a package to visualize your training progress.train_pbar = tqdm(train_loader, position=0, leave=True)for x, y in train_pbar:optimizer.zero_grad() # Set gradient to zero.x, y = x.to(device), y.to(device) # Move your data to device. pred = model(x) loss = criterion(pred, y)loss.backward() # Compute gradient(backpropagation).optimizer.step() # Update parameters.step += 1loss_record.append(loss.detach().item())# Display current epoch number and loss on tqdm progress bar.train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')train_pbar.set_postfix({'loss': loss.detach().item()})mean_train_loss = sum(loss_record)/len(loss_record)writer.add_scalar('Loss/train', mean_train_loss, step)model.eval() # Set your model to evaluation mode.loss_record = []for x, y in valid_loader:x, y = x.to(device), y.to(device)with torch.no_grad():pred = model(x)loss = criterion(pred, y)loss_record.append(loss.item())mean_valid_loss = sum(loss_record)/len(loss_record)print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')writer.add_scalar('Loss/valid', mean_valid_loss, step)if mean_valid_loss < best_loss:best_loss = mean_valid_losstorch.save(model.state_dict(), config['save_path']) # Save your best modelprint('Saving model with loss {:.3f}...'.format(best_loss))early_stop_count = 0else: early_stop_count += 1if early_stop_count >= config['early_stop']:print('\nModel is not improving, so we halt the training session.')return
训练参数的设置
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {'seed': 5201314, # Your seed number, you can pick your lucky number. :)'select_all': True, # Whether to use all features.'valid_ratio': 0.2, # validation_size = train_size * valid_ratio'n_epochs': 3000, # Number of epochs. 'batch_size': 256, 'learning_rate': 1e-5, 'early_stop': 400, # If model has not improved for this many consecutive epochs, stop training. 'save_path': './models/model.ckpt' # Your model will be saved here.
}
开始训练
model = My_Model(input_dim=x_train.shape[1]).to(device) # put your model and data on the same computation device.
trainer(train_loader, valid_loader, model, config, device)
测试函数(保存测试结果)
def save_pred(preds, file):''' Save predictions to specified file '''with open(file, 'w') as fp:writer = csv.writer(fp)writer.writerow(['id', 'tested_positive'])for i, p in enumerate(preds):writer.writerow([i, p])
开始测试
model = My_Model(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device)
save_pred(preds, 'pred.csv')
【源码解析】如何从零实现一个回归模型?相关推荐
- spring boot 源码解析15-spring mvc零配置
前言 spring boot 是基于spring 4 的基础上的一个框架,spring 4 有一个新特效–>基于java config 实现零配置.而在企业的实际工作中,spring 都是和sp ...
- object_detection源码解析-box_list
models.research.object_detection源码解析-core.box_list box_list是一个ObjectDetection项目中,一个综合管理bounding box的 ...
- Laravel5.2之Filesystem源码解析(下)
2019独角兽企业重金招聘Python工程师标准>>> 说明:本文主要学习下\League\Flysystem这个Filesystem Abstract Layer,学习下这个pac ...
- JDK源码解析之java.util.AbstractCollection
AbstractCollection类提供了collection的实现类应该具有的基本方法,具有一定的普适性,可以从大局上了解collection实现类的主要功能. java.util.Abstrac ...
- Laravel核心解读--Session源码解析
Session 模块源码解析 由于HTTP最初是一个匿名.无状态的请求/响应协议,服务器处理来自客户端的请求然后向客户端回送一条响应.现代Web应用程序为了给用户提供个性化的服务往往需要在请求中识别出 ...
- HDFS源码解析:教你用HDFS客户端写数据
摘要:终于开始了这个很感兴趣但是一直觉得困难重重的源码解析工作,也算是一个好的开端. 本文分享自华为云社区<hdfs源码解析之客户端写数据>,作者: dayu_dls. 在我们客户端写数据 ...
- Zookeeper源码解析 -- 本地事务日志持久化之FileTxnLog
序言 在各个分布式组件中,持久化数据到本地的思想并不少见,为的是能保存内存中的数据,以及重启后能够重载上次内存状态的值.那么如何行之有效的进行,内存数据持久化到磁盘,怎么样的落盘策略合适,怎么设计持久 ...
- 【详解】Ribbon 负载均衡服务调用原理及默认轮询负载均衡算法源码解析、手写
Ribbon 负载均衡服务调用 一.什么是 Ribbon 二.LB负载均衡(Load Balancer)是什么 1.Ribbon 本地负载均衡客户端 VS Nginx 服务端负载均衡的区别 2.LB负 ...
- 深入探究JDK中Timer的使用方式与源码解析
导言 定时器Timer的使用 构造方法 实例方法 使用方式 1. 执行时间晚于当前时间 2. 执行时间早于当前时间 3. 向Timer中添加多个任务 4. 周期性执行任务 5. 停止任务 源码解析 T ...
最新文章
- matlab怎么没有编辑器,在不打开编辑器窗口的情况下开始一个新的MATLAB会话
- python是一种语言吗-Python是一种什么样的编程语言?解释?编译?汇编?机械?...
- JavaScript对TreeView的操作全解
- Codeforces 474C Captain Marmot 给定4个点和各自旋转中心 问旋转成正方形的次数
- java excel中删除两列_Java 插入、隐藏/显示、删除Excel行或列
- 数据结构之图:有向图的拓扑排序,Python代码实现——26
- LeetCode 513. 找树左下角的值(递归)
- php fmod小数位数_PHP取余函数介绍MOD(x,y)与x%y
- c语言贪吃蛇游戏 vc6,把tc下的贪吃蛇游戏改到vc6下运行
- Oracle P6培训系列:15定义资源库
- matlab数学建模程序代码大全,matlab程序代码
- linux如何生成tgz文件,linux – 压缩文件夹以创建tgz文件
- 安卓 USB 无权限请求权限崩溃 UsbManager.requestPermission()空指针异常
- 【设计模式】简单工厂模式+工厂方法模式+抽象工厂模式
- 微信小程序动态绑定unit-id
- 侧边栏如何展开与收起
- .net core with 微服务 - Polly 熔断降级
- 如何运用政务智慧引导系统提升群众办事效率
- bugku:游戏过关
- 什么是IP地址定位,优缺点有哪些?
热门文章
- java记事本字体_记事本中的字体+字形+大小设置
- 顶级(top-level window)窗口,被拥有窗口(owned window),子窗口(child window) 与WS_POPUP,WS_CHILD深入浅出
- 【树莓派笔记】USB口供电能力
- Solr(1):Solr概述
- 超声波穿刺焊工艺知多少
- RabbitMQ系列笔记主题订阅模式
- 瑞云服务云携手慧而特,引领餐饮设备服务创新升级
- excel表格内容拆分_Excel表格中制作动态下拉菜单的方法,学会了菜单内容想加就加...
- 计算机教室消防说明,计算机教室消防制度.doc
- 第十三章 使用动态SQL(三)