简介

数据质量的高低是决定使用机器学习算法获得预测结果质量高低的重要因素,在很多常见任务中,数据质量的作用远大于模型的作用,本文讨论数据预处理时会遇到的一个常见问题:训练集与测试集数据分布不一致。

什么是训练集与测试集数据分布不一致?

一个具体的例子,比如我现在要预测泰坦尼克号乘客存活率(Kaggle 上的经典题,已经被各路选手将准确率刷爆了),如果训练集的输入特征中,“性别” 这一特征多数是男性,而在测试集里,“性别” 这一特征多数是女性,这便是训练集与测试集上,某特征其数据分布不均。

训练集和测试集分布不一致也被称作数据集偏移 (Dataset Shift),导致这种问题有两个常见原因:

  • 样本选择偏差 (Sample Selection Bias): 训练集是通过有偏方法得到的,例如非均匀选择 (Non-uniform Selection),导致训练集无法很好表征的真实样本空间。

  • 环境不平稳 (Non-stationary Environments): 当训练集数据的采集环境跟测试集不一致时会出现该问题,一般是由于时间或空间的改变引起的。

先讨论样本选择偏差,在有监督学习里,样本会分为特征数据 (feature) 与目标变量 (label),样本选择偏差也会分分为两种情况:

  • 没有考虑数据中不同特征的分布问题,如前面举例的预测泰坦尼克号乘客存活率问题,训练集的性别特征中,男性比例大,而测试集的性别特征中,女性比例大。

  • 没有考虑数据中目标变量分布问题,从而会出现:训练集类别 A 数据量远多于类别 B,而测试集相反的情况。

样本选择偏差会导致训练好的模型在测试集上鲁棒性很差,因为训练集没有很好覆盖整个样本空间。

接着讨论环境不平稳带来的数据偏移,最典型的就是在时序数据中,用历史时序数据预测未来时序,未来突发事件很可能带来时序的不稳定表现,这便带来了分布差异。

环境因素不仅限于时间和空间,还有数据采集设备、标注人员等。

校验数据分布

如何判断训练集与测试集数据分布是否不一致呢?

通常使用核密度估计 (kernel density estimation, KDE) 分布图和 KS 校验这两种方法来判断。

KDE 分布图

在讨论 KDE 分布图之前,先考虑一下使用概率密度直方图来判断数据分布的问题。

概率密度直方图是用数据集中不同数据出现的次数来表示其概率,需注意这种假设不一定成立。

要对比训练集和测试集数据的分布,我们可以通过绘制相应的概率密度直方图,然后直观的判断直方图的差异,但通过直方图判断数据分布的会有两个缺陷:

  • 1. 受 bin 宽度影响大

  • 2. 不平滑

而 KDE 分布图相比于直方图,它受 bin 影响更小,绘图呈现更平滑,易于对比数据分布,下图便是直方图和核密度估计的一个对比:

在进一步讨论 KDE 前,先讨论一下核函数,核函数定义一个用于生成 PDF (概率分布函数,Probability Distribution Function) 的曲线,不同于将值放入离散 bins 中,核函数对每个样本值都创建一个独立的概率密度曲线,然后加和这些平滑曲线,最终得到一个平滑连续的概率分布曲线。

“核” 在不同的语境下的含义是不同的,在 “非参数估计”(即不知道数据分布情况) 的语境下,“核” 是一个函数,用来提供权重。例如高斯函数 (Gaussian) 就是一个常用的核函数。

KDE 在数学上还有挺多细节,但在实现上,通过 seaborn 库便可以轻松实现,代码如下:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt# 创建样例特征
train_mean, train_cov = [0, 2], [(1, .5), (.5, 1)]
test_mean, test_cov = [0, .5], [(1, 1), (.6, 1)]
#  np.random.multivariate_normal 从多变量正态分布中随机抽取样本
# 多正态分布是一维正态分布向高维的推广。这样的分布是由它的平均值和协方差矩阵来确定的。
# 这些参数类似于一维正态分布的均值(平均或“中心”)和方差(标准差或“宽度”的平方)。
train_feat, _ = np.random.multivariate_normal(train_mean, train_cov, size=50).T
test_feat, _ = np.random.multivariate_normal(test_mean, test_cov, size=50).T# 绘KDE对比分布
sns.kdeplot(train_feat, shade = True, color='r', label = 'train')
sns.kdeplot(test_feat, shade = True, color='b', label = 'test')
plt.xlabel('Feature')
plt.legend()
plt.show()

注意,上述代码中,train_feat 参数是一维的(即单独某个特征的分布,多个多特征,需要绘制多个KDE分布图)。效果如图所示:

从上图可知,训练集与测试集分布差异不大,可以继续模型训练等操作,如果分布差异较大,比如下图这种,就需要对原始数据进行处理了。

KS 检验

KDE 是使用 PDF 来对比,而 KS 检验是基于 CDF (累计分布函数 Cumulative Distribution Function) 来检验两个数据分布是否一致,它也是非参数检验方法。

KS 检验是基于累计分布函数,用于检验一个分布是否符合某种理论分布或比较两个经验分布是否有显著差异。

KS 检验一般返回两个值:

  • 第一个值表示两个分布之间的最大距离,值越小即这两个分布的差距越小,分布也就越一致。

  • 第二个值是 p 值,用来判定假设检验结果的一个参数,p 值越大,越不能拒绝原假设(待检验的两个分布是同分布),即两个分布越是同分布。

通过 scipy 库可以快速实现 KS 检验,代码如下:

from traceback import print_tb
import numpy as np
from scipy import statstrain_mean, train_cov = [0, 2], [(1, .5), (.5, 1)]
test_mean, test_cov = [0, .5], [(1, 1), (.6, 1)]train_feat, _ = np.random.multivariate_normal(train_mean, train_cov, size=50).T
test_feat, _ = np.random.multivariate_normal(test_mean, test_cov, size=50).Tresult = stats.ks_2samp(train_feat, test_feat)
print(result)# 打印结果:
# KstestResult(statistic=0.18, pvalue=0.3959398631708505)

若 KS 统计值小且 p 值大,则可以接受 KS 检验的原假设,即两个数据分布一致。

上面样例数据的统计值较低,p 值大于 10% 但不是很高,因此反映分布略微不一致。如果p 值 < 0.01,建议拒绝原假设,p 值越大,越倾向于原假设成立。

分类器对抗验证

所谓对抗验证,就是构建一个分类模型去分类训练集和测试集,如果分类模型可以清楚的分类,则说明训练集和测试集的分布有明显差异,反之分布差异不大。

分类模型可以直接使用 sklearn 中提供了几种常见分类器来实现,比如 SVM。

具体步骤如下:

  • 训练集和测试集合并,同时新增标签Is_Test去标记训练集样本为 0,测试集样本为 1。

  • 构建分类器 (例如 SVM、LGB、XGB 等) 去训练混合后的数据集 (可采用交叉验证的方式),拟合目标标签Is_Test。

  • 输出交叉验证中最优的 AUC 分数。AUC 越大 (越接近 1),越说明训练集和测试集分布不一致。

结尾

本文的方法虽然基于训练数据与测试数据进行讨论,但同样可以用于训练数据与预测数据的分布检测上,在模型训练测试阶段,我们会将已有的数据划分为训练数据与测试数据,当模型通过测试后,通常会合并训练数据与测试数据,用所有数据进行训练,获得最终的模型,然后上线使用。

如果上线后效果不好,数据分布问题依旧要考虑,通常,我们会收集线上的待预测数据,将待预测数据的特征分与训练数据的特征分布进行比较,依旧使用本文提及的方法,如果分布差异大,则说明,训练数据无法代表待预测数据,当前模型是没有实用价值的。

数据预处理是多数机器学习任务的核心,反倒是模型,因为很多成熟的实现,反而不是啥大问题。最近在整理自己过去的机器学习笔记,后续会将有价值的部分输出到公众号中。

本文相关参考:

训练 / 测试集分布不一致解法总结

Python 可视化神器 Seaborn 入门系列 (一)——kdeplot 和 distplot

密度估计(kernel density estimation)

训练集与测试集数据分布不一致相关推荐

  1. label y 训练集测试集x_Adversarial validation-对抗验证| 一种解决训练集与测试集分布不一致的方法...

    导语: 马上就要五一了,祝全世界人民五一快乐!在这之前,想过好几个准备这些天可以完成的专题,比如Boosting系列在搞点最近几年的新玩意,或者开一个新专题,如心心念念的GNN/GCN(主要是又可以去 ...

  2. ML基础 : 训练集,验证集,测试集关系及划分 Relation and Devision among training set, validation set and testing set...

    首先三个概念存在于 有监督学习的范畴 Training set: A set of examples used for learning, which is to fit the parameters ...

  3. 1. 训练集、开发集、测试集(Train/Dev/Test sets)

    1.在以往的机器学习中 如上图所示,以往机器学习中,对训练集.开发集.测试集的划分比例为60/20/20,如此划分通常可以获得较好的效果. 训练集(training set):训练算法. 开发集(de ...

  4. 机器学习数据集(训练集、测试集)划分方法

    数据集划分方法 留出(Hold-out)法 交叉验证(cross validation)法 自助法(bootstrap)   一个模型的好坏终归还是需要一个客观的评价标准,但是现有标准都比较难以适用于 ...

  5. 训练集、验证集、测试集划分

    一,搞清楚验证集 此段文字摘自<机器学习>周志华,第二章第二节评估方法 . 从文中可以get到几个点: (1)验证集和测试集不同. (2)验证集来自训练集的再划分. (3)验证集的划分是为 ...

  6. 【Python】深度学习中将数据按比例随机分成随机 训练集 和 测试集的python脚本

    深度学习中经常将数据分成 训练集 和 测试集,参考博客,修改python脚本 randPickAITrainTestData.py . 功能:从 输入目录 中随机检出一定比例的文件或目录,移动到保存 ...

  7. R语言图形用户界面数据挖掘包Rattle:加载UCI糖尿病数据集、并启动Rattle图形用户界面、数据集变量重命名,为数据集结果变量添加标签、数据划分(训练集、测试集、验证集)、随机数设置

    R语言图形用户界面数据挖掘包Rattle:加载UCI糖尿病数据集.并启动Rattle图形用户界面.数据集变量重命名,为数据集结果变量添加标签.数据划分(训练集.测试集.验证集).随机数设置 目录

  8. Python计算训练数据集(测试集)中某个分类变量阴性(阳性)标签样本的不同水平(level)或者分类值的统计个数以及比例

    Python计算训练数据集(测试集)中某个分类变量阴性(阳性)标签样本的不同水平(level)或者分类值的统计个数以及比例 目录

  9. python尝试不同的随机数进行数据划分、使用卡方检验依次计算不同随机数划分下训练接和测试集所有分类特征的卡方检验的p值,如果所有p值都大于0.05则训练集和测试集都具有统计显著性、数据划分合理

    python尝试不同的随机数进行数据划分.使用卡方检验依次计算不同随机数划分下训练接和测试集所有分类特征(categorical)的卡方检验的p值,如果所有p值都大于0.05则退出循环.则训练集和测试 ...

  10. Python计算医疗数据训练集、测试集的对应的临床特征:训练集(测试集)的阴性和阳性的样本个数、连续变量的均值(标准差)以及训练测试集阳性阴性的p值、离散变量的分类统计、比率、训练测试集阳性阴性的p值

    Python使用pandas和scipy计算医疗数据训练集.测试集的对应的临床特征:训练集(测试集)的阴性和阳性的样本个数.连续变量的均值(标准差࿰

最新文章

  1. 【硅谷牛仔】优步CEO,最倒霉的成功创业者 -- 特拉维斯·卡兰尼克
  2. 智慧校园“手环考勤”已成为学校常态
  3. Python全栈工程师(文件操作、编码)
  4. 加法器的verilog实现(串行进位、并联、超前进位、流水线)
  5. 示范NTFS 卷上的硬链接
  6. Protel中的快捷键使用(网上资源)
  7. AtCoder Regular Contest 059
  8. python xlwings追加数据_大数据分析Python库xlwings提升Excel工作效率教程
  9. ubuntu下搭建tftp服务器并且验证功能
  10. dubbo接口demo开发
  11. 详尽Ubuntu18安装搜狗输入法教程
  12. 一个很有趣的游戏(看谁的名字打架厉害)
  13. c#SQL参数化查询自动生成SqlParameter列表
  14. SAP中销售价格导致的无法发货的实例分析
  15. IDEA的TODO的使用
  16. HTTP Digest Authentication 使用心得
  17. navicat连接LinuxMySQL10038错误、mysql通过命令行进行导入导出sql文件
  18. 转 兵无常势 水无常形 贴
  19. 阿里巴巴内测全网社交产品来往
  20. CSUOJ 1644 超能陆战队

热门文章

  1. 格雷码的FPGA实现
  2. 谷歌强烈推荐!浏览器助手,让你的浏览器至少提升10个档次!
  3. 通信感知一体化概述(IMT-2030 6G)
  4. 华罗庚的《统筹方法》
  5. HoloView -- Tabular Datasets
  6. cfe刷机教程 斐讯k3_斐讯K3刷机教程官改V2.1D或者其它版本教程
  7. 【EasyPR】Linux安装使用EasyPR开源车牌识别系统
  8. 为什么会存在乱码?什么是编解码?为什么会有这么多字符集?
  9. 计算机网络(第七版)谢希仁知识点总结
  10. Mysql常用命令笔记