定义网络时,在网络类的构造函数网络结构定义中添加如下语句:

for p in self.parameters():p.requires_grad = False

该语句的功能是固定定义在该语句之前的网络权重,在训练网络时,反向传播只会改变定义在该语句之后的网络权重,如以下代码所示:

import torch.nn as nn
import torchvision.models as modelsclass PredTrianedVGG16(nn.Module):def __init__(self):super(PredTrianedVGG16, self).__init__()model = models.vgg16()      # 加载pytorch预训练的vgg16self.vgg_layer = nn.Sequential(*list(model.children())[:-1])     # 去掉最后一层for p in self.parameters():      # 固定权重p.requires_grad = Falseself.classifier = nn.Sequential(nn.Linear(25088, 4096, bias=True),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 4096, bias=True),nn.ReLU(),nn.Dropout(0.5),nn.Linear(4096, 2, bias=True))

在这段代码中,训练网络时只会调整classifier中的三个全连接层的权重,而不会变动vgg_layer的权重,这一部分是直接调用的预训练好的vgg网络。为了实现这一功能,还需要在定义优化器时添加含lambda表达式的过滤器,如:

opt = torch.optim.Adam(filter(lambda p: p.requires_grad, vgg16.parameters()), lr=LR)       # vgg16为实例化的网络变量名

通过这种方法实例化一个网络并训练,将训练前后的网络的所有权重保存到两个.txt作对比,得到如下结果:
这两个文件中的前面所有权重完全相同,只有最后几个的全连接层的权重不同,结果与预期一致。

Pytorch加载预训练网络,替换分类层并重新训练相关推荐

  1. PyTorch 加载预训练权重

    前言  使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习.  在大部分的迁移学习场景 ...

  2. pytorch加载预训练模型_Pytorch-Transformers 1.0发布,支持六个预训练框架,含27个预训练模型...

    AI 科技评论按:刚刚在Github上发布了开源 Pytorch-Transformers 1.0,该项目支持BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM等,并包含 ...

  3. pytorch加载预训练 加载部分参数

    最简单的: state_dict = torch.load(weight_path)    self.load_state_dict(state_dict,strict=False) 加载cpu: m ...

  4. pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

    问题 最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下.     KeyError: 'layer1.0.bn1.num_batches_tracked' 其实 ...

  5. Pytorch 加载预训练模型参数时出现size mismatch错误

    目录 1 不妨先研究一下' resnet18-5c106cde.pth'里面存了什么东西以及它的数据类型 (1_1)' resnet18-5c106cde.pth'的数据类型

  6. pytorch:加载预训练模型(多卡加载单卡预训练模型,多GPU,单GPU)

    在pytorch加载预训练模型时,可能遇到以下几种情况. 分为以下几种 在pytorch加载预训练模型时,可能遇到以下几种情况. 1.多卡训练模型加载单卡预训练模型 2. 多卡训练模型加载多卡预训练模 ...

  7. pytorch 加载不对齐预训练

    以前改变网络通道数,需要重新从头训练,无法加载预训练,今天研究了一下如何改变网络通道后,还有预训练模型可用,这样可以减少980%的训练时间,提供训练效率. 废话不说,直接上代码: 这个代码加载预训练模 ...

  8. Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率

    前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...

  9. keras冻结_Keras 实现加载预训练模型并冻结网络的层

    在解决一个任务时,我会选择加载预训练模型并逐步fine-tune.比如,分类任务中,优异的深度学习网络有很多. ResNet, VGG, Xception等等... 并且这些模型参数已经在imagen ...

最新文章

  1. OEP30W频率响应
  2. 编程式事务与声明式事务
  3. 如何自己动手建立最简单的动态网站
  4. 鸟哥的Linux私房菜(基础篇)- Red Hat 6.x旧文件
  5. 第三次学JAVA再学不好就吃翔(part5)--基础语法之数据类型转换
  6. python异常处理优点_python自测100题(下)
  7. 换手机的再等等!iPhone SE2还有戏:苹果官网悄然更新AppleCare+服务计划
  8. Docker Java程序镜像制作
  9. C#中DllImport用法
  10. linux是否有安装java_Linux 安装 Java
  11. 只想着一直调用一直爽, 那API凭证泄漏风险如何破?
  12. 光子能变成正负电子,能不能变成其他正反物质?
  13. Python Thrift 简单示例
  14. 织梦响应式酒店民宿住宿类网站织梦模板(自适应手机端)
  15. 80psi等于多少kpa_关于胎压的换算psi、bar,kpa
  16. go mysql报错Error 1406: Data truncation: Data too long for column ‘content‘ at row 1
  17. 户外直播、慢直播、赛事直播等直播行业的未来发展趋势
  18. 计算机应用基础实验报告心得体会,计算机应用基础实训总结报告
  19. C#chart绘折线图动态添加数据
  20. android手机绘图软件,手机绘画软件(MediBang Paint Tablet)

热门文章

  1. 测试小故事75:角色
  2. paypal 的支付流程
  3. VMware清理磁盘空间
  4. java开发srm系统_SRM系统需要实现的核心功能有哪些?
  5. 湖北恩施:土家老人展示“九老十八匠”民间手艺
  6. 移动信息终端基带芯片的开发与产业化
  7. MetroGAN: Simulating Urban Morphology with Generative Adversarial Network
  8. HTML+CSS基础总结(下)
  9. 一键 linux桌面安装vnc,Linux OpenVZ Debian 7 32/64bit环境一键安装VNC桌面环境教程
  10. JASS代码加翻译更新(第六篇)