首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑

1.单卡训练,单卡加载

这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件里,这样就可以在加载时只需要加载一个参数文件。
保存:

states = {'state_dict_encoder': encoder.state_dict(),'state_dict_decoder': decoder.state_dict(),}
torch.save(states, fname)

加载:

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

2.单卡训练,多卡加载

保存:
保存过程一样,不做任何改变

states = {'state_dict_encoder': encoder.state_dict(),'state_dict_decoder': decoder.state_dict(),}
torch.save(states, fname)

加载:
加载过程也没有任何改变,但是要注意,先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

3.多卡训练,单卡加载

注意,如果你考虑到以后可能需要单卡加载你多卡训练的模型,建议在保存模型时,去除模型参数字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()
保存:

states = {'state_dict_encoder': encoder.module.state_dict(), #不是encoder.state_dict()'state_dict_decoder': decoder.module.state_dict(),}
torch.save(states, fname)

加载:
要注意由于我们保存的方式是以单卡的方式保存的,所以还是要先加载模型参数,再对模型做并行化处理

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)

同时,你也可以用第二种方式去保存和加载:

3.多卡训练,单卡加载,方法二

使用model.state_dict()保存,但是单卡加载的时候,要把模型做并行化(在单卡上并行)
保存:

states = {'state_dict_encoder': encoder.state_dict(), 'state_dict_decoder': decoder.state_dict(),}
torch.save(states, fname)

加载:
要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

4.多卡保存,多卡加载

这就和多卡保存,单卡加载第二中方式一样了
**使用model.state_dict()**保存,加载的时候,要先把模型做并行化(在多卡上并行)
保存:

states = {'state_dict_encoder': encoder.state_dict(), 'state_dict_decoder': decoder.state_dict(),}
torch.save(states, fname)

加载:
要注意由于我们保存的方式是以多卡的方式保存的,所以无论你加载之后的模型是在单卡运行还是在多卡运行,都先把模型并行化再去加载

#先初始化模型,因为保存时只保存了模型参数,没有保存模型整个结构
encoder = Encoder()
decoder = Decoder()
#并行处理模型
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
#然后加载参数
checkpoint = torch.load(model_path) #model_path是你保存的模型文件的位置
encoder_state_dict=checkpoint['state_dict_encoder']
decoder_state_dict=checkpoint['state_dict_decoder']
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题相关推荐

  1. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  2. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  3. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  4. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  5. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  6. pytorch 不同设备下保存和加载模型,需要指定设备

  7. pytorch保存和加载模型state_dict

    保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...

  8. 保存和加载pytorch模型

    当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...

  9. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

最新文章

  1. 完整项目基础架构精简版-实现权限管理
  2. SQL查询库、表,列等的一些操作
  3. 对PostgreSQL中 index only scan 的初步理解
  4. redis hash
  5. 百度地图拖动标注后获取坐标
  6. unix高级环境编程-基础知识
  7. 嵌入式robocode实训-任务三
  8. 银河麒麟操作系统下载地址收集
  9. Matlab如何下载安装科研绘图工具Gramm并绘图
  10. python中backward是什么意思_python-PyTorch函数中的下划线后缀是什么意思...
  11. 88e1111的1000base-x to copper(GBIC)配置及使用
  12. Android 启动页白屏 快速解决
  13. Windows 系统中添加防火墙规则
  14. 对比学习中的4种经典训练模式
  15. 【python绘图】——删除多余的图例【图示说明】
  16. kaggle比赛前2%摸奖银牌总结
  17. 计算机基础课程改革问卷调查,大學计算机基础课程教学改革的调查与设想.doc...
  18. 软件培训学习中自律很重要
  19. 【电脑一点通】如何开启Windows夜间模式
  20. Android 手机配office365邮箱

热门文章

  1. 2021年程序员1月薪资统计,你在哪一档?
  2. Xcode中StaticLibrary和Framework的共同点和区别
  3. linux7端口聚合,centos7配置链路聚合
  4. 考研数学(180°为什么等于π)
  5. php登录信息首页显示,首页登录后怎么在首页显示用户名以及隐藏登录框?
  6. python变量声明语句_python – 在条件语句中声明变量有问题吗?
  7. 看Java大牛是如何高效学习的?你掌握好这些了吗?
  8. 补丁程序正在运行_针对微软4月14日更新补丁会导致蓝屏问题的检测及解决方法...
  9. matlab zigzag算法,ZIGZAG扫描的MATLAB实现
  10. python 类方法 函数_Python OOP类中的几种函数或方法总结