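PyTorch model transfer and transfer learning

Contents

PyTorch model transfer and transfer learning

1. Transfer learning with resnet18

2. Renaming network modules and transferring the weights

3. Removing some modules from the original model


1. Transfer learning with resnet18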

import torch
from torchvision import models

if __name__ == "__main__":
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = 'cpu'
    print("-----device:{}".format(device))
    print("-----Pytorch version:{}".format(torch.__version__))
    input_tensor = torch.zeros(1, 3, 100, 100)
    print('input_tensor:', input_tensor.shape)
    pretrained_file = "model/resnet18-5c106cde.pth"
    model = models.resnet18()
    model.load_state_dict(torch.load(pretrained_file))
    model.eval()
    out = model(input_tensor)
    print("out:", out.shape, out[0, 0:10])

Output:

input_tensor: torch.Size([1, 3, 100, 100])
out: torch.Size([1, 1000]) tensor([ 0.4010,  0.8436,  0.3072,  0.0627,  0.4446,  0.8470,  0.1882,  0.7012,0.2988, -0.7574], grad_fn=<SliceBackward>)
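A .pth file such as resnet18-5c106cde.pth is simply a saved state_dict (a mapping from parameter name to tensor). The round trip below is a minimal sketch that uses a small stand-in network instead of resnet18 (make_net is a hypothetical helper, not from the original post): save a state_dict, load it into a freshly built copy of the same architecture, and check that the outputs agree. It also passes map_location, which remaps tensors that were saved on a GPU onto the chosen device; the example above defines device but never hands it to torch.load.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_net() -> nn.Module:
    # Hypothetical small network standing in for resnet18
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


device = "cpu"
src = make_net()
path = os.path.join(tempfile.mkdtemp(), "pretrained.pth")
torch.save(src.state_dict(), path)  # a "pretrained file" is just a saved state_dict

dst = make_net()  # fresh, randomly initialized copy of the same architecture
# map_location moves tensors saved on another device (e.g. a GPU) onto `device`
dst.load_state_dict(torch.load(path, map_location=device))
dst.eval()

x = torch.zeros(1, 4)
same = torch.equal(src(x), dst(x))
print("outputs identical:", same)  # True
```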

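If we modify the structure of resnet18, how can we transfer the parameters of the original pretrained model (resnet18-5c106cde.pth) into the new resnet18 network?

For example, here we change the official resnet18's self.layer4 = self._make_layer(block, 512, layers[3], stride=2) to self.layer44 = self._make_layer(block, 512, layers[3], stride=2):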

import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer44 = self._make_layer(block, 512, layers[3], stride=2)  # renamed from layer4
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer44(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(**kwargs):
    # Factory for the modified network, built like torchvision's resnet18 (BasicBlock, [2, 2, 2, 2])
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
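Why does a simple rename break loading? The keys of a state_dict are built from the attribute names under which submodules are registered, so renaming layer4 to layer44 changes every key below that module. A toy illustration, assuming two hypothetical modules, Original and Renamed, that differ only in the attribute name:

```python
import torch.nn as nn


class Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer4 = nn.Linear(8, 8)


class Renamed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer44 = nn.Linear(8, 8)  # same shape, different attribute name


old_keys = list(Original().state_dict().keys())
new_keys = list(Renamed().state_dict().keys())
print(old_keys)  # ['layer4.weight', 'layer4.bias']
print(new_keys)  # ['layer44.weight', 'layer44.bias']
```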

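If we now load the pretrained file directly:

model = resnet18()  # the modified network above, not torchvision's models.resnet18()
model.load_state_dict(torch.load(pretrained_file))

it is bound to fail with a "Missing key(s) in state_dict" or "Unexpected key(s) in state_dict" error such as: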

RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "layer44.0.conv1.weight", "layer44.0.bn1.weight", "layer44.0.bn1.bias", "layer44.0.bn1.running_mean", "layer44.0.bn1.running_var", "layer44.0.conv2.weight", "layer44.0.bn2.weight", "layer44.0.bn2.bias", "layer44.0.bn2.running_mean", "layer44.0.bn2.running_var", "layer44.0.downsample.0.weight", "layer44.0.downsample.1.weight", "layer44.0.downsample.1.bias", "layer44.0.downsample.1.running_mean", "layer44.0.downsample.1.running_var", "layer44.1.conv1.weight", "layer44.1.bn1.weight", "layer44.1.bn1.bias", "layer44.1.bn1.running_mean", "layer44.1.bn1.running_var", "layer44.1.conv2.weight", "layer44.1.bn2.weight", "layer44.1.bn2.bias", "layer44.1.bn2.running_mean", "layer44.1.bn2.running_var". 
    Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias".


RuntimeError: Error(s) in loading state_dict for ResNet:
    Unexpected key(s) in state_dict: "layer4.0.conv1.weight", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.conv2.weight", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.downsample.0.weight", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.1.conv1.weight", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.conv2.weight", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias".
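Instead of reading the error text, load_state_dict(..., strict=False) loads every key that matches and returns a named tuple whose missing_keys and unexpected_keys fields list the rest. The sketch below uses hypothetical toy modules standing in for the resnet18 pair and reproduces the two lists from the error above on a smaller scale:

```python
import torch
import torch.nn as nn


class Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)


class Renamed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)
        self.layer44 = nn.Linear(8, 8)


pretrained = Original().state_dict()
model = Renamed()
# strict=False loads the matching keys (layer3.*) and reports the others instead of raising
result = model.load_state_dict(pretrained, strict=False)
print(result.missing_keys)     # ['layer44.weight', 'layer44.bias']
print(result.unexpected_keys)  # ['layer4.weight', 'layer4.bias']
print(torch.equal(model.layer3.weight, pretrained["layer3.weight"]))  # True
```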

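We want to transfer the parameters of the original pretrained model (resnet18-5c106cde.pth) into the new resnet18 network. Of course, only the parameters that the two networks share can be transferred; the others remain randomly initialized.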

import torch


def transfer_model(pretrained_file, model):
    '''
    Load only the part of pretrained_file that the new model also has.
    :param pretrained_file: path to the pretrained state_dict (.pth)
    :param model: the new network
    :return: model with the shared parameters loaded
    '''
    pretrained_dict = torch.load(pretrained_file)  # get pretrained dict
    model_dict = model.state_dict()  # get model dict
    # Before merging (update), drop the entries of pretrained_dict the new network does not need
    pretrained_dict = transfer_state_dict(pretrained_dict, model_dict)
    # dict.update: each key of pretrained_dict overwrites the same key in model_dict;
    # keys that exist only in model_dict (here layer44.*) keep their current values
    model_dict.update(pretrained_dict)  # update (merge) the model parameters
    model.load_state_dict(model_dict)
    return model


def transfer_state_dict(pretrained_dict, model_dict):
    '''
    Based on model_dict, remove the entries of pretrained_dict that are not needed,
    so that the rest can be transferred to the new network.
    url: https://blog.csdn.net/qq_34914551/article/details/87871134
    :param pretrained_dict: state_dict of the pretrained model
    :param model_dict: state_dict of the new model
    :return: the filtered state_dict
    '''
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            state_dict[k] = v
        else:
            print("Skipped key (not in the new model): {}".format(k))
    return state_dict


if __name__ == "__main__":
    input_tensor = torch.zeros(1, 3, 100, 100)
    print('input_tensor:', input_tensor.shape)
    pretrained_file = "model/resnet18-5c106cde.pth"
    model1 = resnet18()  # the modified network (with layer44) defined above
    model1 = transfer_model(pretrained_file, model1)
    out1 = model1(input_tensor)
    print("out1:", out1.shape, out1[0, 0:10])
    # Output reported for out1; it differs from out above because layer44 is randomly
    # initialized and model1 was not switched to eval() mode:
    # tensor([-0.7119,  0.0688, -1.7247, -1.7182, -1.2161, -0.7323, -2.1065, -0.5433,
    #         -1.5893, -0.5562]
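transfer_state_dict keeps a pretrained entry whenever its name exists in the new model. If a layer keeps its name but changes shape (for example fc with a different num_classes), load_state_dict would then raise a size-mismatch error. One possible variant, sketched here with a hypothetical helper filter_compatible and small Sequential models in place of resnet18, also compares tensor shapes before merging:

```python
import torch
import torch.nn as nn


def filter_compatible(pretrained_dict, model_dict):
    """Hypothetical helper: keep entries whose name exists in model_dict AND whose shape matches."""
    kept, skipped = {}, []
    for k, v in pretrained_dict.items():
        if k in model_dict and model_dict[k].shape == v.shape:
            kept[k] = v
        else:
            skipped.append(k)
    return kept, skipped


pretrained = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1000))  # 1000-class head
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 10))  # new 10-class head

model_dict = model.state_dict()
kept, skipped = filter_compatible(pretrained.state_dict(), model_dict)
model_dict.update(kept)  # same merge step as transfer_model
model.load_state_dict(model_dict)  # the 10-class head keeps its random initialization

print(sorted(kept))  # ['0.bias', '0.weight']
print(skipped)       # ['2.weight', '2.bias']
```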

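2. Renaming network modules and transferring the weights

In the example above we only changed the official resnet18's self.layer4 = self._make_layer(block, 512, layers[3], stride=2) to self.layer44 = self._make_layer(block, 512, layers[3], stride=2). Renaming a single module was enough to make model.load_state_dict(torch.load(pretrained_file)) fail.

So how can we convert the pretrained model "model/resnet18-5c106cde.pth" into parameters that fit the new network?

The method is simple: in the parameters of resnet18-5c106cde.pth, rename every key whose prefix is layer4 so that it starts with layer44 instead.

I have already defined a function for this: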

modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)

import torch


def string_rename(old_string, new_string, start, end):
    # Replace old_string[start:end] with new_string
    new_string = old_string[:start] + new_string + old_string[end:]
    return new_string


def modify_model(pretrained_file, model, old_prefix, new_prefix):
    '''
    :param pretrained_file: path to the pretrained state_dict (.pth)
    :param model: the new network
    :param old_prefix: list of old key prefixes, e.g. ["layer4"]
    :param new_prefix: list of new key prefixes, e.g. ["layer44"]
    :return: model with the renamed parameters loaded
    '''
    pretrained_dict = torch.load(pretrained_file)
    model_dict = model.state_dict()
    state_dict = modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix)
    model.load_state_dict(state_dict)
    return model


def modify_state_dict(pretrained_dict, model_dict, old_prefix, new_prefix):
    '''
    Modify the model dict: keep the keys the new model already has, rename the others by prefix
    '''
    state_dict = {}
    for k, v in pretrained_dict.items():
        if k in model_dict.keys():
            state_dict[k] = v
        else:
            for o, n in zip(old_prefix, new_prefix):
                if k.startswith(o):
                    kk = string_rename(old_string=k, new_string=n, start=0, end=len(o))
                    print("rename layer modules:{}-->{}".format(k, kk))
                    state_dict[kk] = v
    return state_dict


if __name__ == "__main__":
    input_tensor = torch.zeros(1, 3, 100, 100)
    print('input_tensor:', input_tensor.shape)
    pretrained_file = "model/resnet18-5c106cde.pth"
    new_file = "new_model.pth"
    model = resnet18()  # the modified network (with layer44) defined above
    new_model = modify_model(pretrained_file, model, old_prefix=["layer4"], new_prefix=["layer44"])
    torch.save(new_model.state_dict(), new_file)  # save the converted parameters

    model2 = resnet18()
    model2.load_state_dict(torch.load(new_file))  # now loads without key errors
    model2.eval()
    out2 = model2(input_tensor)
    print("out2:", out2.shape, out2[0, 0:10])

Now the output is exactly the same as before:

out: torch.Size([1, 1000]) tensor([ 0.4010,  0.8436,  0.3072,  0.0627,  0.4446,  0.8470,  0.1882,  0.7012,0.2988, -0.7574], grad_fn=<SliceBackward>)
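The renaming idea can also be checked end to end on a small example. This is a sketch with a hypothetical helper rename_prefix and toy modules, not the author's function: it matches old_prefix followed by a dot, so a key such as layer40.weight would not be caught by the layer4 prefix (a plain prefix comparison would match it), and it verifies that the renamed weights reproduce the original outputs exactly.

```python
import torch
import torch.nn as nn


def rename_prefix(state_dict, old_prefix, new_prefix):
    """Hypothetical helper: rename keys 'old_prefix.xxx' to 'new_prefix.xxx'."""
    out = {}
    for k, v in state_dict.items():
        if k.startswith(old_prefix + "."):  # require the dot so 'layer40.*' is left alone
            k = new_prefix + k[len(old_prefix):]
        out[k] = v
    return out


class Original(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(4, 4)
        self.layer4 = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer4(torch.relu(self.layer3(x)))


class Renamed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(4, 4)
        self.layer44 = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer44(torch.relu(self.layer3(x)))


old = Original().eval()
new = Renamed()
new.load_state_dict(rename_prefix(old.state_dict(), "layer4", "layer44"))  # strict load succeeds
new.eval()

x = torch.randn(3, 4)
print(torch.equal(old(x), new(x)))  # True: same weights, same computation
```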

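3. Removing some modules from the original model

Below, without modifying the original model's code, the submodules "fc" and "avgpool" are removed using resnet18.named_children() and resnet18.children():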

import torch
import torchvision.models as models
from collections import OrderedDict

if __name__ == "__main__":
    resnet18 = models.resnet18()  # randomly initialized, no pretrained weights
    print("resnet18", resnet18)

    # use named_children(): keeps the original submodule names
    resnet18_v1 = OrderedDict(resnet18.named_children())
    # remove avgpool, fc
    resnet18_v1.pop("avgpool")
    resnet18_v1.pop("fc")
    resnet18_v1 = torch.nn.Sequential(resnet18_v1)
    print("resnet18_v1", resnet18_v1)

    # use children(): drop the last two submodules (avgpool, fc)
    resnet18_v2 = torch.nn.Sequential(*list(resnet18.children())[:-2])
    print("resnet18_v2", resnet18_v2)
