解决size mismatch for embedding.embed_dict.userid.weight
文章目录
- 一、问题描述
- 二、解决方法
- 三、其他问题
- Reference
一、问题描述
导入之前训练好的模型权重后使用模型预测时如题报错size mismatch for embedding.embed_dict.userid.weight
。
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)
二、解决方法
是因为导入的模型权重(之前训练好、保存的)的维度和当前定义的model
的权重维度不同,所以我选择修改下当前定义的model
,即将自己返回如下beat_sparse_features
等的dataloader,其读取的数据换成之前模型训练的数据,使得模型定义后的model
的模型权重和导入的权重一致。
model = DeepFM(deep_features=beat_dense_features + beat_sparse_features,fm_features=beat_sparse_features,mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"},
)
当然如果根据大家的实际情况改动,如很多时候实例化模型时改变实参即可。
三、其他问题
可能还有其他情况也会报这个错,如导入预训练模型进行微调,首先加载预训练模型权重:
model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model.state_dict()}
model.load_state_dict(pretrained_dict)
model.fc = torch.nn.Linear(512, 5) # 512为原始fc的数目,5是自己任务的分类数
由于分类类别不一致,报错size mismatch for fc.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([x]).
,这里可以选择不加载fc层:
model = models.resnet34(pretrained=False)
pretrained_dict = torch.load('./pretrain/resnet34-333f7ec4.pth')
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if (k in model_dict and 'fc' not in k)} # 将'fc'这一层的权重选择不加载即可。
model_dict.update(pretrained_dict) # 更新权重
model.load_state_dict(model_dict)
可能还有其他情况,如NLP词表维度不一致等等,后面遇到再更新该帖。如有不对之处,恳请大佬们指正!
Reference
[1] 解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias
[2] torch 封装文本数据预处理、训练、评估、预测过程
[3] 关于Pytorch加载模型参数的避坑指南
解决size mismatch for embedding.embed_dict.userid.weight相关推荐
- size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([9
1. 报错 RuntimeError: Error(s) in loading state_dict for FasterRCNN: size mismatch for roi_heads.box_p ...
- 【Python】解决CNN中训练权重参数不匹配size mismatch for fc.weight,size mismatch for fc.bias
目录 1.问题描述 2.问题原因 3.问题解决 3.1思路1--忽视最后一层权重 额外说明:假如载入权重不写strict=False, 直接是model.load_state_dict(pre_wei ...
- size mismatch for xx.weight错误的解决方法
问题重现: RuntimeError: Error(s) in loading state_dict for xxxNet:size mismatch for bn1.weight: copying ...
- 解决使用ICsharpCode解压缩时候报错Size MisMatch的错误
项目用到了这个组件,然后在解压文件时候报Size MisMatch错,解决方法:到https://github.com/icsharpcode/SharpZipLib/releases选择对应的源码下 ...
- size mismatch for fc.weight: copying a param with shape torch.Size([1000, 2048]) from checkpoint, th
问题描述 我想在我自己的项目更换其他的模型,下载的预训练模型出现了FC层不匹配的问题,找了好多人都写了这个点,今天总结一下: 首先我们遇到的问题如下: 他的意思是resnet50的fc层是1000分类 ...
- pytorch神经网络,解决输入图像大小不匹配问题 size mismatch
问题如下:RuntimeError: size mismatch, m1: [4 x 512], m2: [64 x 128]-- RuntimeError: size mismatch, m1: [ ...
- pytorch搭建cnn报错:RuntimeError: size mismatch, m1: [10 x 43264], m2: [10816 x 2] at C...
具体报错信息: Traceback (most recent call last):File "E:/Program Files/PyCharm 2019.2/machinelearning ...
- strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur
strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur ...
- oracle numa map size mismatch,Oracle启动时提示map size mismatch; abort
Oracle启动时提示map size mismatch; abort 发布时间:2020-06-26 13:35:09 来源:51CTO 阅读:1370 作者:会说话的鱼 今天在DELL服务器的Re ...
最新文章
- SMRT测序技术及其在微生物研究中的应用
- JS通过正则限制 input 输入框只能输入整数、小数(金额或者现金) 两位小数
- v-for 切换不同的class
- Yii2.0 数据库更新update
- CF223C【Partial Sums】(组合数学+乱搞)
- jQuery父级以及同级元素查找
- [Sharepoint2007对象模型]第一回:服务器场(SPFarm)
- 我与计算机作文450字,打电脑作文450字
- 《jQuery Mobile入门经典》—— 2.2 展现CSS样式
- Python基础—文件操作
- Labview程序优化
- 数仓建模—数仓建模实战(建模流程/建模工具)
- 并发编程-线程卡死问题实践
- win10电脑外接音响没声音怎么回事?win10电脑外接音响没声音的修复方法
- operator int()用法
- ios 设置导航栏背景色
- 天津城建大学计算机学院官网,天津城建大学计算机与信息工程学院研究生导师简介-杨振舰...
- 多个文本文档合并为一个文件的方法
- iOS 3DES加密解密(一行代码搞定)
- C练题笔记之:Leetcode-780. 到达终点
热门文章
- php 创建透明png,php生成透明背景图片实例
- html页面中漂浮物怎么实现,html的section标签是什么怎么用了
- 海盗王实现物品999个堆叠
- tl494c封装区别_TL494的特点与引脚功能
- ESP32 开发之旅③ Ticker——定时库
- 使用Arduino开发ESP32(07):系统时间和定时任务调度器Ticker
- 第十四届蓝桥杯. 接龙数列(线性DP)
- 取得文件夹下的所有文件的文件名和文件大小
- 解决.Net Core 使用 System.Drawing.Common 在CentOS下报错'Gdip'
- IDEA自动生成实体类