[pytorch、学习] - 4.5 读取和存储
参考
4.5 读取和存储
到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。
4.5.1 读写tensor
我们可以直接使用save
函数和load
函数分别存储和读取Tensor
。
下面的例子创建了Tensor
变量x
,并将其存储在文件名为x.pt
的文件里.
import torch
import torch.nn as nnx = torch.ones(3)
torch.save(x, 'x.pt')
然后我们将数据从存储的文件读回内存
x2 = torch.load('x.pt')
x2
存储一个Tensor列表并返回
y =torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
存储并读取一个从字符串映射到Tensor
的字典
torch.save({'x': x,'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
4.5.2 读写模型
4.5.2.1 state_dict
static_dict
是一个从参数名称映射到参数Tensor
的字典对象
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()
注意,只有具有可学习参数的层(卷积层、线性层)才有 state_dict中的条目
optimizer = torch.optim.SGD(net.parameters(), lr= 0.001, momentum=0.9)
optimizer.state_dict()
4.5.2.2 保存和加载模型
PyTorch中保存和加载训练模型有两种常见的方法:
- 仅保存和加载模型参数(state_dict)
- 保存和加载整个模型。
1. 保存加载static_dict
(推荐方式)
torch.save(model.state_dict(), PATH)
# 保存
torch.save(model.state_dict(), PATH)# 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.laod(PATH))
2. 保存和加载整个模型
# 保存
torch.save(model, PATH)# 加载
model = torch.load(PATH)
采用第一种方法来试验一下:
X = torch.randn(2, 3)
Y = net(X)PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)Y2 ==Y
[pytorch、学习] - 4.5 读取和存储相关推荐
- (pytorch-深度学习系列)读取和存储数据-学习笔记
读取和存储数据 我们可以使用pt文件存储Tensor数据: import torch from torch import nnx = torch.ones(3) torch.save(x, 'x.pt ...
- 在pytorch中自定义dataset读取数据2021-1-8学习笔记
在pytorch中自定义dataset读取数据 utils import os import json import pickle import randomimport matplotlib.pyp ...
- 学习笔记Spark(四)—— Spark编程基础(创建RDD、RDD算子、文件读取与存储)
文章目录 一.创建RDD 1.1.启动Spark shell 1.2.创建RDD 1.2.1.从集合中创建RDD 1.2.2.从外部存储中创建RDD 任务1: 二.RDD算子 2.1.map与flat ...
- Pytorch学习笔记总结
往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...
- PyTorch学习笔记(五):模型定义、修改、保存
往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...
- PyTorch学习笔记(四):PyTorch基础实战
PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...
- pytorch学习五、深度学习计算
来自于 https://tangshusen.me/Dive-into-DL-PyTorch/#/ 官方文档 https://pytorch.org/docs/stable/tensors.html ...
- PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard
文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...
- 【Pytorch学习笔记2】Pytorch的主要组成模块
个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...
最新文章
- python调用动态链接库windows_用win从python ctypes调用标准windows.dll的Segfault
- 1.13 Predicate操作Collection集合
- antd table排序 vue_ant-design-vue中的table取消默认不排序的状态
- Python生成器(send,close,throw)方法详解
- 35岁中年博士失业,决定给找高校教职的后辈一些建议
- android 4.4 屏幕方向,Android4.4屏幕旋转功能
- 谷爱凌惊“险”一跳,最少价值10个亿!
- 金融统计分析与挖掘实战3.3.1-3.3.3
- knx智能照明控制系统电路图_智能照明控制系统KNX
- 我的家乡html网页设计,创作一个以“我的家乡”为主题的网站
- 中国草鱼养殖产业发展现状分析,生态养殖是未来发展趋势「图」
- 电力系统分析—潮流计算代码Python编程练习(基于极坐标形式的常规牛拉法)
- 谈谈信息化、数字化、智能化和数智化的区别
- 【STM32】CubeMX+HAL库之 硬件IIC+DMA控制OLED(兼容SSD1306SH1106驱动)
- 内卷老员工之三级缓存和伪共享
- centos挂载u盘只读_解决CentOS自动挂载U盘/SD Card被识别为只读文件系统
- 一文读懂物联网 MQTT 协议之实战篇
- 芝加哥大学计算机科学在哪个学院,芝加哥大学计算机专业怎么样?
- QQ魔法表情实现原理
- 漫天桃花只为你飘落(代码实现)
热门文章
- vue 改变domclass_手机上的大片制作软件——如何使用VUE
- java精准查询mysql时间_在mysql查询中查找与指定日期时间最接近的日期时间
- python波峰波谷算法_波动均分算法
- 根据oracle入库数据进行告警,Oracle 启动故障案例之--ORA-600 [4193]错误
- 把Sublime Text 2打造成一个轻量级Python的IDE
- Query意图分析:记一次完整的机器学习过程(scikit learn library学习笔记)
- 【LeetCode】200. 岛屿的个数
- QT 子窗体 最大化 界面显示不对
- JZOJ 4421. aplusb
- js 实现 复制 功能 (zeroclipboard)