Early Stopping 早停法原理与实现
Early Stopping
训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合。
当模型在训练集上表现很好,在验证集上表现很差的时候,我们认为模型出现了过拟合的情况,early stoppping 就是用来预防过拟合的一种方法,简单且有效。
原理
early stoppping 的原理是:当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题
缺点
如下图,模型在验证集上的表现可能咱短暂的变差之后有可能继续变好,并不是在验证集上的表现一旦变差就不会变好。early stoppping 主要是训练时间和泛化错误之间的权衡。
pytorch 实现
EarlyStopping 是用于提前停止训练的callbacks。具体地,可以达到当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。
初始化
- patience:自上次模型在验证集上损失降低之后等待的时间,此处设置为7
- verbose:当为False时,运行的时候将不显示详细信息
- counter:计数器,当其值超过patience时候,使用early stopping
- best_score:记录模型评估的最好分数
- early_step:决定模型要不要early stop,为True则停
- val_loss_min:模型评估损失函数的最小值,默认为正无穷(np.Inf)
- delta:表示模型损失函数改进的最小值,当超过这个值时候表示模型有所改进
class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = delta
保存模型
传入参数:val_loss、model 和 path
verbose 为True,则打印详细信息
函数作用:在 path 路径下,保存当前 model,并更新 val_loss_min 为当前 val_loss
def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')self.val_loss_min = val_loss
调用
定义__call__()方法,该方法的功能类似于在类中重载 () 运算符,使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用
初始化时,设定了
self.best_score = None
if 语句第一行判断 self.best_score 是否为初始值,如果是初始值,则将 score 赋值给 self.best_score ,然后调用save_checkpoint()函数保存
当目前分数比最好分数加 self.delta 小时,就认为模型没有改进,将 counter 计数器加1,当计数器值超过 patience 的时候,就令early_stop为True,让模型停止训练。
当目前分数比最好分数加 self.delta 大时,我们认为模型有改进,将目前分数赋值给最好分数,并将模型保存,令计数器归零。
def __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0
总体代码
总体代码如下:
class EarlyStopping:def __init__(self, patience=7, verbose=False, delta=0):self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltadef __call__(self, val_loss, model, path):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model, path)elif score < self.best_score + self.delta:self.counter += 1print(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model, path)self.counter = 0def save_checkpoint(self, val_loss, model, path):if self.verbose:print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')self.val_loss_min = val_loss
Early Stopping 早停法原理与实现相关推荐
- Early Stopping早停法
参考: https://www.jianshu.com/p/9ab695d91459
- R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数,loss function)、评估函数(evaluation function)
R语言构建xgboost模型使用早停法训练模型(early stopping):自定义损失函数(目标函数.loss function.object function).评估函数(evaluation ...
- keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping)
keras构建前馈神经网络(feedforward neural network)进行分类模型构建基于早停法(Early stopping) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性 ...
- 深度学习技巧之Early Stopping(早停法)
深度学习技巧之Early Stopping(早停法) | 数据学习者官方网站(Datalearner) 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization pe ...
- pytorch早停法
作为深度学习训练数据的trick,结合交叉验证法,可以防止模型过早拟合. 早停法是一种被广泛使用的方法,在很多案例上都比正则化的方法要好.是在训练中计算模型在验证集上的表现,当模型在验证集上的表现开始 ...
- EarlyStopping早停法的实现原理
keras中的EarlyStopping使用很方便,当但我测试torch的EarlyStopping时掉坑了!!! torch中pytorchtools工具有早停法,但我测试基本用不了,总是出错,fu ...
- 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。
这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...
- 深度学习——早停法(Early Stopping)
学习链接:https://www.jianshu.com/p/9ab695d91459 https://www.datalearner.com/blog/1051537860479157 目的: 为了 ...
- Earlystopping(早停法)
Earlystopping 简介 当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据). 但是所有的标准深度学习神 ...
最新文章
- 好玩!PyEcharts 绘制时间轮播图
- js原生设计模式——2面向对象编程之继承—new+call(this)组合式继承
- MySQL - 多版本控制 MVCC 机制初探
- Intellij IDEA 默认打开上次项目设置与取消设置
- 汇编原理实验--输出ASCII码10H到100H
- 辗转相除法(欧几里得算法)求 最大公约数与最小公倍数+推论与证明。
- 一个写得很不错的vuex详解(转)
- 如何实时主动监控你的网站接口是否挂掉并及时报警
- RedHat6.2 x86手动配置LNMP环境
- laravel简单的laragon环境搭建不需要composer一键集成
- SAP HR工资核算基础(转)
- 2015 2020 r4烧录卡 区别_2020版药典,药用辅料被重视了!
- 如何将视频转换为HEVC / H.265和AVC / H.264
- 如何提取mp4中的音频?
- 智慧营区解决方案-最新全套文件
- springboot yml文件不是绿叶子问题
- 【Linux】alias及设置
- 绿色版电脑站手机站仿站小工具
- IDEA 连接数据库报错
- 嘉益仕(Litins)再携手南方电网,升级智能仓储管理