Keras Tuner官方教程

import tensorflow as tf
from tensorflow import keras

Install and import the Keras Tuner.

pip install -q -U keras-tuner
import kerastuner as kt

下载准备数据集

Fashion MNIST dataset.

(img_train, label_train), (img_test, label_test) = keras.datasets.fashion_mnist.load_data()

标准化 像素到0-1

# Normalize pixel values between 0 and 1
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0

定义模型

定义hypermodel 的两种方式

  • 模型构建器函数( model builder function )
  • 您还可以使用两个预定义的HyperModel类—HyperXception和HyperResNet用于计算机视觉应用程序。

在本教程中,您将使用模型构建器函数( model builder function )来定义图像分类模型。模型构建器函数返回一个已编译的模型,并使用内联定义的超参数对模型进行超调。

# 构建模型,传入hp参数,使用其定义需要优化的参数范围,构成参数空间
def model_builder(hp):model = keras.Sequential()model.add(keras.layers.Flatten(input_shape=(28, 28)))# Tune the number of units in the first Dense layer# Choose an optimal value between 32-512hp_units = hp.Int('units', min_value=32, max_value=512, step=32)model.add(keras.layers.Dense(units=hp_units, activation='relu'))model.add(keras.layers.Dense(10))# Tune the learning rate for the optimizer# Choose an optimal value from 0.01, 0.001, or 0.0001hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])return model

实例化调优器并执行超参数调优

tuner是Hyperband

实例化调优器以执行超调优。Keras调优器有四个调优器可用-随机搜索,Hyperband,贝叶斯优化和Sklearn。在本教程中,您将使用Hyperband调谐器。
要实例化Hyperband调优器,您必须指定超参模型、要优化的目标objective 和要训练的最大epoch数(max epoch)。

tuner = kt.Hyperband(model_builder,objective='val_accuracy',max_epochs=10,factor=3,directory='my_dir',project_name='intro_to_kt')

Hyperband调优算法采用自适应资源分配 adaptive resource allocation和早停 early-stopping来快速收敛到高性能模型上。

该算法在几个epochs内训练大量模型,只将表现最好的一半模型带入下一轮。Hyperband通过计算1 + log factor(max_epochs)
确定要训练的模型数量四舍五入到最接近的整数。

创建一个回调,在达到验证损失的某个值后尽早停止训练。

stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

执行超级参数搜索。除了上面的回调外,搜索方法的参数与tf.keras.model.fit的参数相同。

tuner.search(img_train, label_train, epochs=50, validation_split=0.2, callbacks=[stop_early])# Get the optimal hyperparameters
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")

output:

Trial 30 Complete [00h 00m 24s]
val_accuracy: 0.8824166655540466Best val_accuracy So Far: 0.8901666402816772
Total elapsed time: 00h 05m 34s
INFO:tensorflow:Oracle triggered exitThe hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is 448 and the optimal learning rate for the optimizer
is 0.001.

Train the model训练模型

利用搜索得到的超参数,找出最优epoch数来训练模型。找到最优的epoch,利用val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1

# Build the model with the optimal hyperparameters and train it on the data for 50 epochs
model = tuner.hypermodel.build(best_hps)
history = model.fit(img_train, label_train, epochs=50, validation_split=0.2)val_acc_per_epoch = history.history['val_accuracy']
best_epoch = val_acc_per_epoch.index(max(val_acc_per_epoch)) + 1
print('Best epoch: %d' % (best_epoch,))

重新实例化超模型,从上面用最优的epochs数训练它。

hypermodel = tuner.hypermodel.build(best_hps)# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)

根据测试数据评估hypermodel

eval_result = hypermodel.evaluate(img_test, label_test)
print("[test loss, test accuracy]:", eval_result)

output:

313/313 [==============================] - 1s 2ms/step - loss: 0.5915 - accuracy: 0.8867
[test loss, test accuracy]: [0.5915395617485046, 0.8866999745368958]

my_dir/intro_to_kt( 一开始在kt.Hyperband()中定义了)
目录包含在超参数搜索期间运行的每个试验(模型配置)的详细日志和检查点。如果重新运行超参数搜索,Keras Tuner将使用这些日志中的现有状态恢复搜索。要禁用此行为,请在实例化调优器时传递一个额外的overwrite=True参数。

总结(不想看细节的直接看)

import kerastuner as kt
model_builder
tuner = kt.Hyperband(model_builder )

stop_early 放在callback里
tuner.search(callbacks=[stop_early])寻找参数
找到参数为best_hps
best_hps=tuner.get_best_hyperparameters(num_trials=1)[0]
#构建 使用最优参数的网络
tuner.hypermodel.build(best_hps)
model.fit()

找到最好的epochs
然后,重新构建hypermodel,用最好的epochs再次训练

参考

翻译原文:教程

论文:Hyperband: A Novel Bandit-Based Approach toHyperparameter Optimization


总体实验:
需要设置模型构建函数,优化目标的名称(最大化还是最小化由内建度量得出),测试的总试验次数(max_trials),每次试验模型构建训练次数(executions_per_trial)。目前的优化器有RandomSearch和Hyperband。


Keras Tuner官方教程相关推荐

  1. Keras Tuner自动调参工具使用入门教程

    主体是翻译的Keras Tuner的说明:https://keras-team.github.io/keras- tuner/documentation/tuners/ github地址:https: ...

  2. TensorFlow 2官方教程 . Keras机器学习基础知识 . 使用TF Hub进行文本分类

    写在前面 此篇博客转载自tensorflow官方教程中文翻译版: https://www.tensorflow.org/tutorials/keras/text_classification_with ...

  3. tensorflow官方Blog-使用Keras Tuner超参数优化框架 进行超参数调整 ,具体实现版本

    文章目录 进入正题,keras tuner超参数优化框架 模型构建def build_model(hp): 实例化tuner 加载数据集,进行超参数搜索tuner.search() 找到最佳的模型tu ...

  4. TensorFlow2.0 Guide官方教程 学习笔记20 -‘Effective TensorFlow 2‘

    本笔记参照TensorFlow Guide官方教程,主要是对'Effictive TensorFlow 2'教程内容翻译和内容结构编排,原文链接:Effictive TensorFlow 2 高效的T ...

  5. 【tf.keras】官方教程一 Keras overview

    目录 Sequential Model:(the simplest type of model) Getting started with the Keras Sequential model Spe ...

  6. TensorFlow2.0 Guide官方教程 学习笔记17 -‘Using the SavedModel format‘

    本笔记参照TensorFlow官方教程,主要是对'Save a model-Training checkpoints'教程内容翻译和内容结构编排,原文链接:Using the SavedModel f ...

  7. Tensorflow2.0学习-Keras Tuner 妙用 (六)

    文章目录 Keras Tuner调整超参数 引包 数据准备 模型准备 跑起来 Keras Tuner调整超参数 Keras Tuner 是一个库,可帮助您为 TensorFlow 程序选择最佳的超参数 ...

  8. TensorFlow2 -官方教程 :保存和恢复模型

    文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...

  9. tf.data官方教程 - - 基于TF-v2

    这是本人关于tf.data的第二篇博文,第一篇基于TF-v1详细介绍了tf.data,但是v1和v2很多地方不兼容,所以替大家瞧瞧v2的tf.data模块有什么新奇之处. TensorFlow版本:2 ...

最新文章

  1. CONVT_NO_NUMBER
  2. Android SDK在线更新镜像服务器
  3. java赋值语句_深度分析:面试阿里,字节99%会被问到Java类加载机制和类加载器...
  4. php废物,PHP的垃圾回收机制以及大概实现
  5. Zoho:尽快修复已遭利用的 ManageEngine 严重漏洞
  6. 华为机试HJ88:扑克牌大小
  7. zookeeper观察者模式设计实例
  8. struts2从form取值的三种方式
  9. Python实用笔记 (27)面向对象高级编程——使用枚举类
  10. POJ 3581:Sequence(后缀数组)
  11. chrome浏览器版本简单介绍
  12. 标准软件开发过程 文档
  13. 数据、数据元素、数据项、数据对象
  14. 医学图像DCM格式文件处理
  15. Atitit 人工智能声音处理乐器总类以及midi规范的标示 目录 1. Atitit 乐器分类 打击乐器 1 1.1. 1.1. 打击乐器(各种鼓 三角铁等 1 2 1.2. 1.2. 管乐器
  16. 微信开放平台申请网站应用
  17. php1蛋白质带电情况,拿到一个蛋白以后,首先需要对蛋白进行全面的了解,所谓知彼知己方能百战不殆:...
  18. flixel 一个游戏开发的框架
  19. 我对“结构化思维”的理解 - 直播分享
  20. webfreer下载及设置

热门文章

  1. js实现登录页面的背景图片的随机展示
  2. 电脑软件连接ABB机器人控制柜
  3. Python数据可视化第 9 讲:matplotlib极坐标图绘制函数polar
  4. 【数据库】什么是 PostgreSQL?开源数据库系统
  5. Linux下Nginx的启动、停止等命令
  6. 在线文档协作进行项目管理
  7. sprintf 用法(sprintf_s)
  8. Android数据存储(二)----PreferenceFragment详解
  9. 网络安全笔记第四天day4(kali基本操作)
  10. easyrecovery2023免费绿色版电脑数据恢复软件