训练误差和泛化误差

在解释上述现象之前,我们需要区分训练误差(training error)和泛化误差(generalization error)。通俗来讲,前者指模型在训练数据集上表现出的误差,后者指模型在任意一个测试数据样本上表现出的误差的期望,并常常通过测试数据集上的误差来近似。计算训练误差和泛化误差可以使用之前介绍过的损失函数,例如线性回归用到的平方损失函数和softmax回归用到的交叉熵损失函数。

让我们以高考为例来直观地解释训练误差和泛化误差这两个概念。训练误差可以认为是做往年高考试题(训练题)时的错误率,泛化误差则可以通过真正参加高考(测试题)时的答题错误率来近似。假设训练题和测试题都随机采样于一个未知的依照相同考纲的巨大试题库。如果让一名未学习中学知识的小学生去答题,那么测试题和训练题的答题错误率可能很相近。但如果换成一名反复练习训练题的高三备考生答题,即使在训练题上做到了错误率为0,也不代表真实的高考成绩会如此。

在机器学习里,我们通常假设训练数据集(训练题)和测试数据集(测试题)里的每一个样本都是从同一个概率分布中相互独立地生成的。基于该独立同分布假设,给定任意一个机器学习模型(含参数),它的训练误差的期望和泛化误差都是一样的。例如,如果我们将模型参数设成随机值(小学生),那么训练误差和泛化误差会非常相近。但我们从前面几节中已经了解到,模型的参数是通过在训练数据集上训练模型而学习出的,参数的选择依据了最小化训练误差(高三备考生)。所以,训练误差的期望小于或等于泛化误差。也就是说,一般情况下,由训练数据集学到的模型参数会使模型在训练数据集上的表现优于或等于在测试数据集上的表现。由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。

机器学习模型应关注降低泛化误差。

模型选择

在机器学习中,通常需要评估若干候选模型的表现并从中选择模型。这一过程称为模型选择(model selection)。可供选择的候选模型可以是有着不同超参数的同类模型。以多层感知机为例,我们可以选择隐藏层的个数,以及每个隐藏层中隐藏单元个数和激活函数。为了得到有效的模型,我们通常要在模型选择上下一番功夫。下面,我们来描述模型选择中经常使用的验证数据集(validation data set)。

验证数据集

从严格意义上讲,测试集只能在所有超参数和模型参数选定后使用一次。不可以使用测试数据选择模型,如调参。由于无法从训练误差估计泛化误差,因此也不应只依赖训练数据选择模型。鉴于此,我们可以预留一部分在训练数据集和测试数据集以外的数据来进行模型选择。这部分数据被称为验证数据集,简称验证集(validation set)。例如,我们可以从给定的训练集中随机选取一小部分作为验证集,而将剩余部分作为真正的训练集。

然而在实际应用中,由于数据不容易获取,测试数据极少只使用一次就丢弃。因此,实践中验证数据集和测试数据集的界限可能比较模糊。从严格意义上讲,除非明确说明,否则本书中实验所使用的测试集应为验证集,实验报告的测试结果(如测试准确率)应为验证结果(如验证准确率)。

KKK折交叉验证

由于验证数据集不参与模型训练,当训练数据不够用时,预留大量的验证数据显得太奢侈。一种改善的方法是KKK折交叉验证(KKK-fold cross-validation)。在KKK折交叉验证中,我们把原始训练数据集分割成KKK个不重合的子数据集,然后我们做KKK次模型训练和验证。每一次,我们使用一个子数据集验证模型,并使用其他K−1K-1K−1个子数据集来训练模型。在这KKK次训练和验证中,每次用来验证模型的子数据集都不同。最后,我们对这KKK次训练误差和验证误差分别求平均。

欠拟合和过拟合

接下来,我们将探究模型训练中经常出现的两类典型问题:一类是模型无法得到较低的训练误差,我们将这一现象称作欠拟合(underfitting);另一类是模型的训练误差远小于它在测试数据集上的误差,我们称该现象为过拟合(overfitting)。在实践中,我们要尽可能同时应对欠拟合和过拟合。虽然有很多因素可能导致这两种拟合问题,在这里我们重点讨论两个因素:模型复杂度和训练数据集大小。

关于模型复杂度和训练集大小对学习的影响的详细理论分析可参见这篇博客。

模型复杂度

为了解释模型复杂度,我们以多项式函数拟合为例。给定一个由标量数据特征xxx和对应的标量标签yyy组成的训练数据集,多项式函数拟合的目标是找一个KKK阶多项式函数

y^=b+∑k=1Kxkwk\hat{y} = b + \sum_{k=1}^K x^k w_k y^​=b+k=1∑K​xkwk​

来近似 yyy。在上式中,wkw_kwk​是模型的权重参数,bbb是偏差参数。与线性回归相同,多项式函数拟合也使用平方损失函数。特别地,一阶多项式函数拟合又叫线性函数拟合。

因为高阶多项式函数模型参数更多,模型函数的选择空间更大,所以高阶多项式函数比低阶多项式函数的复杂度更高。因此,高阶多项式函数比低阶多项式函数更容易在相同的训练数据集上得到更低的训练误差。给定训练数据集,模型复杂度和误差之间的关系通常如图3.4所示。给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合;如果模型复杂度过高,很容易出现过拟合。应对欠拟合和过拟合的一个办法是针对数据集选择合适复杂度的模型。

训练数据集大小

影响欠拟合和过拟合的另一个重要因素是训练数据集的大小。一般来说,如果训练数据集中样本数过少,特别是比模型参数数量(按元素计)更少时,过拟合更容易发生。此外,泛化误差不会随训练数据集里样本数量增加而增大。因此,在计算资源允许的范围之内,我们通常希望训练数据集大一些,特别是在模型复杂度较高时,例如层数较多的深度学习模型。

多项式函数拟合实验

为了理解模型复杂度和训练数据集大小对欠拟合和过拟合的影响,下面我们以多项式函数拟合为例来实验。首先导入实验需要的包或模块。

%matplotlib inline
import torch
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

生成数据集

我们将生成一个人工数据集。在训练数据集和测试数据集中,给定样本特征xxx,我们使用如下的三阶多项式函数来生成该样本的标签:

y=1.2x−3.4x2+5.6x3+5+ϵ,y = 1.2x - 3.4x^2 + 5.6x^3 + 5 + \epsilon,y=1.2x−3.4x2+5.6x3+5+ϵ,

其中噪声项ϵ\epsilonϵ服从均值为0、标准差为0.01的正态分布。训练数据集和测试数据集的样本数都设为100。

n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
features = torch.randn((n_train + n_test, 1))
poly_features = torch.cat((features, torch.pow(features, 2), torch.pow(features, 3)), 1)
labels = (true_w[0] * poly_features[:, 0] + true_w[1] * poly_features[:, 1]+ true_w[2] * poly_features[:, 2] + true_b)
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

看一看生成的数据集的前两个样本。

features[:2], poly_features[:2], labels[:2]

输出:

(tensor([[-1.0613],[-0.8386]]), tensor([[-1.0613,  1.1264, -1.1954],[-0.8386,  0.7032, -0.5897]]), tensor([-6.8037, -1.7054]))

定义、训练和测试模型

我们先定义作图函数semilogy,其中 yyy 轴使用了对数尺度。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def semilogy(x_vals, y_vals, x_label, y_label, x2_vals=None, y2_vals=None,legend=None, figsize=(3.5, 2.5)):d2l.set_figsize(figsize)d2l.plt.xlabel(x_label)d2l.plt.ylabel(y_label)d2l.plt.semilogy(x_vals, y_vals)if x2_vals and y2_vals:d2l.plt.semilogy(x2_vals, y2_vals, linestyle=':')d2l.plt.legend(legend)

和线性回归一样,多项式函数拟合也使用平方损失函数。因为我们将尝试使用不同复杂度的模型来拟合生成的数据集,所以我们把模型定义部分放在fit_and_plot函数中。多项式函数拟合的训练和测试步骤与3.6节(softmax回归的从零开始实现)介绍的softmax回归中的相关步骤类似。

num_epochs, loss = 100, torch.nn.MSELoss()def fit_and_plot(train_features, test_features, train_labels, test_labels):net = torch.nn.Linear(train_features.shape[-1], 1)# 通过Linear文档可知,pytorch已经将参数初始化了,所以我们这里就不手动初始化了batch_size = min(10, train_labels.shape[0])    dataset = torch.utils.data.TensorDataset(train_features, train_labels)train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)optimizer = torch.optim.SGD(net.parameters(), lr=0.01)train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:l = loss(net(X), y.view(-1, 1))optimizer.zero_grad()l.backward()optimizer.step()train_labels = train_labels.view(-1, 1)test_labels = test_labels.view(-1, 1)train_ls.append(loss(net(train_features), train_labels).item())test_ls.append(loss(net(test_features), test_labels).item())print('final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',range(1, num_epochs + 1), test_ls, ['train', 'test'])print('weight:', net.weight.data,'\nbias:', net.bias.data)

三阶多项式函数拟合(正常)

我们先使用与数据生成函数同阶的三阶多项式函数拟合。实验表明,这个模型的训练误差和在测试数据集的误差都较低。训练出的模型参数也接近真实值:w1=1.2,w2=−3.4,w3=5.6,b=5w_1 = 1.2, w_2=-3.4, w_3=5.6, b = 5w1​=1.2,w2​=−3.4,w3​=5.6,b=5。

fit_and_plot(poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:])

输出:

final epoch: train loss 0.00010175639908993617 test loss 9.790256444830447e-05
weight: tensor([[ 1.1982, -3.3992,  5.6002]])
bias: tensor([5.0014])

线性函数拟合(欠拟合)

我们再试试线性函数拟合。很明显,该模型的训练误差在迭代早期下降后便很难继续降低。在完成最后一次迭代周期后,训练误差依旧很高。线性模型在非线性模型(如三阶多项式函数)生成的数据集上容易欠拟合。

fit_and_plot(features[:n_train, :], features[n_train:, :], labels[:n_train],labels[n_train:])

输出:

final epoch: train loss 249.35157775878906 test loss 168.37705993652344
weight: tensor([[19.4123]])
bias: tensor([0.5805])

训练样本不足(过拟合)

事实上,即便使用与数据生成模型同阶的三阶多项式函数模型,如果训练样本不足,该模型依然容易过拟合。让我们只使用两个样本来训练模型。显然,训练样本过少了,甚至少于模型参数的数量。这使模型显得过于复杂,以至于容易被训练数据中的噪声影响。在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很高。这是典型的过拟合现象。

fit_and_plot(poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],labels[n_train:])

输出:

final epoch: train loss 1.198514699935913 test loss 166.037109375
weight: tensor([[1.4741, 2.1198, 2.5674]])
bias: tensor([3.1207])

小结

  • 由于无法从训练误差估计泛化误差,一味地降低训练误差并不意味着泛化误差一定会降低。机器学习模型应关注降低泛化误差。
  • 可以使用验证数据集来进行模型选择。
  • 欠拟合指模型无法得到较低的训练误差,过拟合指模型的训练误差远小于它在测试数据集上的误差。
  • 应选择复杂度合适的模型并避免使用过少的训练样本。

动手深度学习PyTorch(三)模型选择、欠拟合和过拟合相关推荐

  1. 动手深度学习——Pytorch 入门语法一

    持续更新ing 数据操作 import torchx = torch.arange(12) # 创建⼀个⾏向量x.这个⾏向量包含从0开始的前12个整数 print(x.shape) # 通过张量的sh ...

  2. 动手深度学习PyTorch(九)GRU、LSTM、Bi-RNN

    GRU 上一篇介绍了循环神经网络中的梯度计算方法.我们发现,当时间步数较大或者时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸.虽然裁剪梯度可以应对梯度爆炸,但无法解决梯度衰减的问题.通常由于这个 ...

  3. 动手深度学习PyTorch(十二)word2vec

    独热编码 独热编码即 One-Hot 编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效.举个例子,假设我们有四个样 ...

  4. 动手学深度学习(PyTorch实现)(三)--过拟合与欠拟合

    过拟合与欠拟合 1. 过拟合与欠拟合 1.1 训练误差和泛化误差 1.2 模型选择 1.2.1 验证数据集 1.2.2 K折交叉验证 1.3 过拟合与欠拟合 1.3.1 模型复杂度 1.3.2 训练数 ...

  5. ElitesAI·动手学深度学习PyTorch版-第三次打卡

    1.过拟合欠拟合及其解决方案 1.1 模型选择.过拟合和欠拟合 在解释上述现象之前,我们需要区分训练误差(training error)和泛化误差(generalization error). 训练误 ...

  6. 过拟合欠拟合模拟 || 深度学习 || Pytorch || 动手学深度学习11 || 跟李沐学AI

    昔我往矣,杨柳依依.今我来思,雨雪霏霏. ---<采薇> 本文是对于跟李沐学AI--动手学深度学习第11节:模型选择 + 过拟合和欠拟合的代码实现.主要是通过使用线性回归模型在自己生成的数 ...

  7. [pytorch、学习] - 3.11 模型选择、欠拟合和过拟合

    参考 3.11 模型选择.欠拟合和过拟合 3.11.1 训练误差和泛化误差 在解释上述现象之前,我们需要区分训练误差(training error)和泛化误差(generalization error ...

  8. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 04 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 04 学习笔记 Task 04:机器翻译及相关技术:注意力机制与Seq2seq模型:Transformer 微信昵称:WarmIce ...

  10. 深度学习+pytorch自学笔记(三)——线性回归

    参考书籍<动手学深度学习(pytorch版),参考网址为:https://tangshusen.me/Dive-into-DL-PyTorch/#/ 请大家也多多支持这一个很好用的平台~ 大部分 ...

最新文章

  1. JAVA环境变量配置与配置后CMD的使用
  2. 涌现:21世纪科学的统一主题
  3. jQuery实现拖动布局并将排序结果保存到数据库
  4. 双脑协同RSVP目标检测
  5. Windows Server 2003 R2中的“分布式文件系统”案例应用
  6. 太难了!产品经理想拿高薪
  7. tensorflow.GraphDef was modified concurrently during serialization
  8. spring mvc 中对静态资源的访问配置
  9. Exchange Server 2013日记功能
  10. SVN文件上感叹号、加号、问号等图标的原因
  11. 随想录(一种powerpc编译学习的方法)
  12. 辗转相除法 两个数的最大公约数
  13. kubernetes-dashboard部署
  14. 串行通信接口:RS-232、RS-485和RS-422简述
  15. c#语言絢止函数是,取汉子拼音首字母的C#和VB.Net方法
  16. 常见视频封装格式(3) — MP4
  17. 安装linux ubuntu11系统时,应该如何选择键盘布局,在ubuntu上创建新键盘布局需要哪些步骤?...
  18. Python中的爬虫
  19. win10的系统mysql服务器地址,win10的系统mysql服务器地址
  20. ps中矫正镜头的一种方法

热门文章

  1. [分析力学]解题思路 - 最小作用量原理
  2. 论文中参考文献规范格式
  3. YOLOv5 完美实现中文标签显示
  4. native8081端口 react_教你轻松修改React Native端口(如何同时运行多个React Native、8081端口占用问题)...
  5. TP-Link路由器端口映射8081端口的Nexus服务外网无法访问的解决办法
  6. 异数OS 开放式闭源继承人协议
  7. NexT主题添加音乐
  8. mysql导致的502_ab压测过程中出现502及操作数据库失败
  9. Java求抛物线输入角度速度_知道初速度和抛物线的角度,怎么计算落点
  10. [ActiveForm] -- ActiveForm::begin表单用法