Pytorch 网络模型的保存与读取
保存方式一:
例如:对 vgg16 网络模型进行保存,模型如下
模型保存使用 torch.save() 方法
torch.save(模型, "文件名.h5") # 保存 模型结构 + 模型参数
使用方式如下:
import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False) # 加载vgg16网络模型# 保存方式1
torch.save(vgg16, "vgg16_method.pth") # 保存为 vgg16_method.pth 文件
加载保存后的模型文件
import torchmodel = torch.load("vgg16_method.pth")
print(model)
保存方式二:
import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False) # 加载vgg16网络模型# 保存模型的参数
torch.save(vgg16.state_dict(), "vgg16_method.pth") # 把模型的参数 保存成字典
加载保存后的模型文件
import torchmodel = torch.load("vgg16_method.pth") # 使用 torch.load 加载 模型文件,得到模型的参数
print(model)
使用以下方法得到完整的网络模型
import torch
import torchvisionvgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method.pth"))
print(vgg16)
保存自定义模型:
下面是自定义了一个 VVcatModel 模型进行保存。
import torch
from torch import nnclass VVcatModel(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return xvvcat = VVcatModel() # 实例化模型
torch.save(vvcat, "vgg16_method.pth")
加载自定义模型
import torch
from torch import nnclass VVcatModel(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3)def forward(self, x):x = self.conv1(x)return x# 此处不需要实例化模型
model = torch.load("vgg16_method.pth")
print(model)
注:自定义保存模型与加载自定义模型区别在于,加载模型时,不需要实例化模型。
注意:
如果使用GPU去训练生成的网络模型文件,在加载时可能出现以下错误
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
当使用GPU去训练生成的网络模型文件,在加载时使用的是CPU去加载,所以,在加载该网络文件时,需要指定CPU去加载该文件, map_location=torch.device(‘cpu’),添加如下:
model = torch.load("./vgg16_method.pth", map_location=torch.device('cpu'))
Pytorch 网络模型的保存与读取相关推荐
- sklearn与pytorch模型的保存与读取
当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...
- 网络模型的保存和读取
转自:http://blog.csdn.net/lwplwf/article/details/62419087 之前的笔记里实现了softmax回归分类.简单的含有一个隐层的神经网络.卷积神经网络等等 ...
- 模型数据的保存和读取
1,基本内容 目的是将模型数据以文件的形式保存到本地. 使用神经网络模型进行大数据量和复杂模型训练时,训练时间可能会持续增加,此时为避免训练过程出现不可逆的影响,并验证训练效果,可以考虑分段进行,将训 ...
- Pytorch中参数和模型的保存与读取
Tensor变量的存取(包括parameter) 对于普通Tensor变量的存取,如下代码所示: import torch import torch.nn as nn x = torch.ones(3 ...
- PyTorch模型的保存加载以及数据的可视化
文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...
- 实践教程 | Pytorch 模型的保存与迁移
实践教程 | Pytorch 模型的保存与迁移 在本篇文章中,笔者首先介绍了模型复用的几种典型场景:然后介绍了如何查看Pytorch模型中的相关参数信息:接着介绍了如何载入模型.如何进行追加训练以及进 ...
- PyTorch | 模型的保存和加载
PyTorch | 模型的保存和加载 一.模型参数的保存和加载 二.完整模型的保存和加载 一.模型参数的保存和加载 torch.save(module.state_dict(), path):使用mo ...
- pytorch模型的保存和加载、checkpoint
pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...
- LabVIEW保存、读取配置文件
目录 1.保存配置文件 2.读取配置文件 在软件项目开发过程中避免不了要将数据保存到本地,例如,登录信息.账户.密码等.保存数据到本地的方式有很多种,本篇博文主要分享LabVIEW内置的保存.读取配置 ...
最新文章
- 幅度调制后的频率混叠
- python socket编程:实现redirect函数、cookie和session
- 又一个可视化神器Highcharts,Python版也有哦!
- Cocos2d-x 3.2:通过ClippingNode实现一个功能完善的跑马灯公告(1)
- 【报告分享】2020情趣用品线上消费趋势报告.pdf(附下载链接)
- easyui validatebox设置默认值时 去掉校验
- [导入]Nutch 简介 [官方]
- GetLastError返回代码含义
- docker-redis配置文件修改
- 新机购入 戴尔成就5000
- oracle 数据库日志归档,ORACLE 数据库日志归档的清理
- 那些好玩的生成器网站(三)
- python小课笔记_小甲鱼Python第一讲笔记(个人笔记)
- 【论文解读IJCAI 2019】Extracting Entities and Events as a Single Task Using a Transition-Based NeuralModel
- 在64位ubuntu gcc 编译 -m32报错
- 虚幻四C++ 添加角色动画
- IPV6 官方文档 解决ipv6 的问题
- 如何解决Windows 无法完成格式化SD卡问题?
- Systemverilog中static、automatic区别
- 无法启动Outlook,无法打开Outlook…