Sklearn——交叉验证(Cross Validation)
文章目录
- 1.前言
- 2.非交叉验证实验
- 3.交叉验证实验
- 4.准确率与平方误差
- 4.1.准确率实验
- 4.2.均方误差实验
- 5.Learning curve 检查过拟合
- 5.1.加载必要模块
- 5.2.加载数据
- 5.3.调用learning_curve
- 5.4.learning_curve可视化
- 6.validation_curve 检查过拟合
1.前言
Sklearn 中的 Cross Validation (交叉验证)对于我们选择正确的 Model 和 Model 的参数是非常有帮助的, 有了它的帮助,我们能直观的看出不同 Model 或者参数对结构准确度的影响。
2.非交叉验证实验
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifieriris = load_iris()
x = iris.data
y = iris.targetx_train, x_test, y_train, y_test = train_test_split(x,y,random_state = 12)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
print(knn.score(x_test, y_test))#输出
0.9210526315789473
可以看到基础验证的准确率为0.9210526315789473
3.交叉验证实验
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score # K折交叉验证模块iris = load_iris()
x = iris.data
y = iris.targetx_train, x_test, y_train, y_test = train_test_split(x, y, random_state = 16)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
scores = cross_val_score(knn, x, y,cv=5,scoring='accuracy') #使用K折交叉验证模块
print(scores) #将5次的预测准确率打印出
print(scores.mean()) #将5次的预测准确平均率打印出#输出
[0.96666667 1. 0.93333333 0.96666667 1. ]
0.9733333333333334
4.准确率与平方误差
4.1.准确率实验
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as pltiris = load_iris()
x = iris.data
y = iris.targetk_scores = []
for k in range(1,41):knn = KNeighborsClassifier(n_neighbors = k)scores = cross_val_score(knn,x,y,cv=10,scoring = 'accuracy') #分类 十折交叉验证#loss = -cross_val_score(knn,x,y,cv=10,scoring = 'mean_squared_error') #回归k_scores.append(scores.mean())plt.plot(range(1,41),k_scores)
plt.xlabel("Value of K for KNN")
plt.ylabel("Cross Validated Accuracy")
plt.show()
从图中可以得知,选择12~20的k值最好。高过20之后,准确率开始下降则是因为过拟合(Over fitting)的问题。
4.2.均方误差实验
一般来说平均方差(Mean squared error)会用于判断回归(Regression)模型的好坏。
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as pltiris = load_iris()
x = iris.data
y = iris.targetfor k in range(1,41):knn = KNeighborsClassifier(n_neighbors = k)#scores = cross_val_score(knn,x,y,cv=10,scoring = 'accuracy') #分类loss = -cross_val_score(knn,x,y,cv=10,scoring = 'neg_mean_squared_error') #回归k_scores.append(loss.mean())plt.plot(range(1,41),k_scores)
plt.xlabel("Value of K for KNN")
plt.ylabel("Cross Validated Accuracy")
plt.show()
由图可以得知,平均方差越低越好,因此选择13~20左右的K值会最好。
5.Learning curve 检查过拟合
sklearn.learning_curve 中的 learning curve 可以很直观的看出我们的 model 学习的进度, 对比发现有没有 overfitting 的问题. 然后我们可以对我们的 model 进行调整, 克服 overfitting 的问题.
5.1.加载必要模块
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC import matplotlib.pyplot as plt
import numpy as np
5.2.加载数据
加载digits数据集,其包含的是手写体的数字,从0到9。数据集总共有1797个样本,每个样本由64个特征组成, 分别为其手写体对应的8×8像素表示,每个特征取值0~16。
digits = load_digits()
X = digits.data
y = digits.target
5.3.调用learning_curve
观察样本由小到大的学习曲线变化, 采用K折交叉验证 cv=10, 选择平均方差检视模型效能 scoring=‘neg_mean_squared_error’, 样本由小到大分成5轮检视学习曲线(10%, 25%, 50%, 75%, 100%):
train_sizes, train_loss, test_loss = learning_curve(SVC(gamma=0.001), X, y, cv=10, scoring='neg_mean_squared_error',train_sizes=[0.1, 0.25, 0.5, 0.75, 1])#平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)
5.4.learning_curve可视化
plt.plot(train_sizes, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(train_sizes, test_loss_mean, 'o-', color="g",label="Cross-validation")plt.xlabel("Training examples")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()
6.validation_curve 检查过拟合
继续上面的例子,并稍作小修改即可画出图形。这次我们来验证SVC中的一个参数 gamma 在什么范围内能使 model 产生好的结果. 以及过拟合和 gamma 取值的关系.
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC import matplotlib.pyplot as plt
import numpy as npdigits = load_digits()
X = digits.data
y = digits.targetparam_range = np.logspace(-5,-3,6)train_loss, test_loss = validation_curve(SVC(), X, y,param_name = 'gamma',param_range = param_range, cv=10, scoring='neg_mean_squared_error')#平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)plt.plot(param_range, train_loss_mean, 'o-', color="r",label="Training")
plt.plot(param_range, test_loss_mean, 'o-', color="g",label="Cross-validation")plt.xlabel("Gamma")
plt.ylabel("Loss")
plt.legend(loc="best")
plt.show()
Sklearn——交叉验证(Cross Validation)相关推荐
- 训练数据集如何划分验证测试集?train/test(val/dev) set和交叉验证(cross validation)
普通train/test set 直接将训练数据划分为两部分,一部分用来做训练train set,一部分用来固定作为测试集test set.然后反复更换超参在训练集上进行训练,使用测试集依次测试,进行 ...
- 交叉验证(Cross Validation)方法思想简介
交叉验证(CrossValidation)方法思想 以下简称交叉验证(Cross Validation)为CV.CV是用来验证分类器的性能一种统计分析方法,基本思想是把在某种意义下将原始数据(data ...
- 【机器学习】<刘建平Pinard老师博客学习记录>交叉验证(Cross Validation)
交叉验证是在机器学习建立模型和验证模型参数时常用的办法.交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏. ...
- 交叉验证(Cross Validation)最详解
1.OverFitting 在模型训练过程中,过拟合overfitting是非常常见的现象.所谓的overfitting,就是在训练集上表现很好,但是测试集上表现很差.为了减少过拟合,提高模型的泛化能 ...
- matlab 交叉验证 代码,交叉验证(Cross Validation)方法思想简介
本帖最后由 azure_sky 于 2014-1-17 00:30 编辑 2).K-fold Cross Validation(记为K-CV) 将原始数据分成K组(一般是均分),将每个子集数据分别做一 ...
- 交叉验证 cross validation 与 K-fold Cross Validation K折叠验证
交叉验证,cross validation是机器学习中非常常见的验证模型鲁棒性的方法.其最主要原理是将数据集的一部分分离出来作为验证集,剩余的用于模型的训练,称为训练集.模型通过训练集来最优化其内部参 ...
- 交叉验证(Cross Validation)原理小结
交叉验证是在机器学习建立模型和验证模型参数时常用的办法.交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏. ...
- 机器学习笔记——sklearn 交叉验证(Cross-validation)
sklearn cross validation:https://scikit-learn.org/stable/modules/cross_validation.html 交叉验证(Cross Va ...
- 机器学习- Sklearn (交叉验证和Pipeline)
前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...
最新文章
- Android之上下文菜单创建
- oracle 11g 从rman全备中恢复控制文件,拥有RMAN全备(缺少后增文件),丢失全部数据文件,控制文件的恢复...
- 手机和邮箱的正则表达式
- ELK学习9_ELK数据流传输过程_问题总结2
- jquery 一些特效使用
- python爬取+BI分析5000条内衣数据,发现妹子最爱这款文胸
- retorfit converter使用说明
- DW Basic Knowledge1
- h264编解码器知识点
- 人大金仓数据库工程师培训实战教程(同步复制、读写分离、集群高可用)
- 电力线通信(Power Line Communication)简介
- 网站被黑提醒该站点可能受到黑客攻击,部分页面已被非法篡改...
- The C++ Frontend
- python语言程序代码保存在_《计算机二级Python语言程序设计考试》第5章:函数和代码复用...
- echarts 图表不能占满全屏
- 山寨 悟空遥控器的 方向键
- 贵州中小学教师计算机考试题目,2019贵州教师招聘考试习题及答案:小学数学...
- 计算机技术能力校本培训总结,计算机、网络技术校本培训总结.doc
- Ubuntu搭建LDAP服务器
- html 字号和像素的关系,一文搞懂CSS中的字体单位大小(px,em,rem...)