原标题:Python中训练集/测试集的分割和交叉验证

嗨,大家好!在上一篇关于Python线性回归的文章之后,我认为撰写关于切分训练集/测试集和交叉验证的文章是很自然的,和往常一样,我将对该主题进行简短概述,然后给出在Python中实现该示例的示例。这是数据科学和数据分析中两个相当重要的概念,并用作防止(或最小化)过度拟合的工具。我将解释这是什么—当我们使用统计模型(例如,线性回归)时,我们通常将模型拟合到训练集上,以便对未经训练的数据(常规数据)进行预测 )。过度拟合意味着我们对模型的拟合程度过高。我保证,这一切很快就会变得有意义。

1.什么是模型的过拟合/欠拟合?

如前所述,在统计和机器学习中,我们通常将数据分为两个子集:训练数据和测试数据(有时分为三个子集:训练,验证和测试),并在训练数据上拟合我们的模型,以便对测试数据做预测。当我们这样做时,可能会发生两件事之一:我们过度拟合我们的模型或我们欠拟合我们的模型。我们不希望发生这些事情,因为它们会影响我们模型的可预测性-我们可能使用的是准确性较低和/或未概括的模型(这意味着您无法对其他数据进行概括)。让我们看看欠拟合和过度拟合的实际含义:

1.1 过拟合

过度拟合意味着我们训练的模型训练得“太好了”,也就是说太适合训练数据集了。这通常在模型过于复杂时发生(即与观察数量相比,特征/变量太多)。该模型在训练数据上将非常准确,但是在未训练或新数据上可能会非常不准确。这是因为此模型未通用化,当发生这种情况时,模型将学习或描述训练数据中的“噪声”,而不是数据中变量之间的实际关系。显然,这种噪声不是任何新数据集的一部分,也无法对其应用。

1.2 欠拟合

与过度拟合相反,当模型拟合不足时,这意味着该模型不适合训练数据,因此会错过数据中的趋势。这也意味着该模型无法推广到新数据。您可能已经猜到,这通常是一个非常简单的模型(没有足够的预测变量/独立变量)导致的。例如,当我们将线性模型(如线性回归)拟合到非线性数据时,也会发生这种情况。不用说,该模型的预测能力很差(在训练数据上,不能推广到其他数据)。

欠拟合,“恰到好处”,过拟合的实例

值得注意的是,拟合不足不如拟合过度普遍。 但是,我们希望避免在数据分析中同时遇到这两个问题。 您可能会说我们正在尝试寻找模型不足和过度拟合之间的中间点。 正如您将看到的,训练/测试拆分和交叉验证有助于避免过度拟合而不是过度拟合。 让我们一起来研究它们!

2.训练/测试拆分

正如我之前所说,我们使用的数据通常分为训练数据和测试数据。 训练集包含一个已知的输出,并且模型在此数据上学习,以便稍后将其推广到其他数据。 我们具有测试数据集(或子集),以测试我们对该子集的模型预测。

让我们看看如何在Python中执行此操作。

我们将使用Scikit-Learn库(尤其是train_test_split方法)进行此操作。 我们将从导入必要的库开始:

让我们快速浏览一下我导入的库:

Pandas- 将数据文件加载为pandas数据框并分析数据。

From Sklearn- 导入了数据集模块,因此可以加载示例数据集和linear_model,从而可以运行线性回归

From Sklearn- 从子库model_selection中导入了train_test_split,因此可以将其拆分为训练集和测试集

Matplotlib-导入了pyplot,以便绘制数据图

好,准备好了! 让我们加载糖尿病数据集(diabetes dataset),将其转换为数据框并定义列的名称:

# Load the Diabetes dataset

columns = “age sex bmi map tc ldl hdl tch ltg glu”.split # Declare the columns names

diabetes = datasets.load_diabetes # Call the diabetes dataset from sklearn

df = pd.DataFrame(diabetes.data, columns=columns) # load the dataset as a pandas data frame

y = diabetes.target # define the target variable (dependent variable) as y

现在,我们可以使用train_test_split函数进行拆分。 函数内部的test_size = 0.2表示应保留进行测试的数据百分比。 通常是80/20或70/30。

# create training and testing vars

X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2)

print X_train.shape, y_train.shape

print X_test.shape, y_test.shape

(353, 10) (353,)

(89, 10) (89,)

现在,我们将模型拟合到训练数据上:

# fit a model

lm = linear_model.LinearRegressionmodel = lm.fit(X_train,y_train)

predictions = lm.predict(X_test)

如您所见,我们正在训练模型上拟合模型,并试图预测测试数据。 让我们看看预测是什么(部分):

注意:由于我在预测后使用了[0:5],因此只显示了前五个预测值。 删除[0:5]将使其打印我们的模型创建的所有预测值。

让我们绘制模型:

## The line / model

plt.scatter(y_test, predictions)

plt.xlabel(“True Values”)

plt.ylabel(“Predictions”)

并输出准确性分数:

print “Score:”, model.score(X_test, y_test)

做得好! 以下是我所做工作的摘要:我已经加载了数据,将其分为训练和测试集,对训练数据拟合了回归模型,基于该数据对测试数据进行了预测。 看起来不错吧? 但是训练/测试拆分确实有其危险-如果我们进行的拆分不是随机的,该怎么办? 如果我们数据的一个子集仅包含来自某个州的人员,具有一定收入水平的员工,而没有其他收入水平,只有女性或只有特定年龄的人员,该怎么办? (想象一个由这些命令之一排序的文件)。 即使我们试图避免过度拟合,也会导致过度拟合! 这是交叉验证的来源。

3.交叉验证

在上一段中,我提到了训练/测试拆分方法中的注意事项。 为了避免这种情况,我们可以执行称为交叉验证的操作。 它与训练/测试拆分非常相似,但适用于更多子集。 意思是,我们将数据分为k个子集,并在k-1个子集中训练。 我们要做的是保留最后一个子集进行测试, 并且需要为每个子集重复这个步骤。

训练/测试拆分和交叉验证的可视化表示

有很多交叉验证方法,我将介绍其中的两种:第一种是K-folds交叉验证,第二种是“ Leave One Out”交叉验证(LOOCV)

3.1 K折交叉验证

在K折交叉验证中,我们将数据分为k个不同的子集(或折)。 我们使用k-1个子集来训练我们的数据,并保留最后一个子集(或最后一个折叠)作为测试数据。 然后,我们针对每个子集对模型进行平均,然后最终确定模型。 之后,我们针对测试集对其进行测试。

可视化的K折

这是Sklearn文档中有关K折的一个非常简单的示例:

from sklearn.model_selection import KFold # import KFold

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) # create an array

y = np.array([1, 2, 3, 4]) # Create another array

kf = KFold(n_splits=2) # Define the split - into 2 folds

kf.get_n_splits(X) # returns the number of splitting iterations in the cross-validatorprint(kf) KFold(n_splits=2, random_state=None, shuffle=False)

让我们看看结果-折叠:

for train_index, test_index in kf.split(X):

print(“TRAIN:”, train_index, “TEST:”, test_index)

X_train, X_test = X[train_index], X[test_index]

y_train, y_test = y[train_index], y[test_index]

('TRAIN:', array([2, 3]), 'TEST:', array([0, 1]))

('TRAIN:', array([0, 1]), 'TEST:', array([2, 3]))

如您所见,该函数将原始数据拆分为数据的不同子集。 同样,这是一个非常简单的示例,但我认为它很好地解释了这个概念。

3.2 留一交叉验证(LOOCV)

这是另一种进行交叉验证的方法: 留一交叉验证(顺便说一句,这些方法不是仅有的两种,还有很多其他交叉验证方法。请在Sklearn网站上进行检查)。 在这种类型的交叉验证中,折叠(子集)的数量等于我们在数据集中观察到的数量。 然后,我们对所有这些折叠进行平均,并用平均数建立模型。 然后,我们针对最后的折叠测试模型。 由于我们将获得大量训练集(等于样本数量),因此该方法的计算量非常大,应在小型数据集上使用。 如果数据集很大,则最好使用其他方法,例如k-fold。

让我们看看Sklearn的另一个示例:

from sklearn.model_selection import LeaveOneOut

X = np.array([[1, 2], [3, 4]])

y = np.array([1, 2])

loo = LeaveOneOut

loo.get_n_splits(X)

for train_index, test_index in loo.split(X):

print("TRAIN:", train_index, "TEST:", test_index)

X_train, X_test = X[train_index], X[test_index]

y_train, y_test = y[train_index], y[test_index]

print(X_train, X_test, y_train, y_test)

这是输出:

('TRAIN:', array([1]), 'TEST:', array([0]))

(array([[3, 4]]), array([[1, 2]]), array([2]), array([1]))

('TRAIN:', array([0]), 'TEST:', array([1]))

(array([[1, 2]]), array([[3, 4]]), array([1]), array([2]))

同样,简单的示例,但我确实确实认为这有助于理解此方法的基本概念。

那么,我们应该使用什么方法呢? 应该使用多少折? 当我们使用更多折,我们将减少由于偏差引起的误差,但由于方差导致的误差增加; 很明显,计算成本也会上涨-折叠的次数越多,计算时间就越长,并且您将需要更多的内存。K折数量较少时,我们会减少由于方差引起的误差,但由于偏斜引起的误差会更大。 在计算上也更便宜。 因此,在大数据集中,通常建议k = 3。 如前所述,在较小的数据集中,最好使用LOOCV。

让我们看看我以前使用的示例,这次使用交叉验证。 我将使用cross_val_predict函数返回每个数据点在测试切片中的预测值。

# Necessary imports:

from sklearn.cross_validation import cross_val_score, cross_val_predict

from sklearn import metrics

您还记得,我之前为糖尿病数据集创建了训练/测试拆分并拟合了模型。 让我们看看交叉验证后的得分是多少:

如您所见,最后的折叠将原始模型的得分从0.485提高到0.569。 这不是一个惊人的结果,但是,我们将尽我们所能:)

现在,在执行交叉验证后,让我们绘制新的预测:

# Make cross validated predictions

predictions = cross_val_predict(model, df, y, cv=6)

plt.scatter(y, predictions)

您会发现它与之前的原始图有很大不同。 因为我使用cv = 6,所以它的点数是原始图的六倍。

最后,让我们检查一下模型的R²得分(R²是一个“数字,该数字表明可以从自变量中预测出的因变量中的方差比例。”基本上,我们的模型有多准确):

这次就是这样! 希望您喜欢这篇文章。 与往常一样,我欢迎您提出关于您要阅读的主题的问题,注释,评论和帖子请求。 下次见!

作者:Adi Bronshtein

翻译:Fan Wang返回搜狐,查看更多

责任编辑:

python基于训练集预测_Python中训练集/测试集的分割和交叉验证相关推荐

  1. 数据标准化常见问题:对整个数据集数据标准化后再划分训练集、测试集和先对训练级标准化再将规则用于测试集有什么区别(Python实现)

    在数据分析与挖掘.算法建模的都会用到数据标准化.数据的标准化(normalization)是将数据按比例缩放,使之落入一个小的特定区间.在某些比较和评价的指标处理中经常会用到,去除数据的单位限制,将其 ...

  2. python基于模型的预测概率和标签信息可视化ROC曲线、编写自定义函数计算约登值、寻找最佳阈值(threshold、cutoff)、可视化ROC曲线并在曲线中标记最佳阈值及其数值标签

    python基于模型的预测概率和标签信息可视化ROC曲线.编写自定义函数计算约登值.寻找最佳阈值(threshold.cutoff).可视化ROC曲线并在曲线中标记最佳阈值及其数值标签 目录

  3. python第30讲数据挖掘_Python 中的实用数据挖掘

    本文是 2014 年 12 月我在布拉格经济大学做的名为' Python 数据科学'讲座的笔记.欢迎通过 @RadimRehurek 进行提问和评论. 本次讲座的目的是展示一些关于机器学习的高级概念. ...

  4. Python使用tpot获取最优模型、将最优模型应用于交叉验证数据集(5折)获取数据集下的最优表现,并将每一折(fold)的预测结果、概率、属于哪一折与测试集标签、结果、概率一并整合输出为结果文件

    Python使用tpot获取最优模型.将最优模型应用于交叉验证数据集(5折)获取数据集下的最优表现,并将每一折(fold)的预测结果.概率.属于哪一折与测试集标签.结果.概率一并整合输出为结果文件 目 ...

  5. SVM 训练--在训练集上acc为94% 在测试集上为70%

    用SVM 训练的时候: 出现的问题是: Error: specified nu is infeasible 带有下标的赋值维度不匹配. 百度到的答案:赋值维度不匹配...说明等号两边的大小不一致,导致 ...

  6. R语言使用caret包中的createFolds函数对机器学习数据集进行交叉验证抽样、返回的样本列表长度为k个

    R语言使用caret包中的createFolds函数对机器学习数据集进行交叉验证抽样.返回的样本列表长度为k个 目录

  7. R语言使用caret包中的createMultiFolds函数对机器学习数据集进行交叉验证抽样、返回的样本列表长度为k×times个、times为组内抽样次数

    R语言使用caret包中的createMultiFolds函数对机器学习数据集进行交叉验证抽样.返回的样本列表长度为k×times个.times为组内抽样次数 目录

  8. 深度学习模型在训练集上很好而在测试集表现得不好而拟合次数并不多_机器学习中的过拟合,欠拟合和偏倚方差折衷...

    过度拟合在机器学习中很重要. 很直观的解释过拟合:假设我们现在让机器学习考试做题,想象一种情况,机器逐字记住每个问题的答案(拟合非常好-完美).然后,我们可以在练习题上得分很高:我们这样做是基于希望实 ...

  9. python方差分析模型的预测结果中endog表示_python时间序列分析

    题记:毕业一年多天天coding,好久没写paper了.在这动荡的日子里,也希望写点东西让自己静一静.恰好前段时间用python做了一点时间序列方面的东西,有一丁点心得体会想和大家分享下.在此也要特别 ...

最新文章

  1. save_path is not a valid checkpoint
  2. win32com python_python模块:win32com用法详解
  3. CALL TRANSACTION
  4. 芋道 spring security oauth2 入门_Spring官方宣布:新的Spring OAuth2.0授权服务器已经来了
  5. 实体与电商,有啥区别?
  6. SolarWinds 攻击者开发的新后门 FoggyWeb
  7. netword localhost与 127.0.0.1 与 ::1 与 0.0.0.0 区别
  8. java 观察者模式_设计模式-Java-观察者模式-RxJava
  9. 阿里巴巴优酷视频增强和超分辨率挑战赛-持续更新
  10. 设置国内maven镜像仓库
  11. 【zheng】学习搭建github的高星项目:zheng
  12. win10网页找不到服务器dns,win10系统浏览网页提示“找不到服务器或dns错误”的解决方法...
  13. python如何爬取煎蛋图片(js)
  14. 永磁同步电机力矩控制(九):定子磁场中的若干相关概念
  15. 【马克思主义基本原理概论】
  16. Android差分升级原理和实现方式
  17. 前端面试系列-输入url后全过程页面渲染机制DOM生成过程
  18. actuator的端口暴露
  19. JAVA 语言程序设计与数据结构 教材课本源码 和 课后习题答案
  20. 合伙人股权设计的9点常识

热门文章

  1. 安装windows与Ubuntu双系统,并使用GRUB启动引导器
  2. 都说在阿里年薪百万不难,面试入职阿里需要准备什么?
  3. VMware虚拟机下Centos7 桥接方式网络配置完整步骤
  4. ThreeJS FBXLoader 加载3D文件,材质消失,已解决
  5. java jtextarea 超出_java – 如何保持JTextArea的大小不变?
  6. java swing 多行文本,Java Swing JTextArea
  7. 诺丁汉为满足当地需求新建一个数据中心
  8. 2022hit计算机系统大作业
  9. 一份不太简短的LaTeX2e介绍最新版地址2019 The Not So Short In­tro­duc­tion To LATEX (Chi­nese Edi­tion)
  10. (一)Siamese目标跟踪——SiamFC训练和跟踪过程:从论文细节角度出发