strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

问题

我们知道通过

model.load_state_dict(state_dict, strict=False)

可以暂且忽略掉模型和参数文件中不匹配的参数,先将正常匹配的参数从文件中载入模型。

笔者在使用时遇到了这样一个报错:

RuntimeError: Error(s) in loading state_dict for ViT_Aes:size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

一开始笔者很奇怪,我已经写明strict=False了,不匹配参数的不管就是了,为什么还要给我报错。

原因及解决方案

经过笔者仔细打印模型的键和文件中的键进行比对,发现是这样的:strict=False可以保证模型中的键与文件中的键不匹配时暂且跳过不管,但是一旦模型中的键和文件中的键匹配上了,PyTorch就会尝试帮我们加载参数,就必须要求参数的尺寸相同,所以会有上述报错。

比如在我们需要将某个预训练的模型的最后的全连接层的输出的类别数替换为我们自己的数据集的类别数,再进行微调,有时会遇到上述情况。这时,我们知道全连接层的参数形状会是不匹配,比如我们加载 ImageNet 1K 1000分类的预训练模型,它的最后一层全连接的输出维度是1000,但如果我们自己的数据集是10分类,我们需要将最后一层全链接的输出维度改为10。但是由于键名相同,所以PyTorch还是尝试给我们加载,这时1000和10维度不匹配,就会导致报错。

解决方案就是我们将 .pth 模型文件读入后,将其中我们不需要的层(通常是最后的全连接层)的参数pop掉即可。

以 ViT 为例子,假设我们有一个 ViT 模型,并有一个参数文件 vit-in1k.pth,它里面存储着 ViT 模型在 ImageNet-1K 1000分类数据集上训练的参数,而我们要在自己的10分类数据集上微调这个模型。

model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')
msg = model.load_state_dict(ckpt, strict=False)
print(msg)

直接这样加载会出错,就是上面的错误:

 size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

我们将最后 pth 文件加载进来之后(即 ckpt) 中全连接层的参数直接pop掉,至于需要pop掉哪些键名,就是上面报错信息中提到了的,在这里就是 head.weighthead.bias

ckpt.pop('head.weight')
ckpt.pop('head.bias')

之后在运行,会发现我们打印的 msg 显示:

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])

即缺失了head.weighthead.bias 这两个参数,这是正常的,因为在自己的数据集上微调时,我们本就不需要这两个参数,并且已经将它们从模型文件字典 ckpt 中pop掉了。现在,模型全连接之前的层(通常即所谓的特征提取层)的参数已经正常加载了,接下来可以在自己的数据集上进行微调。

因为反正我们也不用这些参数,就直接把这个键值对从字典中pop掉,以免 PyTorch 在帮我们加载时试图加载这些维度不匹配,我们也不需要的参数。

strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur相关推荐

  1. size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([9

    1. 报错 RuntimeError: Error(s) in loading state_dict for FasterRCNN: size mismatch for roi_heads.box_p ...

  2. size mismatch for yolo_head2.1.bias: copying a param with shape torch.Size(【75】) from checkpoint...

    凯哥英语视频 今天一个朋友用YOLO4预测图片报错:size mismatch for yolo_head2.1.bias: copying a param with shape torch.Size ...

  3. 2021-05-31 size mismatch for transformers copying a param

    size mismatch for transformers copying a param with shape torch.Size from checkpoint, the shape in c ...

  4. size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, th

    问题描述 我想在我自己的项目更换其他的模型,下载的预训练模型出现了FC层不匹配的问题,找了好多人都写了这个点,今天总结一下: 首先我们遇到的问题如下: 他的意思是resnet50的fc层是1000分类 ...

  5. 【Python】解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias

    目录 1.问题描述 2.问题原因 3.问题解决 3.1思路1--忽视最后一层权重 额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_wei ...

  6. 解决size mismatch for embedding.embed_dict.userid.weight

    文章目录 一.问题描述 二.解决方法 三.其他问题 Reference 一.问题描述 导入之前训练好的模型权重后使用模型预测时如题报错size mismatch for embedding.embed ...

  7. size mismatch for xx.weight错误的解决方法

    问题重现: RuntimeError: Error(s) in loading state_dict for xxxNet:size mismatch for bn1.weight: copying ...

  8. oracle numa map size mismatch,Oracle启动时提示map size mismatch; abort

    Oracle启动时提示map size mismatch; abort 发布时间:2020-06-26 13:35:09 来源:51CTO 阅读:1370 作者:会说话的鱼 今天在DELL服务器的Re ...

  9. 解决使用ICsharpCode解压缩时候报错Size MisMatch的错误

    项目用到了这个组件,然后在解压文件时候报Size MisMatch错,解决方法:到https://github.com/icsharpcode/SharpZipLib/releases选择对应的源码下 ...

最新文章

  1. 解决 ERROR: Couldn't connect to Docker daemon at http+docker://localunixsocket - is it running?
  2. IDEA报错解决:Error:(33, 35) java: -source 7 中不支持 lambda 表达式 (请使用 -source 8 或更高版本以启用 lambda 表达式)
  3. STS中applicationContext.xml配置文件
  4. hive插入表的insert 执行计划_0651-6.2.0-启用Sentry后Impala执行SQL失败问题分析
  5. 一些培养程序员leadership的经验教训
  6. 自动化用户特定实体的访问控制
  7. sklearn 绘制roc曲线_如何用Tensorflow和scikit-learn绘制ROC曲线?
  8. python闭包怎么理解_Python:闭包的理解
  9. 怎么new一个指针_C++知识点 34:指针运算符重载 -- 智能指针
  10. Protobuf简单编写与使用
  11. 2019-05-22 SQL注入;啊D注入工具;
  12. 搭建web项目常见错误
  13. 指纹识别系统电路设计图集锦 —电路图天天读(200)
  14. BPM实例分享——金额规则大写
  15. 会唱歌的程序员为何如此受欢迎?
  16. Learning to Rank基于pairwise的算法(一)——Ranking SVM、MHR、IRSVM
  17. 关于网页中显示生僻字的方法
  18. Java SE-网络编程二
  19. 【2019全国职业技能大赛大数据技术】任务四:14-数据可视化(20分_题目+答案<图片+分值>)
  20. Python—循环程序

热门文章

  1. 企业实战(Jenkins+GitLab+SonarQube)_11_Jenkins权限的划分
  2. For循环(十分重要)
  3. java ajax查询_java-如何计时ajax查询(发送查询,处理,接收响应)
  4. floquet端口x极化入射波_请问CST 2012 floquet中的模式设置
  5. android重新编译res,使用 gradle 在编译时动态设置 Android resValue / BuildConfig / Manifes中lt;meta-datagt;变量的值...
  6. 努比亚手机浏览器 安全证书失效_浏览器提示“该站点安全证书的吊销信息不可用”的解决方法-...
  7. 面向对象三个特征总结
  8. android实现qq修改密码底部弹出框_易查分强大的“可修改列”功能:轻松实现填表、留言和信息核对...
  9. linux 镜像错误,VituralBox 使用已有镜像文件报错:E_INVALIDARG (0x80070057)
  10. oracle并行parallel update两张表_Oracle与并行性 parallel