Pytorch-早停法(early stopping)原理及其代码
作为深度学习训练数据的trick,这个方法必须知道啊,结合交叉验证法,可以防止模型过早拟合。
早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好。是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题。其主要步骤如下:
1. 将原始的训练数据集划分成训练集和验证集
2. 只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差
3. 当模型在验证集上(权重的更新低于某个阈值;预测的错误率低于某个阈值;达到一定的迭代次数),则停止训练
4. 使用上一次迭代结果中的参数作为模型的最终参数
如下图之后的某个epoch,模型的验证误差逐渐上升,模型出现过拟合,所以需要提前停止训练,早停法主要是训练时间和泛化错误之间的权衡。不同的停止标准也是给我们带来不同的效果。
下面在pytorch上面运用早停法(early stopping)
#Train the Model using Early Stopping
def train_model(model, batch_size, patience, n_epochs):# to track the training loss as the model trainstrain_losses = []# to track the validation loss as the model trainsvalid_losses = []# to track the average training loss per epoch as the model trainsavg_train_losses = []# to track the average validation loss per epoch as the model trainsavg_valid_losses = [] # initialize the early_stopping objectearly_stopping = EarlyStopping(patience=patience, verbose=True)for epoch in range(1, n_epochs + 1):#################### train the model ####################model.train() # prep model for trainingfor batch, (data, target) in enumerate(train_loader, 1):# clear the gradients of all optimized variablesoptimizer.zero_grad()# forward pass: compute predicted outputs by passing inputs to the modeloutput = model(data)# calculate the lossloss = criterion(output, target)# backward pass: compute gradient of the loss with respect to model parametersloss.backward()# perform a single optimization step (parameter update)optimizer.step()# record training losstrain_losses.append(loss.item())###################### # validate the model #######################model.eval() # prep model for evaluationfor data, target in valid_loader:# forward pass: compute predicted outputs by passing inputs to the modeloutput = model(data)# calculate the lossloss = criterion(output, target)# record validation lossvalid_losses.append(loss.item())# print training/validation statistics # calculate average loss over an epochtrain_loss = np.average(train_losses)valid_loss = np.average(valid_losses)avg_train_losses.append(train_loss)avg_valid_losses.append(valid_loss)epoch_len = len(str(n_epochs))print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +f'train_loss: {train_loss:.5f} ' +f'valid_loss: {valid_loss:.5f}')print(print_msg)# clear lists to track next epochtrain_losses = []valid_losses = []# early_stopping needs the validation loss to check if it has decresed, # and if it has, it will make a checkpoint of the current modelearly_stopping(valid_loss, model)if early_stopping.early_stop:print("Early stopping")break# load the last checkpoint with the best modelmodel.load_state_dict(torch.load('checkpoint.pt'))return model, avg_train_losses, avg_valid_losses
具体的完整代码为:https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb
Pytorch-早停法(early stopping)原理及其代码相关推荐
- pytorch早停法
作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合. 早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好.是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始 ...
- R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)
R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...
- keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)
keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...
- 深度学习技巧之Early Stopping(早停法)
深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...
- 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。
这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...
- Early Stopping 早停法原理与实现
Early Stopping 训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据.但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合. 当模型在训练集上表现很好 ...
- 深度学习——早停法(Early Stopping)
学习链接:https://www.jianshu.com/p/9ab695d91459 https://www.datalearner.com/blog/1051537860479157 目的: 为了 ...
- 在PyTorch中进行双线性采样:原理和代码详解
↑ 点击蓝字 关注视学算法 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/257958558 编辑丨极市平台 在pytorch中的双线性采样(Bilinear Sa ...
- Early Stopping早停法
参考: https://www.jianshu.com/p/9ab695d91459
- Earlystopping(早停法)
Earlystopping 简介 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据). 但是所有的标准深度学习神 ...
最新文章
- MySQL基础入门学习【1】基本操作
- Nginx的启动、停止
- iview在vue-cli3如何按需加载
- linux jdk安装_linux运维 - 用脚本快速安装jdk
- Bootstrap模态框(modal)显示、隐藏与禁用ESC代码实现
- C语言quaternion(四元数)(附完整源码)
- Java 线程的生命周期
- tomcat 优化_浅谈Tomcat服务器优化方法
- AMD宣布350亿美元收购赛灵思,CPU、GPU、FPGA全凑齐
- MessAPI V1.1.1 QQ音乐、网易云音乐、酷狗音乐、咪咕音乐、酷我音乐、百度音乐API接口
- word 左侧显示目录
- 2022危险化学品经营单位安全管理人员特种作业证考试题库及在线模拟考试
- 龙芯CPU开发系统固件与内核接口手册资料
- SSL 3.0曝出Poodle漏洞的解决方案
- 2021-07-04 【5】
- 并发编程系列之AQS实现原理
- C语言-概念-fscanf函数和fprintf函数
- java sql timestamp_Java SQL Timestamp before()用法及代码示例
- 在线测试c语言程序代码,C语言在线测评系统的使用
- APP安全性测试总结-移动APP安全测试
热门文章
- win10计算机被网络设备发现,图文解决win10系统网络发现已关闭计算机和设备不见的方法...
- 3dsMax建模,卡线学习笔记
- 万能五笔输入法弹窗_万能五笔输入法广告如何彻底关闭
- k8s安装calico网络插件
- 华为设备MAC地址配置命令
- 浙大pat | 浙大pat乙级 1001~1004
- 有什么软件可以测试汽车的噪音,汽车噪声测试,汽车通过噪声测试
- uni-app知识点整理(1)- uni-app简介、环境搭建、项目创建、项目目录文件
- python在统计专业的应用_Python统计学statistics实战
- 基于E4A的手机蓝牙串口助手app制作