文章目录

  • 5 模型选择及调优
    • 5.1 数据增强
    • 5.2 过拟合
    • 5.3 交叉验证
    • 5.4 超参数搜索——网格搜索

5 模型选择及调优

5.1 数据增强

有时候,你和你的老板说你数据不够,它是不会理你的。老板会发问:为什么你是做机器学习的要那么多数据干嘛,让机器去做不就行了。

对于这种问题有时候即使无语但你也不能正面拆穿,否则你的工作就不用干了。而为了解决数据集不足的问题,我们通常会采用数据增强

这个名词看似高大上,实际上就是把数据集经过某些变换,从而产生新的数据集。

这种方法多用于图片识别上,将图片结果左右对称变换,或反转,或偏移角度来达到拥有更多数据集的目的。

5.2 过拟合

统计学习的目的是使学到的模型不仅对已知数据而且对未知数据都能有很好的预测能力。不同的学习方法会给出不同的模型。当损失函数给定时,基于损失函数的模型的训练误差和模型的测试误差就自然成为学习方法评估的标准。注意,统计学习方法具体采用的损失函数未必是评估时使用的损失函数。当然,让两者一致是比较理想的。

训练误差的大小,对判断给定的问题是不是一个容易学习的问题是有意义的,但本质上不重要。测试误差反映了学习方法对未知的测试数据集的预测能力,是学习中的重要概念。显然,给定两种学习方法,测试误差小的方法具有更好的预测能力,是更有效的方法。

通常将学习方法对未知数据的预测能力称为泛化能力

当假设空间含有不同复杂度的模型时,就要面临模型选择的问题。我们希望选择或学习一个合适的模型。如果在假设空间中存在真模型,那么所选择的模型应该逼近真模型。具体地,所选择的模型要与真模型的参数个数相同,所选择的模型的参数向量与真模型的参数向量相近。

如果一味追求提高对训练数据的预测能力,所选模型的复杂度则往往会比真模型更高。这种现象称为过拟合。过拟合是指学习时选择的模型所包含的参数过多,以至于出现这一模型对已知数据预测的很好,但对未知数据预测得很差的现象。可以说模型选择旨在避免过拟合并提高模型的预测能力。

也就是说,上述的话翻译成人话就是,我们不要那种能够完全贴合训练集的函数,那种函数训练出来训练集在上面跑挺牛逼,一到测试集就不行了。我们需要的是那种在训练集跑的差不多,对于测试集跑出来效果也很好的那种函数。

现在我们有以上的数据集,我们要选择一个模型去拟合真模型,也就是M=0时图中画的曲线,那条曲线即为真模型。当然了,根据我们上面所说,我们要的是做到“差不多”即可,我们不要精度完全一样或者超过真模型。

当我们M=1,选择的是一条直线,这种模型其实是罔顾事实的做法,我们完全不考虑拟合的效果,一上来就乱套模型,这样会导致拟合数据的效果贼差。这种在古老的文献中称为“欠拟合”现象。

当M=3时,我们选择的模型已经接近数据所对应的真模型了,已经几乎拟合了,这时候的模型符合测试误差最小的学习目的了。

当M=9时,这时候就是所谓的过拟合现象了,由于参数设置过多,导致这条曲线几乎穿过了我们已知的所有的数据点。的确,他对已知数据预测很好(穿过了嘛),但是他对未知数据却预测很差(说不定下一个点不在这条线上,这就导致前面预测很准,后面误差越来越大)。

简单来说,想解决过拟合,实际上无非就是选择复杂度适当的模型,以达到使测试误差最小的学习目的。我们常用的模型选择方法:正则化交叉验证。关于正则化的学习我们在后面的学习中会接触到,我们这里要提到的是关于交叉验证

5.3 交叉验证

交叉验证(cross validation)简单来说就是将拿到的训练数据,再次分为训练集和验证集。其中验证集和测试集的功能一样,都是对训练集训练出来的模型进行评估。而交叉验证方法就是将训练集划分为训练集+测试集,测试集通常占1份,而训练集占k-1份,通过四次测试,每次更换不同的验证集来达到在有限的数据集中得出不同精度,得出4组模型结果;得出结果后取平均值作为最终结果。我们把上述的做法称为K折交叉验证。

K折交叉验证示意图如下:

虽然K折交叉验证能够在k均值算法中起到优化K值的效果,那么如何来选取K值呢?什么时候才是最好呢?这就要进入我们的下一小节了。

5.4 超参数搜索——网格搜索

通常情况下,有很多参数时需要手动指定的(如K-近邻算法中的K值),这种叫做超参数。但是手动指定不准且计算复杂,所以我们要对模型预设几种超参数组合,每组超参数都采用交叉验证来进行评估,最后选出最优参数组合建立模型。

知道原理了就是动手写代码的时刻,我们看一下sklearn中有哪些库供我们调用。

sklearn.model_selection.GridSearchCV(estimator,param_grid = None,cv = None)

该API可以对估计器的指定参数值进行详尽搜索

  • estimator:估计器对象
  • param——grid:估计器参数,对应到knn中可以传入多个k值建立多个模型,以此评估哪个模型最好,传入时要用字典形式,如{“n_neighbors”:[1,3,5]}
  • cv:指定几折交叉验证,常用10折交叉验证

fit():输入训练数据

score():准确率

通过调用以下属性可以查看结果:

  • 最佳参数:best_params_
  • 最佳结果:best_score_
  • 最佳估计器:best_estimator_
  • 交叉验证结果:cv_results_

知道上面的原理,让我们对前一讲的KNN分类鸢尾花代码优化一下吧!

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCVdef knn_iris():"""用KNN算法对鸢尾花进行分类"""# 1 导入数据集iris = load_iris()# 2 划分数据集x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=6)# 3 特征工程:标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_test = transfer.transform(x_test)# 4 实例化KNN算法预估器estimator = KNeighborsClassifier()# 选用合适的K值来选择多个模型param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}# 加入超参数网格搜索和交叉验证estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)estimator.fit(x_train, y_train)# 5 模型评估# 方法1 直接比对真实值和预测值y_predict = estimator.predict(x_test)print("y_predict:\n", y_predict)print("直接对比真实值和预测值:\n", y_test == y_predict)# 方法2 计算准确率score = estimator.score(x_test, y_test)print("准确率为:\n", score)# 查看最佳参数print("KNN模型最佳参数:\n", estimator.best_params_)print("最佳结果:\n", estimator.best_score_)print("最佳估计器:\n", estimator.best_estimator_)print("交叉验证结果:\n", estimator.cv_results_)# 调用方法
knn_iris()

机器学习的练功方式(五)——模型选择及调优相关推荐

  1. 【机器学习】K-近邻算法-模型选择与调优

    前言 在KNN算法中,k值的选择对我们最终的预测结果有着很大的影响 那么有没有好的方法能够帮助我们选择好的k值呢? 模型选择与调优 目标 说明交叉验证过程 说明参数搜索过程 应用GirdSearchC ...

  2. 机器学习算法------1.10 交叉验证,网格搜索(交叉验证,网格搜索(模型选择与调优)API、鸢尾花案例增加K值调优)

    文章目录 1.10 交叉验证,网格搜索 学习目标 1 什么是交叉验证(cross validation) 1.1 分析 1.2 为什么需要交叉验证 2 什么是网格搜索(Grid Search) 3 交 ...

  3. 机器学习-分类算法-模型选择与调优09

    模型选择与调优 交叉验证:为了让被评估的模型更加准确可信 网格搜索 from sklearn.neighbors import KNeighborsClassifier from sklearn.mo ...

  4. 简单粗暴理解与实现机器学习之K-近邻算法(十):交叉验证,网格搜索(模型选择与调优)API、鸢尾花案例增加K值调优

    K-近邻算法 文章目录 K-近邻算法 学习目标 1.10 交叉验证,网格搜索 1 什么是交叉验证(cross validation) 1.1 分析 1.2 为什么需要交叉验证 **问题:那么这个只是对 ...

  5. 机器学习——分类算法之K近邻+朴素贝叶斯,模型选择与调优

    目录 K-近邻算法 定义 如何求距离? 数据预处理--标准化 sklearn k-近邻算法API 案例--预测入住位置 分类问题 数据处理 k近邻算法相关问题 k值取多大?有什么影响? 优缺点 应用场 ...

  6. python人工智能——机器学习——模型选择与调优

    1.交叉验证 交叉验证:为了让被评估的模型更加准确可信 交叉验证过程 交叉验证:将拿到的数据,分为训练和验证集. 以下图为例:将数据分成5份,其中一份作为验证集.然后经过5次(组)的测试,每次都更换不 ...

  7. XGBoost学习(五):参数调优

    XGBoost学习(一):原理 XGBoost学习(二):安装及介绍 XGBoost学习(三):模型详解 XGBoost学习(四):实战 XGBoost学习(五):参数调优 XGBoost学习(六): ...

  8. JVM内存模型和性能调优:系列文章 - 导读

    0.JVM课程总体介绍 学习 Java 虚拟机能深入地理解 Java 这门语言,想要深入学习java的各种细节,很多时候你要深入到字节码层次去分析,你才能得到准确的结论,通过学习JVM你了解JVM历史 ...

  9. R语言使用caret包对GBM模型进行参数调优实战:Model Training and Parameter Tuning

    R语言使用caret包对GBM模型进行参数调优实战:Model Training and Parameter Tuning 目录 R语言使用caret包对GBM模型进行参数调优实战:Model Tra ...

最新文章

  1. 计算机累加器有加法器功能吗,累加器-累加器ACC的作用
  2. SubBuilder使用
  3. Jupyter notebook入门教程(下)
  4. slice 定义和用法
  5. Spring_day1
  6. html5类似ios下拉选择器,iosselect:一个js picker项目,在H5中实现IOS的select下拉框效果 - mufc-go...
  7. 「基因组组装」用AMOS/minimus2合并两个contig
  8. 复盘:C语言中int a[][3]={1,2,3,4,5,6,7,8}什么意思,int a[3][]又是什么意思,结果为10的是
  9. 基于数字电路典型分频电路设计
  10. 获取所有打印机,设置默认打印机,获取默认打印机
  11. USB Type-C数据线美国新标准UL9990报告检测项目
  12. 低度酒的诸神之战,能分出胜负吗?
  13. 展示csdn的云服务
  14. 爱上开源之一款查询docker容器启动命令的工具
  15. 如何使用谷歌浏览器远程调试安卓/ios真机H5应用?
  16. Future.get()抛出ExecutionException或InterruptedException?
  17. 问:adb连接逍遥模拟器时,报offline。
  18. Linux进程KILL--Quit,INT,HUP,QUIT,和TERM、PIPE的解释
  19. 如何删除百度快照?百度快照是什么?百度快照优化是什么意思?
  20. layui table动态选中_mac动态图片编辑工具-Motion

热门文章

  1. 对于个人(注册表)与团队(团队表)(两张表没有关联)的展示与可空判断
  2. Udi Dahan对于业务逻辑重用以及微服务方面的观点
  3. mysql导入导出数据
  4. 谷歌笔试题(Google十二岁生日晚)
  5. [书籍推荐]《软件设计精要与模式(第2版)》-张逸——提高设计模式及软件设计的方法...
  6. python初学者_面向初学者的20种重要的Python技巧
  7. 医疗大数据处理流程_我们需要数据来大规模改善医疗流程
  8. leetcode 1319. 连通网络的操作次数(并查集)
  9. ionic4 打包ios_学习Ionic 4并开始创建iOS / Android应用
  10. react和react2_为什么React16是React开发人员的福气