文章目录

  • 网格搜索法在机器学习和深度学习中的使用
    • 1.项目简介
    • 2.机器学习案例
      • 2.1导入相关库
      • 2.2导入数据
      • 2.3拆分数据集
      • 2.4网格搜索法
      • 2.5使用最优参数重新训练模型
    • 3.深度学习案例
      • 3.1导入相关库
      • 3.2导入数据
      • 3.3拆分数据集
      • 3.4构造模型
      • 3.5网格搜索法
      • 3.6使用最优参数重新训练模型

网格搜索法在机器学习和深度学习中的使用

1.项目简介

  在机器学习和深度学习中,经常需要进行超参数优化。其中,网格搜索法在小批量的超参数优化时被经常使用,该项目是一个多分类(26个类别)问题。
  使用Jupyter Notebook完成,代码和数据文件。

2.机器学习案例

2.1导入相关库

from sklearn import svm  # 支持向量机库
import pandas as pd
from sklearn.model_selection import train_test_split  # 拆分数据集
from sklearn.model_selection import GridSearchCV  # 网格搜索法
from sklearn import metrics  # 用于模型评估

2.2导入数据

filepath = 'E:/Jupyter/Mojin/超参数优化/data/letterdata.csv'
letters = pd.read_csv(filepath)
letters.shape
letters.tail(10)  # 返回数据后10行

  运行结果:

  可以看到,总共有2万个数据,17个指标,其中,letter为目标变量,其余16个指标为特征变量。

2.3拆分数据集

features = letters.columns[1:]
trainX, valX, trainY, valY = train_test_split(letters[features], letters[['letter']], test_size=0.2, random_state=1234)

2.4网格搜索法

C = [0.05, 0.1, 0.15, 0.2]
kernel = ['rbf', 'sigmoid']
Hyperparameter = dict(C=C, kernel=kernel)  # 将超参数范围包成字典grid = GridSearchCV(estimator=svm.SVC(),  # 支持向量机中的SVC模型param_grid=Hyperparameter)
# 模型在训练数据集上的拟合
grid_result = grid.fit(trainX, trainY)# 返回最佳参数组合
print('Best:%f using %s' % (grid_result.best_score_, grid_result.best_params_))

  运行结果:

  可以看到,最佳参数组合:C:‘0.2’,kernel:‘rbf’,准确率:0.914125。

2.5使用最优参数重新训练模型

SVC = svm.SVC(kernel='rbf', C=0.2, gamma='scale').fit(trainX, trainY)
pred = SVC.predict(valX)  # 预测
print('模型预测的准确率:{0}'.format(metrics.accuracy_score(valY, pred)))

  运行结果:

3.深度学习案例

3.1导入相关库

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
import pandas as pd
import random  # 随机库
from sklearn.model_selection import GridSearchCV  # 网格搜索法
from keras.wrappers.scikit_learn import KerasClassifier  # Keras库分类模型封装器
from keras.utils.np_utils import to_categorical  # Keras库独热编码

  KerasClassifier是将深度模型包装传递给网格搜索法的接口,具体说明见官方文档。

3.2导入数据

filepath = 'E:/Jupyter/Mojin/超参数优化/data/letterdata.csv'
letters = pd.read_csv(filepath)
letters.shape
letters.tail(10)
# 将letter指标数据类型转换成category
letter = letters['letter'].astype('category')
# 使用标签的编码作为真正的数据
letters['letter'] = letter.cat.codes
letters.tail(10)
# 提取特征变量并转换成数组格式
X = letters[letters.columns[1:]].values
X.shape
Y = letters[['letter']].values
Y.shape

  运行结果:

# 独热编码
Y = to_categorical(Y)
Y.shape

  运行结果:

  可以看到,经过独热编码Y的维度从1变为26。

3.3拆分数据集

random_index = random.sample(list(range(X.shape[0])), X.shape[0])
train_size = int(X.shape[0] * 0.8)
train_index = random_index[:train_size]
test_index = random_index[train_size:]trainX = X[train_index]
trainX.shape
trainY = Y[train_index]testX = X[test_index]
testY = Y[test_index]

3.4构造模型

def create_model(dropout=0.2, depth=2):"""dropout: Dropout层丢弃的比例,一般在0.2~0.5之间depth: 模型隐藏层的层数"""model = Sequential()if depth < 2:raise Exception('至少两层结构')else:model.add(Dense(units=32, input_shape=(16,),  # 特征指标个数:16(trainX.shape[1])activation='relu'))model.add(Dropout(rate=dropout))  # 防止过拟合for i in range(depth - 2):model.add(Dense(units=32,activation='relu'))model.add(Dense(units=26,activation='softmax'))model.compile(loss='categorical_crossentropy',optimizer='rmsprop', metrics=['accuracy'])model.summary()return modelmodel = KerasClassifier(build_fn=create_model)

3.5网格搜索法

# 构建需要优化的超参数范围
depth = [3, 4, 5]
epochs = [30, 50]
batch_size = [100]
param_grid = dict(depth=depth,epochs=epochs,batch_size=batch_size)grid = GridSearchCV(estimator=model, param_grid=param_grid)  # 默认是cv=3,即3折交叉验证
grid.fit(trainX, trainY)
# 返回最佳参数组合
print('Best:%f using %s' % (grid.best_score_, grid.best_params_))

  运行结果:

3.6使用最优参数重新训练模型

model = create_model(depth=4)
model.fit(trainX, trainY, epochs=50, batch_size=100)
predict = model.predict(testX)
loss, acc = model.evaluate(testX, testY)
print('模型预测的准确率:{0}'.format(acc))

  运行结果:

超参数优化:网格搜索法相关推荐

  1. 机器学习之超参数优化 - 网格优化方法(随机网格搜索)

    机器学习之超参数优化 - 网格优化方法(随机网格搜索) 在讲解网格搜索时我们提到,伴随着数据和模型的复杂度提升,网格搜索所需要的时间急剧增加.以随机森林算法为例,如果使用过万的数据,搜索时间则会立刻上 ...

  2. 机器学习之超参数优化 - 网格优化方法(对半网格搜索HalvingSearchCV)

    机器学习之超参数优化 - 网格优化方法(对半网格搜索HalvingSearchCV) 在讲解随机网格搜索之前,我们梳理了决定枚举网格搜索运算速度的因子: 1 参数空间的大小:参数空间越大,需要建模的次 ...

  3. 机器学习、超参数、最优超参数、网格搜索、随机搜索、贝叶斯优化、Google Vizier、Adviser

    机器学习.超参数.最优超参数.网格搜索.随机搜索.贝叶斯优化.Google Vizier.Adviser 最优超参数 选择超参数的问题在于,没有放之四海而皆准的超参数. 因此,对于每个新数据集,我们必 ...

  4. DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏)

    DL之模型调参:深度学习算法模型优化参数之对深度学习模型的超参数采用网格搜索进行模型调优(建议收藏) 目录 神经网络的参数调优 1.神经网络的通病-各种参数随机性 2.评估模型学习能力

  5. 超参数优化(网格搜索和贝叶斯优化)

    超参数优化 1 超参数优化 1.1 网格搜索类 1.1.1 枚举网格搜索 1.1.2 随机网格搜索 1.1.3 对半网格搜索(Halving Grid Search) 1.2 贝叶斯超参数优化(推荐) ...

  6. 【视频】支持向量机SVM、支持向量回归SVR和R语言网格搜索超参数优化实例

    最近我们被客户要求撰写关于SVM的研究报告,包括一些图形和统计输出. 什么是支持向量机 (SVM)? 我们将从简单的理解 SVM 开始. [视频]支持向量机SVM.支持向量回归SVR和R语言网格搜索超 ...

  7. Python集成机器学习:用AdaBoost、决策树、逻辑回归集成模型分类和回归和网格搜索超参数优化

    最近我们被客户要求撰写关于集成机器的研究报告,包括一些图形和统计输出. Boosting 是一类集成机器学习算法,涉及结合许多弱学习器的预测. 视频:从决策树到随机森林:R语言信用卡违约分析信贷数据实 ...

  8. 【机器学习】算法模型自动超参数优化方法

    什么是超参数? 学习器模型中一般有两类参数,一类是可以从数据中学习估计得到,我们称为参数(Parameter).还有一类参数时无法从数据中估计,只能靠人的经验进行设计指定,我们称为超参数(Hyper ...

  9. 全网最全:机器学习算法模型自动超参数优化方法汇总

    什么是超参数? 学习器模型中一般有两类参数,一类是可以从数据中学习估计得到,我们称为参数(Parameter).还有一类参数时无法从数据中估计,只能靠人的经验进行设计指定,我们称为超参数(Hyper ...

最新文章

  1. 开源社区的危机:拒绝被“白嫖”?2大著名项目遭作者破坏
  2. python使用说明书-InfluxDB——python使用手册
  3. Win10系统下Visio安装失败问题
  4. Manjaro 软件源及软件管理相关操作【pacman、pacman-mirrors】整理
  5. matlab mex路径,使用matlab进行mex编译时的路径问题mexopts
  6. 如何使用TensorFlow构建简单的图像识别系统(第2部分)
  7. 多人姿态识别框架——AlphaPose
  8. 剑指_复杂链表的复制(Python)
  9. VBoxGuestAdditions.iso下载地址
  10. cad角度怎么画_初学入门CAD,就这样成精了!
  11. 怎样查找计算机的ip mac地址,如何通过mac地址查ip,教您Mac怎么查看ip地址
  12. 计算机开机后黑屏鼠标显示桌面图标,电脑开机后黑屏怎么解决只显示鼠标
  13. 苹果5G芯片研发失败:继续依赖高通,还要担心被起诉?
  14. pandas - 特别篇(关于读取DataFrame数据显示不完全的解决办法)
  15. 【JavaMap接口】特点实现类HashMap常用方法
  16. MATLAB/Simulink封装子模块图片显示和参数输出设置问题
  17. 多核计算机是指有多个cpu,多核和多个CPU有什么区别?
  18. 《当程序员的那些狗日日子》(四十三)绝缘空间
  19. 【Linux驱动】驱动设计硬件基础----串口、I2C、SPI、以太网接口、PCIE
  20. 使用STM32cubeMX写一个简单的LED闪烁

热门文章

  1. tit-al00 android 6,华为TIT-AL00入网 MTK6735四核全网通手机
  2. 控制系统分析常用命令
  3. redis保护模式的报错
  4. 有5个人坐在一起,问第五个人多少岁?他说比第4个人大2岁。问第4个人岁数,他说比第 3个人大2岁。问第三个人,又说比第2人大两岁。问第2个人,说比第一个人大两岁。最后 问第一个人,他说是10岁。请问第
  5. 铝模板18个标准化安装步骤,照此做法错不了
  6. destoon模板安装方法
  7. 【Requests】获取本地的请求IP和域名解析的IP
  8. php网站页面显示源码,用PHP显示网站的源代码
  9. 计算机基础之计算机的前沿技术
  10. ORA-00937:不是单组分组函数 ORA-22818:这里不允许出现子查询表达式