本文翻译自:How to Check-Point Deep Learning Models in Keras


深度学习模型可能需要数小时,数天甚至数周才能进行训练。
如果意外停止运行,则可能会丢失大量工作。
在这篇文章中,您将了解如何使用Keras库在Python培训期间检查您的深度学习模型。
让我们开始吧。

  • 2017年3月更新:Keras 2.0.2,TensorFlow 1.0.1和Theano 0.9.0的更新示例。
  • 更新March / 2018:添加了备用链接以下载数据集,因为原始图像已被删除。

检验点神经网络模型
应用程序检查点是一种适用于长时间运行过程的容错技术。

这是一种在系统出现故障时采用系统状态快照的方法。如果出现问题,并非全部丢失。检查点可以直接使用,或者用作新运行的起点,从中断处开始。

在训练深度学习模型时,检查点是模型的权重。这些权重可用于按原样进行预测,或用作持续培训的基础。

Keras库通过回调API提供检查点功能。

ModelCheckpoint回调类允许您定义检查模型权重的位置,文件应如何命名以及在何种情况下创建模型的检查点。

API允许您指定要监控的度量标准,例如培训或验证数据集的丢失或准确性。您可以指定是否在最大化或最小化分数时寻求改进。最后,用于存储权重的文件名可以包含诸如纪元号或度量的变量。

然后,在模型上调用fit()函数时,可以将ModelCheckpoint传递给训练过程。

注意,您可能需要安装h5py库以输出HDF5格式的网络权重。


检查点神经网络模型改进

检查点的良好用途是每次在训练期间观察到改进时输出模型权重。

下面的例子为皮马印第安人糖尿病二元分类问题创建了一个小型神经网络。该示例假设pima-indians-diabetes.csv文件位于您的工作目录中。

您可以从此处下载数据集:

皮马印第安人糖尿病数据集
该示例使用33%的数据进行验证。

只有在验证数据集(monitor ='val_acc’和mode =‘max’)的分类准确性有所提高时,才会设置检验点以保存网络权重。权重存储在一个文件中,该文件包含文件名中的分数(权重改进 - {val_acc = .2f} .hdf5)。

# Checkpoint the weights when validation accuracy improves
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.data.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

运行该示例将生成以下输出(为简洁起见,将其截断):

Epoch 00134: val_acc did not improve
Epoch 00135: val_acc did not improve
Epoch 00136: val_acc did not improve
Epoch 00137: val_acc did not improve
Epoch 00138: val_acc did not improve
Epoch 00139: val_acc did not improve
Epoch 00140: val_acc improved from 0.83465 to 0.83858, saving model to weights-improvement-140-0.84.hdf5
Epoch 00141: val_acc did not improve
Epoch 00142: val_acc did not improve
Epoch 00143: val_acc did not improve
Epoch 00144: val_acc did not improve
Epoch 00145: val_acc did not improve
Epoch 00146: val_acc improved from 0.83858 to 0.84252, saving model to weights-improvement-146-0.84.hdf5
Epoch 00147: val_acc did not improve
Epoch 00148: val_acc improved from 0.84252 to 0.84252, saving model to weights-improvement-148-0.84.hdf5
Epoch 00149: val_acc did not improve

您将在工作目录中看到许多文件,其中包含HDF5格式的网络权重。例如:

weights-improvement-53-0.76.hdf5
weights-improvement-71-0.76.hdf5
weights-improvement-77-0.78.hdf5
weights-improvement-99-0.78.hdf5

这是一个非常简单的检查点策略。如果验证准确度在训练时期上下移动,则可能会创建大量不必要的检查点文件。然而,它将确保您拥有在运行期间发现的最佳模型的快照。

仅限检查点最佳神经网络模型

更简单的检查点策略是将模型权重保存到同一文件中,当且仅当验证准确度提高时。

这可以使用上面相同的代码轻松完成,并将输出文件名更改为固定(不包括分数或纪元信息)。

在这种情况下,只有当验证数据集上模型的分类精度提高到目前为止最佳时,模型权重才会写入文件“weights.best.hdf5”。

# Checkpoint the weights for best model on validation accuracy
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.data.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights.best.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

运行此示例提供以下输出(为简洁起见,将其截断):

Epoch 00139: val_acc improved from 0.79134 to 0.79134, saving model to weights.best.hdf5
Epoch 00140: val_acc did not improve
Epoch 00141: val_acc did not improve
Epoch 00142: val_acc did not improve
Epoch 00143: val_acc did not improve
Epoch 00144: val_acc improved from 0.79134 to 0.79528, saving model to weights.best.hdf5
Epoch 00145: val_acc improved from 0.79528 to 0.79528, saving model to weights.best.hdf5
Epoch 00146: val_acc did not improve
Epoch 00147: val_acc did not improve
Epoch 00148: val_acc did not improve
Epoch 00149: val_acc did not improve

您应该在本地目录中看到权重文件。

weights.best.hdf5

这是一个方便的检查点策略,在您的实验中始终使用。它将确保为运行保存最佳模型,以便您以后使用。它避免了您需要在训练时包含代码以手动跟踪和序列化最佳模型。

加载检查指向神经网络模型

现在您已经了解了如何在培训期间检查您的深度学习模型,您需要查看如何加载和使用检查点模型。

检查点仅包括模型权重。它假设您了解网络结构。这也可以序列化为JSON或YAML格式的文件。

在下面的示例中,模型结构是已知的,最佳权重从上一个实验加载,存储在weights.best.hdf5文件的工作目录中。

然后使用该模型对整个数据集进行预测。

# How to load and use weights from a checkpoint
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# create model
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# load weights
model.load_weights("weights.best.hdf5")
# Compile model (required to make predictions)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file")
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.data.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# estimate accuracy on whole dataset using loaded weights
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

运行该示例将生成以下输出:

Created model and loaded weights from file
acc: 77.73%

摘要

在这篇文章中,您已经发现了深度学习模型在长时间训练中的重要性。

您学习了两个检查点策略,您可以在下一个深度学习项目中使用它们:

  • 检查点模型改进。
  • Checkpoint最佳型号。
    您还学习了如何加载检查点模型并进行预测。

如何在Keras中检查深度学习模型(翻译)相关推荐

  1. 如何为Keras中的深度学习模型建立Checkpoint

    深度学习模式可能需要几个小时,几天甚至几周的时间来训练. 如果运行意外停止,你可能就白干了. 在这篇文章中,你将会发现在使用Keras库的Python训练过程中,如何检查你的深度学习模型 Checkp ...

  2. 使用keras进行深度学习_如何在Keras中通过深度学习对蝴蝶进行分类

    使用keras进行深度学习 A while ago I read an interesting blog post on the website of the Dutch organization V ...

  3. 如何在TensorFlow中通过深度学习构建年龄和性别的多任务预测器

    by Cole Murray 通过科尔·默里(Cole Murray) In my last tutorial, you learned about how to combine a convolut ...

  4. 【Django】项目中调用深度学习模型model.predict()(Django两种启动方式runserver和uwsgi的区别)

    目录 问题 测试 解决方法 Django两种启动方式runserver和uwsgi的区别 问题 部署含有深度学习模型的Django项目的uWSGI.Nginx服务器的时候,所有模块都可以正常运行,也可 ...

  5. 在服务器上运行论文中的深度学习模型

    前言 首先需要在服务器上搭建运行环境,参见上一篇博客:[服务器上搭建深度学习模型运行环境:ubuntu] 本文主要讲在搭建好运行环境的情况下如何跑开源模型,以Inf-Net: Automatic CO ...

  6. 医学图像处理中的深度学习模型

    细胞病理学识别和疾病组织目标检测是目标人工智能技术在影像医学和病理方向的重要应用.  该技术主要是前期的预处理技术复杂,主要原因是因为医学的相关病理特征成因复杂,图像方面的随机误差很大(噪音),图像断 ...

  7. 资深算法专家解读CTR预估业务中的深度学习模型

    内容来源:2018 年 01 月 05 日,资深算法专家张俊林在"2018 移动技术创新大会"进行<深度学习在CTR预估业务中的应用>演讲分享.IT 大咖说(微信id: ...

  8. 如何在Chatbot中应用深度学习

    人类其实从很早以前就开始追求人类和机器之间的对话,早先科学家研发的机器在和人对话时都是采用规则性的回复,比如人提问后,计算机从数据库中找出相关的答案来回复.这种规则性的一对一匹配有很多限制.机器只知道 ...

  9. 如何在Chatbot中应用深度学习?

    编者按:本书节选自图书<深度学习算法实践>,本书以一位软件工程师在工作中遇到的问题为主线,阐述了如何从软件工程思维向算法思维转变.如何将任务分解成算法问题,并结合程序员在工作中经常面临的产 ...

最新文章

  1. python可以实现哪些功能_Python中实现机器学习功能的四种方法介绍
  2. Spring Boot Admin Reference Guide
  3. YLMF OS 发布
  4. 矩阵转置行列式的运算规律
  5. 计算机无法访问,您可能没有权限使用网络资源.请与这台服务器的管理员联系的解决办
  6. matlab命令行窗口显示长度设置_设置命令行窗口输出显示格式 | MATLAB format| MathWork...
  7. Modern love 年度最暖心美剧
  8. 胆结石的发病原因有哪些?
  9. 怎么快速做一个excel手机报表?
  10. python画哆啦a梦 代码_python之:tkinter画哆啦A梦
  11. fbx sdk android,FBX SDK环境配置
  12. 路由器计算机无法上网,电脑可以上网路由器不能上网怎么回事?
  13. Go语言学习、结构体
  14. 【转租】【淞虹路独立厨卫一室户2700/月】【与房东直接签合同】
  15. 美团面试 一面+二面
  16. HTML+CSS从入门到入土
  17. oracle11g ocm考试总结
  18. CT图像重建算法------迭代投影模型之距离驱动算法(Distance-Driven Model,DDM)
  19. 快速对帝国竞争算法ICA的了解
  20. Android端与服务端基于TCP/IP协议的Socket通讯

热门文章

  1. python 乱序数组,list等有序结构的方法
  2. C库函数-perror()
  3. 市场有变,中小型基因测序机构机会来了
  4. Microbiome:NGLess语言实现快速可重复分析宏基因组的流程NG-meta-profiler
  5. 宏基因组实战8. 分箱宏基因组binning, MqaxBin, MetaBin, VizBin
  6. Python使用numpy包编写自定义函数计算平均绝对误差(MAE、Mean Absolute Error)、评估回归模型和时间序列模型、解读MAE
  7. numpy使用[]语法索引二维numpy数组中指定行列位置的数值内容(access value at certain row and column in numpy array)
  8. R语言ggplot2可视化:指定标题的坐标轴位置(X轴坐标和Y轴坐标),将图像的标题(title)放置在图像内部的指定位置(customize title positon in plot)
  9. R语言with函数和within函数:with函数基于表达式在dataframe上计算、within函数基于表达式在dataframe上计算并修改原始数据
  10. pandas使用argmax函数返回给定series对象中最大值(max、maximum)的行索引实战