模型选择、欠拟合和过拟合

【代码】

引入:当模型在训练数据集上更准确时,它在测试数据集上却不一定更准确。

1. 训练误差和泛化误差

机器学习模型应关注降低泛化误差

要区分训练误差(training error)和泛化误差(generalization error)。通俗来讲,前者指模型在训练数据集上表现出的误差,后者指模型在任意⼀个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。计算训练误差和泛化误差可以使用损失函数。

训练误差可以认为是做往年高考试题(训练题)时的错误率,泛化误差则可以通过真正参加高考(测试题)时的答题错误率来近似。假设训练题和测试题都随机采样于⼀个未知的依照相同考纲的巨大试题库。如果让⼀名未学习中学知识的小学生去答题,那么测试题和训练题的答题错误率可能很相近。但如果换成⼀名反复练习训练题的高三备考⽣答题,即使在训练题上做到了错误率为0,也不代表真实的⾼考成绩会如此。

假设训练数据集(训练题)和测试数据集(测试题)⾥的每⼀个样本都是从同⼀个概率分布中相互独立地生成的。基于该独立同分布假设,给定任意⼀个机器学习模型(含参数),它的训练误差的期望和泛化误差都是⼀样的。例如,如果我们将模型参数设成随机值(小学生),那么训练误差和泛化误差会非常相近。模型的参数是通过在训练数据集上训练模型而学习出的,参数的选择依据了最小化训练误差(⾼三备考生)。所以,训练误差的期望小于或等于泛化误差。也就是说,⼀般情况下,由训练数据集学到的模型参数会使模型在训练数据集上的表现优于或等于在测试数据集上的表现。由于无法从训练误差估计泛化误差,⼀味地降低训练误差并不意味着泛化误差⼀定会降低。

2. 模型选择

在机器学习中,通常需要评估若干候选模型的表现并从中选择模型。可供选择的候选模型可以是有着不同超参数的同类模型。以多层感知机为例,我们可以选择隐藏层的个数,以及每个隐藏层中隐藏单元个数和激活函数。

验证数据集

从严格意义上讲,测试集只能在所有超参数和模型参数选定后使用⼀次。不可以使⽤测试数据选择模型,如调参。由于无法从训练误差估计泛化误差,因此也不应只依赖训练数据选择模型。鉴于此,我们可以预留⼀部分在训练数据集和测试数据集以外的数据来进⾏模型选择。这部分数据被称为验证数据集,简称验证集(validation set)。例如,我们可以从给定的训练集中随机选取⼀小部分作为验证集,而将剩余部分作为真正的训练集。

K折交叉验证

由于验证数据集不参与模型训练,当训练数据不够⽤时,预留大量的验证数据显得太奢侈。⼀种改善的方法是K折交叉验证(K-fold cross-validation)。在K折交叉验证中,我们把原始训练数据分割成K个不重合的子数据集,然后我们做K次模型和验证。每一次,我们使用一个子数据集验证模型,并使用其他K-1个子数据集来训练模型。在这K次训练和验证中,每次用来验证模型的子数据集都不同。最后,我们对这K次训练误差和验证误差分别求平均。

3. 欠拟合和过拟合

模型训练中经常出现的两类典型问题:⼀类是模型无法得到较低的训练误差,我们将这⼀现象称作欠拟合(underfitting);另⼀类是模型的训练误差远小于它在测试数据集上的误差,我们称该现象为过拟合(overfitting)。

模型复杂度

为了解释模型复杂度,我们以多项式函数拟合为例。给定⼀个由标量数据特征xxx和对应的标量标签yyy组成的训练数据集,多项式函数拟合的⽬标是找⼀个KKK阶多项式函数:
y^=b+∑k=1Kxkwk(1)\hat{y}=b+\sum_{k=1}^{K} x^k w_k \tag 1 y^​=b+k=1∑K​xkwk​(1)
来近似yyy。在上式中,wkw_kwk​是模型的权重参数,bbb是偏差参数。与线性回归相同,多项式函数拟合也是用平方损失函数。

因为高阶多项式函数模型参数更多,模型函数的选择空间更⼤,所以高阶多项式函数比低阶多项式函数的复杂度更⾼。因此,高阶多项式函数⽐低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。给定训练数据集,模型复杂度和误差之间的关系通常如图所⽰。给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型复杂度过高,很容易出现过拟合。应对⽋拟合和过拟合的⼀个办法是针对数据集选择合适复杂度的模型。

训练数据集大小

影响⽋拟合和过拟合的另⼀个重要因素是训练数据集的⼤小。⼀般来说,如果训练数据集中样本数过少,特别是⽐模型参数数量(按元素计)更少时,过拟合更容易发⽣。此外,泛化误差不会随训练数据集⾥样本数量增加而增⼤。因此,在计算资源允许的范围之内,我们通常希望训练数据集大⼀些,特别是在模型复杂度较⾼时,例如层数较多的深度学习模型。

4. 多项式函数拟合实验

# 导包
%matplotlib inline
from mxnet import autograd, gluon, nd
from mxnet.gluon import data as gdata, loss as gloss, nn

生成数据集

我们将⽣成⼀个人工数据集。在训练数据集和测试数据集中,给定样本特征xxx,我们使⽤如下的三阶多项式函数来⽣成该样本的标签:
y=1.2x−3.4x2+5.6x3+5+ε(2)y=1.2x-3.4x^2+5.6x^3+5+\varepsilon \tag 2 y=1.2x−3.4x2+5.6x3+5+ε(2)
其中噪声项ε\varepsilonε服从均值为0、标准差为0.1的正态分布。训练数据集和测试数据集的样本数都设为100。

n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
features = nd.random.normal(shape=(n_train + n_test, 1))
poly_features = nd.concat(features, nd.power(features, 2), nd.power(features, 3))
labels = (true_w[0] * poly_features[:,0] + true_w[1] * poly_features[:,1] + true_w[2] * poly_features[:,2] + true_b)
labels += nd.random.normal(scale=0.1, shape=labels.shape)

生成的数据集的前两个样本。

features[:2], poly_features[:2], labels[:2]
([[1.1630785][0.4838046]]<NDArray 2x1 @cpu(0)>,[[1.1630785  1.3527517  1.5733565 ][0.4838046  0.2340669  0.11324265]]<NDArray 2x3 @cpu(0)>,[10.534649  5.530093]<NDArray 2 @cpu(0)>)

定义、训练和测试模型

先定义作图函数semilogy,其中yyy轴使用了对数尺度。

import matplotlib.pyplot as plt
from utils import set_figsize
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None, legend=None, figsize=(3.5, 2.5)):set_figsize(figsize)plt.xlabel(x_label)plt.ylabel(y_label)plt.semilogy(x_vals, y_vals)if x2_vals and y2_vals:plt.semilogy(x2_vals, y2_vals, linestyle=':')plt.legend(legend)
num_epochs, loss = 100, gloss.L2Loss()
def fit_and_plot(train_features, test_features, train_labels, test_labels):net = nn.Sequential()net.add(nn.Dense(1))net.initialize()batch_size = min(10, train_labels.shape[0])train_iter = gdata.DataLoader(gdata.ArrayDataset(train_features, train_labels), batch_size, shuffle=True)trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:with autograd.record():l = loss(net(X), y)l.backward()trainer.step(batch_size)train_ls.append(loss(net(train_features), train_labels).mean().asscalar())test_ls.append(loss(net(test_features), test_labels).mean().asscalar())print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', range(1, num_epochs + 1), test_ls, ['train', 'test'])print('weight:', net[0].weight.data().asnumpy(), '\nbias:', net[0].bias.data().asnumpy())

三阶多项式函数拟合(正常)

我们先使用与数据生成函数同阶的三阶多项式函数拟合。实验表明,这个模型的训练误差和在测试数据集的误差都较低。

fit_and_plot(poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:])
final epoch: train loss 0.00698169 test loss 0.0063500497
weight: [[ 1.1729248 -3.3906946  5.604663 ]]
bias: [4.985479]

线性函数拟合(欠拟合)

线性函数拟合。很明显,该模型的训练误差在迭代早期下降后便很难继续降低。在完
成最后⼀次迭代周期后,训练误差依旧很⾼。线性模型在⾮线性模型(如三阶多项式函数)⽣成的数据集上容易⽋拟合。

fit_and_plot(features[:n_train, :], features[n_train:, :], labels[:n_train], labels[n_train:])
final epoch: train loss 159.33257 test loss 102.91761
weight: [[22.651974]]
bias: [-0.65602565]

训练样本不足(过拟合)

即便使⽤与数据⽣成模型同阶的三阶多项式函数模型,如果训练样本不⾜,该模型依然
容易过拟合。让我们只使⽤两个样本来训练模型。显然,训练样本过少了,甚⾄少于模型参数的数量。这使模型显得过于复杂,以⾄于容易被训练数据中的噪声影响。在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很⾼。这是典型的过拟合现象。

fit_and_plot(poly_features[0:2,:], poly_features[n_train:,:], labels[0:2], labels[n_train:])
final epoch: train loss 0.47576833 test loss 133.27455
weight: [[2.0588458 1.9273669 2.0477402]]
bias: [2.482129]

模型选择、欠拟合和过拟合相关推荐

  1. [pytorch、学习] - 3.11 模型选择、欠拟合和过拟合

    参考 3.11 模型选择.欠拟合和过拟合 3.11.1 训练误差和泛化误差 在解释上述现象之前,我们需要区分训练误差(training error)和泛化误差(generalization error ...

  2. 欠拟合和过拟合以及如何选择模型

    模型选择.欠拟合和过拟合 在前几节基于Fashion-MNIST数据集的实验中,我们评价了机器学习模型在训练数据集和测试数据集上的表现.如果你改变过实验中的模型结构或者超参数,你也许发现了:当模型在训 ...

  3. 从多项式函数拟合实验出发浅谈“模型选择、欠拟合和过拟合”问题

    在本笔记中,我们将从简单易懂的多项式函数拟合实验出发,谈一谈如今做机器学习绕不开的三个重要概念:模型选择.欠拟合和过拟合,并且进一步挖掘如何选择模型.如何避免欠拟合和过拟合问题.本笔记主要从下面五个方 ...

  4. 从零开始学Pytorch(五)之欠拟合和过拟合

    本文首发于微信公众号"计算机视觉cv" 模型选择.过拟合和欠拟合 训练误差和泛化误差 训练误差(training error)指模型在训练数据集上表现出的误差,泛化误差(gener ...

  5. 机器学习:什么是欠拟合和过拟合

    https://blog.csdn.net/u011630575/article/details/71158656 1. 什么是欠拟合和过拟合 先看三张图片,这三张图片是线性回归模型 拟合的函数和训练 ...

  6. 1.5 欠拟合和过拟合

    1.5 欠拟合和过拟合 欠拟合(Underfitting):选择的模型过于简单,以致于模型对训练集和未知数据的预测都很差的现象. 过拟合(Overfitting):选择的模型过于复杂(所包含的参数过多 ...

  7. 欠拟合和过拟合学习笔记

    欠拟合和过拟合学习笔记 https://www.cnblogs.com/DicksonJYL/p/9620464.html 在建模的过程中会经常出现1.模型的效果,但是泛化能力弱,2.模型的结果很差的 ...

  8. 线性回归;欠拟合和过拟合

    线性回归 定义:线性回归通过一个或者多个自变量与因变量之间之间进行建模的回归分析.其中特点为一个或多个称为回归系数的模型参数的线性组合. 线性回归的误差大小通过损失函数来计算–最小二乘法,目的是去寻找 ...

  9. 解决欠拟合和过拟合的几种方法

    深度学习 欠拟合和过拟合的问题 ... 如何解决欠拟合和过拟合的问题? 深度学习 前言 一.介绍 二.如何解决欠拟合问题 三.如何解决过拟合问题 总结 前言   我们可以将搭建的模型是否发生欠拟合或者 ...

最新文章

  1. 富文本框让最大四百像素_TinyMCE 富文本编辑器 ━━ 基本配置
  2. 树莓派AI视觉云台——6、Linux常用命令及vim编辑器的使用
  3. SAP S/4HANA OData Mock Service 介绍
  4. Educational Codeforces Round 37 (Rated for Div. 2) 1
  5. 电商系统的商品规格设计方案
  6. 大工计算机基础在线作业答案,大工1209《计算机应用基础》在线作业123.doc
  7. iOS 播放音频的几种方法
  8. Java Web架构演变
  9. MSCRM与MS人立方关系的集成
  10. python : autopep8
  11. Script.NET 1.0版本的Tcl+Html界面编程原理
  12. 利用mybatis插件开发动态更改sql
  13. php k线公式源码,黄金K线主图源码
  14. 计算机桌面图标的使用,电脑桌面图标不见了怎么恢复 如何规范使用电脑
  15. zblog火车头采集经验
  16. IllegalStateException: Couldn‘t read row 0, col 10 from CursorWindow. Make sure the Cursor is initi
  17. 变异系数在线计算机,数理统计在线计算器
  18. 什么是正向代理,开放的代理软件使用
  19. 服务器配置记录(五)GNS3虚拟机SSH配置
  20. 【C语言习题】统计君君提水的桶数(不使用ceil函数与floor函数)

热门文章

  1. Mac 配置Nginx域名转发
  2. 图解IFRS9 金融工具(4)金融资产分类框架
  3. 杭州程序员从互联网跳央企,不用996工作摸鱼真的爽!
  4. CentOS 安装 java 环境安装及配置
  5. 黄帝内经.素问.热论篇(31)
  6. 【Java数据结构[链表--单向链表]】
  7. 【附源码】Java计算机毕业设计安卓英语学习app(程序+LW+部署)
  8. 大数据运维方向面试题
  9. python的算法是指_python中的算法
  10. 不可思议的《魔兽世界》