网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法。

以决策树为例,当我们确定了要使用决策树算法的时候,为了能够更好地拟合和预测,我们需要调整它的参数。在决策树算法中,我们通常选择的参数是决策树的最大深度。

于是我们会给出一系列的最大深度的值,比如 {‘max_depth’: [1,2,3,4,5]},我们会尽可能包含最优最大深度。

不过,我们如何知道哪一个最大深度的模型是最好的呢?我们需要一种可靠的评分方法,对每个最大深度的决策树模型都进行评分,这其中非常经典的一种方法就是交叉验证,下面我们就以K折交叉验证为例,详细介绍它的算法过程。

首先我们先看一下数据集是如何分割的。我们拿到的原始数据集首先会按照一定的比例划分成训练集和测试集。比如下图,以8:2分割的数据集:

训练集用来训练我们的模型,它的作用就像我们平时做的练习题;测试集用来评估我们训练好的模型表现如何,它的作用像我们做的高考题,这是要绝对保密不能提前被模型看到的。

因此,在K折交叉验证中,我们用到的数据是训练集中的所有数据。我们将训练集的所有数据平均划分成K份(通常选择K=10),取第K份作为验证集,它的作用就像我们用来估计高考分数的模拟题,余下的K-1份作为交叉验证的训练集。

对于我们最开始选择的决策树的5个最大深度 ,以 max_depth=1 为例,我们先用第2-10份数据作为训练集训练模型,用第1份数据作为验证集对这次训练的模型进行评分,得到第一个分数;然后重新构建一个 max_depth=1 的决策树,用第1和3-10份数据作为训练集训练模型,用第2份数据作为验证集对这次训练的模型进行评分,得到第二个分数……以此类推,最后构建一个 max_depth=1 的决策树用第1-9份数据作为训练集训练模型,用第10份数据作为验证集对这次训练的模型进行评分,得到第十个分数。于是对于 max_depth=1 的决策树模型,我们训练了10次,验证了10次,得到了10个验证分数,然后计算这10个验证分数的平均分数,就是 max_depth=1 的决策树模型的最终验证分数。

对于 max_depth = 2,3,4,5 时,分别进行和 max_depth=1 相同的交叉验证过程,得到它们的最终验证分数。然后我们就可以对这5个最大深度的决策树的最终验证分数进行比较,分数最高的那一个就是最优最大深度,我们利用最优参数在全部训练集上训练一个新的模型,整个模型就是最优模型。

下面提供一个简单的利用决策树预测乳腺癌的例子:

from sklearn.model_selection import GridSearchCV, KFold, train_test_split

from sklearn.metrics import make_scorer, accuracy_score

from sklearn.tree import DecisionTreeClassifier

from sklearn.datasets import load_breast_cancer

data = load_breast_cancer()

X_train, X_test, y_train, y_test = train_test_split(

data['data'], data['target'], train_size=0.8, random_state=0)

regressor = DecisionTreeClassifier(random_state=0)

parameters = {'max_depth': range(1, 6)}

scoring_fnc = make_scorer(accuracy_score)

kfold = KFold(n_splits=10)

grid = GridSearchCV(regressor, parameters, scoring_fnc, cv=kfold)

grid = grid.fit(X_train, y_train)

reg = grid.best_estimator_

print('best score: %f'%grid.best_score_)

print('best parameters:')

for key in parameters.keys():

print('%s: %d'%(key, reg.get_params()[key]))

print('test score: %f'%reg.score(X_test, y_test))

import pandas as pd

pd.DataFrame(grid.cv_results_).T

直接用决策树得到的分数大约是92%,经过网格搜索优化以后,我们可以在测试集得到95.6%的准确率:

best score: 0.938462

best parameters:

max_depth: 4

test score: 0.956140

转载自https://zhuanlan.zhihu.com/p/25637642

kfold_机器学习gridsearchcv(网格搜索)和kfold validation(k折验证)相关推荐

  1. 机器学习之网格搜索技术,如何在Auto-sklearn中应用网格搜索技术

    文章目录 一,机器学习中的网格搜索技术是怎么回事 二,通俗解释 三,在一般情况下使用网格搜索技术 四,GridSearchCV网格搜索技术的原理 五,如何在Auto-sklearn中使用网格搜索技术 ...

  2. 2022-1-17第三章机器学习基础--网格搜索超参数优化、决策树、随机森林

    交叉验证与网格搜索 ①交叉验证(训练集划分-训练集.验证集)–将所有数据分成n等分-并不具备调参能力 4等分就是4折交叉验证:一般采用10折交叉验证 ②网格搜索-调参数(与交叉验证一同使用) 如果有多 ...

  3. sklearn GridSearchCV网格搜索案例与代码

    文章目录 准备数据 网格搜索参数 评估结果 全部代码 需要的包如下: import pandas as pd import numpy as np from sklearn.ensemble impo ...

  4. 机器学习之网格搜索调参sklearn

    网格搜索 网格搜索 GridSearchCV我们在选择超参数有两个途径:1凭经验:2选择不同大小的参数,带入到模型中,挑选表现最好的参数.通过途径2选择超参数时,人力手动调节注意力成本太高,非常不值得 ...

  5. 交叉验证 cross validation 与 K-fold Cross Validation K折叠验证

    交叉验证,cross validation是机器学习中非常常见的验证模型鲁棒性的方法.其最主要原理是将数据集的一部分分离出来作为验证集,剩余的用于模型的训练,称为训练集.模型通过训练集来最优化其内部参 ...

  6. StratifiedKFold和KFold(5折验证)交叉验证的联系和区别Python实例

    Kfold: 将全部训练集分成k个不相交的子集,假设训练集的训练样例个数为m,那么每一个子集有m/k个训练样例,比如[1,2,3,4,5,6]分成两份,则第一份可能为[1,3,5],第二份[2,4,6 ...

  7. python网格搜索核函数_机器学习笔记——模型调参利器 GridSearchCV(网格搜索)参数的说明...

    算法 数据结构 机器学习笔记--模型调参利器 GridSearchCV(网格搜索)参数的说明 GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数.但是这个 ...

  8. python实现留一法_数据分割:留出法train_test_split、留一法LeaveOneOut、GridSearchCV(交叉验证法+网格搜索)、自助法...

    1.10 交叉验证,网格搜索 学习目标 目标 知道交叉验证.网格搜索的概念 会使用交叉验证.网格搜索优化训练模型 1 什么是交叉验证(cross validation) 交叉验证:将拿到的训练数据,分 ...

  9. 加载svr模型_机器学习XGBoost实战,网格搜索自动调参,对比随机森林,线性回归,SVR【完整代码(含注释)+数据集见原文链接】...

    建议:阅读2020.8.7的文章,完全了解GDBT和XGBT的原理. 机器学习- XGBoost,GDBT[过程:决策树,集成学习,随机森林,GDBT,XGBT,LightGBM] 本次实践内容: 数 ...

最新文章

  1. springboot(三):Spring boot中Redis的使用
  2. golang 标准库间依赖的可视化展示
  3. 中兴智能视觉大数据:人脸识别技术目前处于“用的不够,用的不好”
  4. react组件之间传递信息/react组件之间值的传递
  5. 跨站脚本攻击之反射型XSS漏洞【转载】
  6. swagger整合springMVC
  7. 计算机房精密空调术语,机房空调常用单位及计算公式
  8. python列表内存分配_python 列表, 元组内存分配优化
  9. css样式的补充:鼠标悬停字体变大和改变颜色
  10. EfficientNet 简介
  11. 小爱同学指令大全_小爱同学指令
  12. combo box使用
  13. Wireshark之流量包分析+日志分析 (护网:蓝队)web安全 取证 分析黑客攻击流程(上篇)
  14. 【Json转换为实体类】
  15. 数据结构 期末复习主观题练习题(答案版)
  16. 云服务器和一般服务器之间有什么区别?
  17. Docker:overlay2浅析
  18. pd快充线无法连接计算机,一种PD快充高清连接线的制作方法
  19. ThinkPHP6.0使用twig作为模板引擎及自定义过滤器
  20. MySQL--在批处理中执行SQL

热门文章

  1. python将文本转化成语音并播放
  2. 陕西活性炭需求分析_20212027年中国粉末活性炭行业市场发展现状调研与投资趋势前景分析报告...
  3. socket通信流程图
  4. linux大小写敏感和windows大小写不敏感(忽略大小写)导致的直接拷贝文件文件名冲突问题(需要打tar包再分享)
  5. visio 程序设计流程图合符号含义
  6. WEB服务器和HTTP服务器和应用服务器的区别?(web服务器就是HTTP服务器)为什么要把Web服务器独立配置,和应用程序服务器一前一后?
  7. python PyQt5 QSlider类(滑块)
  8. python codecs模块(用于执行编码转换之类的)
  9. Intel Realsense D435 将深度图的灰度图映射为彩色图,打印输出灰度图或彩色图
  10. python opencv cv2.resize()函数