保存方式一:

例如:对 vgg16 网络模型进行保存,模型如下

模型保存使用 torch.save() 方法

torch.save(模型, "文件名.h5")  # 保存 模型结构 + 模型参数

使用方式如下:

import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)  # 加载vgg16网络模型# 保存方式1
torch.save(vgg16, "vgg16_method.pth")  # 保存为 vgg16_method.pth 文件

加载保存后的模型文件

import torchmodel = torch.load("vgg16_method.pth")
print(model)

保存方式二:

import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)  # 加载vgg16网络模型# 保存模型的参数
torch.save(vgg16.state_dict(), "vgg16_method.pth")  # 把模型的参数 保存成字典

加载保存后的模型文件

import torchmodel = torch.load("vgg16_method.pth") # 使用 torch.load 加载 模型文件,得到模型的参数
print(model)


使用以下方法得到完整的网络模型

import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method.pth"))
print(vgg16)

保存自定义模型:

下面是自定义了一个 VVcatModel 模型进行保存。

import torch
from torch import nnclass VVcatModel(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xvvcat = VVcatModel() # 实例化模型
torch.save(vvcat, "vgg16_method.pth")

加载自定义模型

import torch
from torch import nnclass VVcatModel(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return x# 此处不需要实例化模型
model = torch.load("vgg16_method.pth")
print(model)

注:自定义保存模型与加载自定义模型区别在于,加载模型时,不需要实例化模型。

注意:
如果使用GPU去训练生成的网络模型文件,在加载时可能出现以下错误

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor


当使用GPU去训练生成的网络模型文件,在加载时使用的是CPU去加载,所以,在加载该网络文件时,需要指定CPU去加载该文件, map_location=torch.device(‘cpu’),添加如下:

model = torch.load("./vgg16_method.pth", map_location=torch.device('cpu'))

Pytorch 网络模型的保存与读取相关推荐

  1. sklearn与pytorch模型的保存与读取

    当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...

  2. 网络模型的保存和读取

    转自:http://blog.csdn.net/lwplwf/article/details/62419087 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等 ...

  3. 模型数据的保存和读取

    1,基本内容 目的是将模型数据以文件的形式保存到本地. 使用神经网络模型进行大数据量和复杂模型训练时,训练时间可能会持续增加,此时为避免训练过程出现不可逆的影响,并验证训练效果,可以考虑分段进行,将训 ...

  4. Pytorch中参数和模型的保存与读取

    Tensor变量的存取(包括parameter) 对于普通Tensor变量的存取,如下代码所示: import torch import torch.nn as nn x = torch.ones(3 ...

  5. PyTorch模型的保存加载以及数据的可视化

    文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...

  6. 实践教程 | Pytorch 模型的保存与迁移

    实践教程 | Pytorch 模型的保存与迁移 在本篇文章中,笔者首先介绍了模型复用的几种典型场景:然后介绍了如何查看Pytorch模型中的相关参数信息:接着介绍了如何载入模型.如何进行追加训练以及进 ...

  7. PyTorch | 模型的保存和加载

    PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...

  8. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  9. LabVIEW保存、读取配置文件

    目录 1.保存配置文件 2.读取配置文件 在软件项目开发过程中避免不了要将数据保存到本地,例如,登录信息.账户.密码等.保存数据到本地的方式有很多种,本篇博文主要分享LabVIEW内置的保存.读取配置 ...

最新文章

  1. 幅度调制后的频率混叠
  2. python socket编程:实现redirect函数、cookie和session
  3. 又一个可视化神器Highcharts,Python版也有哦!
  4. Cocos2d-x 3.2:通过ClippingNode实现一个功能完善的跑马灯公告(1)
  5. 【报告分享】2020情趣用品线上消费趋势报告.pdf(附下载链接)
  6. easyui validatebox设置默认值时 去掉校验
  7. [导入]Nutch 简介 [官方]
  8. GetLastError返回代码含义
  9. docker-redis配置文件修改
  10. 新机购入 戴尔成就5000
  11. oracle 数据库日志归档,ORACLE 数据库日志归档的清理
  12. 那些好玩的生成器网站(三)
  13. python小课笔记_小甲鱼Python第一讲笔记(个人笔记)
  14. 【论文解读IJCAI 2019】Extracting Entities and Events as a Single Task Using a Transition-Based NeuralModel
  15. 在64位ubuntu gcc 编译 -m32报错
  16. 虚幻四C++ 添加角色动画
  17. IPV6 官方文档 解决ipv6 的问题
  18. 如何解决Windows 无法完成格式化SD卡问题?
  19. Systemverilog中static、automatic区别
  20. 无法启动Outlook,无法打开Outlook…

热门文章

  1. CSS Sprite的应用
  2. 【转】一台台式机电脑 是集成显卡,我现在想搞两个显示器,一台显示器看监控,一台显示器自己...
  3. 第十二课:Sizzle引擎详解
  4. 网络生活点滴 网络管理实用8招技巧
  5. flex 3 使用手册
  6. 8,协议序列化组件NewLife.Serialization
  7. JavaScript之数据类型
  8. Android时钟的widget
  9. 后台开发经典书籍--unix网络编程
  10. Java-CAS初探