TensorFlow2.0(六)--超参数搜索
超参数搜索
- 1. 超参数搜索简介
- 1.1 超参数
- 1.2 超参数搜索
- 2. 手动实现超参数搜索
- 2.1 导入相应的库
- 2.2 数据载入与处理
- 2.3 手动实现超参数搜索
- 3. sklearn实现超参数搜索
- 3.1 sklearn实现超参数搜索
- 3.2 最优训练结果
- 3.3 对得到超参数进行验证
1. 超参数搜索简介
1.1 超参数
超参数就是在神经网络的训练过程中,不变的参数。比如:
- 网络结构参数:层数,每层宽度,每层激活函数等
- 训练参数: batch_size, 学习率, 学习率衰减算法
1.2 超参数搜索
如果我们在训练模型的过程中手动的一个个的更改上述的超参数组合,那么工作量是巨大的,所以我们需要采取超参数搜索策略。超参数搜索有一下几个策略:
网格搜索
网课搜索是一种最简答和最容易理解的超参数搜索策略。以dropout rate和learning rate两个超参数为例,我们可以将两个超参数组成一个二维网格,比如dropout rate取值[0.1, 0.3, 0.6, 0.8]四个值,learning rate取[0.001, 0.005, 0.01, 0.005],我们用二维网格将二个超参数两两结合,然后在多台机器上进行并行训练,就可以快速得到相对优的超参数组合。
随机搜索
随机搜索和网格搜索比较接近,二者的区别是网格搜索的参数分布是固定和相对均匀的,随机搜索的参数是随机生成的。对于网格搜索,最优参数很可能分布在网格中间而非网格的节点上,所以我们往往很难找到最优的超参数组合。随机搜索因为参数分布是随机的,所以找到的超参数组合往往要优于网格搜索,但是随机搜索生成的超参数组合也是要多于网格搜索。
遗传算法搜索
遗传算法是对自然界的模拟。- 首先我们先初始化候选参数集合,然后进行训练,以得到的模型指标作为该模型参数的生存概率,指标越好,生存概率越大。
- 其次,我们对参数进行选择–>交叉–>变异–>产生下一代
- 然后再次进行训练,重复以上步骤,最后得到的最优的参数集合就是我们搜索的最优结果
启发式搜索
启发式搜索是AutoML中的研究热点,启发式搜索使用循环神经网络来生成参数,然后使用强化学习来进行反馈,使用模型来训练生成参数。
2. 手动实现超参数搜索
2.1 导入相应的库
# matplotlib 用于绘图
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
# 处理数据的库
import numpy as np
import sklearn
import pandas as pd
# 系统库
import os
import sys
import time
# TensorFlow的库
import tensorflow as tf
from tensorflow import keras
2.2 数据载入与处理
本篇博客选择使用房价预测的回归问题来完成超参数搜索,因为这个问题的维度比较小,实现起来比较容易。
数据集加载:
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
数据集分割为训练集、测试集与验证集:
from sklearn.model_selection import train_test_split
"""
# test_size 指的是划分的训练集和测试集的比例
# test_size 默认值为0.25 表示数据分四份,测试集占一份
"""
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data, housing.target, random_state = 7, test_size = 0.25)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all, random_state = 11, test_size = 0.25)
数据归一化处理:
# 数据归一化
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
"""
# 训练集数据使用的是 fit_transform,和验证集与测试集中使用的 transform 是不一样的
# fit_transform 可以计算数据的均值和方差并记录下来
# 验证集和测试集用到的均值和方差都是训练集数据的,所以二者的归一化使用 transform 即可
# 归一化只针对输入数据, 标签不变
"""
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)
2.3 手动实现超参数搜索
手动实现超参数搜索:
# 搜索learning rate: [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
learning_rate = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
histories = []
"""
# 我们只以学习率为例
# 学习率的取值有6个
# 以for循环为表达进行学习率的遍历
"""
for lr in learning_rate:# 模型的构建model = keras.models.Sequential([keras.layers.Dense(10, activation='relu', input_shape=x_train.shape[1:]),keras.layers.Dense(1),])# 模型的编译model.compile(loss="mean_squared_error", optimizer = keras.optimizers.SGD(lr))# 回调函数callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]# 开始训练history = model.fit(x_train_scaled, y_train,validation_data=(x_valid_scaled, y_valid),epochs = 100,callbacks= callbacks)# histories存放每个学习率下的训练结果histories.append(history)
3. sklearn实现超参数搜索
手动实现超参数搜索是一件很麻烦的事情,我们只举了一个学习率的例子,但是实际情况中,我们要搜索的超参数有很多,这就造成了我们实际上要做的for循环可能不止上述的6个,可能有成百上千次循环。所以我们利用现有的sklearn库进行超参数搜索。
3.1 sklearn实现超参数搜索
"""
我们利用RandomizedSearchCV包实现超参数搜索
1. 转化为sklearn的model
2. 定义参数集合
3. 搜索参数
"""
# 转化为sklearn的model
def build_model(hidden_layers = 1,layer_size = 30,learning_rate = 3e-3):model = keras.models.Sequential()model.add(keras.layers.Dense(layer_size, activation='relu', input_shape=x_train.shape[1:]))for _ in range(hidden_layers - 1):model.add(keras.layers.Dense(layer_size, activation='relu'))model.add(keras.layers.Dense(1))model.compile(loss="mean_squared_error", optimizer = keras.optimizers.SGD(learning_rate))return modelsklearn_model = keras.wrappers.scikit_learn.KerasRegressor(build_model)
callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]# 定义参数集合
param_distribution = {"hidden_layers":[1,2,3,4],"layer_size":np.arange(1,100),"learning_rate":[1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
}# 搜索参数
from sklearn.model_selection import RandomizedSearchCV
random_search_CV =RandomizedSearchCV(sklearn_model,param_distribution,n_iter =10, # 生成的超参数组合数n_jobs = 1) # 并行处理的数量
# 开始训练
history = random_search_CV.fit(x_train_scaled, y_train, epochs = 100,validation_data = (x_valid_scaled, y_valid),callbacks = callbacks)
3.2 最优训练结果
我们可以打印出来最优的参数,得分以及模型:
print(random_search_CV.best_params_)
print(random_search_CV.best_score_)
print(random_search_CV.best_estimator_)
3.3 对得到超参数进行验证
我们对得到的最优的模型进行验证:
model = random_search_CV.best_estimator_.model
model.evaluate(x_test_scaled, y_test)
输出为:
TensorFlow2.0(六)--超参数搜索相关推荐
- 11_模型的选择与调优,交叉验证,超参数搜索-网格搜索sklearn.model_selection.GridSearchCV
1.交叉验证 交叉验证:为了让被评估的模型更加准确可信. 交叉验证:将拿到的数据,分为训练和验证集.以下图为例:将数据分成5份,其中一份作为验证集.然后经过5次(组)的测试,每次都更换不同的验证集.即 ...
- python 超参数_OpenCV python sklearn随机超参数搜索的实现
""" 房价预测数据集 使用sklearn执行超参数搜索 """ import matplotlib as mpl import matpl ...
- 干货 | 深度学习模型超参数搜索实用指南
乾明 整理编译自 FloydHub Blog 量子位 报道 | 公众号 QbitAI 在文章开始之前,我想问你一个问题:你已经厌倦了小心翼翼地照看你的深度学习模型吗? 如果是的话,那你就来对地方了. ...
- 机器学习基础|K折交叉验证与超参数搜索
文章目录 交叉验证 交叉验证的概念 K的取值 为什么要用K折交叉验证 Sklearn交叉验证API 超参数搜索 超参数的概念 超参数搜索的概念 超参数搜索的原理 Sklearn超参数搜索API 实例 ...
- NNI(自动超参数搜索)工具环境配置及使用
本文主要介绍如何搭建Microsoft的NNI工具环境以及使用NNI进行Mnist分类任务的超参数搜索. NNI简介 NNI (Neural Network Intelligence) 是一个轻量但强 ...
- 超参数搜索——初始学习率搜索的学习笔记
1 概述 由于南溪只有一块2080Ti,所以暂时不会考虑用强化学习的方法来做~ 南溪目前想要学习的超参数搜索算法有: 网格搜索 随机搜索 贝叶斯搜索 粒子群算法 进化算法 遗传算法 (7. 强化学习) ...
- 超参数搜索——网格搜索和随机搜索
https://cloud.tencent.com/developer/article/1187140
- datawhale课程《transformers入门》笔记6:文本分类、超参搜索
Transformers解决文本分类任务.超参搜索 本文主要内容转自天国之影笔记Task06,之后具体的API进行了一些查询,写了一些说明. 文章目录 Transformers解决文本分类任务.超参搜 ...
- Ray.tune可视化调整超参数Tensorflow 2.0
Ray.tune官方文档 调整超参数通常是机器学习工作流程中最昂贵的部分. Tune专为解决此问题而设计,展示了针对此痛点的有效且可扩展的解决方案. 请注意,此示例取决于Tensorflow 2.0. ...
最新文章
- JUnit 3.8 通过反射测试私有方法
- 初等数学O 集合论基础 第三节 序关系
- SOA实现方式与模式
- hihoCoder 1116 计算 (线段树)
- Java String 类的方法
- 虚拟化四路服务器,IDC:4路及8路服务器现状未来趋势分析
- bzoj1293: [SCOI2009]生日礼物
- WPS简历模板的图标怎么修改_新媒体运营-简历模板范文,【工作经历+项目经验+自我评价】怎么写?...
- 如何入门 Python 爬虫?50集免费全套教程视频让你轻松掌握
- c盘扩容提示簇被标记_电脑C盘爆满飘红?系统卡?试试这两种解决办法
- vid在c语言中的作用,——PVID的作用及和VID的区别
- 2018.9.10 工作日志 猎宝行动
- 中标麒麟 NeoKylin-SDK 里都有哪些库文件
- FOC控制中Clark/iClark和Park/iPark变换及matpoltlib仿真
- iSlide(PPT插件)
- linux下查看服务器型号
- R语言ggridges包绘制漂亮的峰峦图(山脊图)-下篇
- C语言strchr()函数以及strstr()函数的实现
- HTML语言中 blur()方法,HTML DOM blur() 方法
- 小米平板刷机shell怎么退_小米平板2如何一键解锁?刷机教程图解