Keras Tuner官方教程
Keras Tuner官方教程
import tensorflow as tf
from tensorflow import keras
Install and import the Keras Tuner.
pip install -q -U keras-tuner
import kerastuner as kt
下载准备数据集
Fashion MNIST dataset.
(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()
标准化 像素到0-1
# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
定义模型
定义hypermodel 的两种方式
- 模型构建器函数( model builder function )
- 您还可以使用两个预定义的HyperModel类—HyperXception和HyperResNet用于计算机视觉应用程序。
在本教程中,您将使用模型构建器函数( model builder function )来定义图像分类模型。模型构建器函数返回一个已编译的模型,并使用内联定义的超参数对模型进行超调。
# 构建模型,传入hp参数,使用其定义需要优化的参数范围,构成参数空间
def model_builder(hp):model = keras.Sequential()model.add(keras.layers.Flatten(input_shape=(28, 28)))# Tune the number of units in the first Dense layer# Choose an optimal value between 32-512hp_units = hp.Int('units', min_value=32, max_value=512, step=32)model.add(keras.layers.Dense(units=hp_units, activation='relu'))model.add(keras.layers.Dense(10))# Tune the learning rate for the optimizer# Choose an optimal value from 0.01, 0.001, or 0.0001hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])return model
实例化调优器并执行超参数调优
tuner是Hyperband
实例化调优器以执行超调优。Keras调优器有四个调优器可用-随机搜索,Hyperband,贝叶斯优化和Sklearn。在本教程中,您将使用Hyperband调谐器。
要实例化Hyperband调优器,您必须指定超参模型、要优化的目标objective 和要训练的最大epoch数(max epoch)。
tuner = kt.Hyperband(model_builder,objective='val_accuracy',max_epochs=10,factor=3,directory='my_dir',project_name='intro_to_kt')
Hyperband调优算法采用自适应资源分配 adaptive resource allocation和早停 early-stopping来快速收敛到高性能模型上。
该算法在几个epochs内训练大量模型,只将表现最好的一半模型带入下一轮。Hyperband通过计算1 + log factor(max_epochs)
确定要训练的模型数量四舍五入到最接近的整数。
创建一个回调,在达到验证损失的某个值后尽早停止训练。
stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)
执行超级参数搜索。除了上面的回调外,搜索方法的参数与tf.keras.model.fit的参数相同。
tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
output:
Trial 30 Complete [00h 00m 24s]
val_accuracy: 0.8824166655540466Best val_accuracy So Far: 0.8901666402816772
Total elapsed time: 00h 05m 34s
INFO:tensorflow:Oracle triggered exitThe hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 448 and the optimal learning rate for the optimizer
is 0.001.
Train the model训练模型
利用搜索得到的超参数,找出最优epoch数来训练模型。找到最优的epoch,利用val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))
重新实例化超模型,从上面用最优的epochs数训练它。
hypermodel = tuner.hypermodel.build(best_hps)# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)
根据测试数据评估hypermodel
eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)
output:
313/313 [==============================] - 1s 2ms/step - loss: 0.5915 - accuracy: 0.8867
[test loss, test accuracy]: [0.5915395617485046, 0.8866999745368958]
my_dir/intro_to_kt( 一开始在kt.Hyperband()中定义了)
目录包含在超参数搜索期间运行的每个试验(模型配置)的详细日志和检查点。如果重新运行超参数搜索,Keras Tuner将使用这些日志中的现有状态恢复搜索。要禁用此行为,请在实例化调优器时传递一个额外的overwrite=True参数。
总结(不想看细节的直接看)
import kerastuner as kt
model_builder
tuner = kt.Hyperband(model_builder )
stop_early 放在callback里
tuner.search(callbacks=[stop_early])寻找参数
找到参数为best_hps
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]
#构建 使用最优参数的网络
tuner.hypermodel.build(best_hps)
model.fit()
找到最好的epochs
然后,重新构建hypermodel,用最好的epochs再次训练
参考
翻译原文:教程
论文:Hyperband: A Novel Bandit-Based Approach toHyperparameter Optimization
总体实验:
需要设置模型构建函数,优化目标的名称(最大化还是最小化由内建度量得出),测试的总试验次数(max_trials),每次试验模型构建训练次数(executions_per_trial)。目前的优化器有RandomSearch和Hyperband。
Keras Tuner官方教程相关推荐
- Keras Tuner自动调参工具使用入门教程
主体是翻译的Keras Tuner的说明:https://keras-team.github.io/keras- tuner/documentation/tuners/ github地址:https: ...
- TensorFlow 2官方教程 . Keras机器学习基础知识 . 使用TF Hub进行文本分类
写在前面 此篇博客转载自tensorflow官方教程中文翻译版: https://www.tensorflow.org/tutorials/keras/text_classification_with ...
- tensorflow官方Blog-使用Keras Tuner超参数优化框架 进行超参数调整 ,具体实现版本
文章目录 进入正题,keras tuner超参数优化框架 模型构建def build_model(hp): 实例化tuner 加载数据集,进行超参数搜索tuner.search() 找到最佳的模型tu ...
- TensorFlow2.0 Guide官方教程 学习笔记20 -‘Effective TensorFlow 2‘
本笔记参照TensorFlow Guide官方教程,主要是对'Effictive TensorFlow 2'教程内容翻译和内容结构编排,原文链接:Effictive TensorFlow 2 高效的T ...
- 【tf.keras】官方教程一 Keras overview
目录 Sequential Model:(the simplest type of model) Getting started with the Keras Sequential model Spe ...
- TensorFlow2.0 Guide官方教程 学习笔记17 -‘Using the SavedModel format‘
本笔记参照TensorFlow官方教程,主要是对'Save a model-Training checkpoints'教程内容翻译和内容结构编排,原文链接:Using the SavedModel f ...
- Tensorflow2.0学习-Keras Tuner 妙用 (六)
文章目录 Keras Tuner调整超参数 引包 数据准备 模型准备 跑起来 Keras Tuner调整超参数 Keras Tuner 是一个库,可帮助您为 TensorFlow 程序选择最佳的超参数 ...
- TensorFlow2 -官方教程 :保存和恢复模型
文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...
- tf.data官方教程 - - 基于TF-v2
这是本人关于tf.data的第二篇博文,第一篇基于TF-v1详细介绍了tf.data,但是v1和v2很多地方不兼容,所以替大家瞧瞧v2的tf.data模块有什么新奇之处. TensorFlow版本:2 ...
最新文章
- CONVT_NO_NUMBER
- Android SDK在线更新镜像服务器
- java赋值语句_深度分析:面试阿里,字节99%会被问到Java类加载机制和类加载器...
- php废物,PHP的垃圾回收机制以及大概实现
- Zoho:尽快修复已遭利用的 ManageEngine 严重漏洞
- 华为机试HJ88:扑克牌大小
- zookeeper观察者模式设计实例
- struts2从form取值的三种方式
- Python实用笔记 (27)面向对象高级编程——使用枚举类
- POJ 3581:Sequence(后缀数组)
- chrome浏览器版本简单介绍
- 标准软件开发过程 文档
- 数据、数据元素、数据项、数据对象
- 医学图像DCM格式文件处理
- Atitit 人工智能声音处理乐器总类以及midi规范的标示 目录 1. Atitit 乐器分类 打击乐器	1 1.1. 1.1. 打击乐器(各种鼓 三角铁等	1	2 1.2. 1.2. 管乐器
- 微信开放平台申请网站应用
- php1蛋白质带电情况,拿到一个蛋白以后,首先需要对蛋白进行全面的了解,所谓知彼知己方能百战不殆:...
- flixel 一个游戏开发的框架
- 我对“结构化思维”的理解 - 直播分享
- webfreer下载及设置