前言

作为一名每天与神经网络训练/测试打交道的同学,是否经常会遇到以下这几个问题,时常怀疑人生:

  1. 怎么肥事,训练正常着呢,咋效果这么差呢?

  2. 嗯。。再等等是不是loss就更低了。啊?明明loss更低了呀,为啥效果更差了?

  3. 又是怎么肥事?我改了哪里,效果提升了这么多?阿哈哈哈哈收工下班。

总而言之,当模型效果不如预期的时候去调试深度学习网络是一件头疼且繁琐的事情,为了让这件麻烦事情更加仅仅有条,笔者结合实际经验简单整理了一些checklist,方便广大炼丹师傅掌握火候。

1. 从最简单的数据/模型开始

现在开源社区做的很好,同学们用模型也十分方便,但也有相应的问题。以句子情感识别为例,新入手的同学可能一上来就调出HuggingFace/transformers代码库,然后一股脑BERT/Roberta啥的跑个结果,当然文档做好的开源代码一般都能照着跑个好结果,但改到自己数据集上往往就懵逼了,啊这?51%的二分类准确率(可能夸张了点,但如果任务要比二分类稍微复杂点,基本结果不会如预期),也太差了吧,HuggingFace/transofmrers这些模型不行啊。算了,咱换一个库吧,再次求助github和谷歌搜索。其实可能都还不清楚数据输入格式对不对?数据量够不够?评测指标含义是否清楚?Roberta的tokenizer是咋做的?模型结构是什么样子?

所以第1个checklist是:请尽量简单!

  • 模型简单

  • 数据简单

模型简单:解决一个深度学习任务,最好是先自己搭建一个最简单的神经网络,就几层全连接的那种。

数据简单:一般来说少于10个样本做调试足够了,一定要做过拟合测试(特别是工作的同学,拿过来前人的代码直接改个小结构就跑全量数据训练7-8天是可能踩坑的哦,比如某tensorflow版本 GPU embedding查表,输入超出了vocab size维度甚至可能都不报错哦,但cpu又会报错)!如果你的模型无法在7、8个样本上过拟合,要么模型参数实在太少,要么有模型有bug,要么数据有bug。为什么不建议1个样本呢?多选几个有代表性的输入数据有助于直接测试出非法数据格式。但数据太多模型就很难轻松过拟合了,所以建议在10个以下,1个以上,基本ok了。

2. loss设计是否合理?

loss决定了模型参数如何更新,所以记得确定一下你的loss是否合理?

  • 初始loss期望值和实际值误差是否过大,多分类例子。

    橘个????:CIFAR-10用Softmax Classifier进行10分类,那么一开始每个类别预测对的概率是0.1(随机预测),用Softmax loss使用的是negative log probability,所以正确的loss大概是:-ln(0.1)= 2.303左右。

  • 初始loss测试,二分类例子。假设数据中有20%是标签是0,80%的标签是1,那么一开始的loss大概应该是-0.2ln(0.5)-0.8ln(0.5)=0.69左右,如果一开始loss比1还大,那么可能是模型初始化不均匀或者数据输入没有归一化。

  • 比如多任务学习的时候,多个loss相加,那这些loss的数值是否在同一个范围呢?

  • 数据不均衡的时候是不是可以尝试一下focal loss呢?

3. 网络中间输出检查、网络连接检查

Pytorch已经可以让我们像写python一样单步debug了,所以输入输出shape对齐这步目前还挺好做的,基本上单步debug走一遍forward就能将网络中间输出shape对齐,连接也能对上,但有时候还是可能眼花看漏几个子网络的连接。

所以最好再外部测试一下每个参数的梯度是否更新了,训练前后参数是否都改变了。

那么具体的模型中间输出检查、网络连接检查是:

  • 确认所有子网络的输入输出shape对齐,并确认全部都连接上了,可能有时候定一个一个子网络,但放一边忘记连入主网络啦。

  • 梯度更新是否正确?如果某个参数没有梯度,那么是不是没有连上?

  • 如果参数的梯度大部分是0,那么是不是学习率太小了?

  • 时刻监测一下梯度对不对/时刻进行修正。经典问题:梯度消失,梯度爆炸。

  • 参数的梯度是否真的被更新了?有时候我们会通过参数名字来设置哪些梯度更新,哪些不更新,而这个时候有木有误操作呢?

读者可以参考stanford cs231n中的Gradient checking:

https://cs231n.github.io/neural-networks-3/#gradcheck

https://cs231n.github.io/optimization-1/#gradcompute

另外用tensorboard来检查一下网络连接/输入输出shape和连接关系也是不错的。

4. 时刻关注着模型参数

所谓模型参数也就是一堆矩阵/或者说大量的数值。如果这些数值中有些数值异常大/小,那么模型效果一般也会出现异常。一般来说,让模型参数保持正常有这么几个方法:

  • 调整batch size(或者说mini-batch)。

  • 统计梯度下降中,我们需要的batch size要求是:1、batch size足够大到能让我们在loss反向传播时候正确估算出梯度;2、batch size足够小到统计梯度下降(SGD)能够一定程度上regularize我们的网络结构。batch size太小优化困难,太大又会导致:Generalization Gap和Sharp Minima(具体参考:论文https://arxiv.org/abs/1609.04836,On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima)。

  • 调整learning rate学习率。

    学习率太小可能会导致局部最优,而太大又会导致模型无法收敛。

    具体读者可以学习斯坦佛cs231n这个部分:

    https://cs231n.github.io/neural-networks-3/#anneal,另外关于学习率几个常用的网站:

    Pytorch:https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html

  • Tensorflow:

    https://www.tensorflow.org/api_docs/python/tf/train/exponential_decay

    Keras:

    https://keras.io/callbacks/#learningratescheduler。

  • 梯度裁剪。

    在反向传播的时候,将参数的梯度限制在一个范围之类:[-min, max]。对于梯度消失和梯度爆炸很有帮助。

  • Batch normalization。将每层的输入进行归一化,有助于解决internal covariate shift问题。当然最近transformer流行的是Layer normalization。不同normalization的区别可以学习张俊林老师关于normalization的博客。

  • Dropout。

    网络参数随机失活,有效防止过拟合/正则的常用手段。但是如果batch normalization和Dropout一起使用的话,建议先学习这个文章:Pitfalls of Batch Norm in TensorFlow and Sanity Checks for Training Networks 还有这个文章:Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift。

  • Regularization。

    Regularization对于模型的泛化能力很重要,他对于模型的参数量和复杂度做了一个惩罚,能在不显著增加模型bias的情况下降低模型的variance。

    但需要注意的是:

    通常情况下,我们的loss是data loss加上regularization loss,那么如果regularization loss比data loss更大更重要了,那么魔性的data loss可能就学不好了。

  • 使用什么优化器?

    一般来说SGD作为baseline就可以了,但如果想要更好的效果比如使用Adam,还有很多其他SGD的改进可以使用。

    这个部分建议阅读这篇文章进行学习:An overview of gradient descent optimization algorithms。

5. 详细记录调试/调参数过程

可能这会儿换了一个learning rate,过会儿增大了dropout,过会儿又加了一个batch normalization,最后也不知道自己改了啥。

一个好的办法是是使用excel(虽然有些古老,其实还是很有效的,可以记录各种自己想要记录的变量)将重点改进,改进结果进行存放,另外合理使用tensorboard也是不错。

由于要实验或者改的地方太多,通常就时不时忘记/不方便使用git了,而是copy一大堆名字相似的文件,这个时候,请千万注意你的代码结构/命名规则,当然使用好bash脚本将使用的参数,训练过程一一存放起来也是不错的选择。

总之,无论是古老的工具,先进的工具,能将实验过程进行记录/复现的就是好工具。具体使用什么工具因人而异,可能有些人就是不喜欢用git。。。当然也有一些别人开发好的工具可以使用啦,比如:Comet.ml

模型对数据/超参数,甚至是随机种子、GPU版本,tensorflow/pytorch版本,所以请尽可能记录好每个部分,并且最好时刻可以复现。最后小时候学的控制变量法也很重要哦。

总结

将以上内容做一个总结:

  1. 简单模型,简单数据,全流程走通。

  2. 调整/选择合理的loss函数/评价指标,最好检查一下初始loss是否符合预期。

  3. 查看网络中间输出、子网络是否都连接上了。

  4. 时刻关注模型参数。无论是优化器的改变、学习率的改变、增加正则方法或者梯度裁剪,主要作用都是在修正/更新模型参数。

  5. 详细记录实验过程。保持良好的训练/测试流程和习惯,SOTA近在眼前~。

参考文献:

https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑温州大学《机器学习课程》视频
本站qq群851320808,加入微信群请扫码:

【深度学习】收藏|神经网络调试Checklist相关推荐

  1. 一文掌握深度学习、神经网络和学习过程的历史

    来源:算法与数学之美 本质上,深度学习是一个新兴的时髦名称,衍生于一个已经存在了相当长一段时间的主题--神经网络. 从20世纪40年代开始,深度学习发展迅速,直到现在.该领域取得了巨大的成功,深度学习 ...

  2. 深度学习与神经网络概述

    本文将简单介绍:人工智能(Artificial Intelligence).机器学习(Machine Learning).深度学习(Deep Learning),并介绍神经网络的发展,以及三个在线演示 ...

  3. 机器学习、深度学习、神经网络学习资料集合(开发必备)

    最近整理了下AI方面的学习资料,包含了学习社区.入门教程.汲取学习.深度学习.自然语言处理.计算机视觉.数据分析.面试和书籍等方面的知识.在这里分享给大家,欢迎大家点赞收藏. 学习社区 神力AI(MA ...

  4. 深度学习和神经网络的介绍(一)

    1.深度学习和神经网络 1.1 深度学习的介绍 目标: 知道什么是深度学习 知道深度学习和机器学习的区别 能够说出深度学习的主要应用场景 知道深度学习的常见框架 1.1.1 深度学习的概念 深度学习是 ...

  5. 这套人工智能算法书已经出版了3卷,其中卷3深度学习和神经网络最受程序员喜欢

    人工智能算法系列图书以一种数学上易于理解的方式讲授人工智能相关概念,这也是本系列图书英文书名中"for Human"的含义. 本系列图书的每一卷均可独立阅读,也可作为系列图书整体阅 ...

  6. 深度学习(1)基础1 -- 深度学习与神经网络基础

    目录 一.深度学习与神经网络 1.深度学习定义 2.神经网络 3.深度学习过程 4.深度学习功能 二.深度学习应用 三.分类数据集推荐 一.深度学习与神经网络 1.深度学习定义 深度学习(deep l ...

  7. 从神经元到神经网络、从神经网络到深度学习:神经网络、深度学习、神经元、神经元模型、感知机、感知机困境、深度网络

    从神经元到神经网络.从神经网络到深度学习:神经网络.深度学习.神经元.神经元模型.感知机.感知机困境.深度网络 目录 从神经元到神经网络.从神经网络到深度学习 神经网络:

  8. 针对深度学习(神经网络)的AI框架调研

    针对深度学习(神经网络)的AI框架调研 在我们的AI安全引擎中未来会使用深度学习(神经网络),后续将引入AI芯片,因此重点看了下业界AI芯片厂商和对应芯片的AI框架,包括Intel(MKL CPU). ...

  9. DL:听着歌曲《成都》三分钟看遍主流的深度学习的神经网络的发展框架(1950~2018)

    DL:听着歌曲<成都>三分钟看遍主流的深度学习的神经网络的发展框架(1950~2018) 视频链接:听着歌曲<成都>三分钟看遍主流的深度神经网络的发展框架(1950~2018) ...

  10. 深度学习(神经网络) —— BP神经网络原理推导及python实现

    深度学习(神经网络) -- BP神经网络原理推导及python实现 摘要 (一)BP神经网络简介 1.神经网络权值调整的一般形式为: 2.BP神经网络中关于学习信号的求取方法: (二)BP神经网络原理 ...

最新文章

  1. 如何使用create-react-app在本地设置HTTPS
  2. Python 速度慢,试试这个方法提高 1000 倍
  3. 《Sibelius 脚本程序设计》连载(二十六) - 2.13 utils库中的函数
  4. 使用 ExtJs Extender Controls 遇到的第一个错误
  5. Java黑皮书课后题第2章:*2.21(金融应用:计算未来投资回报)编写程序,读取投资总额、年利率和年龄,显示未来投资回报金额
  6. MHA环境搭建【4】manager相关依赖的解决
  7. POJ_2112 Optimal Milking(网络流)
  8. 密文恢复出明文的过程称为_整流二极管的反向恢复过程图解
  9. 在wamp里面配置feehicms
  10. L2-016 愿天下有情人都是失散多年的兄妹(DFS)
  11. Virtualbox安装Ubuntu
  12. CAD文字快速添加框
  13. 时间管理:良好的状态是解决重要不紧急的事,而不是陷入重要且紧急的事情中出不来
  14. ecshop后台getshell
  15. 排列组合Cnm的求法
  16. 编程新手表示很想知道JAVA中Bean是什么?
  17. “黑客”必备书籍 你值得拥有!
  18. 阿里云服务器ECS如何临时升级带宽?
  19. failed to register layer: Error processing tar file(exit status 1): archive/tar: invalid tar header
  20. java 骰子游戏_java 骰子游戏

热门文章

  1. 依赖注入容器Autofac的详解[转]
  2. 表设计避免使用保留字
  3. IL语言之.ctor
  4. 微信小程序1. Forgot to add page route in app.json. 2. Invoking Page() in async task.
  5. 序列化反序列化api(入门级)
  6. MS-SQL分页not in 方法改进之使用row_number
  7. 新版上线时发现的数据库优化问题
  8. .NET打包工具怎么注册 .dll文件??
  9. #转载:十大排序方法,动图展示
  10. JavaSE(三)——数组及继承