文章目录

  • 参考依据
  • 两个现象
    • 1.神经网络的训练没有想象中简单
    • 2. 神经网络训练的失败往往是悄无声息的
  • 正确的训练方式
    • 1. 数据第一!
    • 2. 制作端到端的训练/验证框架 + 得到baselines
    • 3. 过拟合
    • 4. 正则化
    • 5. 调参
    • 6. 精益求精

参考依据

参考自Andrej Karpathy大佬(特斯拉AI总监,李飞飞学生)的博客:http://karpathy.github.io/2019/04/25/recipe/

在此文章的基础上,结合了自己的理解和想法,写下了这篇博文。

两个现象

1.神经网络的训练没有想象中简单

很多框架,包括Pytorch、Tensorflow在内,都会提供一些可以直接调用神经网络模型接口,这往往会给人一种错觉,就是神经网络模型是即插即用的。甚至有时候,人们对于模型的训练会简化到以下这种地步:

>>> your_data = # plug your awesome dataset here
>>> model = SuperCrossValidator(SuperDuper.fit, your_data, ResNet50, SGDOptimizer)
# conquer world here

说明:
上面这个是Andrej在描述这种现象写的一个示例代码,注释真的是太搞笑了。

虽然我们十分熟悉这种API调用的模式,并且希望达到这种效果,但是神经网络没有我们想的这么容易。神经网络是数据驱动的,不同的数据分布在同一神经网络模型展现的效果不总是同样好的,神经网络有一个著名的理论叫No Free Lunch Theorem:对于基于迭代而产生的最优算法,不存在某种算法对于所有问题都有效。 并且简单地使用反向传播和随机梯度下降不会使得你的模型总是work,Batch Norm也不总是能加速收敛,在一些数据集上,丢掉BatchNorm反而能更好地拟合数据。所以如果你不了解其中的原理,你可能就会失败(我觉得大概率,至少不是最优解)。

2. 神经网络训练的失败往往是悄无声息的

这些错误往往不是显性的语法错误,而是一些内容上的、逻辑上的错误。你有时候会发现你的模型能够work,但实际上并不是这样。比如,我有时候会因为标签标错,然后在训练进行测试的时候才发现这个问题,只通过训练时反馈的信息无法看出任何异常。还有一次我训练分割模型,对数据集的标签进行增广时,由于缩放和旋转采用的方式是双线性插值,导致模型在训练过程中loss越来越大。 所以,如果你的模型在训练时报错,那么你是幸运的,因为在很多时候,它往往是悄无声息的。

总的来说,快速而暴力的训练方式在神经网络的训练中是不起作用的,必须要有耐心,并且循序渐进。

正确的训练方式

1. 数据第一!

训练一个神经网络,第一步绝对不是敲代码,而是检查数据(特别是自己的数据集)!!这个我真的深有体会,当你在魔改网络,加各种骚操作之前,一定要确保你的数据是正确的!! 所以与其在之后花大工夫来检查,然后发现其实是数据问题,还不如一开始就确保数据的正确性。所以先浏览一遍数据是必要的,观察数据的分布特点以及是否存在类别不平衡情况,这将取决我们应该去探索哪种网络模型。例如,局部特征就足够了吗?是否需要全局信息特征?应该采取什么形式的数据增广?空间信息重要吗?还是直接可以平均池化?图像的细节重要吗?我们应该下采样到什么程度?标签是否有噪声?
除了定性的观察一遍数据,也可以编写程序(搜索、过滤、排序等)对数据进行一些定量分析(例如标签类型、标注数量、标注大小等),然后可视化数据分布,并且找出异常值(outliers)。

2. 制作端到端的训练/验证框架 + 得到baselines

下一步就需要建立一个完整的训练/验证框架,并通过一系列的实验来确保正确性。最好先选择一个简单的模型把流程跑通 我们在这个模型上完成训练、可视化loss和一些其他的指标、进行模型预测和进行一系列消融实验。

Tips & Tricks:

  1. 设置随机数种子 设置好随机数种子,确保模型的可复现性。
  2. 简化 不要去抱有一些不必要的想法。例如数据增广,因为它是用来提高模型的泛化能力的操作,在目前这个阶段,是不必要引入的,徒增训练的负荷。
  3. 有效的评估 当绘制loss值时,要以整个test/val数据集的loss为单位,而不是以batch为一个单位。
  4. 在初始阶段验证损失函数 从初始化阶段就要确保损失函数计算的正确性。
  5. 一个好的初始化 正确初始化最后一层的权重。例如你要对一个均值为50的数据集做回归预测,那么最后的logits的bias就可以初始化为50;对于分类任务,如果你的数据不平衡,假设正负样本比例为1:10,就可以让最后的logits的bias初始化为0.1。一个好的初始化能够加速收敛,避免你的网络在前几个迭代过程都是在学习偏差。
  6. human baseline 监控所以的对于人类来说是可解释的或者可检查的指标(比如accuracy、Mean IOU等),将这些指标和人类的指标相对比(比如,如果让你人为地去分类这批数据,你觉得你的准确率会是多少)。
  7. 与输入无关的baseline 训练一个与输入无关的baseline(最简单的方法是将所有的输入变为0),这个可以检测你的模型是否能够获取输入信息。说实话,这个我没太懂,有大佬懂了的话,麻烦指点一二。
  8. 过拟合一个单batch模型 令batch_size = 1(或者2)来训练部分样本,从而得到一个过拟合的模型。这样做有两个好处:首先是可以增加模型的容量(比如增加一些层和卷积核),其次是可以观察现在这个模型能达到的最小loss值。这个可以检验模型的能力,要确保现在选的模型具有足够的拟合能力。
  9. 确保训练loss减少 在所有数据上训练,如果训练集的loss不再减少,需要再次确定你的模型具有足够的拟合能力。
  10. 数据送入网络前先可视化 在将数据送入模型前,先可视化一下数据,确保数据是正确。在y = model(x)之前,进行可视化。
  11. 动态可视化预测 在训练过程中,对固定批次上的样本可视化预测结果。
  12. 使用反向传播来绘制关系依赖图

3. 过拟合

在这一阶段,我们应该对数据集有一个很好的了解,并且有完整的训练/验证流程。对于任何的模型,我们都能计算得出我们需要的指标。现在可以开始迭代一个好模型了。一般为两个阶段:首先是使得模型足够强,能够在训练集上过拟合;然后在使用归一化策略,放弃一些训练loss,从而降低验证loss,达到一个平衡。

Tips & Tricks:

  1. 选择模型 先要为数据选择一个好的、合适的模型。其中重要的一点就是:不要逞英雄!不然一开始尝试复杂的、花里胡哨的模型,然后疯狂地做一些骚操作,先选择一个最简明、最普遍应用的模型。比如,如果是分类任务,直接上ResNet-50。
  2. Adam是保险的 Adam对学习率的设置更宽容,但是SGD的表现性能要更好(学习率的调整范围更窄:需要更精确的学习率)
  3. 一次只复杂化一个 这意味着,当我们有很多可以增加模型复杂度的方法,不要一股脑地全部用上去,一次只使用一个。
  4. 注意学习率衰减 要主要学习率的衰减策略,最开始的可以不用学习率衰减策略,而是使用恒定的学习率。这可以避免你的学习率过低导致模型不够拟合。

4. 正则化

理想情况下,进行到这一步的时候,我们可以得到一个能够拟合训练集的模型了(有可能存在过拟合现象)。现在我们需要加一些正则化操作,是的模型具有更强的泛化能力。
Tips & Tricks:

  1. 更多的数据 提高模型的泛化能力最重要的一条就是尽可能地收集多的真实样本数据。花尽功夫想使得小样本获得一个好的拟合性能是不现实的。
  2. 数据增强 这个不用多说,使用现有的数据来模拟一些数据。
  3. 有创意的增强 如果2.还不能满足要求,可以使用例如GANs的方法用来进行数据增强。
  4. 预训练模型 即使你拥有足够多的数据量,使用预训练模型也是没有坏处的。
  5. 坚持监督学习 目前还没有任何版本的无监督训练模型在计算机视觉领域取得显著成果。
  6. 更小的输入维度 如果图像细节不重要,可以将图像缩放小一些。
  7. 更小的模型 许多情况下,可以给网络加上领域知识限制(Domain Knowledge Constraints),使得模型变小。比如,我们之前都是给分类网络加上全连接层,后来,逐渐被简单的平均池化所代替,大大减少了参数
  8. 减少Batch Size 小的Batch size对于Batch Norm来说,在某种程度上能够增强模型的泛化能力。
  9. Drop 增加Dropout层,但是对于内置Batch Norm的网络来说,Dropout层似乎效果不太好(所以慎用)
  10. 增加weight decay
  11. 提前停止 出现loss不再下降的情况,应该提前停止训练,避免过拟合。
  12. 尝试更大的模型 最后提到这一点,并且在提前停止后才提到。虽然大的模型会过拟合得更厉害,但使用“提前停止”会使得它们比小模型表现得更好。

最后,提到一点,可以对于网络的第一层参数做一个可视化效果,看网络是否能够捕捉到一些有用的边缘信息。如果看起来像一团噪音,那么就要注意模型是否出现问题了。

5. 调参

Tips & Tricks:

  1. 随机网格搜索 这种方法有点费时,但是可以对于某些重要的超参做随机网格搜索。
  2. 超参数优化

6. 精益求精

Tips & Tricks:

  1. ensembles 把几个模型融合在一起,至少可以提高2%。如果算力顶不住,可以尝试使用网络蒸馏(https://arxiv.org/abs/1503.02531)
  2. 让模型’飞’一会 有时候你需要做的只是什么都不管,让模型继续训练下去,保持耐心可能会获得额外的守护喔,静静地让‘子弹飞一会’。

如何训练好一个神经网络?相关推荐

  1. 训练softmax分类器实例_知识蒸馏:如何用一个神经网络训练另一个神经网络

    作者:Tivadar Danka 编译:ronghuaiyang 原文链接 知识蒸馏:如何用一个神经网络训练另一个神经网络​mp.weixin.qq.com 导读 知识蒸馏的简单介绍,让大家了解知识蒸 ...

  2. 知识蒸馏:如何用一个神经网络训练另一个神经网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 如果你曾经用神经网络来解决一个复杂的问题,你就会知道它们的尺寸可能 ...

  3. python神经网络训练数据_用Python从头开始实现一个神经网络

    注:本篇文章非原创,翻译自Implementing a Neural Network from Scratch in Python – An Introduction​www.wildml.com ...

  4. 04_面向初学者的快速入门、建立图像分类的一个神经网络、训练这个神经网络、评估模型的精确度

    翻译自:https://tensorflow.google.cn/tutorials/quickstart/beginner 这是一个的使用Keras做如下事情的简短介绍: 建立图像分类的一个神经网络 ...

  5. 训练神经网络的详细步骤,如何训练一个神经网络

    如何训练神经网络 1.先别着急写代码训练神经网络前,别管代码,先从预处理数据集开始.我们先花几个小时的时间,了解数据的分布并找出其中的规律. Andrej有一次在整理数据时发现了重复的样本,还有一次发 ...

  6. 神经网络的三种训练方法,如何训练一个神经网络

    1.神经网络有哪些主要分类规则并如何分类? 神经网络模型的分类 人工神经网络的模型很多,可以按照不同的方法进行分类.其中,常见的两种分类方法是,按照网络连接的拓朴结构分类和按照网络内部的信息流向分类. ...

  7. 从原理上“训练”一个神经网络(下)

    点击关注我哦 一篇文章带你了解函数声明时的优雅操作 四.训练 当我们从神经网络开始时,我们会随机初始化权重.显然,它不会给很好的结果.在训练过程中,我们希望从性能不佳的神经网络入手,并以高准确度结束网 ...

  8. 从原理上“训练”一个神经网络(上)

    点击关注我哦 一篇文章带你了解函数声明时的优雅操作 一.引言 这是我计划的系列优化算法的第1部分,该算法特别用于机器学习和神经网络中的"训练".在这篇文章中,将介绍Gradient ...

  9. 用Python从头实现一个神经网络

    用Python从头实现神经网络 实在是觉得LaTeX编译出来的公式太好看了,所以翻译了一下,原文地址: Machine Learning for Beginners: An Introduction ...

最新文章

  1. 二十三、死锁的处理策略---避免死锁(银行家算法)
  2. Mybatis插入MySQL数据库中文乱码
  3. Redis Sentinel集群部署
  4. HDU多校10 - 6880 Permutation Counting(dp+思维)
  5. java 计算两个时间戳_Java时间戳计算重叠持续时间与间隔
  6. 通过 getResources 找不到jar包中的资源和目录的解决方法
  7. Windows环境下使用Linux命令
  8. 圆形取景框 相机_据说这款设备可以使老旧单反相机解决无线联机拍摄方案
  9. ITK简介与ITK Pipeline
  10. 思科路由器Ez***测试
  11. 【原创】nbsp;Window7nbsp;vs201…
  12. MarkDown桌面版使用下载+安装+使用教程(包括全套的使用语法,欢迎大家查看)
  13. dell笔记本插上耳机没有声音_笔记本扬声器没声音,但耳机有声音怎么办
  14. 功能齐全的 ESP32 智能手表,具有多个表盘、心率传感器硬件设计
  15. DySAT: Deep Neural Representation Learning on Dynamic Graph via Self-Attention Networks
  16. sublime软件中如何给很多行数据加逗号及双引号并转换成一行?
  17. 网店美工之你不知道的图片设计技巧
  18. 原来写在BlogBus(博客大巴)上的文章搬运到这里!~
  19. CSAPP 拆炸弹 中科大实验
  20. 中文分词预处理之N最短路径法小结

热门文章

  1. 在网页中点击链接就可以和在线好友QQ聊天
  2. Docker——使用docker工具管理软件/组件的运行,镜像、容器、数据卷的基本概念,常用指令,使用docker搭建Java微服务运行环境
  3. 当程序员遇到有远见的丈母娘,找对象那不是事
  4. VUE的父传子 子传父
  5. 启明欣欣STM32开发板闪烁LED实验
  6. 同一个网址电脑手机访问显示不同内容思路
  7. 判断方阵是否沿主对角线对称
  8. js 26个字母排序
  9. 【机器学习】基于mnist数据集的手写数字识别
  10. 如何在短时间内提高推广?做好展现量