文章目录

  • 早停的目的与流程
  • 早停策略
  • pytorch使用示例
  • 参考网站

早停的目的与流程

目的:防止模型过拟合,由于深度学习模型可以无限迭代下去,因此希望在即将过拟合时、或训练效果微乎其微时停止训练。

流程如下:

  1. 将数据集切分为三部分:训练数据(数据量最多),验证数据(数据量最少,一般10%-20%左右即可),测试数据(数据量第二多)
  2. 模型通过训练集,得到训练集的LosstrainLoss_{train}Losstrain​
  3. 然后模型通过验证集,此时不是训练,不需要反向传播,得到验证集的LossvalidLoss_{valid}Lossvalid​
  4. 早停策略通过LosstrainLoss_{train}Losstrain​与LossvalidLoss_{valid}Lossvalid​来判断,是否需要中断训练

早停策略

早停策略,我们都是拿着验证集训练集来说事:

  1. 常用的策略:

    ♣ 如果训练集loss与验证集loss连续几次下降不明显,就早停
    ♣ 验证集loss连续n次不降反升则早停。(通常是3次)

  2. 根据泛化损失卡阈值的策略

    ♣ 将目前已有的验证集的最小loss记录下来,看当前的验证集loss与最小的loss之间的差距
    ♣ 通过公式:GL(t)=100⋅(Eva(t)Eopt(t)−1){GL(t)} = 100 \cdot \big( \frac{E_{va}(t)}{E_{opt}(t)} - 1)GL(t)=100⋅(Eopt​(t)Eva​(t)​−1)计算一个值,并称之为泛化损失
    ♣ 当这个泛化损失超过阈值的时候停止训练

  3. 根据度量进展卡阈值的策略:我们通常假设过拟合会出现在训练集loss很难下降的时候,此时模型继续强行下降loss会导致过拟合的风险,因此,

    ♣ 定一个迭代周期,为训练k次,判断本次迭代的时候平均训练loss比最小训练loss大多少
    (公式:Pk(t)=1000⋅(∑t′=t−k+1tEtr(t′)k⋅mint′=t−k+1tEtr(t′)−1)P_k(t) = 1000 \cdot \big( \frac{ \sum_{t' = t-k+1}^t E_{tr}(t') }{ k \cdot min_{t' = t-k+1}^t E_{tr}(t') } -1 \big)Pk​(t)=1000⋅(k⋅mint′=t−k+1t​Etr​(t′)∑t′=t−k+1t​Etr​(t′)​−1))
    ♣ 然后结合上面的泛化损失,计算GL(t)Pk(t)\frac{GL(t)}{P_k(t)}Pk​(t)GL(t)​
    ♣ 当这个值大于一个阈值时,停止训练

pytorch使用示例

我们参考https://github.com/Bjarten/early-stopping-pytorch这个项目的早停策略

EarlyStopping类在:https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

结合深度学习的示例如下:

import torch
import torch.nn as nn
import os
from sklearn.datasets import make_regression
from torch.utils.data import Dataset, DataLoader
import numpy as npclass EarlyStopping: # 这个是别人写的工具类,大家可以把它放到别的地方"""Early stops the training if validation loss doesn't improve after a given patience."""def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):"""Args:patience (int): How long to wait after last time validation loss improved.Default: 7verbose (bool): If True, prints a message for each validation loss improvement.Default: Falsedelta (float): Minimum change in the monitored quantity to qualify as an improvement.Default: 0path (str): Path for the checkpoint to be saved to.Default: 'checkpoint.pt'trace_func (function): trace print function.Default: print"""self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltaself.path = pathself.trace_func = trace_funcdef __call__(self, val_loss, model):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model)elif score < self.best_score + self.delta:self.counter += 1self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model)self.counter = 0def save_checkpoint(self, val_loss, model):'''Saves model when validation loss decrease.'''if self.verbose:self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), self.path)self.val_loss_min = val_lossclass MyDataSet(Dataset):  # 定义数据格式def __init__(self, train_x, train_y, sample):self.train_x = train_xself.train_y = train_yself._len = sampledef __getitem__(self, item: int):return self.train_x[item], self.train_y[item]def __len__(self):return self._lendef get_data():"""构造数据"""sample = 20000data_x, data_y = make_regression(n_samples=sample, n_features=100)  # 生成数据集train_data_x = data_x[:int(sample * 0.8)]train_data_y = data_y[:int(sample * 0.8)]valid_data_x = data_x[int(sample * 0.8):]valid_data_y = data_y[int(sample * 0.8):]train_loader = DataLoader(MyDataSet(train_data_x, train_data_y, len(train_data_x)), batch_size=10)valid_loader = DataLoader(MyDataSet(valid_data_x, valid_data_y, len(valid_data_x)), batch_size=10)return train_loader, valid_loaderclass LinearRegressionModel(nn.Module):def __init__(self, input_dim, output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim)  # 输入的个数,输出的个数def forward(self, x):out = self.linear(x)return outdef main():train_loader, valid_loader = get_data()model = LinearRegressionModel(input_dim=100, output_dim=1)optimizer = torch.optim.SGD(model.parameters(), lr=0.001)criterion = nn.MSELoss()early_stopping = EarlyStopping(patience=4, verbose=True)  # 早停# 开始训练模型for epoch in range(1000):# 正常的训练print("迭代第{}次".format(epoch))model.train()train_loss_list = []for train_x, train_y in train_loader:optimizer.zero_grad()outputs = model(train_x.float())loss = criterion(outputs.flatten(), train_y.float())loss.backward()train_loss_list.append(loss.item())optimizer.step()print("训练loss:{}".format(np.average(train_loss_list)))# 早停策略判断model.eval()with torch.no_grad():valid_loss_list = []for valid_x, valid_y in valid_loader:outputs = model(valid_x.float())loss = criterion(outputs.flatten(), valid_y.float())valid_loss_list.append(loss.item())avg_valid_loss = np.average(valid_loss_list)print("验证集loss:{}".format(avg_valid_loss))early_stopping(avg_valid_loss, model)if early_stopping.early_stop:print("此时早停!")breakif __name__ == '__main__':main()

参考网站

  • 深度学习技巧之Early Stopping(早停法):https://www.datalearner.com/blog/1051537860479157
  • early-stopping-pytorch:https://github.com/Bjarten/early-stopping-pytorch

pytorch使用早停策略相关推荐

  1. 深度学习之早停策略EarlyStopping以及保存测试集准确率最高的模型ModelCheckpoint

    在训练神经网络时,如果epochs设置的过多,导致最终结束时测试集上模型的准确率比较低,而我们却想保存准确率最高时候的模型参数,这就需要用到Early Stopping以及ModelCheckpoin ...

  2. pytorch早停法

    作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合. 早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好.是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始 ...

  3. 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。

    这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...

  4. R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)

    R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...

  5. keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...

  6. 早停!? earlystopping for keras

    为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策.超参数之一是定型周期(epoch)的数量:亦即应当完整遍历数据集多少次(一次为一个epoch)?如果epoch数量太 ...

  7. 精确拐点交易体系之追涨停策略

    [导言]通常做交易获利的主要途径有两种:一是研究创造价值,如择势.择股都需要大量的研究:二是利用波动本身创造价值,即利用择时体系低买高卖或高卖低买.此文介绍的精确拐点交易体系之追涨停策略,是璞一投资利 ...

  8. EarlyStopping早停法的实现原理

    keras中的EarlyStopping使用很方便,当但我测试torch的EarlyStopping时掉坑了!!! torch中pytorchtools工具有早停法,但我测试基本用不了,总是出错,fu ...

  9. 深度学习技巧之Early Stopping(早停法)

    深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...

最新文章

  1. 【noiOJ】P1996
  2. 七、Linux串口编程
  3. eclipse替换空格和注释
  4. java下拉文本框_java swing 下拉框与文本框
  5. python无条件跳转_python按按钮实现界面跳转_python实现界面跳转 - CSDN
  6. 利用高级筛选功能巧妙删除Excel的重复记录
  7. 自学python能干些什么副业-揭秘!女程序员为啥更赚钱?这4个大招,用Python做副业躺赚...
  8. mysql简单命令行操作以及环境变量的配置
  9. 深入浅出4G标准 LTE FDD和LTE TDD
  10. 图像识别技术原理和神经网络的图像识别技术
  11. URLConnection 传入参数
  12. UNIX环境高级编程——1.UNIX基础知识
  13. BP神经网络的详细推导
  14. SpringAop篇 (1) AOP 基础之动态代理的实现
  15. 揭秘 zCloud 3.0丨企业需要怎么样的DBA?
  16. 视频中的硬字幕该如何提取和翻译?
  17. 打造最强终端之一:Fish shell简明教程
  18. t3显示乱码_用友T3软件客户端不能输入汉字或者为乱码
  19. 网易云歌词居中滚动、点击/滑动进度条对应滚动、当前播放歌词高亮
  20. UltraISO制作启动盘及提取U盘为ISO镜像

热门文章

  1. python里什么叫子图_Python中的两个子图(matplotlib)
  2. tomcat加上了https后访问不了_西部数码使用指南:部署https后访问提示存在安全隐患的排查解决方法...
  3. TypeError at / 'AnonymousUser' object is not iterable
  4. mysql数据库年龄_sql获取时间、年龄
  5. 3D空间中射线与三角形的交叉检测算法
  6. HDU 4546 比赛难度 (优先队列 * * )
  7. 动手Lab|利用CSI和Kubernetes实现动态扩容
  8. PHP生成器--动态生成内容的数组
  9. scp上传服务器加特殊端口
  10. spring 启动完成后事件监听器处理