文章目录

  • 0.General Guide
    • 训练集上的loss太大怎么办?
    • 测试集上的loss太大怎么办?
  • 1.局部最小值与鞍点
  • 2.批次(batch)与动量(momentum)
  • 3.自动调整学习率(Adaptive Learning Rate)
  • 4.loss也可能有影响
  • 5.批次标准化(Batch Normalization)

主要是一些训练的tips,从训练集和测试集出发。

0.General Guide

模型训练的一般准则如下图:

训练集上的loss太大怎么办?

  1. model bais:模型太简单,需要重新设计model。
  2. optimation:优化没做好,没有找到全局最小点,可能是局部最小点。
    如何判断训练数据loss太大是由什么原因造成?——考察不同的模型,看训练的loss,如果更深的模型在训练集上没有获得更小的loss,那么这就是个优化问题。

如何解决optimzation的问题?——见下文momentum

测试集上的loss太大怎么办?

1.overfitting:训练数据loss小,但测试数据loss大,说明过拟合
解决:增加数据量;数据增强;给模型一些限制(更少的参数,共享参数)……

随着模型变得复杂,训练loss会一直减小,但测试loss可能会在某处反而增大,所以该如何选取比较好的模型?

将training set分为training set 和validation set,通过验证集来检验训练集的loss

怎么划分训练集和验证集——用N-fold Cross Validation
以不同的方式划分训练集和验证集,计算loss,再选择最好的方式

2.mismatch:
训练资料和测试资料分布不同,因此增加数据集之类的做法对测试集loss没有多大帮助

1.局部最小值与鞍点

优化失败的原因有:局部最小点或鞍点(此时梯度皆为0)
假设loss函数L(θ)L(\theta)L(θ)在θ=θ′\theta=\theta'θ=θ′ 附近的泰勒估计是
L(θ)≈L(θ′)+(θ−θ′)Tg+12(θ−θ′)TH(θ−θ′)L(\theta)\approx L(\theta')+(\theta-\theta')^Tg+\frac{1}{2}(\theta-\theta')^TH(\theta-\theta')L(θ)≈L(θ′)+(θ−θ′)Tg+21​(θ−θ′)TH(θ−θ′)

在critical point(驻点)即极值点或鞍点时,g=0,此时泰勒估计可以写成
L(θ)≈L(θ′)+12(θ−θ′)TH(θ−θ′)L(\theta)\approx L(\theta')+\frac{1}{2}(\theta-\theta')^TH(\theta-\theta')L(θ)≈L(θ′)+21​(θ−θ′)TH(θ−θ′)
令 θ−θ′=v\theta-\theta'=vθ−θ′=v,则:L(θ)≈L(θ′)+12vTHvL(\theta)\approx L(\theta')+\frac{1}{2}v^THvL(θ)≈L(θ′)+21​vTHv
For all v:vTHv>0v^THv>0vTHv>0,H是正定矩阵,在 θ′\theta'θ′ 附近,L(θ)>L(θ′)L(\theta)>L(\theta')L(θ)>L(θ′)——局部最小点
For all v:vTHv<0v^THv<0vTHv<0,H是负定矩阵,在 θ′\theta'θ′ 附近,L(θ)<L(θ′)L(\theta)<L(\theta')L(θ)<L(θ′)——局部最大点
For v : vTHvv^THvvTHv 有时大于0有时小于0——鞍点

如果是鞍点,可以通过H的特征向量的方向让L下降

2.批次(batch)与动量(momentum)

  • batch就是说将数据分成很多个批次,主要探讨small batch和large batch对训练的影响(比如batch size=1和 batch size=N)
    batch size也是个超参,具体区别如下图:

  • momentum——类比真实世界中的动量,在梯度很小或者等于0的时候还可以继续优化,解决鞍点和局部最小点
    之前的的优化过程是这样的,沿着负梯度方向前进:

现在加上动量:

Movement: movement of last step minus gradient at present

此时的优化方向是上次的运动方向减去当前的梯度,具体算法流程:

optimation(优化)+momentum(动量)
1.start at θ0\theta^0θ0,令动量movement m0=0m^0=0m0=0,计算θ0\theta^0θ0处的梯度g^0
2.计算动量 m1=λm0−ηg0m^1=\lambda m^0-\eta g^0m1=λm0−ηg0
3.计算此时优化方向 θ1=θ0+m1g1\theta^1=\theta^0+m^1 g^1θ1=θ0+m1g1,计算θ1\theta^1θ1处梯度g1g^1g1
4.计算动量m2=λm1−ηg1计算动量 m^2=\lambda m^1-\eta g^1计算动量m2=λm1−ηg1
5.计算此时优化方向 θ2=θ1+m2\theta^2=\theta^1+m^2θ2=θ1+m2

也就是说,在梯度很小或者dengyu0的时候,加上动量可以让优化继续

3.自动调整学习率(Adaptive Learning Rate)

训练停止不一定是梯度很小(Training stuck ≠ Small Gradient),也有可能是学习率的问题,有时学习率太大,可能错过了最小点;有时学习率太小,可能需要训练很久也不一定能找到最小点。另外不同的参数需要不同的学习率
还是从之前的优化出发:θit+1=θit−ηgit\theta_i^{t+1}=\theta_i^{t}-\eta g_i^tθit+1​=θit​−ηgit​,对学习率进行改动:
θit+1=θit−ησitgit\theta_i^{t+1}=\theta_i^{t}-\frac{\eta}{\sigma_i^t} g_i^tθit+1​=θit​−σit​η​git​
可以看出此时的学习率既是迭代的又是参数独立的,σit\sigma_i^tσit​ 怎么去求呢,常见的有这两种办法
1.Root Mean Square:导数的均方根

2.RMSProp

优化中著名的Adam算法就是 RMSProp+Monentum

  • Learning Rate Scheduling
    学习率的调整,对上述的优化表达式再进行变形:
    θit+1=θit−ηtσitgit\theta_i^{t+1}=\theta_i^{t}-\frac{\eta_t}{\sigma_i^t} g_i^tθit+1​=θit​−σit​ηt​​git​
    即让 η\etaη 与时间有关,比如说可以让学习率随着时间下降,或者先增大再减小(RAdam论文)。更多学习率调整技巧可参加 click here

所以就优化表达式总结:一般的优化:θit+1=θit−ηgit\theta_i^{t+1}=\theta_i^{t}-\eta g_i^tθit+1​=θit​−ηgit​
改进的优化:

4.loss也可能有影响

这个问题主要是针对回归和分类的loss函数

回归和分类差不多,稍有不同的就是分类需要在输出时再通过一层softmax层,可以理解成将输出的任意 y 值变换到0-1之间,softmax公式是:
yi′=exp⁡(yi)∑iexp(yi)y_i'=\frac{\exp(y_i)}{\sum_i exp(y_i)}yi′​=∑i​exp(yi​)exp(yi​)​
在二分类时,常用sigmoid函数,其实经过计算比较发现此时两者结果一样
Q:softmax和sigmoid有什么区别和联系?

我们知道回归的loss函数一般是 MSE:e=∑(y^i−yi)2e=\sum(\hat{y}_i-y_i)^2e=∑(y^​i​−yi​)2
分类的loss函数一般是 cross-entropy:e=−∑iyi^lnyi′e=-\sum_i \hat{y_i}ln y_i'e=−∑i​yi​^​lnyi′​
在实际训练中,如果把分类问题当作回归问题,用MSE代替cross-entropy,效果不会很好。
所以:

Changing the loss function can change the difficulty of optimization

5.批次标准化(Batch Normalization)

  • Question:为什么会产生不好train的时候?
    不同的参数下降速度不同,并且梯度下降斜率可能相差很大,这个时候使用固定的学习率优化效果可能就不太好,所以就用到上面讲过的自适应学习率,Adam等
  • 换个角度的思路:
    当稍微改变w值的时候,希望loss变化也比较小,那一个可能方法就是让输入的x比较小,这样改变了w,对整体的loss影响就不会很大,即
    较小的数据x1x_1x1​对loss影响较小,会让error surface(误差曲面)更smooth
    较大的数据x2x_2x2​对loss影响较大,会让error surface(误差曲面)更steep
  • solution:
    所以希望不同dimension的特征有相同的数值范围(same range),即有接近的数值——进行归一化

进行归一化,计算所有数据同一个dimension的特征均值及标准差,通过公式
x~ir=(xir−mi)/σi\tilde{x}_i^r=(x_i^r-m_i)/\sigma_ix~ir​=(xir​−mi​)/σi​ 计算归一化后的值,此时所有维度的特征都在0的附近。
一般来说,特征归一化可以让收敛更快

进一步的可以延申至每一层网络,即不仅在数据输入层进行归一化,在经过一层神经网络后再进行归一化。
另外一般数据量很大时可以将其分成很多个batch(比如batch=64即每次输入64笔数据),减小GPU内存压力,所以称之为batch normalization.

参考链接:
https://speech.ee.ntu.edu.tw/~hylee/ml/2022-spring.php

类神经网络训练不起来怎么办——机器学习模型训练指南相关推荐

  1. 机器学习模型训练之GPU使用

    机器学习模型训练之GPU使用 1.电脑自带GPU 2.kaggle之免费GPU 3.amazon SageMaker Studio Lab 免费GPU使用推荐 深度学习框架由大量神经元组成,它们的计算 ...

  2. AI:神经网络IMDB电影评论二分类模型训练和评估

    AI:Keras神经网络IMDB电影评论二分类模型训练和评估,python import keras from keras.layers import Dense from keras import ...

  3. 机器学习模型训练_您打算什么时候重新训练机器学习模型

    机器学习模型训练 You may find a lot of tutorials which would help you build end to end Machine Learning pipe ...

  4. 机器学习模型训练全流程

    一.机器学习模型训练全流程 1.获得原始数据集 同时包含X和Y--可以用于监督学习(回归或分类):只包含X--无监督学习. 若Y包含定量值,那么数据集(由X和Y组成)用于回归:若Y包含定性值,那么数据 ...

  5. 9 张手绘图:阐明机器学习模型训练全流程

    Datawhale干货 译者:张峰,安徽工业大学,Datawhale成员 周末在家无聊闲逛github,发现一个很有趣的开源项目,作者用手绘图的方式讲解了机器学习模型构建的全流程,逻辑清晰.生动形象. ...

  6. 干货:机器学习模型训练全流程!

    [提醒:公众号推送规则变了,如果您想及时收到推送,麻烦右下角点个在看,或者把本号置顶] 正文开始 周末在家无聊闲逛github,发现一个很有趣的开源项目,作者用手绘图的方式讲解了机器学习模型构建的全流 ...

  7. python训练模型、如何得到模型训练总时长_【绝对干货】机器学习模型训练全流程!...

    周末在家无聊闲逛github,发现一个很有趣的开源项目,作者用手绘图的方式讲解了机器学习模型构建的全流程,逻辑清晰.生动形象.同时,作者也对几张图进行了详细的讲解,学习之后,收获很多,于是将其翻译下来 ...

  8. 机器学习模型训练问答

    内容主要来自Aurelien Geron<Hands-on Machine Learning withi Scikit-Learn&TensorFlow> 线性回归 1. 如果训练 ...

  9. 最全的机器学习模型训练全流程

    简言 发现一个很有趣的开源项目,作者用手绘图的方式讲解了机器学习模型构建的全流程,逻辑清晰.生动形象.想给大家分享一下. 项目地址:https://github.com/dataprofessor/i ...

  10. python训练模型、如何得到模型训练总时长_模型训练时间的估算

    模型训练时间的估算 昨天群里一个朋友训练一个BERT句子对模型,使用的是CPU来进行训练,由于代码是BERT官方代码,并没有显示训练需要的总时间,所以训练的时候只能等待.他截图发了基本的信息,想知道训 ...

最新文章

  1. 使用变量对象引出作用域链
  2. 拿下计网协议后,我就是公园里最靓的仔
  3. 【转】android是32-bit系统还是64-bit系统
  4. LVS(3)——针对于真实主机的增删改操作
  5. js 让浏览器全屏模式的方法launchFullscreen
  6. linux监控电脑配置,Zabbix基本配置及监控主机
  7. MYSQL8 度分秒(DMS)转度(DDD)函数编写实战
  8. 数字猜谜游戏python_Python Tkinter教程系列02:数字猜谜游戏
  9. iQOO 5G版8月上市,价格更加亲民!
  10. java api练习_Java接口练习
  11. 史上最详细的Studio教程二来啦
  12. 在线查服务器地址,工具|查询域名所在服务器的其他网站和IP
  13. iptv错误代码2003什么意思_IPTV部分错误代码及原因解释
  14. 高频头极化角调整+用什么本振的高频头
  15. 金山办公推出协同办公全家桶 WPS升级为超级工作入口
  16. 时钟晶振电路EMC设计标准电路详解
  17. R语言:rJava包的安装
  18. vb.net 编写的简易串口调试程序
  19. 国内股票KDJ指标计算,Python实现KDJ指标计算,Talib实现KDJ指标计算
  20. 在iview中render函数使用Switch功能

热门文章

  1. phpunit光速入门
  2. 新建的web项目为什么默认访问index.jsp
  3. My 2007 Fash game: Elite Shooter
  4. 手把手教你基于PaddlePaddle的情绪识别
  5. python 输出结果图文混排_Django图文混排
  6. Go语言攻略:“面向对象”
  7. php的chunk_split,php函数chunk_split详解
  8. 比犀利哥更经典的话语
  9. UGUI_03_补充之_Image的属性(image type这个属性simple、Sliced、tiled、filled样式详解)
  10. 计算机专业动漫设计毕业论文,计算机动漫设计与制作专业毕业论文.doc