关于交叉验证,我在之前的文章中已经进行了简单的介绍,而现在我们则通过几个更加详尽的例子.详细的介绍

CV

%matplotlib inline
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svmiris = datasets.load_iris()
iris.data.shape,iris.target.shape
((150, 4), (150,))

一般的分割方式,训练集-测试集.然而这种方式并不是很好

X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.4, random_state=0) clf_svc = svm.SVC(kernel='linear').fit(X_train,y_train)
clf_svc.score(X_test,y_test)
0.9666666666666667
  • 缺点一:浪费数据
  • 缺点二:容易过拟合,且矫正方式不方便

这时,我们需要使用另外一种分割方式-交叉验证

from sklearn.model_selection import cross_val_score
clf_svc_cv = svm.SVC(kernel='linear',C=1)
scores_clf_svc_cv = cross_val_score(clf_svc_cv,iris.data,iris.target,cv=5)
print(scores_clf_svc_cv)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv.mean(), scores_clf_svc_cv.std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
Accuracy: 0.98 (+/- 0.03)

同时我们也可以为cross_val_score选择不同的性能度量函数

from sklearn import metrics
scores_clf_svc_cv_f1 = cross_val_score(clf_svc_cv,iris.data,iris.target,cv=5,scoring='f1_macro')
print("F1: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv_f1.mean(), scores_clf_svc_cv_f1.std() * 2))
F1: 0.98 (+/- 0.03)

同时也正是这些特性使得,cv与数据转化以及pipline(sklearn中的管道机制)变得更加契合

from sklearn import preprocessing
from sklearn.pipeline import make_pipeline
clf_pipline = make_pipeline(preprocessing.StandardScaler(),svm.SVC(C=1))
scores_pipline_cv = cross_val_score(clf_pipline,iris.data,iris.target,cv=5)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores_clf_svc_cv_f1.mean(), scores_clf_svc_cv_f1.std() * 2))
Accuracy: 0.98 (+/- 0.03)

同时我们还可以在交叉验证使用多个度量函数

from sklearn.model_selection import cross_validate
from sklearn import metricsscoring = ['precision_macro', 'recall_macro']
clf_cvs = svm.SVC(kernel='linear', C=1, random_state=0)
scores_cvs = cross_validate(clf_cvs,iris.data,iris.target,cv=5,scoring=scoring,return_train_score = False)
sorted(scores_cvs.keys())
['fit_time', 'score_time', 'test_precision_macro', 'test_recall_macro']
print(scores_cvs['test_recall_macro'])
print("test_recall_macro: %0.2f (+/- %0.2f)" % (scores_cvs['test_recall_macro'].mean(), scores_cvs['test_recall_macro'].std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
test_recall_macro: 0.98 (+/- 0.03)

同时cross_validate也可以使用make_scorer自定义度量功能
或者使用单一独量

from sklearn.metrics.scorer import make_scorer
scoring_new = {'prec_macro': 'precision_macro','recall_micro': make_scorer(metrics.recall_score, average='macro')}
# 注意此处的make_scorer
scores_cvs_new = cross_validate(clf_cvs,iris.data,iris.target,cv=5,scoring=scoring_new,return_train_score = False)
sorted(scores_cvs_new.keys())
['fit_time', 'score_time', 'test_prec_macro', 'test_recall_micro']
print(scores_cvs_new['test_recall_micro'])
print("test_recall_micro: %0.2f (+/- %0.2f)" % (scores_cvs_new['test_recall_micro'].mean(), scores_cvs_new['test_recall_micro'].std() * 2))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]
test_recall_micro: 0.98 (+/- 0.03)

关于Sklearn中的CV还有cross_val_predict可用于预测,下面则是Sklearn中一个关于使用该方法进行可视化预测错误的案例

from sklearn import datasets
from sklearn.model_selection import cross_val_predict
from sklearn import linear_model
import matplotlib.pyplot as pltlr = linear_model.LinearRegression()
boston = datasets.load_boston()
y = boston.target# cross_val_predict returns an array of the same size as `y` where each entry
# is a prediction obtained by cross validation:
predicted = cross_val_predict(lr, boston.data, y, cv=10)fig, ax = plt.subplots()
fig.set_size_inches(18.5,10.5)
ax.scatter(y, predicted, edgecolors=(0, 0, 0))
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=4)
ax.set_xlabel('Measured')
ax.set_ylabel('Predicted')
plt.show()

KFlod的例子

Stratified k-fold:实现了分层交叉切分

from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2, 3, 4],[11, 12, 13, 14],[21, 22, 23, 24],[31, 32, 33, 34],[41, 42, 43, 44],[51, 52, 53, 54],[61, 62, 63, 64],[71, 72, 73, 74]])y = np.array([1, 1, 0, 0, 1, 1, 0, 0])stratified_folder = StratifiedKFold(n_splits=4, random_state=0, shuffle=False)
for train_index, test_index in stratified_folder.split(X, y):print("Stratified Train Index:", train_index)print("Stratified Test Index:", test_index)print("Stratified y_train:", y[train_index])print("Stratified y_test:", y[test_index],'\n')
Stratified Train Index: [1 3 4 5 6 7]
Stratified Test Index: [0 2]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] Stratified Train Index: [0 2 4 5 6 7]
Stratified Test Index: [1 3]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] Stratified Train Index: [0 1 2 3 5 7]
Stratified Test Index: [4 6]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] Stratified Train Index: [0 1 2 3 4 6]
Stratified Test Index: [5 7]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0]
from sklearn.model_selection import StratifiedKFold
X = np.array([[1, 2, 3, 4],[11, 12, 13, 14],[21, 22, 23, 24],[31, 32, 33, 34],[41, 42, 43, 44],[51, 52, 53, 54],[61, 62, 63, 64],[71, 72, 73, 74]])y = np.array([1, 1, 0, 0, 1, 1, 0, 0])stratified_folder = StratifiedKFold(n_splits=4, random_state=0, shuffle=False)
for train_index, test_index in stratified_folder.split(X, y):print("Stratified Train Index:", train_index)print("Stratified Test Index:", test_index)print("Stratified y_train:", y[train_index])print("Stratified y_test:", y[test_index],'\n')
Stratified Train Index: [1 3 4 5 6 7]
Stratified Test Index: [0 2]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] Stratified Train Index: [0 2 4 5 6 7]
Stratified Test Index: [1 3]
Stratified y_train: [1 0 1 1 0 0]
Stratified y_test: [1 0] Stratified Train Index: [0 1 2 3 5 7]
Stratified Test Index: [4 6]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0] Stratified Train Index: [0 1 2 3 4 6]
Stratified Test Index: [5 7]
Stratified y_train: [1 1 0 0 1 0]
Stratified y_test: [1 0]

除了这几种交叉切分KFlod外,还有很多其他的分割方式,比如StratifiedShuffleSplit重复分层KFold,实现了每个K中各类别的比例与原数据集大致一致,而RepeatedStratifiedKFold 可用于在每次重复中用不同的随机化重复分层 K-Fold n 次。至此基本的KFlod在Sklearn中都实现了

注意

i.i.d 数据是机器学习理论中的一个常见假设,在实践中很少成立。如果知道样本是使用时间相关的过程生成的,则使用 time-series aware cross-validation scheme 更安全。 同样,如果我们知道生成过程具有 group structure (群体结构)(从不同 subjects(主体) , experiments(实验), measurement devices (测量设备)收集的样本),则使用 group-wise cross-validation 更安全。

下面就是一个分组KFold的例子,

from sklearn.model_selection import GroupKFoldX = [0.1, 0.2, 2.2, 2.4, 2.3, 4.55, 5.8, 8.8, 9, 10]
y = ["a", "b", "b", "b", "c", "c", "c", "d", "d", "d"]
groups = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]gkf = GroupKFold(n_splits=3)
for train, test in gkf.split(X, y, groups=groups):print("%s %s" % (train, test))
[0 1 2 3 4 5] [6 7 8 9]
[0 1 2 6 7 8 9] [3 4 5]
[3 4 5 6 7 8 9] [0 1 2]

更多内容请参考:sklearn相应手册

Sklearn中的CV与KFold详解相关推荐

  1. sklearn中warm_start参数的作用详解

    warm_start参数用在模型训练过程中,默认是False,从字面上理解就是"从温暖的地方开始"训练模型. sklearn官网定义: 使用方法: 1.如果warm_start=T ...

  2. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  3. python labelencoder参数_对python 数据处理中的LabelEncoder 和 OneHotEncoder详解

    python 怎么读取 Label 里面的值 #简单来说 LabelEncoder 是对不连续的数字或者文本进行编号 from sklearn.preprocessing import LabelEn ...

  4. 手机摄影中多摄融合理论详解与代码实战

    转载AI Studio项目链接https://aistudio.baidu.com/aistudio/projectdetail/3465839 手机摄影中多摄融合理论详解与代码实战 前言   从20 ...

  5. linux命令打包文件,Linux系统中打包文件的命令详解

    linux系统中遇到要打包文件的时候我们该使用什么命令呢?下面由秋天网 Qiutian.ZqNF.Com小编为大家整理了linux系统中打包文件的命令详解的相关知识,希望对大家有帮助! linux系统 ...

  6. Linux中/proc目录下文件详解

    Linux中/proc目录下文件详解(一) 声明:可以自由转载本文,但请务必保留本文的完整性. 作者:张子坚 email:zhangzijian@163.com 说明:本文所涉及示例均在fedora ...

  7. python创建列向量_关于Numpy中的行向量和列向量详解

    关于Numpy中的行向量和列向量详解 行向量 方式1 import numpy as np b=np.array([1,2,3]).reshape((1,-1)) print(b,b.shape) 结 ...

  8. jQuery中getJSON跨域原理详解

    详见:http://blog.yemou.net/article/query/info/tytfjhfascvhzxcytp28 jQuery中getJSON跨域原理详解 前几天我再开发一个叫 河蟹工 ...

  9. java mod %区别_Java中 % 与Math.floorMod() 区别详解

    %为取余(rem),Math.floorMod()为取模(mod) 取余取模有什么区别呢? 对于整型数a,b来说,取模运算或者取余运算的方法都是: 1.求 整数商: c = a/b; 2.计算模或者余 ...

最新文章

  1. web-inf下文件访问
  2. Java 多线程常见问题
  3. C++primer 第 3 章 字符串、向量和数组 3 . 4 迭代器介绍
  4. 数据库函数依赖及范式
  5. dj电商-数据表的设计-商品表的设计
  6. html 图片防盗链,配置.htaccess文件实现图片防盗链
  7. 转:在Nginx上配置多个站点
  8. 数据结构 2-2 线性表的顺序表实现
  9. vs2012 使用mysql_vs2012连接mysql
  10. 网管学习笔记-hybrid口配置
  11. 使用Net::DNS CPAN模块模拟一个DNS CNAME应答记录
  12. Fortran77基础
  13. tplink无线路由dns服务器,TPLINK无线路由器如何设置DNS服务器的DNS服务器地址,TPLINK无线路由器的地址 | 192.168.1.1手机登陆...
  14. 移动GIS开发:手机基站定位+离线切片地图(矢量vtpk+栅格tpk)导航安卓APP
  15. 自发光物体能被烘焙之后是否会影响周围的物体
  16. 柳传志:古稀人生最后一个大弯
  17. error converting to execution character set illegal byte sequence报错解决办法
  18. Whale帷幄 - 智慧化门店 智慧化运营
  19. 用牛顿迭代法求方程的根
  20. strlen、strcpy、strcmp、strcat函数的实现

热门文章

  1. 在Linux Debian 8下部署基于PHP的Web项目。
  2. 汉字内码UNICODE转换表
  3. Mybatis报错:无效的列类型
  4. 银行祖传系统重构实例:创立12年,只支持Python 2,跑着500多个应用程序
  5. 一篇来自前端同学对后端接口的吐槽:痛!
  6. Spring Boot ES 实战,值得参考!
  7. 磁盘满了,为啥du却显示还有很大空间?
  8. 框架:Spring的自动装配
  9. STL泛型算法:reverse
  10. ppp lcp协商报文有哪些_PPP协议