文章目录

  • 1.10 交叉验证,网格搜索
    • 学习目标
    • 1 什么是交叉验证(cross validation)
      • 1.1 分析
      • 1.2 为什么需要交叉验证
    • 2 什么是网格搜索(Grid Search)
    • 3 交叉验证,网格搜索(模型选择与调优)API:
    • 4 鸢尾花案例增加K值调优
    • 5 总结

1.10 交叉验证,网格搜索

学习目标

  • 目标

    • 知道交叉验证、网格搜索的概念
    • 会使用交叉验证、网格搜索优化训练模型

1 什么是交叉验证(cross validation)

交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4次(组)的测试,每次都更换不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。

1.1 分析

我们之前知道数据分为训练集和测试集,但是**为了让从训练得到模型结果更加准确。**做以下处理

  • 训练集:训练集+验证集
  • 测试集:测试集

1.2 为什么需要交叉验证

交叉验证目的:为了让被评估的模型更加准确可信

问题:这个只是让被评估的模型更加准确可信,那么怎么选择或者调优参数呢?

2 什么是网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

3 交叉验证,网格搜索(模型选择与调优)API:

  • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)

    • 对估计器的指定参数值进行详尽搜索

    • estimator:估计器对象

    • param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}

    • cv:指定几折交叉验证

    • fit:输入训练数据

    • score:准确率

    • 结果分析:

      • best_score__:在交叉验证中验证的最好结果_
      • best_estimator_:最好的参数模型
      • cv_results_:每次交叉验证后的验证集准确率结果和训练集准确率结果

4 鸢尾花案例增加K值调优

  • 使用GridSearchCV构建估计器
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier# 1、获取数据集iris = load_iris()# 2、数据基本处理 -- 划分数据集x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)# 3、特征工程:标准化# 实例化一个转换器类transfer = StandardScaler()# 调用fit_transformx_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)# 4、KNN预估器流程#  4.1 实例化预估器类estimator = KNeighborsClassifier()# 4.2 模型选择与调优——网格搜索和交叉验证# 准备要调的超参数param_dict = {"n_neighbors": [1, 3, 5]}estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)# 4.3 fit数据进行训练estimator.fit(x_train, y_train)# 5、评估模型效果# 方法a:比对预测结果和真实值y_predict = estimator.predict(x_test)print("比对预测结果和真实值:\n", y_predict == y_test)# 方法b:直接计算准确率score = estimator.score(x_test, y_test)print("直接计算准确率:\n", score)
  • 然后进行评估查看最终选择的结果和交叉验证的结果
print("在交叉验证中验证的最好结果:\n", estimator.best_score_)
print("最好的参数模型:\n", estimator.best_estimator_)
print("每次交叉验证后的准确率结果:\n", estimator.cv_results_)
  • 最终结果
 比对预测结果和真实值:[ True  True  True  True  True  True  True False  True  True  True  TrueTrue  True  True  True  True  True False  True  True  True  True  TrueTrue  True  True  True  True  True  True  True  True  True  True  TrueTrue  True]直接计算准确率:0.947368421053在交叉验证中验证的最好结果:0.973214285714最好的参数模型:KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',metric_params=None, n_jobs=1, n_neighbors=5, p=2,weights='uniform')每次交叉验证后的准确率结果:{'mean_fit_time': array([ 0.00114751,  0.00027037,  0.00024462]), 'std_fit_time': array([  1.13901511e-03,   1.25300249e-05,   1.11011951e-05]), 'mean_score_time': array([ 0.00085751,  0.00048693,  0.00045625]), 'std_score_time': array([  3.52785082e-04,   2.87650037e-05,   5.29673344e-06]), 'param_n_neighbors': masked_array(data = [1 3 5],mask = [False False False],fill_value = ?), 'params': [{'n_neighbors': 1}, {'n_neighbors': 3}, {'n_neighbors': 5}], 'split0_test_score': array([ 0.97368421,  0.97368421,  0.97368421]), 'split1_test_score': array([ 0.97297297,  0.97297297,  0.97297297]), 'split2_test_score': array([ 0.94594595,  0.89189189,  0.97297297]), 'mean_test_score': array([ 0.96428571,  0.94642857,  0.97321429]), 'std_test_score': array([ 0.01288472,  0.03830641,  0.00033675]), 'rank_test_score': array([2, 3, 1], dtype=int32), 'split0_train_score': array([ 1.        ,  0.95945946,  0.97297297]), 'split1_train_score': array([ 1.        ,  0.96      ,  0.97333333]), 'split2_train_score': array([ 1.  ,  0.96,  0.96]), 'mean_train_score': array([ 1.        ,  0.95981982,  0.96876877]), 'std_train_score': array([ 0.        ,  0.00025481,  0.0062022 ])}

5 总结

  • 交叉验证【知道】

    • 定义:

      • 将拿到的训练数据,分为训练和验证集
      • *折交叉验证
    • 分割方式:
      • 训练集:训练集+验证集
      • 测试集:测试集
    • 为什么需要交叉验证
      • 为了让被评估的模型更加准确可信
      • 注意:交叉验证不能提高模型的准确率
  • 网格搜索【知道】
    • 超参数:

      • sklearn中,需要手动指定的参数,叫做超参数
    • 网格搜索就是把这些超参数的值,通过字典的形式传递进去,然后进行选择最优值
  • api【知道】
    • sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)

      • estimator – 选择了哪个训练模型
      • param_grid – 需要传递的超参数
      • cv – 几折交叉验证

机器学习算法------1.10 交叉验证,网格搜索(交叉验证,网格搜索(模型选择与调优)API、鸢尾花案例增加K值调优)相关推荐

  1. 简单粗暴理解与实现机器学习之K-近邻算法(十):交叉验证,网格搜索(模型选择与调优)API、鸢尾花案例增加K值调优

    K-近邻算法 文章目录 K-近邻算法 学习目标 1.10 交叉验证,网格搜索 1 什么是交叉验证(cross validation) 1.1 分析 1.2 为什么需要交叉验证 **问题:那么这个只是对 ...

  2. 机器学习算法——线性回归的详细介绍 及 利用sklearn包实现线性回归模型

    目录 1.线性回归简介 1.1 线性回归应用场景 1.2 什么是线性回归 1.2.1 定义与公式 1.2.2 线性回归的特征与目标的关系分析 2.线性回归api初步使用 2.1 线性回归API 2.2 ...

  3. 机器学习算法拾遗:(七)隐马尔科夫模型(前向后向算法、鲍姆-韦尔奇算法、维特比算法)

    1.隐马尔科夫模型HMM 隐马尔科夫模型的图结构如下 从上图中主要有两个信息:一是观测变量xi 仅仅与与之对应的状态变量yi 有关:二是当前的状态变量yi 仅仅与它的前一个状态变量yi-1 有关. 隐 ...

  4. 机器学习算法——系统性的学会使用 K近邻算法(KNN)

    目录 1.K-近邻算法简介 1.1 什么是K-近邻算法 1.2 K-近邻算法(KNN)概念 (1)定义: (2)距离公式: 1.3 电影类型分析 1.4 KNN算法流程总结 2.k近邻算法api初步使 ...

  5. 【机器学习】K-近邻算法-模型选择与调优

    前言 在KNN算法中,k值的选择对我们最终的预测结果有着很大的影响 那么有没有好的方法能够帮助我们选择好的k值呢? 模型选择与调优 目标 说明交叉验证过程 说明参数搜索过程 应用GirdSearchC ...

  6. 【06】机器学习算法——评估方法总结

    分类的评估方法 1.精确率与召回率 (1)混淆矩阵 在分类任务下,预测结果(Predicted Condition)与正确标记(True Condition)之间存在四种不同的组合,构成混淆矩阵(适用 ...

  7. 机器学习算法(1)—— K-近邻算法

    K-近邻算法 1 KNN介绍 2 KNN的初步使用 3 距离度量 3.1 基本性质 3.2 常见距离公式 3.3 距离属性 4 k值选择 5 KNN优化-kd树 5.1 kd树简介 5.2 构造方法 ...

  8. 通过交叉验证寻找K近邻算法的最优K值

    问题引出 之前我们使用K近邻算法尝试寻找用户年龄与预估薪资之间的某种相关性,以及他们是否有购买SUV的决定.主要代码如下: from sklearn.neighbors import KNeighbo ...

  9. 机器学习算法小结与收割offer遇到的问题

    机器学习是做NLP和计算机视觉这类应用算法的基础,虽然现在深度学习模型大行其道,但是懂一些传统算法的原理和它们之间的区别还是很有必要的.可以帮助我们做一些模型选择.本篇博文就总结一下各种机器学习算法的 ...

最新文章

  1. 每天一个linux命令(8):cp 命令
  2. centos7重启命令_centos7单用户模式更改root一种方法
  3. Gradle修改缓存路径 和 Gradle修改Maven仓库地址
  4. Java EE 7:新增功能???
  5. LeetCode || Copy List with Random Pointer
  6. 我的第一款 Drone 插件
  7. android应用开发(26)---Parcelables 和 Bundles
  8. 【李宏毅2020 ML/DL】P16 PyTorch Tutorial | 最后提及了 apex.amp
  9. 学生如何免费使用Jetbrains旗下包含Pycharm等开发工具(中文详细教程)
  10. LeetCode.495 Teemo Attacking
  11. 【C语言】指针的理解(乱七芭蕉)
  12. 基于TI Sitara系列AM437x ARM Cortex-A9核心板 处理器
  13. 休假管理系统的问题描述与词汇表
  14. Docker学习笔记八:删除镜像构建私有Registry
  15. 全球疫情形势动态地图展示(超帅超好玩的python动图)
  16. [转载]【职场新人必看】领导谆谆寄语
  17. 软考中级怎么选?如何备考?
  18. CSDN编程挑战赛第六期—参赛心得+题解
  19. 微信考勤签到 php,【投稿】微信签到打卡领积分源码,每日积分签到
  20. mysql随机生成数据并插入_python生成随机数据插入mysql

热门文章

  1. 视频小程序风口,行业的新机遇
  2. js隐藏html页面元素高度,如何使用jQuery获取隐藏元素的高度?
  3. vue的生命周期 (11个钩子函数)看了都能懂的
  4. Http会话保持机制:Cookie、Session和Token
  5. rtmp服务器信息,搭建RTMP服务器
  6. Kubernetes 使用 PVC 持久卷后,持久卷内数据丢失问题
  7. JavaScript中监听程序运行时间
  8. 群聊私聊天建群社交即时通讯H5系统开发
  9. 图像算法---磨皮算法研究汇总
  10. 复试c语言笔试题,2014年暨南大学C语言考研复试试题(回忆版)