文章目录

  • 背景
  • 报错
  • 原因
  • 解决

背景

Pytorch在加载模型参数的时候,有两种情况可能出现这种问题:

  1. 自己写的网络结构,例如:
  • 代码
import models
arch = 'resnet50'
model = models.__dict__[arch]() # 根据resnet,自己写的的多任务网络结构
checkpoint = torch.load(ckptFile)model.load_state_dict(checkpoint['state_dict'])
model = torch.nn.DataParallel(model).cuda()
  1. 按照官方示例代码
import torch
from torchvision import models
model = models.resnet50(pretrained=False) # 使用官方结构和预训练权重
ckptFile = '../resnet50-19c8e357.pth'
ckpt = torch.load(ckptFile)model = torch.nn.DataParallel(model)
model.load_state_dict(ckpt)

报错

相应的报错也分为两种情况:

  1. 第一种情况报错:
Traceback (most recent call last):File "/home/user1/project1/utils/eval.py", line 193, in <module>test_models(r'/home/user1/models')File "/home/user1/project1/utils/eval.py", line 157, in test_modelsmodel.load_state_dict(checkpoint['state_dict'])File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 777, in load_state_dictself.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", ...". Unexpected key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias"...
  1. 第二种情况报错:
Traceback (most recent call last):File "/home/data/PJS/test_bed/torch_ddp.py", line 238, in <module>model.load_state_dict(ckpt)File "/home/data/pkgs/miniconda3/envs/thpj/lib/python3.6/site-packages/torch/nn/modules/module.py", line 777, in load_state_dictself.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:Missing key(s) in state_dict: "module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var",Unexpected key(s) in state_dict: "conv1.weight", "bn1.running_mean", "bn1.running_var", "bn1.weight", "bn1.bias", "layer1.0.conv1.weight",

原因

跟代码执行顺序有关,模型加载的顺序和模型放到DDP / DP 的顺序。
第一种情况需要先把模型放到DDP / DP ,再加载;
第二种情况,需要先加载模型,然后再放到DDP / DP

解决

正确的顺序

import modelsarch = 'resnet50'
model = models.__dict__[arch]()
# 交换了顺序,先放到DP
model = torch.nn.DataParallel(model).cuda()checkpoint = torch.load(ckptFile)
model.load_state_dict(checkpoint['state_dict'])
  1. 第二种情况正确的写法:
import torch
from torchvision import models
model = models.resnet50(pretrained=False)
ckptFile = '/home/k/Downloads/Persepolis/Others/resnet50-19c8e357.pth'
ckpt = torch.load(ckptFile)
# 交换了顺序,先加载
model.load_state_dict(ckpt)
model = torch.nn.DataParallel(model)

pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1相关推荐

  1. pytorch 加载模型报错:‘function‘ object has no attribute ‘copy‘

    太粗心了,保存模型的时候写错了,写成了如下: torch.save(model,model_file) 而实际上应该是: torch.save(model.state_dict(),model_fil ...

  2. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  3. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

  4. 使用np.load()加载数据 报错 Object arrays cannot be loaded when allow_pickle=False

    使用np.load()加载数据 报错 Object arrays cannot be loaded when allow_pickle=False https://blog.csdn.net/weix ...

  5. webpack使用css-loader跟style-loader加载css报错

    webpack使用css-loader跟style-loader加载css报错 webpack使用css-loader跟style-loader加载css报错 webpack.config.js 配置 ...

  6. Office2016打开PPT出现加载项报错。

    首先来看错误截图吧: 抱歉,由于某种原因,PowerPoint无法加载... 错误原因:原先安装过MathType然后有卸载了,但在卸载Mathtype时没有卸载干净.造成的结果是不管是否重装Offi ...

  7. 加载lua报错cannot load incompatible bytecode

    问题描述 加载lua报错cannot load incompatible bytecode 原因分析: 显而易见就是字面原因:无法加载不兼容的字节码 1.查看文件修改日期,日前开发对其做过升级. 2. ...

  8. WKWebView 加载 http:// ** 报错WebPageProxy::didFailProvisionalLoadForFrame:

    WKWebView 加载 http:// ** 报错WebPageProxy::didFailProvisionalLoadForFrame: 模拟器:iOS14 iPhone11 Pro Max 猜 ...

  9. selenium加载cookie报错问题:selenium.common.exceptions.InvalidCookieDomainException: Message: invalid cooki

    selenium加载cookie报错问题:selenium.common.exceptions.InvalidCookieDomainException: Message: invalid cooki ...

最新文章

  1. 李小璐PGONE事件对推荐系统的考验
  2. Vue中使用a标签实现点击在新标签页中打开实现照片预览
  3. python如何入侵服务器的_通过redis入侵服务器的步骤
  4. 科普 | CPU 是如何工作的?
  5. Gitter - 高颜值GitHub小程序客户端诞生记 1
  6. 示波器采样速率单位Ms/s、Gs/s
  7. WindowsXP、Windows2003本地密码清除方法
  8. 计算机房等电位接地规范,电子计算机机房接地装置设计要求
  9. UBUNTU内核升级后,如何更新 kernel headers
  10. node 生成随机头像_给微信设置卡通头像,再不怕撞脸!
  11. 基于BPM(业务流程管理)的低代码开发平台有哪些优势?
  12. linux 批量删除任务,Linux-Shell脚本学习心得之批量创建、删除用户
  13. 【创文进行时】创建文明城市社区在行动
  14. 全球回报最好的 40 个 VC 投资案例,我们可以从中学到什么?
  15. Kaggle实战:泰坦尼克幸存者预测 -下
  16. SQL必知必会读书笔记
  17. 丑数求解以及丑数的优化
  18. WWDC20 CoreImage 专题
  19. 计算机中ar的作用,AR增强现实的作用
  20. 新买的阿里云服务器无法进行远程桌面

热门文章

  1. python3飞机大战源码及源码使用教程(让小白做出第一个小游戏)
  2. 分页和分段有什么区别?
  3. 分布式链路监控与追踪系统(SpringCloud Sleuth + Zipkin)
  4. koa工程化项目使用的基本包
  5. 浅谈车道偏离预警系统
  6. 用C++语言实现求长方形面积
  7. 《信号与系统》—MATLAB分析与实现(二)
  8. 基于STM32单片机的智能加湿器(Proteus仿真+程序)
  9. Chapter13 : Ultrahigh Throughput Protein-Ligand Docking with Deep Learning
  10. Mac 上的 Keynote 讲演文件转ppt格式