机器学习 学习曲线 Python实现学习曲线及案例解析

学习曲线

如果数据集的大小为 mmm,则通过下面的流程即可画出学习曲线:

  • 把数据集分成训练数据集和交叉验证数据集。
  • 取训练数据集的 20%20\%20% 作为训练样本,训练出模型参数。
  • 使用交叉验证数据集来计算训练出来的模型的准确性。
  • 以训练数据集的准确性,交叉验证的准确性作为纵坐标,训练数据集个数作为横坐标,在坐标轴上画出上述步骤计算出来的模型准确性。
  • 训练数据集增加 10%10\%10%,跳到步骤3继续执行,直到训练数据集大小为 100%100\%100% 为止。

学习曲线要表达的内容是,当训练数据集增加时,模型对训练数据集你和的准确性以及交叉验证数据集预测的准确性的变化规律

实例:画出学习曲线

生成一个在y=xy=\sqrt{x}y=x​附件波动的点来作为训练样本。

import numpy as np
n_dots = 200X = np.linspace(0, 1, n_dots)
y = np.sqrt(X) + 0.2 * np.random.rand(n_dots) - 0.1# 因为 sklearn 的接口里,需要用到 n_sample x n_feature 的矩阵
# 所以需要转化为 200 x 1 的矩阵X = X.reshape(-1, 1)
y = y.reshape(-1, 1)

需要构造一个多项式模型

在scikit-learn里,需要用 Pipeline 来构造多项式模型,Pipeline 的意思是流水线,即这个流水线里可以包含多个数据处理模型,前一个模型处理完,转到下一个模型处理。

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegressiondef polynomial_model(degree=1):polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)linear_regression = LinearRegression()# 这是一个流水线,先增加多项式阶数,然后再用先行回归算法来拟合数据pipeline = Pipeline([('polynomial_features', polynomial_features), ("linear_regression", linear_regression)])return pipeline

polynomial_model()函数生成一个多项式模型,其中参数 degree 表示多项式的阶数
,比如polynomail_model(3)将生成一个三阶多项式的模型。

在scikit-learn里面,我们不用自己去实现学习曲线算法,直接使用 sklearn.model_selection.learning_curve()函数来画出学习曲线,它会自动把训练样本的数量按照预定的规则逐渐增加,然后画出不同的训练样本数量时的模型准确性。
其中 train_sizes 参数就是指定训练样本数量的变化规则,比如 train_sizes=np.linspace(.1, 1.0, 5)表示把训练样本数量从 0.1∼10.1\sim10.1∼1 分成五等分,生成 [0.1,0.352,0.55,0.775,1][ 0.1, 0.352, 0.55, 0.775, 1][0.1,0.352,0.55,0.775,1] 的序列,从序列中取出训练样本数量百分比,逐个计算在当前训练样本数量情况下训练出来的模型准确性。

from sklearn.model_selection import learning_curve
from sklearn.model_selection import ShuffleSplit
import matplotlib.pyplot as pltdef plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):plt.title(title)if ylim is not None:plt.ylim(*ylim)plt.xlabel("Training examples")plt.ylabel("Score")train_sizes, train_scores, test_scores = learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)train_scores_mean = np.mean(train_scores, axis=1)train_scores_std = np.std(train_scores, axis=1)test_scores_mean = np.mean(test_scores, axis=1)test_scores_std = np.std(test_scores, axis=1)plt.grid()plt.fill_between(train_sizes, train_scores_mean - train_scores_std, train_scores_mean + train_scores_std, alpha=0.1, color="r")plt.fill_between(train_sizes, test_scores_mean - test_scores_std, test_scores_mean + test_scores_std, alpha=0.1, color="g")plt.plot(train_sizes, train_scores_mean, 's--', color="r", label="Training score")plt.plot(train_sizes, test_scores_mean, 'o-', color="g", label="Cross-validation score")plt.legend(loc="best")return plt

这个函数实现的功能就是画出模型的学习曲线。

其中有个细节需要注意,当计算模型的准确性时,是随机从数据集中分配出训练样本和交叉验证样本,这样会导致数据分布不均匀。
即同样训练样本数量的模型,由于随机分配,导致每次计算出来的准确性都不一样。
为了解决这个问题,我们在计算模型的准确性时,多次计算,并求准确性的的平均值和方差。
上述代码中 plt.fill_between() 函数会把模型准确性的平均值的上下方差的空间里用颜色填充。
然后用plt.plot()函数画出模型准确性的平均值。上诉函数画出了训练样本的的准确性,也画出了交叉验证样本的准确性。

使用ploynomial_model()函数构造出3个模型,分别是一阶多项式、三阶多项式、十阶多项式,分别画出这3个模型的学习曲线。
# 为了让学习曲线更平滑,计算10次交叉验证数据集的分数
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
titles = ['Learning Curves (Under Fitting)', 'Learning Curves', 'Learning Curves (Over Fitting)']
degrees = [1, 3, 10]plt.figure(figsize=(18, 4), dpi=200)
for i in range(len(degrees)):plt.subplot(1, 3, i + 1)plot_learning_curve(polynomial_model(degrees[i]), titles[i], X, y, ylim=(0.75, 1.01), cv=cv)plt.show()


左图:一阶多项式,欠拟合;
中图:三阶多项式,较好地拟合了数据集;
右图:十阶多项式,过拟合。
虚线:针对训练数据集计算出来的分数,即针对训练数据集拟合的准确性,
实线:针对交叉验证数据集计算出来的分数,即针对交叉验证数据集预测的准确性。

从左图我们可以观察到,当模型欠拟合(High Bias, Under Fitting)时,随着训练数据集的增加,交叉验证数据集的准确性(实线)逐渐增大,逐渐和训练数据集的准确性(虚线)靠近,但其总体水平比较低,收敛在 0.880.880.88 左右。其训练数据集的准确性也比较低,收敛在 0.900.900.90 左右。这就是过拟合的表现。从这个关系可以看出来,当发生高偏差时,增加训练样本数量不会对算法准确性有较大的改善

从右图我们可以观察到,当模型过拟合(High Variance, Over Fitting)时,随着训练数据集的增加,交叉验证数据集的准确性(实线)也在增加,逐渐和训练数据集的准确性(虚线)靠近,但两者之间的间隙比较大。
训练数据集的准确性很高,收敛在 0.950.950.95 左右,是三者中最高的,但其交叉验证数据集的准确性值却较低,最终收敛在 0.910.910.91 左右。

中图,我们选择的三阶多阶式较好地拟合了数据,最终训练数据集的准确性(虚线)和交叉验证数据集的准确性(实线)靠得很近,最终交叉验证数据集收敛在 0.930.930.93 附近,训练数据集的准确性收敛在 0.940.940.94 附近。3个模型对比,这个模型的准确性最好。

当需要改进学习算法时,可以画出学习曲线,以便判断算法时处在高偏差还是高分差问题。
学习曲线是诊断模型算法准确性的一个非常重要的工具。

过拟合和欠拟合的特征

到此,我们可以总结出过拟合和欠拟合的特点如下。

  • 过拟合:模型对训练数据集的准确性比较高,其成本 Jtrain(θ)J_{train}(\theta)Jtrain​(θ)比较低,对交叉验证数据集的准确性比较低,其成本 Jcv(θ)J_{cv}(\theta)Jcv​(θ) 比较高。
  • 欠拟合:模型对训练数据集的准确性比较低,其成本 Jtrain(θ)J_{train}(\theta)Jtrain​(θ)比较高,对交叉验证数据集的准确性夜比较低,其成本 Jcv(θ)J_{cv}(\theta)Jcv​(θ) 也比较高。

一个好的机器学习算法应该是对训练数据集准确性高、成本低,即较准确地拟合数据,同时对交叉验证数据集准确性高、成本低、误差小,即对未知数据有良好的预测性。

机器学习 学习曲线 Python实现学习曲线及案例解析相关推荐

  1. Python使用pyexecjs代码案例解析

    针对现在大部分的网站都是使用js加密,js加载的,并不能直接抓取出来,这时候就不得不适用一些三方类库来执行js语句 execjs,一个比较好用且容易上手的类库(支持py2,与py3),支持 JS ru ...

  2. python爬虫正则表达式实例-Python 正则表达式爬虫使用案例解析

    现在拥有了正则表达式这把神兵利器,我们就可以进行对爬取到的全部网页源代码进行筛选了. 下面我们一起尝试一下爬取内涵段子网站: 打开之后,不难看出里面一个一个非常有内涵的段子,当你进行翻页的时候,注意u ...

  3. python查找字符串关键词_Python字符串查找基本操作案例解析

    本篇文章小编给大家分享一下Python字符串查找基本操作案例解析,文章介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看. 字符串查找基本操作主要分为三个关键词:fi ...

  4. python案例解析_python案例讲解

    <21天学通Python(第2版)>| 每日读本书 编辑推荐 基础知识→核心技术→典型实例→综合练习→项目案例,轻松上手与提高.全面掌握Python只需21天! √ 160个典型案例.2个 ...

  5. 【Python五篇慢慢弹(5)】类的继承案例解析,python相关知识延伸

    类的继承案例解析,python相关知识延伸 作者:白宁超 2016年10月10日22:36:57 摘要:继<快速上手学python>一文之后,笔者又将python官方文档认真学习下.官方给 ...

  6. python scrapy爬虫视频_python爬虫scrapy框架的梨视频案例解析

    之前我们使用lxml对梨视频网站中的视频进行了下载 下面我用scrapy框架对梨视频网站中的视频标题和视频页中对视频的描述进行爬取 分析:我们要爬取的内容并不在同一个页面,视频描述内容需要我们点开视频 ...

  7. 机器学习中的数学——学习曲线如何区别欠拟合与过拟合

    通过这篇博客,你将清晰的明白什么是如何区别欠拟合与过拟合.这个专栏名为白话机器学习中数学学习笔记,主要是用来分享一下我在 机器学习中的学习笔记及一些感悟,也希望对你的学习有帮助哦!感兴趣的小伙伴欢迎私 ...

  8. Stanford机器学习---第六周.学习曲线、机器学习系统的设计

    第六周.学习曲线.机器学习系统的设计 Learning Curve and Machine Learning System Design 关键词 学习曲线.偏差方差诊断法.误差分析.机器学习系统的数值 ...

  9. 类的继承python事例_【Python五篇慢慢弹(5)】类的继承案例解析,python相关知识延伸...

    作者:白宁超 2016年10月10日22:36:57 摘要:继一文之后,笔者又将python官方文档认真学习下.官方给出的pythondoc入门资料包含了基本要点.本文是对文档常用核心要点进行梳理,简 ...

最新文章

  1. 看懂SQL Server的查询计划(绝对好文!)
  2. Lucene.Net 2.3.1开发介绍 —— 三、索引(六)
  3. 0-1背包-分支限界
  4. python爬虫登陆网页版腾讯课堂
  5. Flink 1.13,面向流批一体的运行时与 DataStream API 优化
  6. python查找第二次输入字符串在第一次字符串中出现的次数
  7. TCP header
  8. wsimport指令
  9. C语言课设——电影院选票系统
  10. 视频水印怎么去除?超简单 千万不要错过
  11. 【汇智学堂】基于Socket实现的网络版梅花易数一撮金游戏
  12. 科比投篮选择——数据采集
  13. #clickid#CID#全新小程序链路CID/clickid解决方案,合规、完美防阿里封禁
  14. 浅谈CPU位数和操作系统位数
  15. 20189216 2018-2019-2 《密码与安全新技术专题》第二次作业
  16. HDR视频色调映射算法(之五:flicker reduction TMO)
  17. 云服务器建网站(安装Java与Tomcat)
  18. WIN7/WIN10 临时及永久 强制关闭驱动签名验证
  19. OneNote for Windows 10 笔记丢失踩雷
  20. 梦里Babel知多少(一)

热门文章

  1. python库之—psycopg2
  2. 【二分】Kevin喜欢零
  3. 3C低头族 小心飞蚊症找上你
  4. [Java]Java的静态构造函数 多线程下安全的单例模式
  5. 普鸥知识产权|为什么大家都要注册欧盟商标?有什么优势?
  6. Redis之介绍、下载安装
  7. 大厂嫡系文化,养肥了谁?
  8. 罗赛塔科技获数百万元人民币天使轮融资
  9. [附源码]JAVA+ssm基于远程协作的汽车故障诊断系统(程序+Lw)
  10. 破解马赛克有多「容易」?