最近在鼓捣使用pytorch的distributeddataparallel这个API搭一个数据并行的训练测试任务,过程中遇到了一个问题,做一下记录。

1、问题

  使用DDP打包了一个模型训练了一段时间,loss不断下降metric不断上升,一切都是很正常的现象。当因为意外暂停或者手动暂停更改学习率而停止了程序,再开启程序加载之前的checkpoint继续训练,却发现loss突然比之前上升或者metric比之前下降了很多。仔细看了一下loss的值,发现直接回到刚开始第一次训练模型时的水平,仿佛checkpoint根本没加载进去,是从初始化开始训练的一样。

2、原因分析

根据我之前的框架使用经验,认为可能的原因有以下两点:

2.1 模型的train和eval模式问题

  由于很多算子在训练模式和测试模式下的前向传播原理不同,例如batchnorm和dropout等,导致几乎所有的框架都会对模型设置一个train或eval的flag。Pytorch可以通过调用model.train()或model.eval()将模型的状态进行切换。在训练模式下如果模型是eval状态或者在推理模式下模型是train状态都会使得结果计算不正确,可能是导致上述问题的一个原因。
  但这个猜想很快就被我给否掉了。第一,按照我的经验,如果一个模型已经训练到一个比较好的状态,即便是搞混了train和eval的状态flag,结果虽然不对但是一般也不会差的特别多。我之前用0.1学习率训练了七八个epoch,损失已经到了0.4~0.5左右。再次训练加载checkpoint损失直接飙到了差不多快到8了,这个跳跃太大了。第二,我去check了一下我的代码,发现并没有出现train和eval搞混的问题(手动狗头)。

2.2 模型没有正确的加载进去

  出现上述问题的另外一个可能的原因是:Pytorch没有正确的将模型加载进去。经常使用pytorch的同学可能都遇到过这样一种情况:自己设计了一个网络用来做某项任务,选择了某个经典分类模型(如resnet等)的特征提取部分作为backbone。训练时在github上下载了已经在ImageNet数据集上pretrain的分类模型,并把这个模型的特征提取部分的权重直接加载到自己的模型中实现backbone预训练。但是效果却并不好,可能的原因之一就是backbone并没有成功的加载进去。
  Pytorch中模型参数的保存底层使用的是字典的结构,因此参数加载需要保证参数名必须是一一对应的。常用的一个加载模型参数的API是load_state_dict,其中有一个参数是strict=True,这个参数用来控制加载模型是否是“严格”的。严格指的是代码模型定义里的所有parameter和buffer必须和要加载的checkpoint里的parameter和buffer的参数名、参数维度、参数类型等能够一一对应上,一个都不能多也不能少,否则就会报错。strict=False则可以允许代码模型定义里的部分parameter或buffer和checkpoint中的对应不上,如果有能对应上的就加载,否则就忽略。比如下面的情况,当strict=False时,parameter2、3、4和5可以被正确加载,parameter1和6不会被加载而采用用户定义的方式初始化;当strict=True时,加载会报错。

  在我遇到的问题中,经过确认我排除了这个可能性。checkpoint是用我自定义的模型训练得到的而不是从网上下载的,模型定义我没有更改过因此和之前的是一样的,而且我设置了strict=True,也没有报错说明模型是被正确加载进去的。

2.3 DistributedDataParallel问题

  以上两种思考没有解决我的问题,此时我痛定思痛仔细回想一下整个过程。同样的代码同样的逻辑之前不做数据并行的时候是没有问题的,但是一做DistributedDataParallel训练就出现了问题,说明bug出在DistributedDataParallel这里。看了一下这个API的源码,找到了问题所在。在这个类的__init__函数里有这么一段:

class DistributedDataParallel(Module):def __init__(self, ...):...# Sync params and buffersself._sync_params_and_buffers(authoritative_rank=0)...

也就是说在调用这个API把一个普通的model打包成一个ddp的model后,即实例化一个DistributedDataParallel对象的时候,就已经完成了模型的parameter和buffer在主进程模型和其他进程上replica的同步。而我的代码里,是先实例化了一个ddp对象,然后才去加载checkpoint

...
model = MyModel()
model.to(device=rank)
model = nn.parallel.DistributedDataParallel(model, devices=[rank])
if rank == 0:ret = model.load_state_dict(torch.load(xxx), strict=True)
...

此时代码的执行过程是:1、实例化一个MyModel对象并随机初始化;2、实例化一个ddp对象并用之前随机初始化的model去同步其他进程上replica的parameter和buffer;3、将checkpoint的parameter和buffer加载到主进程上的model中。此时其他几个进程上的model的parameter和buffer还都是随机初始化的,在前向和反向传播时虽然主进程上的model给出了类似之前checkpoint比较准确的结果。可是其他几个子进程上的模型由于参数是随机初始化的所以结果差的很远,各个进程上的梯度经过reduce_mean后就错的很离谱了。因此应该调整一下代码的顺序为:

...
model = MyModel()
model.to(device=rank)
if rank == 0:ret = model.load_state_dict(torch.load(xxx), strict=True)
model = nn.parallel.DistributedDataParallel(model, devices=[rank])
...

  此时仍然有一个小小的bug,就是通过DistributedDataParallel这个API去打包模型后,模型的所有参数的名字都会多一个module的前缀,还是看一下API的源码:

class DistributedDataParallel(Module):def __init__(self, module, ...):...self.module = module...

熟悉Pytorch.nn.Module这个类的变量命名规则的同学应该知道,加了这个成员变量赋值的语句后,所有模型变量的名字前缀都会多一个module。比如MyModel()实例化的对象中有一个名为conv1.weight的参数,经过DDP打包后得到的新模型中,对应的参数变量名会变为module.conv1.weight,一种解决办法是可以通过保存模型时指定保存DDP对象的module模块来消除这个前缀。

  水平有限,欢迎讨论。

Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)相关推荐

  1. Pytorch采坑记录:每隔num_workers个iteration数据加载速度很慢

      最近在做某个视觉任务的模型训练,由于数据量比较少为了效果好一点,决定现在imagenet上pretrain一下骨干网络.但是在训练的时候遇到了一个问题:每隔num_workers个iteratio ...

  2. Centos7.9上利用cephadm安装Ceph Octopus 15.2的采坑记录,附带K8S挂载方法

    Centos7.9上利用cephadm安装Ceph Octopus 15.2的采坑记录,附带K8S挂载方法 0.亮点 1 准备 1.1 修改历史记录 1.2 升级系统内核 1.3 配置免密登录 问题1 ...

  3. Pytorch踩坑记录:关于用net.eval()和with no grad装饰器计算结果不一样的问题

    Pytorch踩坑记录 相同点 net.eval()和with toch.no_grad()的相同点:都停止反向传播 不同点: 1.net.eval() 用net.eval(),此时BN层会用训练时的 ...

  4. iOS 微信SDK1.8.6后需要UniversalLink解决方案及采坑记录

    项目最初因审核原因,一直使用iOS原生分享, 最近因项目需求要求, 接入微信分享, 以为和原来的没有区别, 但是接入时才发现改动的地方还是挺多的, 主要是需要配置UniversalLink和提包时的一 ...

  5. php给微信公众号接入聊天机器人程序+采坑记录

    php给微信公众号接入聊天机器人程序 今天逛了下我的公众号,突然心血来潮,想添加个自动聊天功能,于是-动手-!! 主要用到的api: 图灵机器人api 青云客智能聊天机器人API 茉莉机器人API 至 ...

  6. MATLAB 非线性隐函数拟合采坑记录(使用 fsolve solve nlinfit lsqcurvefit函数)

    MATLAB 非线性隐函数拟合采坑记录(使用 fsolve solve nlinfit lsqcurvefit函数) 问题描述 解决思路 错误示范1 代码思路 原因解释 模型更正 更正模型1 更正模型 ...

  7. Mac EOS 采坑记录

    Mac EOS 采坑记录 eos版本 dawn v4.0.0 Mac OS 版本 10.13.4 错误信息: Could not find a package configuration file p ...

  8. H5拍照、预览、压缩、上传采坑记录

    H5拍照.预览.压缩.上传采坑记录 公司项目前段时间需要实现手机拍照上传的功能,本来以为用createObjectURL和canvas可以很轻松的实现,结果发现问题多多,特此记录下来. DEMO预览( ...

  9. PyTorch 保存模型结构参数及加载模型

    PyTorch 保存模型结构参数及加载模型 保存模型与加载 保存模型分为两种方式: 保存整个网络结构和参数 保存整个网络的参数 # 1.保存并加载整个网络结构和参数 # 保存模型 torch.save ...

最新文章

  1. ArcGIS案例学习1_2
  2. centos离线安装mysql8_CentOS7离线安装Mysql8.0
  3. 关于如何正确地在android项目中添加第三方jar包
  4. VIsual Studio编译OpenCV:无法打开python27_d.lib(python36_d.lib)的问题
  5. 前端学习(159):meta
  6. epoll边缘触发_4.2.3、epoll:水平触发与边缘触发
  7. 《Imperfect C++中文版》——1.3 运行期契约:前置条件、后置条件和不变式
  8. css background-image显示全部_CSS 与网络性能,看了都说好
  9. dummy node
  10. sql数据库 ‘xxxxxx‘ 已存在,请选择其他数据库名称
  11. Ubuntu 16.04安装Synaptic Package Manager图形化APT管理工具
  12. 《“通用语”与“兽人语”互译手册》(全集)
  13. 大一女生废话编程爆火!懂不懂编程的看完都拴Q了
  14. 用MATLAB拟合实验报告,MATLAB插值与拟合实验报告材料
  15. Powershell 脚本创建 iso 映像文件
  16. 发生系统错误 1275.此驱动程序被阻止加载 解决方案
  17. 嵌入式linux学习-驱动(2) hello world 模块实现记录 基于RK3568
  18. Quest固件下载链接,最全版本升级包,降级,Quest2,Firmware,rom,system.img,boot.img, 附录下载地址大全
  19. 能被某些数整除的数的特征
  20. 基于Python的一个疫情传播可视化模拟实验

热门文章

  1. 大数据场景中语言虚拟机的应用和挑战
  2. 2016第三届科学数据大会诚邀商务合作
  3. 数据库系统实训——实验八——数据库维护
  4. 【Python】处理 from sklearn.externals import joblib 报错问题
  5. (四)Go 语言编译流程简述
  6. Firefly推出了小型高性能嵌入式主机
  7. Java程序员的日常—— 垃圾回收中引用类型的作用
  8. shell:判断一个进程是否存在
  9. 技术女性的是是非非(2)
  10. 当自动化遇见数字化——德资企业儒拉玛特的数字化实践