作者丨Fatescript

来源丨https://zhuanlan.zhihu.com/p/450779978

编辑丨GiantPandaCV

本文算是我工作一年多以来的一些想法和经验,最早发布在旷视研究院内部的论坛中,本着开放和分享的精神发布在我的知乎专栏中,如果想看干货的话可以直接跳过动机部分。另外,后续在这个专栏中,我会做一些关于原理和设计方面的一些分享,希望能给领域从业人员提供一些看待问题的不一样的视角。

动机

前段时间走在路上,一直在思考一个问题:我的时间开销很多都被拿去给别人解释一些在我看起来显而易见的问题了,比如( https://link.zhihu.com/?target=https%3A//github.com/Megvii- BaseDetection/cvpods )里面的一些code写法问题(虽然这在某些方面说明了文档建设的不完善),而这变相导致了我实际工作时间的减少,如何让别人少问一些我觉得答案显而易见的问题?如何让别人提前规避一些不必要的坑?只有解决掉这样的一些问题,我才能从一件件繁琐的小事中解放出来,把精力放在我真正关心的事情上去。

其实之前同事有跟我说过类似的话,每次带一个新人,都要告诉他:你的实现需要注意这里blabla,还要注意那里blabla。说实话,我很佩服那些带intern时候非常细致和知无不言的人,但我本性上并不喜欢每次花费时间去解释一些我觉得显而易见的问题,所以我写下了这个帖子,把我踩过的坑和留下来的经验分享出去。希望能够方便别人,同时也节约我的时间。

加入旷视以来,个人一直在做一些关于框架相关的内容,所以内容主要偏向于模型训练之类的工作。因为 一个拥有知识的人是无法想象知识在别人脑海中的样子的(the curse of knowledge),所以我只能选取被问的最多的,和我认为最应该知道的

准备好了的话,我们就启航出发(另,这篇专栏文章会长期进行更新)。

坑/经验

Data模块

  1. python图像处理用的最多的两个库是opencv和Pillow(PIL),但是两者读取出来的图像并不一样, opencv读取的图像格式的三个通道是BGR形式的,但是PIL是RGB格式的 。这个问题看起来很小,但是衍生出来的坑可以有很多,最常见的场景就是数据增强和预训练模型中。比如有些数据增强的方法是基于channel维度的,比如megengine里面的HueTransform,这一行代码 (https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L958 ) 显然是需要确保图像是BGR的,但是经常会有人只看有Transform就无脑用了,从来没有考虑过这些问题。

  2. 接上条,RGB和BGR的另一个问题就是导致预训练模型载入后训练的方式不对,最常见的场景就是预训练模型的input channel是RGB的(例如torch官方来的预训练模型),然后你用cv2做数据处理,最后还忘了convert成RGB的格式,那么就是会有问题。这个问题应该很多炼丹的同学没有注意过,我之前写CenterNet-better(https://github.com/FateScript/CenterNet-better)就发现CenterNet(https://github.com/xingyizhou/CenterNet)存在这么一个问题,要知道当时这可是一个有着3k多star的仓库,但是从来没有人意识到有这个问题。当然,依照我的经验,如果你训练的iter足够多,即使你的channel有问题,对于结果的影响也会非常小。不过,既然能做对,为啥不注意这些问题一次性做对呢?

  3. torchvision中提供的模型,都是输入图像经过了ToTensor操作train出来的。也就是说最后在进入网络之前会统一除以255从而将网络的输入变到0到1之间。torchvision的文档(https://pytorch.org/vision/stable/models.html)给出了他们使用的mean和std,也是0-1的mean和std。如果你使用torch预训练的模型,但是输入还是0-255的,那么恭喜你,在载入模型上你又会踩一个大坑(要么你的图像先除以255,要么你的code中mean和std的数值都要乘以255)。

  4. ToTensor之后接数据处理的坑。上一条说了ToTensor之后图像变成了0到1的,但是一些数据增强对数值做处理的时候,是针对标准图像,很多人ToTensor之后接了这样一个数据增强,最后就是练出来的丹是废的(心疼电费QaQ)。

  5. 数据集里面有一个图特别诡异,只要train到那一张图就会炸显存(CUDA OOM),别的图训练起来都没有问题,应该怎么处理?通常出现这个问题,首先判断数据本身是不是有问题。如果数据本身有问题,在一开始生成Dataset对象的时候去掉就行了。如果数据本身没有问题,只不过因为一些特殊原因导致显存炸了(比如检测中图像的GT boxes过多的问题),可以catch一个CUDA OOM的error之后将一些逻辑放在CPU上,最后retry一下,这样只是会慢一个iter,但是训练过程还是可以完整走完的,在我们开源的YOLOX里有类似的参考code(https://github.com/Megvii-BaseDetection/YOLOX/blob/0.1.0/yolox/models/yolo_head.py#L330-L334)。

  6. pytorch中dataloader的坑。有时候会遇到pytorch num_workers=0(也就是单进程)没有问题,但是多进程就会报一些看不懂的错的现象,这种情况通常是因为torch到了ulimit的上限,更核心的原因是 torch的dataloader不会释放文件描述符 (参考issue: https://github.com/pytorch/pytorch/issues/973)。可以ulimit -n 看一下机器的设置。跑程序之前修改一下对应的数值。

  7. opencv和dataloader的神奇联动。很多人经常来问为啥要写cv2.setNumThreads(0),其实是因为cv2在做resize等op的时候会用多线程,当torch的dataloader是多进程的时候,多进程套多线程,很容易就卡死了(具体哪里死锁了我没探究很深)。除了setNumThreads之外,通常还要加一句cv2.ocl.setUseOpenCL(False),原因是cv2使用opencl和cuda一起用的时候通常会拖慢速度,加了万事大吉,说不定还能加速。感谢评论区 @Yuxin Wu(https://www.zhihu.com/people/ppwwyyxx) 大大的指正

  8. dataloader会在epoch结束之后进行类似重新加载的操作,复现这个问题的code稍微有些长,放在后面了。这个问题算是可以说是一个高级bug/feature了,可能导致的问题之一就是炼丹师在本地的code上进行了一些修改,然后训练过程直接加载进去了。解决方法也很简单,让你的sampler源源不断地产生数据就好,这样即使本地code有修改也不会加载进去。

Module模块

  1. BatchNorm在训练和推断的时候的行为是不一致的。这也是新人最常见的错误(类似的算子还有dropout,这里提一嘴, pytorch的dropout在eval的时候行为是Identity ,之前有遇到过实习生说dropout加了没效果,直到我看了他的code:x = F.dropout(x, p=0.5)

  2. BatchNorm叠加分布式训练的坑。在使用DDP(DistributedDataParallel)进行训练的时候,每张卡上的BN统计量是可能不一样的,仔细检查broadcast_buffer这个参数 。DDP的默认行为是在forward之前将rank0 的 buffer做一次broadcast(broadcast_buffer=True),但是一些常用的开源检测仓库是将broadcast_buffer设置成False的(参考:mmdet(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206) 和 detectron2(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206),我猜是在检测任务中因为batchsize过小,统一用卡0的统计量会掉点) 这个问题在一边训练一边测试的code中更常见 ,比如说你train了5个epoch,然后要分布式测试一下。一般的逻辑是将数据集分到每块卡上,每块卡进行inference,最后gather到卡0上进行测点。但是 因为每张卡统计量是不一样的,所以和那种把卡0的模型broadcast到不同卡上测试出来的结果是不一样的。这也是为啥通常训练完测的点和单独起了一个测试脚本跑出来的点不一样的原因 (当然你用SyncBN就不会有这个问题)。

  3. Pytorch的SyncBN在1.5之前一直实现的有bug,所以有一些老仓库是存在使用SyncBN之后掉点的问题的。

  4. 用了多卡开多尺度训练,明明尺度更小了,但是速度好像不是很理想?这个问题涉及到多卡的原理,因为分布式训练的时候,在得到新的参数之后往往需要进行一次同步。假设有两张卡,卡0的尺度非常小,卡1的尺度非常大,那么就会出现卡0始终在等卡1,于是就出现了虽然有的尺度变小了,但是整体的训练速度并没有变快的现象(木桶效应)。解决这个问题的思路就是 尽量把负载拉均衡一些

  5. 多卡的小batch模拟大batch(梯度累积)的坑。假设我们在单卡下只能塞下batchsize = 2,那么为了模拟一个batchsize = 8的效果,通常的做法是forward / backward 4次,不清理梯度,step一次(当然考虑BN的统计量问题这种做法和单纯的batchsize=8肯定还是有一些差别的)。在多卡下,因为调用loss.backward的时候会做grad的同步,所以说前三次调用backward的时候需要加ddp.no_sync(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync)的context manager(不加的话,第一次bp之后,各个卡上的grad此时会进行同步),最后一次则不需要加。当然,我看很多仓库并没有这么做,我只能理解他们就是单纯想做梯度累积(BTW,加了ddp.no_sync会使得程序快一些,毕竟加了之后bp过程是无通讯的)。

  6. 浮点数的加法其实不遵守交换律的 ,这个通常能衍生出来GPU上的运算结果不能严格复现的现象。可能一些非计算机软件专业的同学并不理解这一件事情,直接自己开一个python终端体验可能会更好:

print(1e100 + 1e-4 + -1e100)  # ouptut: 0
print(1e100 + -1e100 + 1e-4)  # output: 0.0001

训练模块

  1. FP16训练/混合精度训练。使用Apex训练混合精度模型,在保存checkpoint用于继续训练的时候,除了model和optimizer本身的state_dict之外,还需要保存一下amp的state_dict,这个在amp的文档(https://link.zhihu.com/?target=https%3A//nvidia.github.io/apex/amp.html%23checkpointing)中也有提过。(当然,经验上来说忘了保存影响不大,会多花几个iter search一个loss scalar出来)

  2. 多机分布式训练卡死的问题。好友 @NoahSYZhang(https://www.zhihu.com/people/syzhangbuaa) 遇到的一个坑。场景是先申请了两个8卡机,然后机器1和机器2用前4块卡做通讯(local rank最大都是4,总共是两机8卡)。可以初始化process group,但是在使用DDP的时候会卡死。原因在于pytorch在做DDP的时候会猜测一个rank,参考code(https://github.com/pytorch/pytorch/blob/0d437fe6d0ef17648072eb586484a4a5a080b094/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1622-L1630)。对于上面的场景,第二个机器上因为存在卡5到卡8,而对应的rank也是5到8,所以DDP就会认为自己需要同步的是卡5到卡8,于是就卡死了。

复现Code

Data部分

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import tqdm
import time
class SimpleDataset(Dataset):def __init__(self, length=400):self.length = lengthself.data_list = list(range(length))def __getitem__(self, index):data = self.data_list[index]time.sleep(0.1)return datadef __len__(self):return self.length
def train(local_rank):dataset = SimpleDataset()dataloader = DataLoader(dataset, batch_size=1, num_workers=2)iter_loader = iter(dataloader)max_iter = 100000for _ in tqdm.tqdm(range(max_iter)):try:_ = next(iter_loader)except StopIteration:print("Refresh here !!!!!!!!")iter_loader = iter(dataloader)_ = next(iter_loader)
if __name__ == "__main__":import torch.multiprocessing as mpmp.spawn(train, args=(), nprocs=2, daemon=False)

当程序运行起来的时候,可以在Dataset里面的__getitem__方法里面加一个print输出一些内容,在refresh之后,就会print对应的内容哦(看到现象是不是觉得自己以前炼的丹可能有问题了呢hhh)

一些碎碎念

一口气写了这么多条也有点累了,后续有踩到新坑的话我也会继续更新这篇文章的。毕竟写这篇文章是希望工作中不再会有人踩类似的坑 & 炼丹的人能够对深度学习框架有意识(虽然某种程度上来讲这算是个心智负担)。

如果说今年来什么事情是最大的收获的话,那就是理解了一个开放的生态是可以迸发出极强的活力的,也希望能看到更多的人来分享自己遇到的问题和解决的思路。毕竟探索的答案只是一个副产品,过程本身才是最大的财宝。

本文仅做学术分享,如有侵权,请联系删文。

重磅!计算机视觉工坊-学习交流群已成立

扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在交流顶会、顶刊、SCI、EI等写作与投稿事宜。

同时也可申请加入我们的细分方向交流群,目前主要有ORB-SLAM系列源码学习、3D视觉CV&深度学习SLAM三维重建点云后处理自动驾驶、CV入门、三维测量、VR/AR、3D人脸识别、医疗影像、缺陷检测、行人重识别、目标跟踪、视觉产品落地、视觉竞赛、车牌识别、硬件选型、深度估计、学术交流、求职交流等微信群,请扫描下面微信号加群,备注:”研究方向+学校/公司+昵称“,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进去相关微信群。原创投稿也请联系。

▲长按加微信群或投稿

▲长按关注公众号

3D视觉从入门到精通知识星球:针对3D视觉领域的视频课程(三维重建系列三维点云系列结构光系列手眼标定相机标定、激光/视觉SLAM、自动驾驶等)、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:

学习3D视觉核心技术,扫描查看介绍,3天内无条件退款

圈里有高质量教程资料、可答疑解惑、助你高效解决问题

觉得有用,麻烦给个赞和在看

关于炼丹,你是否知道这些细节?相关推荐

  1. 数据竞赛Tricks集锦

    点击上方"Datawhale",选择"星标"公众号 第一时间获取价值内容 本文将对数据竞赛的『技巧』进行全面的总结,同时还会分享下个人对比赛方法论的思考.前者比 ...

  2. PyTorch消除训练瓶颈 提速技巧

    [GiantPandaCV导读]训练大型的数据集的速度受很多因素影响,由于数据集比较大,每个优化带来的时间提升就不可小觑.硬件方面,CPU.内存大小.GPU.机械硬盘orSSD存储等都会有一定的影响. ...

  3. 关于炼丹,那些不为人知的细节

    作者丨Fatescript 来源丨https://zhuanlan.zhihu.com/p/450779978:仅做学术分享: 序 本文算是我工作一年多以来的一些想法和经验,最早发布在旷视研究院内部的 ...

  4. 【干货】新手炼丹经验总结

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨hzwer@知乎 来源丨https://zhuanlan.zh ...

  5. 2021年炼丹笔记最受欢迎的10篇技术文章

    阶段性整理知识笔记是炼丹笔记的习惯,在这里我们温故而知新,根据文章在全网的阅读情况整理了2021年度,最受欢迎的10篇文章,错过的朋友可以补一下哦. 推荐系统内容实在太丰富了,以至于刚开始学的人都无从 ...

  6. 推荐系统炼丹笔记:EdgeRec阿里边缘计算推荐系统

    作者:一元 公众号:炼丹笔记 背景 推荐系统(RS)已经成为大多数web应用程序的关键模块.最近,大多数RSs都是基于云到边缘框架的瀑布式的,其中推荐的结果通过在云服务器中预先计算传送到edge(例如 ...

  7. 推荐系统炼丹笔记:推荐系统Bias/Debias大全

    作者:一元 公众号:炼丹笔记 背景 在实践中,做推荐系统的很多朋友思考的问题是如何对数据进行挖掘,大多数论文致力于开发机器学习模型来更好地拟合用户行为数据.然而,用户行为数据是观察性的,而不是实验性的 ...

  8. 推荐算法炼丹笔记:做向量召回 All You Need is 双塔

    作者:十方,公众号:炼丹笔记 对于基于向量召回,那就不得不提到双塔.为什么双塔在工业界这么常用?双塔上线有多方便,真的是谁用谁知道,user塔做在线serving,item塔离线计算embeding建 ...

  9. 推荐系统炼丹笔记:阿里边缘计算+奉送20个推荐系统强特

    作者:一元 公众号:炼丹笔记 背景 推荐系统(RS)已经成为大多数web应用程序的关键模块.最近,大多数RSs都是基于云到边缘框架的瀑布式的,其中推荐的结果通过在云服务器中预先计算传送到edge(例如 ...

  10. 炼丹面试官的面试笔记

    作者:无名,某小公司算法专家 排版:一元,四品炼丹师 公众号:炼丹笔记 关于Attention和Transformer的灵魂拷问 背景 现在几乎很多搞深度学习的朋友都把attention和Transf ...

最新文章

  1. jQuery Tools:Web开发必备的 jQuery UI 库
  2. 英语是缺乏AOP的语言,汉语是具备AOP的语言。
  3. python for循环求和_python用for循环求和的方法总结
  4. JAVA 正则表达式 分组
  5. 替换a链接的href和title
  6. VTK:检查VTK的版本用法实战
  7. [NOI2018]冒泡排序
  8. 常用代码生成工具介绍
  9. 2020年墨天轮数据报告发布!
  10. 【解决方案】PDF文字复制后乱码
  11. Microsoft Visio 2016 专业版
  12. Redis开发与运维学习笔记
  13. Spring Guide:Securing a Web Application(中文大概意思)
  14. 掌中革命--手机富媒体
  15. SURF(Speeded Up Robust Features)算法原理
  16. 欢迎清风艾艾在ITPUB博客安家!
  17. linux的s权限和t权限
  18. js网页点击播放背景音乐,再次点击暂停播放背景音乐
  19. 千里之行,始于足下——有感于平安林伟丹的分享
  20. 高级开发工程师如何快速晋升为架构师?高级开发工程师与架构师到底有啥区别?

热门文章

  1. 服务器里怎么找到K3账套文件,金蝶K3账套自动备份步骤详解
  2. UG二次开发GRIP过滤
  3. OpenCV中feature2D学习——Shi-Tomasi角点检测
  4. 基于机器学习方法对销售预测的研究
  5. Dubbo Monitor 分析
  6. Panabit流控软件使用相关说明及配置文件说明
  7. 神仙打架!清华公布2020特奖候选人名单,有人三篇顶会一作!还有人...
  8. Ruby语言的优点和缺点
  9. 2018,丁磊的野心静悄悄
  10. T410i笔记本DP线转接HDMI链接外设无法传输声音问题解决