希望将训练好的模型加载到新的网络上。如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。

Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。

表明了加载过程中,期望获得的key值为feature...,而不是module.features....。这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。

You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.

解决上面的问题有三个办法:

1. 对load的模型创建新的字典,去掉不需要的key值"module".

# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt')  # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。

2. 直接用空白''代替'module.'

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。

3. 最简单的方法,加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。

如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个"module. ***"。

model = VGG()# 实例化自己的模型;
checkpoint = torch.load('checkpoint.pt', map_location='cpu')  # 加载模型文件,pt, pth 文件都可以;
if torch.cuda.device_count() > 1:# 如果有多个GPU,将模型并行化,用DataParallel来操作。这个过程会将key值加一个"module. ***"。model = nn.DataParallel(model)
model.load_state_dict(checkpoint) # 接着就可以将模型参数load进模型。

4. 总结

从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。 这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。

model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():print(k) #只打印key值,不打印具体参数。

features.0.0.weight   
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked

model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():print(k)
print("*****************************************")

输出结果为:

module.features.0.0.weight",

"module.features.0.1.weight",

"module.features.0.1.bias

可以看出不匹配,模型的参数中,key值不同,多了module。

PS: 2020-12-25

在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:

from collections import OrderedDict
checkpoint = torch.load(pretrained_model_file_path,map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():if not k.endswith('total_ops') and not k.endswith('total_params'):name = k[7:]new_state_dict[name] = v

如果有用,记得点赞

PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .相关推荐

  1. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  2. pytorch加载模型报错Unexpected key(s) in state_dict: module.conv1.weight, module.bn1

    文章目录 背景 报错 原因 解决 背景 Pytorch在加载模型参数的时候,有两种情况可能出现这种问题: 自己写的网络结构,例如: 代码 import models arch = 'resnet50' ...

  3. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

  4. Unexpected key(s) in state_dict: “module.conv1.weight“, “module.bn1.weight“, “module.bn1.bias“,

    由于服务器老是断电 所以想加载已经训练好的上一个epoch的模型,但是在加载时遇到了这个问题 这是由于保存模型字典时每一个模块的key都自动加上了'module'.所以在加载模型参数继续训练时就会与模 ...

  5. Unexpected key(s) in state_dict: module.backbone.bn1.num_batches_tracked

    pytorch预测的时候报异常了: {RuntimeError}Error(s) in loading state_dict for DataParallel:     Unexpected key( ...

  6. OpenGL 加载模型Model

    OpenGL 模型Model 模型Model简介 导入3D模型到OpenGL 从Assimp到网格 索引 材质 重大优化 和箱子模型告别 模型Model简介 现在是时候接触Assimp并创建实际的加载 ...

  7. pytorch加载模型时出现.....ckpt_100.pth is a zip archive (did you mean to use torch.jit.load()?)

    在测试加载训练好的模型时出现上方问题,参考这篇文章,原因是训练和测试的torch版本不一致. 训练的时候是1.6,测试的时候是1.2,因此需要先在1.6版本下加载模型,重新保存,在保存的时候设置use ...

  8. pytorch 加载模型 模型大小测试速度

    直接加载整个模型 Pytorch保存和加载整个模型: save_net=model if hasattr(model, 'module'):save_net=model.module torch.sa ...

  9. 使用PyTorch加载模型部分参数方法

    前言 在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们通常会希望加载预训练模型文件的参数,如果网络结构不变,只需要使用load_state_dict方法即可.而当我们改动网 ...

最新文章

  1. Linux课堂随笔---第四天
  2. U3D包大小优化之microlib
  3. 日期型转json格式(springboot)
  4. SQL Server 2000卸载后重新安装的问题
  5. Ubuntu 16.04安装Chrome浏览器
  6. Servlet API
  7. Spring MVC:表单处理卷。 4 –单选按钮
  8. 【kafka】在 Kafka Streams 中启用 Exactly-Once
  9. arduino esp8266_Arduino-httpupdate-OTA-esp8266升级探险记
  10. Spring 源码分析(四) ——MVC(六)M 与 C 的实现
  11. Python设置显示屏分辨率
  12. 《单片机原理与接口技术》期中测评
  13. iPhone IPv6上网
  14. 外卖行业现状分析_2020餐饮外卖行业市场前景及现状分析
  15. iOS 初学者功能代码大集合,个人笔记
  16. Sentinel流量卫兵
  17. 相亲角、地摊,暗访小县城的夜市
  18. 软件测试如何分类?又有哪些类别?
  19. nrf51822 --- 动态修改连接间隔
  20. 实验有效的js原生前端 全国三级联动

热门文章

  1. 2022-2028年中国特高压电网行业深度调研及投资前景预测报告
  2. 2022-2028年中国TPE手套行业市场全景调查及发展策略分析报告
  3. 【Sql Server】DateBase-连接查询
  4. LeetCode简单题之交替位二进制数
  5. LeetCode简单题之删列造序
  6. 在NVIDIA A100 GPU中使用DALI和新的硬件JPEG解码器快速加载数据
  7. 摄像头ISP系统原理(中)
  8. 2021年大数据Flink(三):​​​​​​​Flink安装部署 Local本地模式
  9. Python:CrawlSpiders
  10. ad 卡尔曼_卡尔曼滤波剪影__Kalman Filtering · Make Intuitive