在pytorch中对model进行调整有多种方法。但是总有些莫名奇妙会报错的。
下面有三种,详情见博客

pytorch中的pre-train函数模型引用及修改(增减网络层,修改某层参数等)
(继)pytorch中的pretrain模型网络结构修改

一是加载完模型后直接修改,(对于resnet比较适用,对于vgg就不能用了)比如:
model.fc = nn.Linear(fc_features, 9)
这种情况,适用于修改的层,可以由self.层的名字获取到。
如果层在sequential中。因为sequential类型没有定义setitem,只有getitem 所以不能直接获取某一层并进行修改。就是sequential[0]=nn.Linear(fc_features, 9)是会报错的。(不知道有没有别的方法。)
二是用参数覆盖的方法,即自己先定义一个类似的网络,再将预训练中的参数提取到自己的网络中来。这里以resnet预训练模型举例。
这个方法不太理解。。我还是不知道怎么用到sequential里面。。感觉改动会比较大。
通过state_dict() 去获取每一层的名字并给予权重。就是新定义的网络要注意不能和pretrained的网络有同样名字的层。
三是使用nn.module的model.children()的函数,重新定义自己model的层。这个比较灵活。
self.layer= nn.Sequential(*list(model.children())[:-2])
例如对于vgg11 我想修改成1channel输入 ,输出是100个类别的实现如下:修改和添加的代码比较少。

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import mathclass VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=True):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096, num_classes),)if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_layers(cfg, batch_norm=False):layers = []in_channels = 3for v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)#生成一个1channel输入的model
def make_one_channel_layers(cfg, batch_norm=False):layers = []in_channels = 1for v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)cfg = {'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}def vgg11(pretrained=False, **kwargs):"""VGG 11-layer model (configuration "A")Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""if pretrained:kwargs['init_weights'] = Falsemodel = VGG(make_layers(cfg['A']), **kwargs)#输出为100个类别mymodel=VGG(make_one_channel_layers(cfg['A']),num_classes=100, **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))#在预训练好的model中选择要的部分,拼接自己定义的mymodel类型部分model.features=nn.Sequential(list(mymodel.features.children())[0],*list(model.features.children())[1:])mymodel.classifier=nn.Sequential(*list(model.classifier.children())[:-1],list(mymodel.classifier.children())[-1])  return mymodel

呃。。。。。。。。。。。

其实不用上面那么麻烦 直接修改需要修改的层就可以了像

model.features=nn.Sequential(nn.Conv2d(1, 96, kernel_size=7, stride=2),*list(model.features.children())[1:])

另外直接用
model_conv.classifier[6].out_features = Output_features
这样直接修改参数,输出模型是修改之后的,但是运行时还是会报错Given groups=1, weight[64, 3, 3, 3], so expected input[32, 1, 224, 224] to have 3 channels, but got 1 channels 这样的错。。。所以。。不知道怎么改,如果可以这样的话,就会很方便呀!!!可是报错。。。

附上部分更新模型参数的方法(新模型增加了一些层)

pretrained_dict = ...
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

【深度学习】关于pytorch中使用pretrained的模型,对模型进行调整相关推荐

  1. 【深度学习】PyTorch 中的线性回归和梯度下降

    作者 | JNK789   编译 | Flin  来源 | analyticsvidhya 我们正在使用 Jupyter notebook 来运行我们的代码.我们建议在Google Colaborat ...

  2. [PyTorch] 深度学习框架PyTorch中的概念和函数

    Pytorch的概念 Pytorch最重要的概念是tensor,意为"张量". Variable是能够构建计算图的 tensor(对 tensor 的封装).借用Variable才 ...

  3. 深度学习之Pytorch中的flatten()、transpose()和permute()

    1.flatten():压缩维度 input_tensor.flatten(start_dim, end_dim),其中 input_tensor 就是输入的你想压缩的 tensor,start_di ...

  4. DL框架之PyTorch:深度学习框架PyTorch的简介、安装、使用方法之详细攻略

    DL框架之PyTorch:PyTorch的简介.安装.使用方法之详细攻略 DL框架之PyTorch:深度学习框架PyTorch的简介.安装.使用方法之详细攻略 目录 PyTorch的简介 1.pyto ...

  5. 深度学习二(Pytorch物体检测实战)

    深度学习二(Pytorch物体检测实战) 文章目录 深度学习二(Pytorch物体检测实战) 1.PyTorch基础 1.1.基本数据结构:Tensor 1.1.1.Tensor数据类型 1.1.2. ...

  6. 《深度学习之PyTorch物体检测实战》—读书笔记

    随书代码 物体检测与PyTorch 深度学习 为了赋予计算机以人类的理解能力与逻辑思维,诞生了人工智能(Artificial Intelligence, AI)这一学科.在实现人工智能的众多算法中,机 ...

  7. 《深度学习之pytorch实战计算机视觉》第8章 图像风格迁移实战(代码可跑通)

    上一章<深度学习之pytorch实战计算机视觉>第7章 迁移学习(代码可跑通)介绍了迁移学习.本章将完成一个有趣的应用,基于卷积神经网络实现图像风格迁移(Style Transfer).和 ...

  8. 《动手学深度学习》PyTorch版GitHub资源

    之前,偶然间看到过这个PyTorch版<动手学深度学习>,当时留意了一下,后来,着手学习pytorch,发现找不到这个资源了.今天又看到了,赶紧保存下来. <动手学深度学习>P ...

  9. Lesson 12.1 深度学习建模实验中数据集生成函数的创建与使用

    Lesson 12.1 深度学习建模实验中数据集生成函数的创建与使用   为了方便后续练习的展开,我们尝试自己创建一个数据生成器,用于自主生成一些符合某些条件.具备某些特性的数据集.相比于传统的机器学 ...

  10. DL:深度学习框架Pytorch、 Tensorflow各种角度对比

    DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...

最新文章

  1. lvs 负载均衡原理及其配置之 nat 模式
  2. Java IO (二),常见的输入/输出流
  3. hdu 1251 统计难题(trie树入门)
  4. 到底什么是IT服务管理
  5. axios.post提交的三种请求方式
  6. java单个变量的表达式_java中使用Lambda表达式的5种语法
  7. CCF CSP201912-1 报数
  8. 活动目录中组的类型和可用范围
  9. matlab数据分类与识别,Matlab图像识别/检索系列(3)—10行代码完成caltech图象集分类和识别...
  10. linux 文件上传扫描_SecureCRT实现windows与linux文件上传下载
  11. linux socket tcp程序,Linux下Socket TCP的简单例子
  12. Git的简介及使用技巧 PPT
  13. git 小乌龟代码回退
  14. 使用Mac系统来进行Java编程
  15. 【java学习之旅】——JSP入门
  16. 【电脑删不掉文件或文件夹】总结7种方法永久删除!
  17. maya表情blendshape_【Maya】角色表情绑定-BlendShape的使用技巧
  18. C语言经典面试题 与 C语言面试宝典
  19. 网易云音乐唱片机效果
  20. 2019微博热点,盘一盘那些记忆中的大瓜

热门文章

  1. Linux软件磁盘列阵RAID
  2. C# dataGridView控件单元格底色 dataGridView背景色 背景色调整 Header背景色前景色
  3. 利用抓包思想实现“优雅”请假
  4. Keras实现小数量集图片分类——6类别Birds数据集分类
  5. 笔记本安装Ubuntu9.04.图文并茂
  6. 系统hosts文件进行域名解析
  7. 【三】多线程 —— 设计模式
  8. 重重事故下,区块链安全的难题与出路 |链捕手
  9. 警告: A docBase inside the host appBase has been specified, and will be ignore
  10. c语言考场排座系统,具才考场座次编排系统