PyTorch载入预训练权重方法和冻结权重方法
载入预训练权重
1. 直接载入预训练权重
简单粗暴法:
pretrain_weights_path = "./resnet50.pth"
net.load_state_dict(torch.load(pretrain_weights_path))
如果这里的pretrain_weights
与我们训练的网络不同,一般指的是包含大于模型参数时,可以修改为
net.load_state_dict(torch.load(pretrain_weights_path), strict=False)
2. 修改网络结构
常用方法1:
model_weight_path = "resnet34pre.pth"
net.load_state_dict(torch.load(model_weight_path))
# 这里假设最后一层为FC层,使用迁移学习,将分类结果修改
# net是实例化的resnet网络,in_features是网络输入结构参数,最后的5是修改的输出参数
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
# 注意,最后去转换我们的模型设备,否则可能会报错,怀疑是修改的模型部分和原模型部分使用的设备不同
net.to(device)
常用方法2:
net = MobileNetV2(num_class=5)
net_weights = net.state_dict()
model_weights_path = "./mobilenet_v2pre.pth"
pre_weights = torch.load(model_weights_path)
# delete classifier weights
# 这种方法主要是遍历字典,.pth文件(权重文件)的本质就是字典的存储
# 通过改变我们载入的权重的键值对,可以和当前的网络进行配对的
# 这里举到的例子是对"classifier"结构层的键值对剔除,或者说是不载入该模块的训练权重
pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
# 另一种方法会直接两种权重对比,直接两种方法对比,减少问题的存在
pre_dict = {k: v for k, v in pre_weight.items() if net_weights[k].numel() == v.numel()}
# 如果修改了载入权重或载入权重的结构和当前模型的结构不完全相同,需要加strict=False,保证能够权重载入
net.load_state_dict(pre_dict, strict=False)
net.to(device)
灵活提升:
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():if "fc" in key: # 这里可以多加一下字段,比如 or "layer4" in key:del_key.append(key)
# missing_keys表示net中的部分权重未出现在pre_weights中
# unexpected_keys表示pre_weights当中有一部权重不在net中
missing_keys, unexpected_keys = net.load_state_dict(del_key, strict=False)
# 执行结果
# [missing_keys]:
# fc.weight
# fc.bias
# [unexpected_keys]:
载入权重常见问题
key-val
不匹配问题
解决方式:模型结构修改了,没有正确修改预训练权重,导致载入权重与模型不同,使用上文中的方法适当修改载入权重即可- 载入预训练权重
param
名称和模型中的param
名称不同,导致载入失败
解决方法:修改模型中的para
名称,或者修改网络中模块的名称。
冻结训练
冻结训练方法很简单,只要对requires_grad = False
即可
for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "fc" not in name:para.requires_grad_(False)# 或者 para.requires_grad = False
还有一个小建议,将model
中需要反向梯度传播的param
单独list
出来
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
PyTorch载入预训练权重方法和冻结权重方法相关推荐
- 迁移学习、载入预训练权重和冻结权重
迁移学习就是载入别人预训练好的权重,拿别人的训练好的参数作为我们自己模型的初始化参数,再在这个基础上继续优化.比起从头开始一点一点随机初始化,让模型胡乱地找梯度最优的方向,肯定是迁移学习快啦. 目录 ...
- PyTorch 的预训练,是时候学习一下了
前言 最近使用 PyTorch 感觉妙不可言,有种当初使用 Keras 的快感,而且速度还不慢.各种设计直接简洁,方便研究,比 tensorflow 的臃肿好多了.今天让我们来谈谈 PyTorch 的 ...
- 【论文写作分析】之三《基于预训练语言模型的案件要素识别方法》
[1] 参考论文信息 论文名称:<基于预训练语言模型的案件要素识别方法> 发布期刊:<中文信息学报> 期刊信息:CSCD 论文写作分析摘要:本文非常典型.首先网 ...
- pytorch载入预训练模型后,训练指定层
1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') ...
- 微软最新论文解读 | 基于预训练自然语言生成的文本摘要方法
作者丨张浩宇 学校丨国防科技大学计算机学院 研究方向丨自然语言生成.知识图谱问答 本文解读的是一篇由国防科技大学与微软亚洲研究院共同完成的工作,文中提出一种基于预训练模型的自然语言生成方法. 摘要 在 ...
- 无需在数据集上学习和预训练,这种图像修复新方法效果惊人 | 论文
林鳞 编译自 Github 量子位 出品 | 公众号 QbitAI Reddit上又炸了,原因是一个无需在数据集上学习和预训练就可以超分辨率.修补和去噪的方法:Deep image prior. 帖子 ...
- PyTorch搭建预训练AlexNet、DenseNet、ResNet、VGG实现猫狗图片分类
目录 前言 AlexNet DensNet ResNet VGG 前言 在之前的文章中,利用一个简单的三层CNN猫狗图片分类,正确率不高,详见: CNN简单实战:PyTorch搭建CNN对猫狗图片进行 ...
- Kaggle数据集猫狗分类(Pytorch+ResNet34预训练)99%以上正确率
关于训练集的介绍和数据划分可以参照上一个博客: https://blog.csdn.net/qq_41685265/article/details/104895273 数据加载 class DogCa ...
- CVPR 2022 | GEN-VLKT:基于预训练知识迁移的HOI检测方法
近日,阿里巴巴大淘宝技术多媒体算法团队与计算机视觉青年学者刘偲教授团队合作论文:<GEN-VLKT: Simplify Association and Enhance Interaction U ...
最新文章
- 命令别名的设置alias,unalias
- 一道有意思的找规律题目 --- CodeForces - 964A
- C语言中static详细分析
- ESIM (Enhanced LSTM for Natural Language Inference)
- 网络模型和TCP协议族
- ES6中的新特性:Iterables和iterators
- 找最大重复次数的数和重复次数(C++ Pair)
- api怎么写_使用Node.js原生API写一个web服务器
- Hadoop的改进实验(中文分词词频统计及英文词频统计)(4/4)
- 5个python标准库及作用_零基础编程——Python标准库使用
- 浅入浅出数据结构(18)——希尔排序
- 解决Mac osx AirPort: Link Down on en1. Reason 8 (Disassociated because station leaving)
- java课程设计---彩票销售管理系统
- top 监控系统内存、进程的资源占用情况
- 我的世界1.8.9无需正版的服务器,我的世界1.8-1.8.9勇者世界生存服务器
- 电影《无双》中的管理知识
- strtolower() 函数
- 【和UI斗智斗勇的日子】如何实现一个类似哈罗单车APP主页打车模块的卡片式切的View
- Power Apps平台利用CDS(Common Data Service)制作问卷调查
- 关于No enclosing instance of type MyProject is accessible的报错