文章目录

  • 1.前言
  • 2.非交叉验证实验
  • 3.交叉验证实验
  • 4.准确率与平方误差
    • 4.1.准确率实验
    • 4.2.均方误差实验
  • 5.Learning curve 检查过拟合
    • 5.1.加载必要模块
    • 5.2.加载数据
    • 5.3.调用learning_curve
    • 5.4.learning_curve可视化
  • 6.validation_curve 检查过拟合

1.前言

Sklearn 中的 Cross Validation (交叉验证)对于我们选择正确的 Model 和 Model 的参数是非常有帮助的, 有了它的帮助,我们能直观的看出不同 Model 或者参数对结构准确度的影响。

2.非交叉验证实验

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifieriris = load_iris()
x = iris.data
y = iris.targetx_train, x_test, y_train, y_test = train_test_split(x,y,random_state = 12)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
print(knn.score(x_test, y_test))#输出
0.9210526315789473

可以看到基础验证的准确率为0.9210526315789473

3.交叉验证实验

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score # K折交叉验证模块iris = load_iris()
x = iris.data
y = iris.targetx_train, x_test, y_train, y_test = train_test_split(x, y, random_state = 16)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
scores = cross_val_score(knn, x, y,cv=5,scoring='accuracy')   #使用K折交叉验证模块
print(scores)    #将5次的预测准确率打印出
print(scores.mean())   #将5次的预测准确平均率打印出#输出
[0.96666667 1.         0.93333333 0.96666667 1.        ]
0.9733333333333334

4.准确率与平方误差

4.1.准确率实验

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as pltiris = load_iris()
x = iris.data
y = iris.targetk_scores = []
for k in range(1,41):knn = KNeighborsClassifier(n_neighbors = k)scores = cross_val_score(knn,x,y,cv=10,scoring = 'accuracy')   #分类 十折交叉验证#loss = -cross_val_score(knn,x,y,cv=10,scoring = 'mean_squared_error')   #回归k_scores.append(scores.mean())plt.plot(range(1,41),k_scores)
plt.xlabel("Value of K for KNN")
plt.ylabel("Cross Validated Accuracy")
plt.show()


从图中可以得知,选择12~20的k值最好。高过20之后,准确率开始下降则是因为过拟合(Over fitting)的问题。

4.2.均方误差实验

一般来说平均方差(Mean squared error)会用于判断回归(Regression)模型的好坏。

from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as pltiris = load_iris()
x = iris.data
y = iris.targetfor k in range(1,41):knn = KNeighborsClassifier(n_neighbors = k)#scores = cross_val_score(knn,x,y,cv=10,scoring = 'accuracy')   #分类loss = -cross_val_score(knn,x,y,cv=10,scoring = 'neg_mean_squared_error')   #回归k_scores.append(loss.mean())plt.plot(range(1,41),k_scores)
plt.xlabel("Value of K for KNN")
plt.ylabel("Cross Validated Accuracy")
plt.show()


由图可以得知,平均方差越低越好,因此选择13~20左右的K值会最好。

5.Learning curve 检查过拟合

sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度, 对比发现有没有 overfitting 的问题. 然后我们可以对我们的 model 进行调整, 克服 overfitting 的问题.

5.1.加载必要模块

from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC import matplotlib.pyplot as plt
import numpy as np

5.2.加载数据

加载digits数据集,其包含的是手写体的数字,从0到9。数据集总共有1797个样本,每个样本由64个特征组成, 分别为其手写体对应的8×8像素表示,每个特征取值0~16。

digits = load_digits()
X = digits.data
y = digits.target

5.3.调用learning_curve

观察样本由小到大的学习曲线变化, 采用K折交叉验证 cv=10, 选择平均方差检视模型效能 scoring=‘neg_mean_squared_error’, 样本由小到大分成5轮检视学习曲线(10%, 25%, 50%, 75%, 100%):

train_sizes, train_loss, test_loss = learning_curve(SVC(gamma=0.001), X, y, cv=10, scoring='neg_mean_squared_error',train_sizes=[0.1, 0.25, 0.5, 0.75, 1])#平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

5.4.learning_curve可视化

plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label="Cross-validation")plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

6.validation_curve 检查过拟合

继续上面的例子,并稍作小修改即可画出图形。这次我们来验证SVC中的一个参数 gamma 在什么范围内能使 model 产生好的结果. 以及过拟合和 gamma 取值的关系.

from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC import matplotlib.pyplot as plt
import numpy as npdigits = load_digits()
X = digits.data
y = digits.targetparam_range = np.logspace(-5,-3,6)train_loss, test_loss = validation_curve(SVC(), X, y,param_name = 'gamma',param_range = param_range, cv=10, scoring='neg_mean_squared_error')#平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")plt.xlabel("Gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()

Sklearn——交叉验证(Cross Validation)相关推荐

  1. 训练数据集如何划分验证测试集?train/test(val/dev) set和交叉验证(cross validation)

    普通train/test set 直接将训练数据划分为两部分,一部分用来做训练train set,一部分用来固定作为测试集test set.然后反复更换超参在训练集上进行训练,使用测试集依次测试,进行 ...

  2. 交叉验证(Cross Validation)方法思想简介

    交叉验证(CrossValidation)方法思想 以下简称交叉验证(Cross Validation)为CV.CV是用来验证分类器的性能一种统计分析方法,基本思想是把在某种意义下将原始数据(data ...

  3. 【机器学习】<刘建平Pinard老师博客学习记录>交叉验证(Cross Validation)

    交叉验证是在机器学习建立模型和验证模型参数时常用的办法.交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏. ...

  4. 交叉验证(Cross Validation)最详解

    1.OverFitting 在模型训练过程中,过拟合overfitting是非常常见的现象.所谓的overfitting,就是在训练集上表现很好,但是测试集上表现很差.为了减少过拟合,提高模型的泛化能 ...

  5. matlab 交叉验证 代码,交叉验证(Cross Validation)方法思想简介

    本帖最后由 azure_sky 于 2014-1-17 00:30 编辑 2).K-fold Cross Validation(记为K-CV) 将原始数据分成K组(一般是均分),将每个子集数据分别做一 ...

  6. 交叉验证 cross validation 与 K-fold Cross Validation K折叠验证

    交叉验证,cross validation是机器学习中非常常见的验证模型鲁棒性的方法.其最主要原理是将数据集的一部分分离出来作为验证集,剩余的用于模型的训练,称为训练集.模型通过训练集来最优化其内部参 ...

  7. 交叉验证(Cross Validation)原理小结

    交叉验证是在机器学习建立模型和验证模型参数时常用的办法.交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏. ...

  8. 机器学习笔记——sklearn 交叉验证(Cross-validation)

    sklearn cross validation:https://scikit-learn.org/stable/modules/cross_validation.html 交叉验证(Cross Va ...

  9. 机器学习- Sklearn (交叉验证和Pipeline)

    前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...

最新文章

  1. Android之上下文菜单创建
  2. oracle 11g 从rman全备中恢复控制文件,拥有RMAN全备(缺少后增文件),丢失全部数据文件,控制文件的恢复...
  3. 手机和邮箱的正则表达式
  4. ELK学习9_ELK数据流传输过程_问题总结2
  5. jquery 一些特效使用
  6. python爬取+BI分析5000条内衣数据,发现妹子最爱这款文胸
  7. retorfit converter使用说明
  8. DW Basic Knowledge1
  9. h264编解码器知识点
  10. 人大金仓数据库工程师培训实战教程(同步复制、读写分离、集群高可用)
  11. 电力线通信(Power Line Communication)简介
  12. 网站被黑提醒该站点可能受到黑客攻击,部分页面已被非法篡改...
  13. The C++ Frontend
  14. python语言程序代码保存在_《计算机二级Python语言程序设计考试》第5章:函数和代码复用...
  15. echarts 图表不能占满全屏
  16. 山寨 悟空遥控器的 方向键
  17. 贵州中小学教师计算机考试题目,2019贵州教师招聘考试习题及答案:小学数学...
  18. 计算机技术能力校本培训总结,计算机、网络技术校本培训总结.doc
  19. Ubuntu搭建LDAP服务器
  20. html 字号和像素的关系,一文搞懂CSS中的字体单位大小(px,em,rem...)

热门文章

  1. mysql 减去_MySql进阶面试题
  2. 『力荐汇总』这些 VS Code 快捷键太好用,忍不住录了这34张gif动图
  3. 为eclipse安装python、shell开发环境和SVN插件
  4. FPGA--------随笔总结(持续更新)
  5. smokeping的启动脚本
  6. java面试 谈谈jvm内存结构
  7. 一道很简单却也很容易入坑的java面试题
  8. python更新织梦网站_怎么让dedecms织梦网站首页自动更新
  9. 转载:OpenStack从入门到放弃
  10. GDAL1.11版本对SHP文件索引加速测试