参考

4.5 读取和存储

到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

4.5.1 读写tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor

下面的例子创建了Tensor变量x,并将其存储在文件名为x.pt的文件里.

import torch
import torch.nn as nnx = torch.ones(3)
torch.save(x, 'x.pt')

然后我们将数据从存储的文件读回内存

x2 = torch.load('x.pt')
x2


存储一个Tensor列表并返回

y =torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list



存储并读取一个从字符串映射到Tensor的字典

torch.save({'x': x,'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy


4.5.2 读写模型

4.5.2.1 state_dict

static_dict是一个从参数名称映射到参数Tensor的字典对象

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

注意,只有具有可学习参数的层(卷积层、线性层)才有 state_dict中的条目

optimizer = torch.optim.SGD(net.parameters(), lr= 0.001, momentum=0.9)
optimizer.state_dict()

4.5.2.2 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict)
  2. 保存和加载整个模型。

1. 保存加载static_dict(推荐方式)
torch.save(model.state_dict(), PATH)

# 保存
torch.save(model.state_dict(), PATH)# 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.laod(PATH))

2. 保存和加载整个模型

# 保存
torch.save(model, PATH)# 加载
model = torch.load(PATH)

采用第一种方法来试验一下:

X = torch.randn(2, 3)
Y = net(X)PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)Y2 ==Y

[pytorch、学习] - 4.5 读取和存储相关推荐

  1. (pytorch-深度学习系列)读取和存储数据-学习笔记

    读取和存储数据 我们可以使用pt文件存储Tensor数据: import torch from torch import nnx = torch.ones(3) torch.save(x, 'x.pt ...

  2. 在pytorch中自定义dataset读取数据2021-1-8学习笔记

    在pytorch中自定义dataset读取数据 utils import os import json import pickle import randomimport matplotlib.pyp ...

  3. 学习笔记Spark(四)—— Spark编程基础(创建RDD、RDD算子、文件读取与存储)

    文章目录 一.创建RDD 1.1.启动Spark shell 1.2.创建RDD 1.2.1.从集合中创建RDD 1.2.2.从外部存储中创建RDD 任务1: 二.RDD算子 2.1.map与flat ...

  4. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  5. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  6. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  7. pytorch学习五、深度学习计算

    来自于 https://tangshusen.me/Dive-into-DL-PyTorch/#/ 官方文档 https://pytorch.org/docs/stable/tensors.html ...

  8. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  9. 【Pytorch学习笔记2】Pytorch的主要组成模块

    个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...

最新文章

  1. python调用动态链接库windows_用win从python ctypes调用标准windows.dll的Segfault
  2. 1.13 Predicate操作Collection集合
  3. antd table排序 vue_ant-design-vue中的table取消默认不排序的状态
  4. Python生成器(send,close,throw)方法详解
  5. 35岁中年博士失业,决定给找高校教职的后辈一些建议
  6. android 4.4 屏幕方向,Android4.4屏幕旋转功能
  7. 谷爱凌惊“险”一跳,最少价值10个亿!
  8. 金融统计分析与挖掘实战3.3.1-3.3.3
  9. knx智能照明控制系统电路图_智能照明控制系统KNX
  10. 我的家乡html网页设计,创作一个以“我的家乡”为主题的网站
  11. 中国草鱼养殖产业发展现状分析,生态养殖是未来发展趋势「图」
  12. 电力系统分析—潮流计算代码Python编程练习(基于极坐标形式的常规牛拉法)
  13. 谈谈信息化、数字化、智能化和数智化的区别
  14. 【STM32】CubeMX+HAL库之 硬件IIC+DMA控制OLED(兼容SSD1306SH1106驱动)
  15. 内卷老员工之三级缓存和伪共享
  16. centos挂载u盘只读_解决CentOS自动挂载U盘/SD Card被识别为只读文件系统
  17. 一文读懂物联网 MQTT 协议之实战篇
  18. 芝加哥大学计算机科学在哪个学院,芝加哥大学计算机专业怎么样?
  19. QQ魔法表情实现原理
  20. 漫天桃花只为你飘落(代码实现)

热门文章

  1. vue 改变domclass_手机上的大片制作软件——如何使用VUE
  2. java精准查询mysql时间_在mysql查询中查找与指定日期时间最接近的日期时间
  3. python波峰波谷算法_波动均分算法
  4. 根据oracle入库数据进行告警,Oracle 启动故障案例之--ORA-600 [4193]错误
  5. 把Sublime Text 2打造成一个轻量级Python的IDE
  6. Query意图分析:记一次完整的机器学习过程(scikit learn library学习笔记)
  7. 【LeetCode】200. 岛屿的个数
  8. QT 子窗体 最大化 界面显示不对
  9. JZOJ 4421. aplusb
  10. js 实现 复制 功能 (zeroclipboard)