在训练神经网络时,如果epochs设置的过多,导致最终结束时测试集上模型的准确率比较低,而我们却想保存准确率最高时候的模型参数,这就需要用到Early Stopping以及ModelCheckpoint

一.早停策略之EarlyStopping

EarlyStopping是用于提前停止训练callbacks,callbacks用于指定在每个epoch开始和结束的时候进行哪种特定操作。简而言之,就是可以达到当测试集上的loss不再减小(即减小的程度小于某个阈值)的时候停止继续训练

1.EarlyStopping的原理

1.将数据分为训练集和测试集
2.每个epoch结束后(或每N个epoch后): 在测试集上获取测试结果,随着epoch的增加,如果在测试集上发现测试误差上升,则停止训练;
3.将停止之后的权重作为网络的最终参数。

这儿就有一个疑惑,在平常模型训练时,会发现模型的loss值有时会出现降低再上升再下降的情况,难道只要再上升的时候就要停止嘛?上升之后再下降有可能会得到更低的loss值,那么如果只要上升就停止的话,就会得不偿失。现实肯定不是这样的不能根据一两次的连续降低就判断不再提高。一般的做法是,在训练的过程中,记录到目前为止最好的测试集精度,当连续10次epoch(或者更多次)没达到最佳精度时,则可以认为精度不再提高了。

看图直观感受一下:

2.EarlyStopping的优缺点

优点:只运行一次梯度下降,我们就可以找出w的较小值,中间值和较大值。而无需尝试L2正则化超级参数lambda的很多值。
缺点:不能独立地处理以上两个问题,使得要考虑的东西变得复杂

3.参数解释

tf.keras.callbacks.EarlyStopping(monitor="acc",min_delta=0,patience=0,verbose=0,mode="max",baseline=None,restore_best_weights=False,
)

1.monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。

2.mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。例如监控的是’acc’,那么就设置为’max’。

3.min_delta:增大或减小的阈值,只有大于这个部分才算作改善(监控的数据不同,变大变小就不确定)。这个值的大小取决于monitor,也反映了你的容忍程度。

4.patience:能够容忍多少个epoch内都没有改善。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,patience设置的值应当稍大于epoch number。在learning rate变化的情况下,建议要略小于最大的抖动epoch number

5.baseline:监控数据的基线值,如果在训练过程中,模型训练结果相比于基线值没有什么改善的话,就停止训练。

二.ModelCheckpoint

函数原型:

tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

参数解释

1.filename:字符串,保存模型的路径,filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_endlogs关键字所填入。
例如:filepath = “weights_{epoch:03d}-{val_loss:.4f}.h5”,则会生成对应epoch和测试集loss的多个文件。

2.monitor:需要监视的值,通常为:val_acc 、 val_loss 、 acc 、 loss四种。

3.verbose:信息展示模式,0或1。为1表示输出epoch模型保存信息,默认为0表示不输出该信息。

4.save_best_only:当设置为True时,将只保存在测试集上性能最好的模型

5.mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断

6.save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)。

7.period:CheckPoint之间的间隔的epoch数。

三.样例示范

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStoppingearlystopper = EarlyStopping(monitor='loss', patience=1, verbose=1,mode = 'min')checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=0,save_best_only=True,save_weights_only=True,mode = 'max')
train_model  = model.fit(train_ds,epochs=epochs,validation_data=test_ds,callbacks=[earlystopper, checkpointer]#<-看这儿)

努力加油a啊

参考链接:

https://blog.csdn.net/zwqjoy/article/details/86677030
https://blog.csdn.net/zengNLP/article/details/94589469

深度学习之早停策略EarlyStopping以及保存测试集准确率最高的模型ModelCheckpoint相关推荐

  1. 【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。

    这个方法更好的解决了模型过拟合问题. EarlyStopping的原理是提前结束训练轮次来达到"早停"的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch ...

  2. 深度学习的实用层面 —— 1.1 训练/开发/测试集

    在配置训练.验证和测试数据集的过程中做出正确决策会在很大程度上帮助大家创建高效的神经网络. 在训练神经网络时,我们需要做出很多决策,例如神经网络分多少层,每层含有多少个隐藏单元,学习速率是多少,各层采 ...

  3. pytorch使用早停策略

    文章目录 早停的目的与流程 早停策略 pytorch使用示例 参考网站 早停的目的与流程 目的:防止模型过拟合,由于深度学习模型可以无限迭代下去,因此希望在即将过拟合时.或训练效果微乎其微时停止训练. ...

  4. 【Python】深度学习中将数据按比例随机分成随机 训练集 和 测试集的python脚本

    深度学习中经常将数据分成 训练集 和 测试集,参考博客,修改python脚本 randPickAITrainTestData.py . 功能:从 输入目录 中随机检出一定比例的文件或目录,移动到保存 ...

  5. 浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现)

    浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现) 总包含文章: 一个完整的机器学习模型的流程 浅谈深度学习:了解RNN和构建并预测 浅谈深度学习:基于对LS ...

  6. 深度学习技术系列(1):Mosaic Model — 不良图片检测开源模型

    最近整理了在图像深度学习方面的一部分工作,开源了一个不良图片检测的CNN模型(Mosaic Model),放在了github上.目前github上开源了最新的模型,以及demo的python文件,可以 ...

  7. 【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

    文章目录 前言 CIFAR10简介 Backbone选择 训练+测试 训练环境及超参设置 完整代码 部分测试结果 完整工程文件 Reference 前言 分享一下本人去年入门深度学习时,在CIFAR1 ...

  8. 深度学习theano/tensorflow多显卡多人使用问题集

    深度学习theano/tensorflow多显卡多人使用问题集 转载自:https://zhuanlan.zhihu.com/p/23250782 其实一直想写这篇东西,今天还是抽空系统整理一下吧. ...

  9. 深度学习网络模型训练过程中的Loss问题合集

    把数据集随机分为训练集,验证集和测试集,然后用训练集训练模型,用验证集验证模型,根据情况不断调整模型,选择出其中最好的模型,再用训练集和验证集数据训练出一个最终的模型,最后用测试集评估最终的模型 训练 ...

最新文章

  1. 广东安网2016:重拳挥出 打造安宁互联网环境
  2. Mysql数据库编码转换问题
  3. python np fft_Python的武器库05:numpy模块(下)
  4. ios apple pay 证书配置
  5. King of Range
  6. 【计算机网络】Quiz集合
  7. python重定向_Python接口自动化(十)重定向(Location)
  8. mysql 存在更新不存在写入_梅姨这个人,到底存在不存在?
  9. (2)把BlackBerry作为插件安装到已有的Eclipse中
  10. socket连接时间太长受什么原因影响?_晶振不起振的原因和应对措施
  11. groovy-实现接口
  12. 香港股票交易成本计算器 android,股票交易手续费计算器
  13. HTML5游戏开发进阶指南.pdf
  14. 电池电压测试技术总结
  15. 深信服科技2019年校园招聘 移动应用开发 一面
  16. 阿里云服务器安装mysql
  17. 20190713 关于session串号问题的记录
  18. RPC框架的意义和用法,什么是RPC
  19. HTML5+CSS3笔记
  20. centos安装aria2c_CentOS下搭建Aria2远程下载环境

热门文章

  1. 汇编中addr和offset
  2. C++中extern关键字的作用
  3. MFC中绘制动态曲线
  4. swift5主线程延迟操作的几种写法
  5. git add后取消_git 必看,各种撤销操作
  6. 剑指offer(12)旋转数组的最小数字
  7. php url地址 怎么写,php url地址重写
  8. python删除字符串中的字母_在Python中删除字符串中的大写字母
  9. Android开发之WebView加载HTML源码包含转义字符实现富文本显示的方法
  10. flask 路由 php文件,Flask 请求处理流程(一):WSGI 和 路由