文章目录

  • 一、问题描述
  • 二、解决方法
  • 三、其他问题
  • Reference

一、问题描述

导入之前训练好的模型权重后使用模型预测时如题报错size mismatch for embedding.embed_dict.userid.weight

state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

二、解决方法

是因为导入的模型权重(之前训练好、保存的)的维度和当前定义的model的权重维度不同,所以我选择修改下当前定义的model,即将自己返回如下beat_sparse_features等的dataloader,其读取的数据换成之前模型训练的数据,使得模型定义后的model的模型权重和导入的权重一致。

model = DeepFM(deep_features=beat_dense_features + beat_sparse_features,fm_features=beat_sparse_features,mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"},
)

当然如果根据大家的实际情况改动,如很多时候实例化模型时改变实参即可。

三、其他问题

可能还有其他情况也会报这个错,如导入预训练模型进行微调,首先加载预训练模型权重:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
model.load_state_dict(pretrained_dict)
model.fc = torch.nn.Linear(512, 5) # 512为原始fc的数目,5是自己任务的分类数

由于分类类别不一致,报错size mismatch for fc.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([x]).,这里可以选择不加载fc层:

model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)} # 将'fc'这一层的权重选择不加载即可。
model_dict.update(pretrained_dict) # 更新权重
model.load_state_dict(model_dict)

可能还有其他情况,如NLP词表维度不一致等等,后面遇到再更新该帖。如有不对之处,恳请大佬们指正!

Reference

[1] 解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias
[2] torch 封装文本数据预处理、训练、评估、预测过程
[3] 关于Pytorch加载模型参数的避坑指南

解决size mismatch for embedding.embed_dict.userid.weight相关推荐

  1. size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([9

    1. 报错 RuntimeError: Error(s) in loading state_dict for FasterRCNN: size mismatch for roi_heads.box_p ...

  2. 【Python】解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias

    目录 1.问题描述 2.问题原因 3.问题解决 3.1思路1--忽视最后一层权重 额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_wei ...

  3. size mismatch for xx.weight错误的解决方法

    问题重现: RuntimeError: Error(s) in loading state_dict for xxxNet:size mismatch for bn1.weight: copying ...

  4. 解决使用ICsharpCode解压缩时候报错Size MisMatch的错误

    项目用到了这个组件,然后在解压文件时候报Size MisMatch错,解决方法:到https://github.com/icsharpcode/SharpZipLib/releases选择对应的源码下 ...

  5. size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, th

    问题描述 我想在我自己的项目更换其他的模型,下载的预训练模型出现了FC层不匹配的问题,找了好多人都写了这个点,今天总结一下: 首先我们遇到的问题如下: 他的意思是resnet50的fc层是1000分类 ...

  6. pytorch神经网络,解决输入图像大小不匹配问题 size mismatch

    问题如下:RuntimeError: size mismatch, m1: [4 x 512], m2: [64 x 128]-- RuntimeError: size mismatch, m1: [ ...

  7. pytorch搭建cnn报错:RuntimeError: size mismatch, m1: [10 x 43264], m2: [10816 x 2] at C...

    具体报错信息: Traceback (most recent call last):File "E:/Program Files/PyCharm 2019.2/machinelearning ...

  8. strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

    strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur ...

  9. oracle numa map size mismatch,Oracle启动时提示map size mismatch; abort

    Oracle启动时提示map size mismatch; abort 发布时间:2020-06-26 13:35:09 来源:51CTO 阅读:1370 作者:会说话的鱼 今天在DELL服务器的Re ...

最新文章

  1. SMRT测序技术及其在微生物研究中的应用
  2. JS通过正则限制 input 输入框只能输入整数、小数(金额或者现金) 两位小数
  3. v-for 切换不同的class
  4. Yii2.0 数据库更新update
  5. CF223C【Partial Sums】(组合数学+乱搞)
  6. jQuery父级以及同级元素查找
  7. [Sharepoint2007对象模型]第一回:服务器场(SPFarm)
  8. 我与计算机作文450字,打电脑作文450字
  9. 《jQuery Mobile入门经典》—— 2.2 展现CSS样式
  10. Python基础—文件操作
  11. Labview程序优化
  12. 数仓建模—数仓建模实战(建模流程/建模工具)
  13. 并发编程-线程卡死问题实践
  14. win10电脑外接音响没声音怎么回事?win10电脑外接音响没声音的修复方法
  15. operator int()用法
  16. ios 设置导航栏背景色
  17. 天津城建大学计算机学院官网,天津城建大学计算机与信息工程学院研究生导师简介-杨振舰...
  18. 多个文本文档合并为一个文件的方法
  19. iOS 3DES加密解密(一行代码搞定)
  20. C练题笔记之:Leetcode-780. 到达终点

热门文章

  1. php 创建透明png,php生成透明背景图片实例
  2. html页面中漂浮物怎么实现,html的section标签是什么怎么用了
  3. 海盗王实现物品999个堆叠
  4. tl494c封装区别_TL494的特点与引脚功能
  5. ESP32 开发之旅③ Ticker——定时库
  6. 使用Arduino开发ESP32(07):系统时间和定时任务调度器Ticker
  7. 第十四届蓝桥杯. 接龙数列(线性DP)
  8. 取得文件夹下的所有文件的文件名和文件大小
  9. 解决.Net Core 使用 System.Drawing.Common 在CentOS下报错'Gdip'
  10. IDEA自动生成实体类