1. 保存模型:torch.save(model.state_dict(), PATH)

加载模型:model.load_state_dict(torch.load(PATH))

model.eval()

2. 什么是状态字典:state_dict?

在PyTorch中, torch.nn.Module 模型的可学习参数(即权重和偏差)包含在模型的参数中,(使用model.parameters() 可以进行访问)。 state_dict 是Python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层,线性层等)的模型才具有state_dict 这一项。目标优化torch.optim 也有state_dict 属性,它包含有关优化器的状态信息,以及使用的超参数。

因为state_dict的对象是Python字典,所以它们可以很容易的保存、更新、修改和恢复,为PyTorch模型和优化器添加了大量模块。

3. torch.nn.Module.state_dict(destination=None, prefix='', keep_vars=False)
返回一个包含模型状态信息的字典。包含参数(weighs and biases)和持续的缓冲值(如:观测值的平均值)。只有具有可更新参数的层才会被保存在模型的 state_dict 数据结构中。

当保存好模型用来推断的时候,只需要保存模型学习到的参数,使用torch.save() 函数来保存模型state_dict ,它会给模型恢复提供最大的灵活性

4. pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用网络结构,并且提供了预训练模型,可通过调用来读取网络结构和预训练模型(模型参数)。往往为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数。加载model如下所示:

import torchvision.models as models

1.加载网络结构和预训练参数:resnet34 = models.resnet34(pretrained=True)

2.#只加载网络结构,不加载预训练参数,即不需要用预训练模型的参数来初始化:

resnet18 = models.resnet18(pretrained=False) #pretrained参数默认是False,为了代码清晰,最好还是加上参数赋值.

resnet18.load_state_dict(torch.load(path_params.pkl))#其中,path_params.pkl为预训练模型参数的保存路径。加载预先下载好的预训练参数到resnet18,用预训练模型的参数初始化resnet18的层,此时resnet18发生了改变。调用model的load_state_dict方法用预训练的模型参数来初始化自己定义的新网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和自己定义的网络结构层严格对应相等(比如层名和维度)。故,当新定义的网络(model_dict)和预训练网络(pretrained_dict)的层名不严格相等时,需要先将pretrained_dict里不属于model_dict的键剔除掉 :
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} ,再用预训练模型参数更新model_dict,最后用load_state_dict方法初始化自己定义的新网络结构。

5.例子:

model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=False)  # 加载模型,不加载预训练参数

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict  = model.state_dict()  # 读取

pretrained_dict = torch.load(model_path, map_location = device)

pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} # 将pretrained_dict和model_dict中命名一致的层加入pretrained_dict(包括参数)

model_dict.update(pretrained_dict)  # 更新现有的model_dict
model.load_state_dict(model_dict)   #加载真正需要的state_dict  # 加载预先下载好的预训练参数到model

无论是从缺少某些键的 state_dict 加载还是从键的数目多于加载模型的 state_dict , 都可以通过在
load_state_dict() 函数中将strict 参数设置为 False来忽略非匹配键的函数。
如果要将参数从一个层加载到另一个层,但是某些键不匹配,主要修改正在加载的 state_dict 中的
参数键的名称以匹配要在加载到模型中的键即可。

model.load_state_dict(torch.load(PATH), strict=False)

参考:pytorch预训练 - 五妹 - 博客园Pytorch预训练模型以及修改 pytorch中自带几种常用的深度学习网络预训练模型,torchvision.models包中包含alexnet、densenet、inception、resnet、https://www.cnblogs.com/wmlj/p/9917827.html

Pytorch预训练模型加载相关推荐

  1. Whole Word Masking (wwm) BERT PaddlePaddle常用预训练模型加载

    Whole Word Masking (wwm),暂翻译为全词Mask或整词Mask,是谷歌在2019年5月31日发布的一项BERT的升级版本,主要更改了原预训练阶段的训练样本生成策略. 简单来说,原 ...

  2. PyTorch:数据加载,数学原理,猫鱼分类,CNN,预训练,迁移学习

    1,数据加载 PyTorch开发了与数据交互的标准约定,所以能一致地处理数据,而不论处理图像.文本还是音频.与数据交互的两个主要约定是数据集(dataset)和数据加载器(dataloader).数据 ...

  3. php 图片预览原理,JavaScript_纯JS实现的批量图片预览加载功能,1.实现原理直接见代码,需要一 - phpStudy...

    纯JS实现的批量图片预览加载功能 1.实现原理直接见代码,需要一张转圈的小图片,需要预览的所有图片默认的位置全是这张小图片,滚轮滚到原图需要出现的位置时候,预览加载替换小图片.实现效果 复制代码 代码 ...

  4. pytorch 入门学习加载数据集-8

    pytorch 入门学习加载数据集 import torch import numpy as np import torchvision import numpy as np from torch.u ...

  5. 小程序预览加载不出图片

    小程序预览加载不出图片 比如在开发者工具做了一个swiper轮播图 <view> <swiper indicator-dots="true" indicator- ...

  6. pytorch使用Dataloader加载自己的数据集train_X和train_Y

    Pytorch使用Dataloader加载自己的数据集train_X和train_Y 1.重构一个新的dataloader函数 2.调用 1.重构一个新的dataloader函数 在使用torch进行 ...

  7. Pytorch预训练模型下载并加载(以VGG为例)自定义路径

    简述 一般来讲,Pytorch用torchvision调用vgg之类的模型话,如果电脑在cache(Pytorch硬编码的一个地址)(如果在环境变量中添加了TORCH_HOME 和TORCH_MODE ...

  8. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  9. pytorch加载预训练 加载部分参数

    最简单的: state_dict = torch.load(weight_path)    self.load_state_dict(state_dict,strict=False) 加载cpu: m ...

最新文章

  1. hihoCoder#1384 : Genius ACM
  2. PythonNET网络编程3
  3. java package private,Java中的public,protected,package-private和private有什么区别?
  4. 盘点全球最美的15座数据中心
  5. RocketMQ多Master多Slave模式部署
  6. Linux下编译、链接、加载运行C++ OpenCV的两种方式及常见问题的解决
  7. 关于前段与后端数据库的连接
  8. Django template 过滤器
  9. C语言的隐式类型转换
  10. 网站丨平淡的生活里增添一点幸福感
  11. 文本框里面加删除按钮
  12. java getbytes 乱码_深入解析java String中getBytes()的编码问题
  13. 速成KeePass全局自动填表登录QQ与迅雷(包括中文输入法状态时用中文用户名一键登录)...
  14. 数独-比回溯法更优的人类思维逻辑的数独解法
  15. 查看各大网站服务器操作系统
  16. 思科WLC与AP无法正常Join
  17. 日撸 Java 三百行: DAY1 AND DAY2
  18. 安卓sdk自带模拟器的使用
  19. 解决C/C++报错error: cannot pass objects of non-trivially-copyable type ‘std::string’问题
  20. 升级php7后的报错处理

热门文章

  1. DQL:数据库查询语句
  2. Zotero文献管理软件使用指南——进阶篇
  3. 译文:A Robust and Modular Multi-Sensor Fusion ApproachApplied to MAV Navigation
  4. ARM汇编中的:比较指令--CMN / CMP / TEQ / TST
  5. c# + halcon编程(读图、显示图、处理图、鼠标和图像交互)
  6. 18天精读掌握《费曼物理学讲义卷一》 第3天 2019.6.14
  7. 电脑CPU使用率过高怎么办
  8. gcc降版本方法 - [学习]
  9. Android产品研发(十)--尽量不使用静态变量保存数据
  10. 教师资格证考试科目汇总