夏乙 栗子 编译自 Khanna.cc 
量子位 报道 | 公众号 QbitAI

想要训练个深度神经网络,也准备好了可以直接用的数据,要从哪里开始上手?

来自美国的Harry Khanna,精心编织了一套六步法。大家可以瞻仰一下,然后决定要不要行动。

整个过程中,过拟合的问题时时刻刻都要注意。

1. 选个损失函数

选择怎样的损失函数,取决于需要解决怎样的问题。

如果是回归问题,就可以用均方误差 (MSE) 损失函数。

如果是分类问题,就用交叉熵 (Cross-Entropy) 损失函数。

只有两类的话,要用二值交叉熵 (Binary Cross-Entropy) 。

如果遇到了不常见的问题,比如一次性学习 (One-Shot Learning) ,可能就要自行定制一个损失方程了。

2. 选个初始架构

说到结构化学习,比如预测销售情况,从全连接的隐藏层开始,是个不错的选择。

这一层的激活数 (Number of Activations) ,要在输入神经元与输出神经元的数量之间。

两者取个平均数,就可以。

像下面这样的取法,也可以。



Ni,是输入神经元数。
No,是输出神经元数。
Ns,训练集里的样本数。
a,尺度因子,在2到10之间选。

计算机视觉领域的小伙伴来说,像ResNet这样的架构,就很友好。

3. 拟合训练集

这一步,最重要的超参数,是学习率 (Learning Rate) (α) 。

不需要试错,fast.ai的库里面,有一个rate finder



只要写两行代码,就可以得到一个学习率的曲线。



在损失还在明显下降的区域,选取学习率——

比如,最陡部分的旁边一点点,损失仍在剧烈下降,没有到平坦的地方。

上图的话,10-4就可以。

如果,模型训练还是很慢,或者效果不好的话,可以用Adam优化,代替初始架构里的随机梯度下降 (SGD) 优化。

这时候,如果模型还是不能和训练集愉快玩耍,考虑一下学习率衰减 (Learning Rate Decay) ——

有指数衰减,离散阶梯衰减 (Discrete Staircase Decay) ,甚至还有一些手动方法,人类可以在损失不再下降的时候,自己把学习率 (α) 往下调。

其中,余弦型 (Cosine) 衰减,在一个回合 (Epoch) 开始的时候,衰减最慢,中间最快,结束时又最慢。



然后,可以加上一个“重启 (Restart) ”功能。这样,每 (几) 个回合开始时,学习率都会回到没有衰减的状态。

迁移学习的话,要把开始几层解冻 (Unfreeze) ,然后用差分学习率来训练神经网络。

如果,训练集还是不开心,还有另外几个超参数可以调整——

· 隐藏层的unit数
· 小批量 (Minibatch) 的大小:64,128,256,512……
· 隐藏层数

还不行的话,就要看目标能不能再细化一下。

输入的训练数据,可能并没有预测输出值所需的有效信息。

举个栗子,仅仅基于股票的历史价格,来预测未来走势,就很难。

4. 拟合验证集

这一步,是最难的,也最花时间。

怎样才能解决训练集上的过拟合问题?

丢弃 (Dropout)

把训练集中的一些神经元,随机清零。

那么,概率 (p) 要怎么设置?

虽然,没有万能之法,但还是有一些可以尝试的方法——

找到p=0.25的最后一个线性层,对这之前 (包含本层) 的那些层,执行随机抛弃。

然后,在把p往上调到0.5的过程中,实验几次。

如果还是不行,就给再往前的线性层,也执行随机丢弃操作,还是用p=0.25到0.5之间的范围。

并没有通天大法,所以有时候还是要试错,才能知道在哪些层里,取多大的p,更有效。

权重衰减/L2正则化 (Weight Decay/ L2 Regularization)

第二小步,加上权重衰减。就是在损失函数里面添加一项——



λ,是正则化超参数

wj,是权重矩阵w里面的特征j。

n,是权重矩阵w里的特征数。

过拟合的其中一个原因,就是权重大。而权重衰减,可以有效打击大权重。

输入归一化 (Normalize Inputs)

减少过拟合的第三小步,就是把输入特征的均值方差,各自除以训练集,归一化。

特征x1除以训练样本总数m,要等于0,方差要等于1。x2,x3…也一样。



μ向量,维数等于单个训练样本的输入特征数

x是一个n x m矩阵,n是输入特征数,m是训练样本数。

x-μ,就是x的每一列都要减掉μ。

标准差归一化的公式,看上去就比较熟悉了——



注意,要归一化的是,除以训练样本总数m,之后的均值和方差,不是除以每个样本的特征数n。

再注意,是用训练集的均值和方差,不是验证集。

批量归一化 (Batch Normalization)

上一小步,归一的是输入特征,而这里,要把隐藏层神经元的均值方差归一化。

和之前一样的是,用了训练集的均值方差,来调教验证集。

 不支持一次吃太多

不同的是,要一小批一小批地进行,并非整个训练集一步到位。

这种情况下,可以使用均值和方差的指数加权平均 (exponentially weighted average) ,因为有很多的均值,和很多的方差。

多点训练数据、数据扩增

根据经验,过拟合最好的解决办法,就是增加训练数据,继续训练。

如果没有太多数据,就只好做数据扩增。计算机视觉最爱这种方法,调光,旋转,翻转等等简单操作,都可以把一张图变成好几张。

不过,在结构化数据和自然语言处理中,数据扩增就没有什么栗子了。

梯度消失,梯度爆炸

梯度消失 (Vanishing Gradients) ,是指梯度变得很小很小,以至于梯度下降的学习效果不怎么好。

梯度爆炸 (Exploding Gradients) ,是指梯度变得很大很大,超出了设备的计算能力

解决这两个问题,第一条路,就是用一种特殊的方式把权重矩阵初始化——名曰“He Initialization”。

不是第三人称,是出自何恺明等2015年的论文:
Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification。

把权重矩阵W[l],变成一个均值为零的高斯分布。标准差长这样:



效果意想不到的好,梯度消失和爆炸,都少有发生了。

如果是训练自然语言处理RNN,LSTM是首选,也是一种减少梯度消失或爆炸的方式。

出现NaN,很可能就是梯度爆炸了。

一个简单粗暴的处理方式是,梯度裁剪 (Gradient Clipping) ,给梯度设一个上限。超出了限制,梯度就会被切。

神经网络架构搜索

有时候调整超参数没什么用,不管怎么调,验证集的loss还是比训练集高好多。

这时候就该好好看看神经网络的架构了。

输入里有太多特征、隐藏层太多激活数,都可能会导致神经网络拟合了训练集里的噪音。

这时候就要调整隐藏层的神经元数量,或者增减隐藏层的数量。

这个过程需要试错,可能要试过很多架构才能找到一个好用的。

5. 在测试集上检验性能

当神经网络在训练集和验证集上都表现良好,要保持警惕:优化过程中,有可能一不小心在验证集上过拟合了。

上一步拟合验证集时,超参数都向着在验证集上优化的方向调整。于是,这些超参数有可能将验证集中的噪音考虑了进去,模型对新数据的泛化能力可能很差。

所以,到了这一步,就要在一个没见过的测试集上来运行神经网络,确认还能取得和验证集上一样的成绩。

如果在测试集上表现不好,就要通过增加新数据或者数据增强(data augmentation)的方式,扩大验证集规模。

然后重复第4、5步。

注意:不要根据测试集损失来调整超参数!这样只能得到一个对训练集、验证集和测试集都过拟合了的模型。

6. 在真实世界中检验性能

如果你训练了一个猫片识别器,就喂它一些你的猫片;

如果你训练了一个新闻稿情绪识别器,就喂它一些微软最近的新闻。

如果这个神经网络在训练、验证、测试集上表现都不错,到了现实世界却是个渣渣,那一定出了什么问题。比如说,有可能过拟合了测试集

这时候就需要换个验证集、测试集,看看模型表现怎么样。如果这个现实世界的渣渣依然表现良好,那么问题可能出在损失函数身上。

这种情况,还是挺少见的。

一般只要成功熬到第6步,模型在现实世界里都挺厉害的。

来,我们回顾一下

刚才讲的这么多,最后可以汇集成下面这个checklist:

  • 第1步:损失函数

    • 回归问题用MSE

    • 多类别分类问题用交叉熵

    • 二分类问题用二值交叉熵

  • 第2步:初始神经网络架构

    • 结构化学习:一个激活数在输入输出神经元数之间的全连接层

    • 计算机视觉:从ResNet开始

  • 第3步:训练集

    • 用learning rate finder来选学习率

    • Adam优化

    • 余弦学习率衰减

    • 学习率重启

    • 如果做迁移学习,尝试一下可微分学习率

    • 隐藏层的神经元数量

    • minibatch大小

    • 隐藏层数量

  • 第4步:验证集

    • Dropout

    • L2正则化

    • 输入特征归一化

    • 批量归一化

    • 数据扩增

    • 为训练集补充数据

    • 梯度消失或爆炸

      • He初始化

      • 用LSTM神经元

      • 梯度裁剪

    • 调整神经网络架构

  • 第5步:测试集

    • 如果有问题,扩大验证集,回到第4步

  • 第6步:真实世界

    • 如果有问题,换个验证集和测试集,回到第4步

加入社群

量子位AI社群18群开始招募啦,欢迎对AI感兴趣的同学,加小助手微信qbitbot8入群;

此外,量子位专业细分群(自动驾驶、CV、NLP、机器学习等)正在招募,面向正在从事相关领域的工程师及研究人员。

进群请加小助手微信号qbitbot8,并务必备注相应群的关键词~通过审核后我们将邀请进群。(专业群审核较严,敬请谅解)

诚挚招聘

量子位正在招募编辑/记者,工作地点在北京中关村。期待有才气、有热情的同学加入我们!相关细节,请在量子位公众号(QbitAI)对话界面,回复“招聘”两个字。

量子位 QbitAI · 头条号签约作者

վ'ᴗ' ի 追踪AI技术和产品新动态

怎样构建深度学习模型?六步走,时刻小心过拟合 | 入门指南相关推荐

  1. 5大关键步骤!如何构建深度学习模型?

    深度学习的关注度正持续上升,它是机器学习的一个子领域,基于人工神经网络的概念来执行特定任务.然而在理论上,人工神经网络与人类大脑的运作方式并不相同,甚至都不相似! 它们之所以被命名为人工神经网络,是因 ...

  2. R基于H2O包构建深度学习模型实战

    R基于H2O包构建深度学习模型实战 目录 R基于H2O包构建深度学习模型实战 #案例分析

  3. 通过 Keras 构建深度学习模型的步骤

  4. 用 Java 训练深度学习模型,原来这么简单

    作者 | DJL-Keerthan&Lanking 来源 | HelloGitHub 头图 | CSDN下载自东方IC 前言 很长时间以来,Java 都是一个很受企业欢迎的编程语言.得益于丰富 ...

  5. 【深度学习】深度学习模型训练全流程!

    Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集.模型训练.模型加载和模型调参四个部分对深度学习中模型训练的全流程进行讲解. 一个成熟合格的深度学习训练流 ...

  6. 如何使用Keras和TensorFlow建立深度学习模型以预测员工留任率

    The author selected Girls Who Code to receive a donation as part of the Write for DOnations program. ...

  7. 深度学习模型训练全流程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...

  8. 了解如何在Google Colaboratory中构建深度学习系统

    原文来自特拉字节:https://telabytes.com/article/preview?id=119 这是练习深度学习的好时机.现有的主要深度学习框架如TensorFlow.Keras和PyTo ...

  9. tflearn教程_利用 TFLearn 快速搭建经典深度学习模型

    使用 TensorFlow 一个最大的好处是可以用各种运算符(Ops)灵活构建计算图,同时可以支持自定义运算符(见本公众号早期文章<TensorFlow 增加自定义运算符>).由于运算符的 ...

最新文章

  1. pytorch中Schedule与warmup_steps的用法
  2. java时钟面板clock
  3. 计算机辅助的开发方法,基于计算机辅助设计技术(TCAD)的工艺开发
  4. python对字符串和集合的内存垃圾回收机制
  5. Android SDK上手指南:应用程序数据
  6. pycharm 类型注释_学习Python第一步,变量与数据类型
  7. [转载]dbms_lob用法小结
  8. 经验分享:三步走教你升级企业NAS设备
  9. hadoop lambda_Delta架构:统一Lambda架构并利用Hadoop / REST中的Storm
  10. 浏览器解析html全过程详解
  11. mysql ddl分类_MySQL语言分类——DDL
  12. css 平行四边形 梯形 组合_微课|人教版五年级数学上册6.4组合图形的面积(P99)...
  13. Linux 命令(113)—— seq 命令
  14. 汉诺塔c++_C语言,递归解决汉诺塔问题
  15. Git客户端操作GitHub
  16. S71500 PLC 与第三方设备 ModbusTCP 项目调试记录
  17. html5给文字添加拼音,word怎么为文字添加拼音
  18. java 电子实时看板,物理看板还是电子看板?
  19. liunx系统下搭建domian
  20. 【Python学习记录】Numpy广播机制(broadcast)

热门文章

  1. 经过 180 年的训练,OpenAI 在 DOTA 2 上完虐人类!
  2. 如何仅凭 README 就名列 GitHub No.1 并收获上万 Star?
  3. 用c语言定义一个分式方程,计算机C语言实验报告(00001)
  4. day03【后台】管理员维护
  5. python实现食品推荐_Python分析亚马逊赞不绝口食物评论
  6. 宏碁笔记本linux,Acer宏碁(Acer宏碁)Acer 4752G-2332G50Mnkk Linux笔记本电脑整体评测-ZOL中关村在线...
  7. mysql linux 中文乱码怎么解决_如何解决mysql linux 中文乱码的问题
  8. win10怎么用计算机的搜索,win10 搜索用不了的问题
  9. 为什么二维码这么神奇,扫一下就能得到各种各样的信息?
  10. android小灯泡实验代码,typecho常用代码片段收集