这个方法更好的解决了模型过拟合问题。

EarlyStopping的原理是提前结束训练轮次来达到“早停“的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch)。

首先,我们需要一个一个标识,可以采用'val_acc’、’val_loss’等等,这些量在每一个轮次中都会不断更新自己的值,也和模型的参数息息相关,所以我们想通过他们间接操作模型参数。以val_loss来说,当模型训练时可能会出现当val_loss到一定值的时候会出现回弹的情况,所以我们希望在他回弹之前结束模型的训练。

早停法其实一共有3类停止标准,这里我们选用最简单的一种入门。话不多说,上代码!!!

import numpy as np
import torch

导入两个最基本的包就行,因为早停法是一种可以自己就写出来的算法!!!

参数有5个:

第一个patience:这个是当有连续的patience个轮次数值没有继续下降,反而上升的时候结束训练的条件(以val_loss为例)

第二个verbose:这个其实就是是否print一些值,可也不传参,因为他有默认值

第三个delta:这个就是控制对比是的”标准线“

第四个path:这个是权重保存路径,早停法会在每一轮次次产生最优解(就是val_loss继续减少)的时候保存当前的模型参数。注:只要保存路径不变,每一次保存在文件里面的参数都会覆盖上一次保存在文件里面的参数。

第五个trace_func:这个就是显示每一个轮次变化的数值的方式,默认print,也可以改成进度条显示(tqdm的对象)

class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0, path='weight7-stop.pth', trace_func=print):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltaself.path = pathself.trace_func = trace_funcdef __call__(self, val_loss, model):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model)elif score < self.best_score + self.delta:self.counter += 1self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model)self.counter = 0def save_checkpoint(self, val_loss, model):'''Saves model when validation loss decrease.'''if self.verbose:self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), self.path)self.val_loss_min = val_loss

重点就在中间那个__call__方法里面,比较的是这一轮的val_loss和之前最好的val_loss(可以加上一个数实现‘标准线’的‘上移’或者‘下移’)

实际应用与项目当中

这是我再积水检测项目中的代码的一部分。

我设置了patience为7.

epoch为200。(这个推荐小一点,因为太大没有意义,一定会过拟合的)

注:本文使用的早停法源代码不是原创,取自github。

【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。相关推荐

  1. 深度学习之早停策略EarlyStopping以及保存测试集准确率最高的模型ModelCheckpoint

    在训练神经网络时,如果epochs设置的过多,导致最终结束时测试集上模型的准确率比较低,而我们却想保存准确率最高时候的模型参数,这就需要用到Early Stopping以及ModelCheckpoin ...

  2. 深度学习技巧之Early Stopping(早停法)

    深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...

  3. pytorch早停法

    作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合. 早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好.是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始 ...

  4. keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...

  5. DL框架之PyTorch:深度学习框架PyTorch的简介、安装、使用方法之详细攻略

    DL框架之PyTorch:PyTorch的简介.安装.使用方法之详细攻略 DL框架之PyTorch:深度学习框架PyTorch的简介.安装.使用方法之详细攻略 目录 PyTorch的简介 1.pyto ...

  6. PyTorch核心贡献者开源书:《使用PyTorch进行深度学习》完整版现已发布!

    来源|新智元 [导读]<使用PyTorch进行深度学习>一书的完整版现已发布!教你如何使用PyTorch创建神经网络和深度学习系统,内含图解与代码,操作易上手. 由Luca Antiga. ...

  7. R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)

    R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...

  8. Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配、对应版本安装之详细攻略

    Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配.对应版本安装之详细攻略 目录 深度学习中pytorch/torchvision版本和CUDA版本最正确 ...

  9. 使用pytorch进行深度学习

    文章目录 使用pytorch进行深度学习 1. 深度学习构建模块:仿射变换, 非线性函数以及目标函数 1.1 仿射变换 1.2 非线性函数 1.3 Softmax和概率 1.4 目标函数 2. 优化和 ...

最新文章

  1. 成功解决No such file or directory: site-packages\\pyyaml-5.3-py3.6-win-amd64.egg\\EGG-INFO\\top_level.t
  2. Divide and conquer:K Best(POJ 3111)
  3. nyoj66分数拆分
  4. assertj_AssertJ的SoftAssertions –我们需要它们吗?
  5. 专访涯海:阿里云中间件是如何支撑双11的?
  6. 浅析ElasticSearch原理
  7. 挽救数据库性能的 30 条黄金法则 | 原力计划
  8. 公交车宜配备逃生绳索
  9. 如何培养项目管理的领导力?
  10. Ajax的readyState和status
  11. 计算机系女学霸男生追,杨紫李现解锁恋爱新姿势:吃最甜的糖,追最燃的梦
  12. 莱西姆大学计算机专业,菲律宾的大学排名是根据什么指标排的
  13. 关于学习软件逆向分析意义的阐述
  14. error: Microsoft Visual C++ 9.0 is required (Unabl
  15. 一位在微软公司的粉丝,写给我的信
  16. python 导入模型_scikit-learn系列之如何存储和导入机器学习模型
  17. 荣耀30s刷鸿蒙,荣耀30S“超过”苹果XS,靠华为鸿蒙框架优化能力
  18. 【Paper material】
  19. 微信小程序开发中常见问题及解决方法
  20. GPRS模块(SIM900A)在QT下的通信例程

热门文章

  1. Python读取,写入,保存txt文件
  2. JUC下的CountDownLatch,CyclicBarrier、Semaphore的使用方法
  3. c语言字符串子串问题,C语言计算字符串子串出现的次数
  4. 股东转让股权后是否还应承担出资义务
  5. RocketMQ(七) RocketMQ的两种消费模式
  6. Fiddler调试利器
  7. 新站如何使用好百度站长平台工具
  8. 大航海时代: 流行5掠夺篇
  9. Verilog中的!和~
  10. PROGRESSIVE GROWING OF GANS FOR IMPROVED QUALITY, STABILITY, AND VARIATION(PGGAN)