引言

在机器学习中,我们将模型在训练集上的误差称之为训练误差,又称之为经验误差,在新的数据集(比如测试集)上的误差称之为泛化误差,泛化误差也可以说是模型在总体样本上的误差。对于一个好的模型应该是经验误差约等于泛化误差,也就是经验误差要收敛于泛化误差,根据霍夫丁不等式可知经验误差在一定条件下是可以收敛于泛化误差的。
当机器学习模型对训练集学习的太好的时候(再学习数据集的通性的时候,也学习了数据集上的特性,这些特性是会影响模型在新的数据集上的表达能力的,也就是泛化能力),此时表现为经验误差很小,但往往此时的泛化误差会很大,这种情况我们称之为过拟合,而当模型在数据集上学习的不够好的时候,此时经验误差较大,这种情况我们称之为欠拟合。具体表现如下图所示,第一幅图就是欠拟合,第三幅图就是过拟合。

过拟合

原因

1、训练集的数量级和模型的复杂度不匹配。训练集的数量级要小于模型的复杂度

2、训练集和测试集特征分布不一致

3、样本里的噪音数据干扰过大,大到模型过分的记住了噪音特征,而忽略了真实的输入输出间的关系

4、权值学习迭代次数足够多(Overtraining),拟合了训练数据中的噪声和训练样例中没有代表性的特征

方法

1、simpler model structure
调小模型复杂度,使其适合自己训练集的数量级(缩小宽度和减小深度)

2、data augmentation
训练集越多,过拟合的概率越小。在计算机视觉领域中,增广的方式是对图像旋转,缩放,剪切,添加噪声等

3、regularization
参数太多,会导致我们的模型复杂度上升,容易过拟合,也就是我们的训练误差会很小。正则化是指通过引入额外新信息来解决过拟合问题的一种。这种额外信息通常的形式是模型复杂性带来的惩罚度。正则化可以保持模型简单,另外,规则项的使用还可以约束我们模型的特性。

4、dropout
在训练的时候让神经元以一定概率不工作。
dropout会导致网络的训练速度慢2、3倍,而且数据小的时候,dropout的效果并不会太好,因此只会在大型网络上使用。

左图没有dropout的标准神经网络,右图有dropout的神经网络,即在训练时候以一定的概率p来跳过一定的神经元

5、Early stopping
对模型进行训练的过程即使对模型的参数进行学习更新的过程,这个参数学习的过程往往会用到一些迭代方法,如梯度下降(Gradient descent)学习算法。Early stopping便是一种迭代次数截断的方法来防止过拟合的,即在模型对训练数据集迭代收敛之前停止迭代来防止过拟合。
其具体做法是,在每一个Epoch结束时(一个Epoch集为对所有的训练数据的一轮遍历)计算validation data 的 accuracy,当accuracy不再提高时,就停止训练。这种做法很符合直观感受,因为accuracy都不再提高了,在继续训练也是无益的。那么该做法的重点就是“怎样才认为accuracy不再提高了?” “并不是accuracy一降下来便认为不再提高了,因为可能经过这个Epoch后,accuracy降低了,但是随后的Epoch又让accuracy上去了,因此不能根据一次两次连续降低就判断不再提高了”。
一般做法是:在训练过程中,记录到目前为止最好的alidation accuracy,当连续10次Epoch(或者更多次)没达到最佳accuracy时,则可以认为accuracy不再提高了。此时便可以停止迭代了(Early Stopping),这种策略也称为“No-improvement-in-n” , n即Epoch的次数,可以根据实际情况选取,如10、20、30…

开源代码

import numpy as np
import torchclass EarlyStopping:"""Early stops the training if validation loss doesn't improve after a given patience."""def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):"""Args:patience (int): How long to wait after last time validation loss improved.上次验证集损失值改善后等待几个epochDefault: 7verbose (bool): If True, prints a message for each validation loss improvement. 如果是True,为每个验证集损失值改善打印一条信息Default: Falsedelta (float): Minimum change in the monitored quantity to qualify as an improvement.检测数量的最小变化,以符合改进的要求Default: 0path (str): Path for the checkpoint to be saved to.Default: 'checkpoint.pt'trace_func (function): trace print function.Default: print            """self.patience = patienceself.verbose = verboseself.counter = 0self.best_score = Noneself.early_stop = Falseself.val_loss_min = np.Infself.delta = deltaself.path = pathself.trace_func = trace_funcdef __call__(self, val_loss, model):score = -val_lossif self.best_score is None:self.best_score = scoreself.save_checkpoint(val_loss, model)elif score < self.best_score + self.delta:self.counter += 1self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.save_checkpoint(val_loss, model)self.counter = 0def save_checkpoint(self, val_loss, model):'''Saves model when validation loss decrease.验证损失减少时保存模型'''if self.verbose:self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')# 这里会存储迄今最优的模型torch.save(model.state_dict(), self.path)self.val_loss_min = val_loss

伪代码——pytorch

# 打包数据集mydata_loader = Dataloader(dataset= , batch_size= 批量大小, shuffle = True or False 是否打乱数据顺序)
model = MyModel()
# 指定损失函数, 这里是交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 指定优化器
optimizer = torch.nn.Adam(model.parameters())
# 初始化early stopping 对象
# 这里是当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,防止过拟合
patience = 20
early_stopping = EarlyStopping(patience = patience, verbose = True)batch_size = 64
n_epochs = 100 #可以设置大一点,比较希望通过早停法来结束模型训练# 训练模型,直到epoch == n_epochs 或者触发early_stopping来结束训练for epoch in range(1, n_epochs+1):# 模型设置为训练模式model.train()# 按小批量训练for batch, (data, label) in enumerate(mydata_loader):# 清空所有参数的梯度optimizer.zero_grad()# 输出模型预测值output = model(data)# 计算损失loss = criterion(output, label)# 计算损失对于各个参数的梯度loss.backward()# 执行单步优化操作: 更新参数optimizer.step()# 模型设置为评估/测试模式,关闭dropout,并将模型参数锁定model.eval()# 一般如果验证集不是很大,模型验证时就不需要按批量进行,但要注意输入参数的维度不能错# 预测数据pre_data 对应目标/标签 pre_labelvalid_output = model(pre_data)valid_loss = criterion(valid_output, pre_label)early_stopping(valid_loss, model)# 若满足early stopping 要求if early_stopping.early_stop:print("Early Stopping!")# 结束模型训练break# 获得 early stopping 时的模型参数
model.load_state_dict(torch.load('checkpoint.pt'))

参考文档:pytorch-实现早停法

参考文档:深度学习技巧之早停法
6、ensemble
集成学习算法也可以有效的减轻过拟合。
Bagging通过评价多个模型的结果,来降低模型的方差。Boosting不仅能过减小偏差,还能减小方差。

欠拟合

原因及方法

1、欠拟合是由于学习不足,可以考虑添加特征,从数据中挖掘更多的特征,有时候还需要对特征进行变换,使用组合特征和高次特征

2、模型简单也会导致欠拟合,如线性模型只能拟合一次函数的数据。尝试使用更高级的模型有助于解决欠拟合,如svm等

3、正则化参数是用来防止过拟合的,出现欠拟合的情况就要考虑减少正则化参数了。

常见过拟合、欠拟合原因及解决办法相关推荐

  1. ECONNABORTED,Socket 常见连接错误之一,原因分析 + 解决办法

    原创博文,欢迎转载,转载时请务必附上博文链接,感谢您的尊重. 前言 通过本篇,你将简单了解到 Socket 连接错误-- ECONNABORTED [产生的原因],还有[避免的方法],结论直接见最后的 ...

  2. 欠拟合的原因以及解决办法(深度学习)

    之前这篇文章,我分析了一下深度学习中,模型过拟合的主要原因以及解决办法: 过拟合的原因以及解决办法(深度学习)_大黄的博客-CSDN博客 这篇文章中写一下深度学习中,模型欠拟合的原因以及一些常见的解决 ...

  3. 过拟合原因及解决办法

    过拟合原因及解决办法 知乎 过拟合出现的原因以及解决方案 过拟合 欠拟合过拟合出现的原因及解决办法

  4. 【深度学习】模型过拟合的原因以及解决办法

    [深度学习]模型过拟合的原因以及解决办法 1.背景 2.模型拟合 3.简述原因 4.欠拟合解决办法 5.过拟合解决办法 1.背景 所谓模型过拟合现象: 在训练网络模型的时候,会发现模型在训练集上表现很 ...

  5. 过拟合的原因以及解决办法

    1.什么是过拟合  欠拟合是指模型没有能够很好的表现数据的结构,而出现的拟合度不高的情况.  过拟合是指模型过分的拟合训练样本,但对测试样本预测准确率不高的情况,也就是说模型泛化能力很差.如下图所示: ...

  6. 模型过拟合原因及解决办法

    模型过拟合原因及解决办法 过拟合现象 导致过拟合原因 解决办法 过拟合现象 对于样本量有限.但需要使用强大模型的复杂任务,模型很容易出现过拟合的表现,即在训练集上的损失小,在验证集或测试集上的损失较大 ...

  7. 06 回归算法 - 损失函数、过拟合欠拟合

    == 损失函数 == 损失函数是衡量一个模型好坏的指标,一般来说损失函数的值越小越好. 0~1损失函数: J(θ)=$begin{cases} 1,Y≠f(X)\ 0,Y=f(X)\ end{case ...

  8. 31,32,33_过拟合、欠拟合的概念、L2正则化,Pytorch过拟合欠拟合,交叉验证-Train-Val-Test划分,划分训练集和测试集,K-fold,Regularization

    1.26.过拟合.欠拟合及其解决方案 1.26.1.过拟合.欠拟合的概念 1.26.1.1.训练误差和泛化误差 1.26.1.2.验证数据集与K-fold验证 1.26.1.3.过拟合和欠拟合 1.2 ...

  9. 机器学习--过度拟合 欠拟合

    过度拟合(overfitting)是指数据模型在训练集里表现非常满意,但是一旦应用到真实业务实践时,效果大打折扣:换成学术化语言描述,就是模型对样本数据拟合非常好,但是对于样本数据外的应用数据,拟合效 ...

  10. 网站服务器挂了导致排名下降,常见关键词排名消失的原因及解决对策

    常见关键词排名消失的原因及解决对策 (2013-06-07 14:34:29) 标签: 关键词排名消失的原因 我们做seo,一定会遇到关键词排名消失的现象.其实关键词排名消失不可怕,怕的是我们找不到关 ...

最新文章

  1. 30天提升技术人的写作力-第二天
  2. 高通平台msm8909 LK 实现LCD 兼容
  3. 鸭子在Java中打字? 好吧,不完全是
  4. Django基本概念、安装、配置到实现框架,Xmind学习笔记
  5. websocket学习总结记录
  6. PIC18F中断定时器
  7. 使用WxPython进行Win32下Python编程
  8. CentOS 7 启动与切换图形界面
  9. mysql互为主从 keep_KeepAlived+MySQL互为主从
  10. Shell脚本学习-阶段十六-备份和恢复系统权限
  11. 贺利坚老师汇编课程57笔记:CMP和JXXX配合实现条件转移指令if
  12. css中如何使图标的旋转
  13. 根据c语言标识符的命名规则 标识符只能由,二级C语言教程同步习题集答案解析1-2章.doc...
  14. 把故事收回到一杯茶里,不知不觉,茶已经凉了
  15. 如何让自己发了疯、拼了命、石乐志的学习?
  16. 递推算法与递推套路(手撕算法篇)
  17. 晶品特装科创板上市:市值68亿 主打地面无人装备研发与产销
  18. 即时通讯视频聊天代码和技术架构
  19. HttpHelper类
  20. animFilters -3dsMax动画曲线优化,减帧

热门文章

  1. 常见的DOM操作方式有哪些
  2. git clone 提速方法
  3. 求助! fdisk创建新分区失败
  4. java的类型_Java的基本类型
  5. 拼音输入法(MPinyinIME)
  6. kafka-应用场景
  7. 自定义flink es source
  8. Android-短信弹窗提示
  9. 寓言故事《小松鼠的冒险与反思》
  10. JavaScript算法题整理