深度学习中模型训练效果不好的原因

  • 1. 是否选择合适的损失函数
  • 2. 是否选择了合适的Mini-batch size
  • 3. 是否选择了合适的激活函数
  • 4. 是否选择了合适的学习率
  • 5. 优化算法是否使用了动量(Momentum)
  • 6. 其他原因

当我们用自定义的模型去训练某个数据集时,
经常会出现效果不佳的情况:精度太低、损失降不下去、泛性太差等情况。可能的原因有:

  • 数据集样本太少,多样性不够;
  • 网络模型是否添加了BN层,损失函数和激活函数的选取;
  • 优化器的选取,学习率的设置等;

这里暂时不考虑数据集的原因,我们首先来看一下网络模型和优化算法中可能存在的问题:

1. 是否选择合适的损失函数

神经网络的损失函数是非凸的,有多个局部最低点,目标是找到一个可用的最低点。
非凸函数是凹凸不平的,但是不同的损失函数凹凸起伏的程度不同,例如下述的平方损失和交叉熵损失,后者起伏更大,且后者更容易找到一个可用的最低点,从而达到优化的目的。
- Square Error(平方损失)
- Cross Entropy(交叉熵损失)

2. 是否选择了合适的Mini-batch size

使用合适的batch size进行学习,一方面可以减少计算量,一方面有助于跳出局部最优点。
batch取太大会陷入局部最小值,batch取太小会抖动厉害,因此要选择一个合适的batch size。

batch size选取时可以采用以下策略:

  • 当有足够算力时,选取batch size为32或更小一些。
  • 算力不够时,在效率和泛化性之间做trade-off,尽量选择更小的batch size。
  • 当模型训练到尾声,想更精细化地提高成绩(比如论文实验/比赛到最后),有一个有用的trick,就是设置batch size为1,即做纯SGD,慢慢把error磨低。

3. 是否选择了合适的激活函数

使用激活函数把卷积层输出结果做非线性映射,但是要选择合适的激活函数。

  • Sigmoid函数是一个平滑函数,且具有连续性和可微性,它的最大优点就是非线性。但该函数的两端很缓,易发生学不动的情况,产生梯度弥散。
  • ReLU函数是如今设计神经网络时使用最广泛的激活函数,该函数为非线性映射,且简单,可缓解梯度弥散。

4. 是否选择了合适的学习率

  1. 学习率过大,会抖动厉害,导致没有优化提升,容易错过最优解
  2. 学习率太小,下降太慢,训练会很慢

学习率可以采用以下策略:

  • 如果模型是非常稀疏的,那么优先考虑自适应学习率的算法。
  • 在模型设计实验过程中,要快速验证新模型的效果,可以先用Adam进行快速实验优化;在模型上线或者结果发布前,可以用精调的SGD进行模型的极致优化。并且制定一个合适的学习率衰减策略。 可以使用定期衰减策略,比如每过多少个epoch就衰减一次。

5. 优化算法是否使用了动量(Momentum)

在SGD的基础上使用动量,有助于冲出局部最低点。

当我们将一个小球从山上滚下来时,没有阻力的话,它的动量会越来越大,但是如果遇到了阻力,速度就会变小。

在SGD的基础上使用动量,可以使得梯度方向不变的维度上速度变快,梯度方向有所改变的维度上的更新速度变慢,这样就可以加快收敛并减小震荡。

6. 其他原因

如果以上五部分都选对了,效果还不好,那就是产生过拟合了,可使如下方法来防止过拟合,分别是

  1. 早停法(earyly stoping):早停法将数据分成训练集和验证集,训练集用来计算梯度、更新权重和阈值,验证集用来估计误差,若训练集误差降低但验证集误差升高,则停止训练,同时返回具有最小验证集误差的连接权和阈值。
  2. 权重衰减(Weight Decay):到训练的后期,通过衰减因子使权重的梯度下降地越来越缓,可以采用L1或L2正则化。
  3. Dropout:Dropout是正则化的一种处理,以一定的概率关闭神经元的通路,阻止信息的传递。由于每次关闭的神经元不同,从而得到不同的网路模型,最终对这些模型进行融合。
  4. 调整网络结构(Network Structure)。

参考文章:
https://www.julyedu.com/question/big/kp_id/26/ques_id/2589
本文是对其他参考文章的总结,若内容和图片有涉及侵权,请联系作者删除。

深度学习中模型训练效果不好的原因以及防止过拟合的方法相关推荐

  1. 【深度学习】模型训练过程可视化思路(可视化工具TensorBoard)

    [深度学习]模型训练过程可视化思路(可视化工具TensorBoard) 文章目录 1 TensorBoard的工作原理 2 TensorFlow中生成log文件 3 启动TensorBoard,读取l ...

  2. 【深度学习中模型评价指标汇总(混淆矩阵、recall、precision、F1、AUC面积、ROC曲线、ErrorRate)】

    深度学习中模型好坏的所有评价指标汇总(混淆矩阵.recall.precision.F1score.AUC面积.ROC曲线.ErrorRate) 导航 0.混淆矩阵 1.AUC面积 2.ROC曲线 3. ...

  3. 【深度学习】深度学习中模型计算量(FLOPs)和参数量(Params)等的理解以及四种在python应用的计算方法总结

    接下来要分别概述以下内容: 1 首先什么是参数量,什么是计算量 2 如何计算 参数量,如何统计 计算量 3 换算参数量,把他换算成我们常用的单位,比如:mb 4 对于各个经典网络,论述他们是计算量大还 ...

  4. 【机器学习算法】神经网络与深度学习-7 DNN深度学习算法模型出现学习效果不好的情况,如何补救,对策如下,建议收藏。

    目录 深度学习效果不好的对策 训练数据效果不好的情况,采用什么对策 方法1:proper loss 方法2:Mini-Batch 方法3:采用new activation function 方法4:A ...

  5. 深度学习中模型攻击与防御(Attack DL Models and Defense)的原理与应用(李宏毅视频课笔记)

    文章目录 0 前言 1 Attack(模型攻击) 1.1 Attack on Image Recognition Network Model(对图像识别网络模型的攻击) 1.1.1 Loss Func ...

  6. 深度学习大模型训练--分布式 deepspeed PipeLine Parallelism 源码解析

    deepspeed PipeLine Parallelism 源码解析 basic concept PipeDream abstract 1F1B 4 steps Code comprehension ...

  7. 【深度学习】模型训练教程之Focal Loss调参和Dice实现

    文章目录 1 Focal Loss调参概述 2 实验 3 FocalLoss 对样本不平衡的权重调节和减低损失值 4 多分类 focal loss 以及 dice loss 的pytorch以及ker ...

  8. 「AI不惑境」深度学习中的多尺度模型设计

    https://www.toutiao.com/a6716408071637172748/ 大家好,这是专栏<AI不惑境>的第七篇文章,讲述计算机视觉中的多尺度问题. 进入到不惑境界,就是 ...

  9. 【AI不惑境】深度学习中的多尺度模型设计

    大家好,这是专栏<AI不惑境>的第七篇文章,讲述计算机视觉中的多尺度问题. 进入到不惑境界,就是向高手迈进的开始了,在这个境界需要自己独立思考.如果说学习是一个从模仿,到追随,到创造的过程 ...

最新文章

  1. NLP中文面试学习资料:面向算法面试,理论代码俱全,登上GitHub趋势榜
  2. 宅男程序员给老婆的计算机课程之9:数据模型
  3. [Python图像处理] 四十一.Python图像平滑万字详解(均值滤波、方框滤波、高斯滤波、中值滤波、双边滤波)
  4. 任务31:课时介绍 任务32:Cookie-based认证介绍 任务33:34课 :AccountController复制过来没有移除[Authorize]标签...
  5. arduino i2c 如何写16位寄存器_Arduino之我见
  6. 同比暴增3700%!百度取代谷歌成世界第二
  7. windows访问mysql57_windows下 Mysql5.5升级5.7(其实就是安装了两个版本的mysql)
  8. Java获取接口所有实现类的方式
  9. html平面图绘制,如何利用Word绘制校园平面图?
  10. 2022年全国图书参考联盟读秀5.0/4.0/3.0/2.0/1.0书库网盘数据索引在线搜索查询系统搭建教程,可以实现ISBN/SS号/书封面链接/书名/作者/出版社…等信息一键搜索查询
  11. 利用企业微信/飞书/钉钉扫码认证连接办公WiFi无线网络解决方案
  12. P4234(最小差值生成树 lct维护生成树)
  13. 游戏HTML翻翻乐,大班益智游戏翻翻乐教案
  14. 如何让游戏讲一个好故事?
  15. VLC 21年,重新审视低延迟直播
  16. HTML中引入CSS文件的几种方法
  17. Efficient Use Of Tmux
  18. EXPDP预估导出空间estimate,estimate_only
  19. 灼热丝试验箱操作规程 洛克仪器 Labverse
  20. 韩顺平java--Collection 集合专题

热门文章

  1. 淘宝技术分享:手淘亿级移动端接入层网关的技术演进之路
  2. 数据可视化:8款小众但好用的可视化工具
  3. 苹果ipv6审核解决方案
  4. Linux 压缩包乱码
  5. PPT结束语有哪些?
  6. 【第二届】无锡太湖学院ICPC校队对抗赛原创 IOI D题题解
  7. 软件测试英文项目,一个成功软件测试项目的经验(国外英文资料).doc
  8. mysql查询第二个字母为a_MSSQL_关于SQL Server查询语句的使用,一.查询第二个字母是t或者a的 - phpStudy...
  9. 2011 Esri中国开发者大会
  10. 对话黄骁俭:SAP的工程师文化