参考资料

  1. pytorch 模型部分参数的加载
  2. Pytorch中,只导入部分模型参数的做法
  3. Correct way to freeze layers
  4. Pytorch自由载入部分模型参数并冻结
  5. pytorch冻结部分参数训练另一部分
  6. PyTorch更新部分网络,其他不更新
  7. Pytorch固定部分参数(只训练部分层)

加载部分参数

如果加载现有模型的所有参数,我们常使用的是代码如下:

torch.load(model.state_dict())

在训练过程中,我们常常会使用预训练模型,有时我们是在自己的模型中加入别人的某些模块,或者对别人的模型进行局部修改,这个时候再使用torch.load(model.state_dict()),就会出现类似这些的错误:RuntimeError: Error(s) in loading state_dict for Net:Missing key(s) in state_dict:xxx。出现这个错误就是某些参数缺失或者不匹配。

保持原来网络层的名称和结构不变

现有模型中引入的那部分网络结构的网络层的名称和结构保持不变,这时候加载参数的代码很简单。

# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict =  model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)

引入的网络层名称发生修改

这个时候再直接使用上面的加载方法,会导致部分key的value无法实现更新。

我就曾在这个位置犯过很严重的错误。首先我定义了AttentionResNet,这是一个UNet来实现图像分割,然后在另一个模型中我使用了这个模型self.attention_map = AttentionResNet(XXX)。因为我在引用的过程中并没有对AttentionResNet那部分代码进行修改,所以本能的觉得这部分网络层的名称是相同的,所以加载这部分参数时,我直接使用了上面的方法。这个错误隐藏了差不多一个星期。直到我开始冻结这部分参数进行训练时,发现情况不对。因为我在输出attention_map的特征图时,我发现它是一张全黑图(像素全为0),这表示加载的参数不对,然后我尝试输出pretrained_dict时,它是一个空字典。然后继续输出pretrained_dict.keys()(未修改之前的pretrained_dict)和model_dict.keys()发现预期相同的那部分key中都多了一部分attention_map.。问题主要出在self.attention_map = AttentionResNet(XXX)这一句,它使原有的网络层名称都加了个前缀attention_map.,知道了错误,修改起来很简单。

# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict =  model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {'attention_map.' + k:v for k,v in pretrained_dict.items() if 'attention_map.' + k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)

其实我这个位置的修改有点投机,更加常规的方法是:

引用自Pytorch自由载入部分模型参数并冻结

我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。

先输出两个模型的参数字典,观察需要加载的那部分参数所处的位置,然后利用for循环构建新的字典。

冻结参数

  1. 将需要固定的那部分参数的requires_grad置为False.
  2. 在优化器中加入filter根据requires_grad进行过滤.

ps: 解决AttributeError: ‘NoneType’ object has no attribute ‘data’问题的一种思路就是冻结参数,参考博客

代码如下:

# requires_grad置为False
for p in net.XXX.parameters():p.requires_grad = False# filter
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

当需要冻结的那部分参数的网络层名称不太明确时,可以采用pytorch冻结部分参数训练另一部分的思路,打印出所有网络层,通过参数名称进行冻结。

Pytorch载入部分参数并冻结相关推荐

  1. Pytorch 查看模型参数

    Pytorch 查看模型参数 查看利用Pytorch搭建模型的参数,直接看程序 import torch # 引入torch.nn并指定别名 import torch.nn as nn import ...

  2. pytorch 获取模型参数_Pytorch获取模型参数情况的方法

    分享人工智能技术干货,专注深度学习与计算机视觉领域! 相较于Tensorflow,Pytorch一开始就是以动态图构建神经网络图的,其获取模型参数的方法也比较容易,既可以根据其内建接口自己写代码获取模 ...

  3. Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系

    Pytorch统计网络参数 #网络参数数量 def get_parameter_number(net):total_num = sum(p.numel() for p in net.parameter ...

  4. PyTorch载入预训练权重方法和冻结权重方法

    载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...

  5. pytorch加载之前训练模型中的部分参数以及冻结部分参数(实测,自己实际项目代码中的)

    我的需求是,由于我在不停的尝试各种模型,导致模型木块一直会变.如果每次重复重新开始训练要花费大把时间. 我之前运行的模型 ResNet ->                            ...

  6. pytorch载入预训练模型后,训练指定层

    1.有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练: pretrained_params = torch.load('Pretrained_Model') ...

  7. Pytorch导入模型参数

    1.Pytorch中,只导入部分模型参数 https://blog.csdn.net/qq_34914551/article/details/87871134 2. pytorch如何在某些层上冻结网 ...

  8. pytorch 获取模型参数_剑指TensorFlow,PyTorch Hub官方模型库一行代码复现主流模型...

    选自PyTorch 机器之心编译 参与:思源.一鸣 经典预训练模型.新型前沿研究模型是不是比较难调用?PyTorch 团队今天发布了模型调用神器 PyTorch Hub,只需一行代码,BERT.GPT ...

  9. pytorch打印模型参数_Pytorch网络压缩系列教程一:Prune你的模型

    Pytorch网络压缩系列教程一:Prune你的模型 本文由林大佬原创,转载请注明出处,来自腾讯.阿里等一线AI算法工程师组成的QQ交流群欢迎你的加入: 1037662480 深度学习模型取得了前所未 ...

最新文章

  1. android根据ip获取域名_android常用工具类 通过域名获取ip
  2. linux root权限_深入了解 Linux 权限
  3. IBM AIX JFS 数据恢复记录(暂)
  4. 从 datetime2 数据类型到 datetime 数据类型的转换产生一个超出范围的值
  5. 常用前端框架Angular和React的一些认识
  6. 【学习笔记】传输层:TCP协议(报文段、连接管理{握手}、可靠传输、流量控制、拥塞控制)
  7. python读取csv画图datetime_python – CSV数据(Timestamp和事件)的时间表绘图:x-label常量...
  8. 减肥日程表(WPS文档反馈群253147947)
  9. GBDT+LR算法解析及Python实现
  10. JAVA Excel下载学习
  11. 微信支付一面(C++后台)
  12. 电路设计基础知识(一)[转]
  13. Pandas修改列名
  14. 工业数字相机的应用及基础知识
  15. 穷人和富人的距离0.05厘米
  16. 天龙八部服务器维护怎么进去,天龙八部怎么进不去?维护了吗?到什么时候?...
  17. nbu新增media server过程简介
  18. 【物联网】物联网关键技术与应用分析
  19. Centos7安装加速下载工具aria2
  20. 变压器绕制工艺之分布电容

热门文章

  1. 知乎上一些JAVA精选问答
  2. 2021年最新React状态管理解决方案
  3. SQL Saturday活动再起
  4. HTML仿QQ音乐页面附源码(无框架)
  5. Boost.Spirit x3学习笔记
  6. 数据类型与数据结构 文件读写及绘图
  7. expect命令简介及使用案例
  8. Git实用教程 4.0:回到过去
  9. 简单爬取Library genesis 免费文献下载网
  10. 高中计算机竞赛考试题,2019年高中信息技术基本功竞赛试卷试题