本文整理了绝大多数keras里的Callback回调)函数,并且收集了代码调用示例。

大多数内容整理自网络,参考资料已在文章最后给出。

回调函数Callbacks

回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。

【Tips】虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼

Callback

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类

类属性

  • params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)

  • model:keras.models.Model对象,为正在训练的模型的引用

回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。

目前,模型的.fit()中有下列参数会被记录到logs中:

  • 在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,accloss,如果指定了验证集,还会包含验证集正确率和误差val_acc)val_lossval_acc还额外需要在.compile中启用metrics=['accuracy']

  • 在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数

  • 在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc

BaseLogger

keras.callbacks.BaseLogger()

该回调函数用来对每个epoch累加metrics指定的监视指标的epoch平均值

该回调函数在每个Keras模型中都会被自动调用


ProgbarLogger

keras.callbacks.ProgbarLogger()

该回调函数用来将metrics指定的监视指标输出到标准输出上

History

keras.callbacks.History()

该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值


ModelCheckpoint【重点】

Checkpoint神经网络模型简介

应用程序Checkpoint是为长时间运行进程准备的容错技术。

这是一种在系统故障的情况下拍摄系统状态快照的方法。一旦出现问题不会让进度全部丢失。Checkpoint可以直接使用,也可以作为从它停止的地方重新运行的起点。

训练深度学习模型时,Checkpoint是模型的权重。他们可以用来作预测,或作持续训练的基础。

Keras库通过回调API提供Checkpoint功能。

ModelCheckpoint回调类允许你定义检查模型权重的位置在何处,文件应如何命名,以及在什么情况下创建模型的Checkpoint。

API允许你指定要监视的指标,例如训练或验证数据集的丢失或准确性。你可以指定是否寻求最大化或最小化分数的改进。最后,用于存储权重的文件名可以包括诸如训练次数的编号或标准的变量。

当模型上调用fit()函数时,可以将ModelCheckpoint传递给训练过程。

注意,你可能需要安装h5py库以HDF5格式输出网络权重。

用法示例:

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

该回调函数将在每个epoch后保存模型到filepath

filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_endlogs关键字所填入

例如,filepath若为weights.{epoch:02d-{val_loss:.2f}}.hdf5,则会生成对应epoch和验证集loss的多个文件。

参数

  • filename:字符串,保存模型的路径

  • monitor:需要监视的值

  • verbose:信息展示模式,0或1

  • save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

  • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

  • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

  • period:CheckPoint之间的间隔的epoch数

例子No.0:官网示例

from keras.callbacks import ModelCheckpointmodel = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')'''
saves the model weights after each epoch if the validation loss decreased
'''
checkpointer = ModelCheckpoint(filepath="/tmp/weights.hdf5", verbose=1, save_best_only=True)
model.fit(X_train, Y_train, batch_size=128, epochs=20, verbose=0, validation_data=(X_test, Y_test), callbacks=[checkpointer])

例子No.1:保存改进的每一个改进的CheckPoint模型

应用Checkpoint时,应在每次训练中观察到改进时输出模型权重。

下面的示例创建一个小型神经网络Pima印第安人发生糖尿病的二元分类问题。你可以在UCI机器学习库下载这个数据集。本示例使用33%的数据进行验证。

Checkpoint设置成当验证数据集的分类精度提高时保存网络权重(monitor=’val_acc’ and mode=’max’)。权重存储在一个包含评价的文件中(weights-improvement – { val_acc = .2f } .hdf5)。

# Checkpoint the weights when validation accuracy improves
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint= ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list= [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

你将在工作目录中看到包含多个HDF5格式的网络权重文件。例如:

...
weights-improvement-53-0.76.hdf5
weights-improvement-71-0.76.hdf5
weights-improvement-77-0.78.hdf5
weights-improvement-99-0.78.hdf5

这是一个非常简单的Checkpoint策略。如果验证精度在训练周期上下波动 ,则可能会创建大量不必要的Checkpoint文件。然而,它将确保你具有在运行期间发现的最佳模型的快照。

例子No.2:保存最佳的Checkpoint模型

如果验证精度提高的话,一个更简单的Checkpoint策略是将模型权重保存到相同的文件中。

这可以使用上述相同的代码轻松完成,并将输出文件名更改为固定(不包括评价或次数的信息)。

在这种情况下,只有当验证数据集上的模型的分类精度提高到到目前为止最好的时候,才会将模型权重写入文件“weights.best.hdf5”。

# Checkpoint the weights for best model on validation accuracy
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# checkpoint
filepath="weights.best.hdf5"
checkpoint= ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
callbacks_list= [checkpoint]
# Fit the model
model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, callbacks=callbacks_list, verbose=0)

你应该在本地目录中看到权重文件:

weights.best.hdf5

这是一个在你的实验中需要经常用到的方便的Checkpoint策略。它将确保你的最佳模型被保存,以便稍后使用。它避免了输入代码来手动跟踪,并在训练时序列化最佳模型。

例子No.3:加载之前保存过的Checkpoint模型

Checkpoint只包括模型权重。它假定你了解网络结构。这也可以序列化成JSON或YAML格式。

在下面的示例中,模型结构是已知的,并且最好的权重从先前的实验中加载,然后存储在weights.best.hdf5文件的工作目录中。

那么将该模型用于对整个数据集进行预测。

# How to load and use weights from a checkpoint
from keras.modelsimport Sequential
from keras.layersimport Dense
from keras.callbacksimport ModelCheckpoint
import matplotlib.pyplot as plt
import numpy
# fix random seed for reproducibility
seed= 7
numpy.random.seed(seed)
# create model
model= Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(8, kernel_initializer='uniform', activation='relu'))
model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
# load weights
model.load_weights("weights.best.hdf5")
# Compile model (required to make predictions)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
print("Created model and loaded weights from file")
# load pima indians dataset
dataset= numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X= dataset[:,0:8]
Y= dataset[:,8]
# estimate accuracy on whole dataset using loaded weights
scores= model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
运行示例生成以下输出:
Created modeland loaded weightsfrom file
acc:77.73%

EarlyStopping

keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')

当监测值不再改善时,该回调函数将中止训练

参数

  • monitor:需要监视的量

  • patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。

  • verbose:信息展示模式

  • mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

LearningRateScheduler

keras.callbacks.LearningRateScheduler(schedule)

该回调函数是学习率调度器

参数

  • schedule:函数,该函数以epoch号为参数(从0算起的整数),返回一l新学习率(浮点数)

例子:LearningRateScheduler调整学习率

from keras.callbacks import LearningRateScheduler
def scheduler(epoch):if epoch % 100 == 0 and epoch != 0:lr = K.get_value(model.optimizer.lr)K.set_value(model.optimizer.lr, lr * 0.1)print("lr changed to {}".format(lr * 0.1))return K.get_value(model.optimizer.lr)reduce_lr = LearningRateScheduler(scheduler)
model.fit(train_x, train_y, batch_size=32, epochs=5, callbacks=[reduce_lr])

ReduceLROnPlateau

keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)

当评价指标不在提升时,减少学习率

当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。该回调函数检测指标的情况,如果在patience个epoch中看不到模型性能提升,则减少学习率

参数

  • monitor:被监测的量
  • factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少
  • patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
  • mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。
  • epsilon:阈值,用来确定是否进入检测值的“平原区”
  • cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作
  • min_lr:学习率的下限

例子:ReduceLROnPlateau自动减少学习率

from keras.callbacks import ReduceLROnPlateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])

CSVLogger

keras.callbacks.CSVLogger(filename, separator=',', append=False)

将epoch的训练结果保存在csv文件中,支持所有可被转换为string的值,包括1D的可迭代数值如np.ndarray.

参数

  • fiename:保存的csv文件名,如run/log.csv
  • separator:字符串,csv分隔符
  • append:默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件

例子:CSVLoggerb保存训练结果

csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])

LambdaCallback

keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None)

用于创建简单的callback的callback类

该callback的匿名函数将会在适当的时候调用,注意,该回调函数假定了一些位置参数on_eopoch_beginon_epoch_end假定输入的参数是epoch, logs. on_batch_beginon_batch_end假定输入的参数是batch, logson_train_beginon_train_end假定输入的参数是logs

参数

  • on_epoch_begin: 在每个epoch开始时调用
  • on_epoch_end: 在每个epoch结束时调用
  • on_batch_begin: 在每个batch开始时调用
  • on_batch_end: 在每个batch结束时调用
  • on_train_begin: 在训练开始时调用
  • on_train_end: 在训练结束时调用

例子:轻量级自定义回调函数LambdaCallback

# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(on_batch_begin=lambda batch,logs: print(batch))# Plot the loss after every epoch.
import numpy as np
import matplotlib.pyplot as plt
plot_loss_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: plt.plot(np.arange(epoch),logs['loss']))# Terminate some processes after having finished model training.
processes = ...
cleanup_callback = LambdaCallback(on_train_end=lambda logs: [p.terminate() for p in processes if p.is_alive()])model.fit(...,callbacks=[batch_print_callback,plot_loss_callback,cleanup_callback])

编写自己的回调函数

我们可以通过继承keras.callbacks.Callback编写自己的回调函数,回调函数通过类成员self.model访问访问,该成员是模型的一个引用。

这里是一个简单的保存每个batch的loss的回调函数:

class LossHistory(keras.callbacks.Callback):def on_train_begin(self, logs={}):self.losses = []def on_batch_end(self, batch, logs={}):self.losses.append(logs.get('loss'))

例子:记录损失函数的历史数据

class LossHistory(keras.callbacks.Callback):def on_train_begin(self, logs={}):self.losses = []def on_batch_end(self, batch, logs={}):self.losses.append(logs.get('loss'))model = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')history = LossHistory()
model.fit(X_train, Y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])print history.losses
# outputs
'''
[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]

主要参考:

  • Keras中文文档
  • 如何为Keras中的深度学习模型建立Checkpoint
  • 【深度学习】Keras学习率调整
  • Keras英文文档

Keras中的各种Callback函数示例(含Checkpoint模型的保存、读取示例)-----记录相关推荐

  1. java addcallback函数_java中怎么使用callback函数?

    UYOU 在很多场景,作为开发都会想到,在执行完毕一个任务的时候,能执行一个callback函数是多么好的事情.现在模拟一下这个情景:定义三个类.分别是主函数类.callback函数的接口类.业务处理 ...

  2. C语言编程题—结构体—设计程序,已知学生的记录由学号和学习成绩构成,N名学生的数据已存入a结构体数组中。请编写函数 fun:找出成绩最低的学生记录,通过形参返回主函数(规定只有一个最低分

    4 C语言编程题--结构体 **设计程序,已知学生的记录由学号和学习成绩构成,N名学生的数据已存入a结构体数组中.请编写函数 fun,函数的功能是:找出成绩最低的学生记录,通过形参返回主函数(规定只有 ...

  3. Keras中几个重要函数用法

    官方keras例子:http://keras-cn.readthedocs.io/en/latest/getting_started/sequential_model/ 模块需导入包: [python ...

  4. go 变量在其中一个函数中赋值 另一个函数_go 学习笔记之仅仅需要一个示例就能讲清楚什么闭包...

    本篇文章是 Go 语言学习笔记之函数式编程系列文章的第二篇,上一篇介绍了函数基础,这一篇文章重点介绍函数的重要应用之一: 闭包 空谈误国,实干兴邦,以具体代码示例为基础讲解什么是闭包以及为什么需要闭包 ...

  5. keras中model的evaluate函数的返回值究竟是什么?

    不少帖子都解释evaluate()函数的返回值是什么,但是通常都让人看得云里雾里,其实该函数第一个返回值是损失(loss),第二个返回值是准确率(acc).可以通过打印model.metrics_na ...

  6. java 消息传递示例_java actor模型和消息传递简单示例

    接上面java actor模型框架ujavaactorhttp://zhwj184.iteye.com/admin/blogs/1613351,上面的示例比较复杂,写一个简单点的示例: import ...

  7. Callback 函数

    说明:此文章出自<深入浅出mfc>第6章的"Callback  函数" Callback 函数 Hello的OnPaint在程序收到 WM_PAINT之后开始运作.为了 ...

  8. Ray----Tune(5):Tune包中的类和函数参考

    本篇主要介绍一下tune中常用的一些函数用处,可以作为一个简单的API使用. ray.tune ray.tune.grid_search(values) 用于指定值上的网格搜索的快捷方法. 参数:va ...

  9. Keras中Callback函数的使用

    回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息.通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数. [Ti ...

最新文章

  1. 利用反射实现类的动态加载
  2. 画活动图教程_绘画教程116—传统的山水现代的刀画,看了就会的步骤图
  3. [詹兴致矩阵论习题参考解答]习题6.6
  4. 派森编程软件python-零基础学习Python需要用什么开发工具?
  5. LightOJ1032 Fast Bit Calculations(数位DP)
  6. 微服务网关Gateway中StripPrefix讲解
  7. Linux RPM包校验和数字证书
  8. [bzoj3036]绿豆蛙的归宿
  9. 6大新品重磅发布,华为云全栈云原生技术能力持续创新升级
  10. cpu矿工cpuminer-multi编译与使用
  11. (05)FPGA内部资源
  12. 使用firefox44版本,弃用chrome
  13. 【Oracle】表级别分区操作对索引(本地分区索引,全局分区索引,非分区索引)的影响
  14. maven库的查询和配置
  15. python实现艾宾浩斯背单词功能,实现自动提取单词、邮件发送,部署在阿里云服务器,再也不用担心背单词啦!!
  16. 计算机配色在纺织中的应用,计算机配色在印染行业的应用
  17. 访问本机php文件无法解析_浏览器访问.php文件不解析直接下载
  18. 记录安装Ubuntu16.04后必须要做的事,杂篇
  19. 吾生也有涯,而学也无涯
  20. 常用笔记啊(持续更新)

热门文章

  1. macbook搭建java环境_MacBook从零开始搭建java环境
  2. java 分析 线程堵塞 原因_java如何避免线程阻塞?相关方法解析
  3. 最大连接数和最大线程数相关知识点的总结
  4. 机器视觉系统设计的难点解析
  5. LINUX系统不识别NTFS格式的硬盘
  6. 公路村村通(含注释)
  7. 英文Essay写作标准结构及风格分析
  8. 矩阵的行列式(递归C++)
  9. javascript里的转义字符
  10. 【Spring Boot 源码研究 】- 自动化装配机制核心注解剖析