Pytorch载入部分参数并冻结
参考资料
- pytorch 模型部分参数的加载
- Pytorch中,只导入部分模型参数的做法
- Correct way to freeze layers
- Pytorch自由载入部分模型参数并冻结
- pytorch冻结部分参数训练另一部分
- PyTorch更新部分网络,其他不更新
- Pytorch固定部分参数(只训练部分层)
加载部分参数
如果加载现有模型的所有参数,我们常使用的是代码如下:
torch.load(model.state_dict())
在训练过程中,我们常常会使用预训练模型,有时我们是在自己的模型中加入别人的某些模块,或者对别人的模型进行局部修改,这个时候再使用torch.load(model.state_dict())
,就会出现类似这些的错误:RuntimeError: Error(s) in loading state_dict for Net:Missing key(s) in state_dict:xxx
。出现这个错误就是某些参数缺失或者不匹配。
保持原来网络层的名称和结构不变
现有模型中引入的那部分网络结构的网络层的名称和结构保持不变,这时候加载参数的代码很简单。
# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict = model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)
引入的网络层名称发生修改
这个时候再直接使用上面的加载方法,会导致部分key的value无法实现更新。
我就曾在这个位置犯过很严重的错误。首先我定义了AttentionResNet
,这是一个UNet来实现图像分割,然后在另一个模型中我使用了这个模型self.attention_map = AttentionResNet(XXX)
。因为我在引用的过程中并没有对AttentionResNet
那部分代码进行修改,所以本能的觉得这部分网络层的名称是相同的,所以加载这部分参数时,我直接使用了上面的方法。这个错误隐藏了差不多一个星期。直到我开始冻结这部分参数进行训练时,发现情况不对。因为我在输出attention_map
的特征图时,我发现它是一张全黑图(像素全为0),这表示加载的参数不对,然后我尝试输出pretrained_dict
时,它是一个空字典。然后继续输出pretrained_dict.keys()
(未修改之前的pretrained_dict
)和model_dict.keys()
发现预期相同的那部分key中都多了一部分attention_map.
。问题主要出在self.attention_map = AttentionResNet(XXX)
这一句,它使原有的网络层名称都加了个前缀attention_map.
,知道了错误,修改起来很简单。
# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict = model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {'attention_map.' + k:v for k,v in pretrained_dict.items() if 'attention_map.' + k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)
其实我这个位置的修改有点投机,更加常规的方法是:
引用自Pytorch自由载入部分模型参数并冻结
我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。
先输出两个模型的参数字典,观察需要加载的那部分参数所处的位置,然后利用for循环构建新的字典。
冻结参数
- 将需要固定的那部分参数的
requires_grad
置为False. - 在优化器中加入filter根据
requires_grad
进行过滤.
ps: 解决AttributeError: ‘NoneType’ object has no attribute ‘data’
问题的一种思路就是冻结参数,参考博客
代码如下:
# requires_grad置为False
for p in net.XXX.parameters():p.requires_grad = False# filter
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
当需要冻结的那部分参数的网络层名称不太明确时,可以采用pytorch冻结部分参数训练另一部分的思路,打印出所有网络层,通过参数名称进行冻结。
Pytorch载入部分参数并冻结相关推荐
- Pytorch 查看模型参数
Pytorch 查看模型参数 查看利用Pytorch搭建模型的参数,直接看程序 import torch # 引入torch.nn并指定别名 import torch.nn as nn import ...
- pytorch 获取模型参数_Pytorch获取模型参数情况的方法
分享人工智能技术干货,专注深度学习与计算机视觉领域! 相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模 ...
- Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系
Pytorch统计网络参数 #网络参数数量 def get_parameter_number(net):total_num = sum(p.numel() for p in net.parameter ...
- PyTorch载入预训练权重方法和冻结权重方法
载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...
- pytorch加载之前训练模型中的部分参数以及冻结部分参数(实测,自己实际项目代码中的)
我的需求是,由于我在不停的尝试各种模型,导致模型木块一直会变.如果每次重复重新开始训练要花费大把时间. 我之前运行的模型 ResNet -> ...
- pytorch载入预训练模型后,训练指定层
1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') ...
- Pytorch导入模型参数
1.Pytorch中,只导入部分模型参数 https://blog.csdn.net/qq_34914551/article/details/87871134 2. pytorch如何在某些层上冻结网 ...
- pytorch 获取模型参数_剑指TensorFlow,PyTorch Hub官方模型库一行代码复现主流模型...
选自PyTorch 机器之心编译 参与:思源.一鸣 经典预训练模型.新型前沿研究模型是不是比较难调用?PyTorch 团队今天发布了模型调用神器 PyTorch Hub,只需一行代码,BERT.GPT ...
- pytorch打印模型参数_Pytorch网络压缩系列教程一:Prune你的模型
Pytorch网络压缩系列教程一:Prune你的模型 本文由林大佬原创,转载请注明出处,来自腾讯.阿里等一线AI算法工程师组成的QQ交流群欢迎你的加入: 1037662480 深度学习模型取得了前所未 ...
最新文章
- android根据ip获取域名_android常用工具类 通过域名获取ip
- linux root权限_深入了解 Linux 权限
- IBM AIX JFS 数据恢复记录(暂)
- 从 datetime2 数据类型到 datetime 数据类型的转换产生一个超出范围的值
- 常用前端框架Angular和React的一些认识
- 【学习笔记】传输层:TCP协议(报文段、连接管理{握手}、可靠传输、流量控制、拥塞控制)
- python读取csv画图datetime_python – CSV数据(Timestamp和事件)的时间表绘图:x-label常量...
- 减肥日程表(WPS文档反馈群253147947)
- GBDT+LR算法解析及Python实现
- JAVA Excel下载学习
- 微信支付一面(C++后台)
- 电路设计基础知识(一)[转]
- Pandas修改列名
- 工业数字相机的应用及基础知识
- 穷人和富人的距离0.05厘米
- 天龙八部服务器维护怎么进去,天龙八部怎么进不去?维护了吗?到什么时候?...
- nbu新增media server过程简介
- 【物联网】物联网关键技术与应用分析
- Centos7安装加速下载工具aria2
- 变压器绕制工艺之分布电容