Early Stopping

训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合

当模型在训练集上表现很好,在验证集上表现很差的时候,我们认为模型出现了过拟合的情况,early stoppping 就是用来预防过拟合的一种方法,简单且有效

原理

early stoppping 的原理是:当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题

缺点

如下图,模型在验证集上的表现可能咱短暂的变差之后有可能继续变好,并不是在验证集上的表现一旦变差就不会变好。early stoppping 主要是训练时间和泛化错误之间的权衡。

pytorch 实现

EarlyStopping 是用于提前停止训练的callbacks。具体地,可以达到当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。

初始化

  • patience:自上次模型在验证集上损失降低之后等待的时间,此处设置为7
  • verbose:当为False时,运行的时候将不显示详细信息
  • counter:计数器,当其值超过patience时候,使用early stopping
  • best_score:记录模型评估的最好分数
  • early_step:决定模型要不要early stop,为True则停
  • val_loss_min:模型评估损失函数的最小值,默认为正无穷(np.Inf)
  • delta:表示模型损失函数改进的最小值,当超过这个值时候表示模型有所改进
class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = delta

保存模型

传入参数:val_loss、model 和 path

verbose 为True,则打印详细信息
函数作用:在 path 路径下,保存当前 model,并更新 val_loss_min 为当前 val_loss

    def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')self.val_loss_min = val_loss

调用

定义__call__()方法,该方法的功能类似于在类中重载 () 运算符,使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用

  1. 初始化时,设定了self.best_score = None

  2. if 语句第一行判断 self.best_score 是否为初始值,如果是初始值,则将 score 赋值给 self.best_score ,然后调用save_checkpoint()函数保存

  3. 当目前分数比最好分数加 self.delta 小时,就认为模型没有改进,将 counter 计数器加1,当计数器值超过 patience 的时候,就令early_stop为True,让模型停止训练

  4. 当目前分数比最好分数加 self.delta 大时,我们认为模型有改进,将目前分数赋值给最好分数,并将模型保存,令计数器归零。

    def __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0

总体代码

总体代码如下:

class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltadef __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')self.val_loss_min = val_loss

Early Stopping 早停法原理与实现相关推荐

  1. Early Stopping早停法

    参考: https://www.jianshu.com/p/9ab695d91459

  2. R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)

    R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...

  3. keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)

    keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...

  4. 深度学习技巧之Early Stopping(早停法)

    深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...

  5. pytorch早停法

    作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合. 早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好.是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始 ...

  6. EarlyStopping早停法的实现原理

    keras中的EarlyStopping使用很方便,当但我测试torch的EarlyStopping时掉坑了!!! torch中pytorchtools工具有早停法,但我测试基本用不了,总是出错,fu ...

  7. 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。

    这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...

  8. 深度学习——早停法(Early Stopping)

    学习链接:https://www.jianshu.com/p/9ab695d91459 https://www.datalearner.com/blog/1051537860479157 目的: 为了 ...

  9. Earlystopping(早停法)

    Earlystopping 简介 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据). 但是所有的标准深度学习神 ...

最新文章

  1. 好玩!PyEcharts 绘制时间轮播图
  2. js原生设计模式——2面向对象编程之继承—new+call(this)组合式继承
  3. MySQL - 多版本控制 MVCC 机制初探
  4. Intellij IDEA 默认打开上次项目设置与取消设置
  5. 汇编原理实验--输出ASCII码10H到100H
  6. 辗转相除法(欧几里得算法)求 最大公约数与最小公倍数+推论与证明。
  7. 一个写得很不错的vuex详解(转)
  8. 如何实时主动监控你的网站接口是否挂掉并及时报警
  9. RedHat6.2 x86手动配置LNMP环境
  10. laravel简单的laragon环境搭建不需要composer一键集成
  11. SAP HR工资核算基础(转)
  12. 2015 2020 r4烧录卡 区别_2020版药典,药用辅料被重视了!
  13. 如何将视频转换为HEVC / H.265和AVC / H.264
  14. 如何提取mp4中的音频?
  15. 智慧营区解决方案-最新全套文件
  16. springboot yml文件不是绿叶子问题
  17. 【Linux】alias及设置
  18. 绿色版电脑站手机站仿站小工具
  19. IDEA 连接数据库报错
  20. 嘉益仕(Litins)再携手南方电网,升级智能仓储管理

热门文章

  1. 软件测试书清华大学出版社,清华大学出版社-图书详情-《软件测试技术与实践》...
  2. 哈尔滨工程大学 自动控制原理 真题
  3. 职称计算机ppt教程,职称计算机考试WPS教程:幻灯片格式的设置
  4. 【技术规划】描绘未来第 4 部分:技术路线图
  5. 带你走进准确率高于 99.7% 的智能鉴黄功能
  6. (上部)你要的 wechaty 微信机器人教程
  7. 面具Zygigk插件开发入门教程
  8. 什么是pmp证书,pmp证书有什么用,考试时间?
  9. OSPF的七种状态机
  10. 小米网抢购系统开发实践阅读心得