如何为Keras中的深度学习模型建立Checkpoint
深度学习模式可能需要几个小时,几天甚至几周的时间来训练。
如果运行意外停止,你可能就白干了。
在这篇文章中,你将会发现在使用Keras库的Python训练过程中,如何检查你的深度学习模型
Checkpoint神经网络模型
应用程序Checkpoint是为长时间运行进程准备的容错技术。
这是一种在系统故障的情况下拍摄系统状态快照的方法。一旦出现问题不会让进度全部丢失。Checkpoint可以直接使用,也可以作为从它停止的地方重新运行的起点。
训练深度学习模型时,Checkpoint是模型的权重。他们可以用来作预测,或作持续训练的基础。
Keras库通过回调API提供Checkpoint功能。
ModelCheckpoint回调类允许你定义检查模型权重的位置在何处,文件应如何命名,以及在什么情况下创建模型的Checkpoint。
API允许你指定要监视的指标,例如训练或验证数据集的丢失或准确性。你可以指定是否寻求最大化或最小化分数的改进。最后,用于存储权重的文件名可以包括诸如训练次数的编号或标准的变量。
当模型上调用fit()函数时,可以将ModelCheckpoint传递给训练过程。
注意,你可能需要安装h5py库以HDF5格式输出网络权重。
Checkpoint神经网络模型改进
应用Checkpoint时,应在每次训练中观察到改进时输出模型权重。
下面的示例创建一个小型神经网络Pima印第安人发生糖尿病的二元分类问题。你可以在UCI机器学习库下载这个数据集。本示例使用33%的数据进行验证。
Checkpoint设置成当验证数据集的分类精度提高时保存网络权重(monitor=’val_acc’ and mode=’max’)。权重存储在一个包含评价的文件中(weights-improvement – { val_acc = .2f } .hdf5)。
# Checkpoint the weights when validation accuracy improves
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint= ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list= [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)
运行示例会生成以下输出(有删节):
...
Epoch00134: val_acc didnot improve
Epoch00135: val_acc didnot improve
Epoch00136: val_acc didnot improve
Epoch00137: val_acc didnot improve
Epoch00138: val_acc didnot improve
Epoch00139: val_acc didnot improve
Epoch00140: val_acc improvedfrom 0.83465 to0.83858, saving model to weights-improvement-140-0.84.hdf5
Epoch00141: val_acc didnot improve
Epoch00142: val_acc didnot improve
Epoch00143: val_acc didnot improve
Epoch00144: val_acc didnot improve
Epoch00145: val_acc didnot improve
Epoch00146: val_acc improvedfrom 0.83858 to0.84252, saving model to weights-improvement-146-0.84.hdf5
Epoch00147: val_acc didnot improve
Epoch00148: val_acc improvedfrom 0.84252 to0.84252, saving model to weights-improvement-148-0.84.hdf5
Epoch00149: val_acc didnot improve
你将在工作目录中看到包含多个HDF5格式的网络权重文件。例如:
...
weights-improvement-53-0.76.hdf5
weights-improvement-71-0.76.hdf5
weights-improvement-77-0.78.hdf5
weights-improvement-99-0.78.hdf5
这是一个非常简单的Checkpoint策略。如果验证精度在训练周期上下波动 ,则可能会创建大量不必要的Checkpoint文件。然而,它将确保你具有在运行期间发现的最佳模型的快照。
Checkpoint最佳神经网络模型
如果验证精度提高的话,一个更简单的Checkpoint策略是将模型权重保存到相同的文件中。
这可以使用上述相同的代码轻松完成,并将输出文件名更改为固定(不包括评价或次数的信息)。
在这种情况下,只有当验证数据集上的模型的分类精度提高到到目前为止最好的时候,才会将模型权重写入文件“weights.best.hdf5”。
# Checkpoint the weights for best model on validation accuracy
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights.best.hdf5"
checkpoint= ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list= [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)
运行示例会生成以下输出(有删节):
...
Epoch00139: val_acc improvedfrom 0.79134 to0.79134, saving model to weights.best.hdf5
Epoch00140: val_acc didnot improve
Epoch00141: val_acc didnot improve
Epoch00142: val_acc didnot improve
Epoch00143: val_acc didnot improve
Epoch00144: val_acc improvedfrom 0.79134 to0.79528, saving model to weights.best.hdf5
Epoch00145: val_acc improvedfrom 0.79528 to0.79528, saving model to weights.best.hdf5
Epoch00146: val_acc didnot improve
Epoch00147: val_acc didnot improve
Epoch00148: val_acc didnot improve
Epoch00149: val_acc didnot improve
你应该在本地目录中看到权重文件:
weights.best.hdf5
这是一个在你的实验中需要经常用到的方便的Checkpoint策略。它将确保你的最佳模型被保存,以便稍后使用。它避免了输入代码来手动跟踪,并在训练时序列化最佳模型。
加载Checkpoint神经网络模型
现在你已经了解了如何在训练期间检查深度学习模型,你需要回顾一下如何加载和使用一个Checkpoint模型。
Checkpoint只包括模型权重。它假定你了解网络结构。这也可以序列化成JSON或YAML格式。
在下面的示例中,模型结构是已知的,并且最好的权重从先前的实验中加载,然后存储在weights.best.hdf5文件的工作目录中。
那么将该模型用于对整个数据集进行预测。
# How to load and use weights from a checkpoint
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# load weights
model.load_weights("weights.best.hdf5")
# Compile model (required to make predictions)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file")
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# estimate accuracy on whole dataset using loaded weights
scores= model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
运行示例生成以下输出:
Created modeland loaded weightsfrom file
acc:77.73%
总结
在这篇文章中,你已经发现Checkpoint对深度学习模型长期训练的重要性。
你学习了两种可用于你下一个深入Checkpoint学习项目的Checkpoint策略:
- Checkpoint模型改进。
- Checkpoint的最佳模型。
转载于https://cloud.tencent.com/developer/article/1049579,如有侵权,请联系 chris.zhang@wiz.ai 删除
如何为Keras中的深度学习模型建立Checkpoint相关推荐
- 如何在Keras中检查深度学习模型(翻译)
本文翻译自:How to Check-Point Deep Learning Models in Keras 深度学习模型可能需要数小时,数天甚至数周才能进行训练. 如果意外停止运行,则可能会丢失大量 ...
- 深度学习模型建立过程_所有深度学习都是统计模型的建立
深度学习模型建立过程 Deep learning is often used to make predictions for data driven analysis. But what are th ...
- 深度学习模型建立的整体流程和框架
深度学习模型建立的整体流程和框架 框架图如下,纵向是建立模型的主要流程,是一个简化且宏观的概念,横向是针对具体模块的延展. 数据处理 数据处理一般涉及到一下五个环节: 读入数据 划分数据集 生成批次数 ...
- 使用keras进行深度学习_如何在Keras中通过深度学习对蝴蝶进行分类
使用keras进行深度学习 A while ago I read an interesting blog post on the website of the Dutch organization V ...
- 【Django】项目中调用深度学习模型model.predict()(Django两种启动方式runserver和uwsgi的区别)
目录 问题 测试 解决方法 Django两种启动方式runserver和uwsgi的区别 问题 部署含有深度学习模型的Django项目的uWSGI.Nginx服务器的时候,所有模块都可以正常运行,也可 ...
- 在服务器上运行论文中的深度学习模型
前言 首先需要在服务器上搭建运行环境,参见上一篇博客:[服务器上搭建深度学习模型运行环境:ubuntu] 本文主要讲在搭建好运行环境的情况下如何跑开源模型,以Inf-Net: Automatic CO ...
- 医学图像处理中的深度学习模型
细胞病理学识别和疾病组织目标检测是目标人工智能技术在影像医学和病理方向的重要应用. 该技术主要是前期的预处理技术复杂,主要原因是因为医学的相关病理特征成因复杂,图像方面的随机误差很大(噪音),图像断 ...
- 资深算法专家解读CTR预估业务中的深度学习模型
内容来源:2018 年 01 月 05 日,资深算法专家张俊林在"2018 移动技术创新大会"进行<深度学习在CTR预估业务中的应用>演讲分享.IT 大咖说(微信id: ...
- 提升深度学习模型的表现,你需要这20个技巧
选自machielearningmastery 机器之心编译 作者:Jason Brownlee 参与:杜夏德.陈晨.吴攀.Terrence.李亚洲 本文原文的作者 Jason Brownlee 是一 ...
最新文章
- C指针7:指针作为函数返回值
- c# yield关键字原理
- python基础学习22----协程
- 【机器学习基础】--感知机完全解读
- vins中imu融合_双目版 VINS 项目发布,小觅双目摄像头作为双目惯导相机被推荐...
- python画圆形螺旋线_宝宝爱看小猪佩奇,很简单,让我们用python搞定它
- c++调用Java以及string互转
- 【linux】Ubuntu 18.04 设置桌面快捷启动方式
- Theano 中文文档 0.9 - 7.1.1 Python教程
- JS获取整个HTML网页代码 - Android 集美软件园 - 博客频道 - CSDN.NET
- hashtable允许null键和值吗_【29期】Java集合框架 10 连问,你有被问过吗?
- fabric批量操作远程操作主机的练习
- 大工19春计算机文化基础在线测试3,大工19春《计算机文化基础》在线测试3.doc...
- C# Winfrom MQTT 客户端与服务器【代码】
- web前端,多语言切换,data-localize,
- QQ聊天机器人--基于酷Q写的插件
- 分享106个PHP源码,总有一款适合您
- Splashtop 教育行业用户增加700%
- 工业互联网数据展现软件之组态工具
- 【数据结构】史上最好理解的红黑树讲解,让你彻底搞懂红黑树
热门文章
- Survey | 深度学习方法在生物网络中的应用
- SMILES | 简化分子线性输入规范
- 第十七课.有向图模型与条件独立性
- ubuntu下使用SVN
- 一个基于长数据转化为宽数据的小软件---data_tran.exe
- NAR:中科院微生物所发布全球模式微生物基因组测序计划进展
- Microbiome:首个地球微生物“社会关系”网络在浙大绘制!
- 华裔教授教你写论文2.引言的逻辑解析
- h5在线浏览word_怎样将PDF在线转换成Word?教你成为一个高手的方法
- R语言ggplot2可视化:使用长表数据(窄表数据)( Long Data Format)可视化多个时间序列数据、在同一个可视化图像中可视化多个时间序列数据(Multiple Time Series)