在实际中,我们有时需要把训练好的模型部署到很多不同的设备。这种情况下我们需要将内存训练好的模型参数存储在硬盘上供后续读取使用。

一、读写Tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor。save使用Python的pickle模块将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象。包括模型、张量和字典等。而load使用pickle unpickle工具将pickle的对象文件反序列化为内存。

下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。

x=torch.ones(3)
torch.save(x,'x.pt')

然后我们将数据从存储的文件读回内存:

x2=torch.load('x.pt')
print(x2)

输出:

tensor([1., 1., 1.])

我们还可以存储一个Tensor列表并读回内存。

y=torch.zeros(4)
torch.save([x,y],'xy.pt')
xy_list=torch.load('xy.pt')
print(xy_list)

输出:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典:

torch.save({'x':x,'y':y},'xy_dict.pt')
xy=torch.load('xy_dict.pt')
print(xy)

输出:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

二、读写模型

2.1 state_dict

Pytorch中,Module的可学习参数(即权重和偏差)以及模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称映射到参数Tensor的字典对象。

class MLP(nn.Module):def __init__(self):super(MLP,self).__init__()self.hidden=nn.Linear(3,2)self.act=nn.ReLU()self.output=nn.Linear(2,1)def forward(self,x):a=self.act(self.hidden(x))return self.output(a)net=MLP()
print(net.state_dict())

输出:

OrderedDict([('hidden.weight', tensor([[-0.3099, -0.3298,  0.2657],[-0.3176,  0.3805,  0.0322]])), ('hidden.bias', tensor([-0.0345, -0.0435])), ('output.weight', tensor([[-0.1456,  0.0254]])), ('output.bias', tensor([0.3656]))])

注意,只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目,优化器(optim)中也有一个state_dict,其中包含关于优化器状态以及所使用超参数的信息。

optimizer=torch.optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
optimizer.state_dict()

输出:

{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}

2.2 保存和加载模型

Pytorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict):
  2. 保存和加载整个模型。

2.2.1 保存和加载state_dict(推荐方式)

保存:

torch.save(model.state_dict(),PATH)
#注意推荐的文件后缀名是pt或pth

加载:

model=TheModelClass(*args,**kwargs)
model.load_state_dict(torch.load(PATH))

2.2.2 保存和加载整个模型

保存:

torch.save(model,PATH)

加载:

model=torch.load(PATH)

接下来实现一个完整的过程:

X=torch.randn(2,3)
Y=net(X)PATH='./net.pt'
torch.save(net.state_dict(),PATH)net2=MLP()
net2.load_state_dict(torch.load(PATH))
Y2=net2(X)
print(Y2)
print(Y)

输出:

tensor([[-0.4076],[-0.4076]], grad_fn=<AddmmBackward>)
tensor([[-0.4076],[-0.4076]], grad_fn=<AddmmBackward>)

因为net和net2具有同样的模型参数,所以对同一个输入X的计算结果是一样的,上面的输出也证明了这点。

Pytorch数据的读取与存储相关推荐

  1. 3-7 pandas数据的读取与存储

    数据分析工具pandas 7. 数据的读取与存储 7.1 读操作 7.2 写操作 7.3 JSON格式 7.4 分块读取大文件 Pandas是一个强大的分析结构化数据的工具集,基于NumPy构建,提供 ...

  2. ios 沙盒 plist 数据的读取和存储

    plist 只能存储基本的数据类型 和 array  字典 [objc] view plaincopy - (void)saveArray { // 1.获得沙盒根路径 NSString *home  ...

  3. chapter.外部数据读取和存储1.3

    web数据的读取和存储 互联网时代,网络上每天都会产生大量的数据,如何从这些非结构化的数据中提取有效的信息进行分析尤为重要. 1.读取HTML表格 对于HTML网页中的表格数据,使用pandas中的r ...

  4. [pytorch、学习] - 4.5 读取和存储

    参考 4.5 读取和存储 到目前为止,我们介绍了如何处理数据以及如何构建.训练和测试深度学习模型.然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备.在这种情况下,我们可以把内存中训练好的模 ...

  5. 2021年大数据HBase(十三):HBase读取和存储数据的流程

    全网最详细的大数据HBase文章系列,强烈建议收藏加关注! 新文章都已经列出历史文章目录,帮助大家回顾前面的知识重点. 目录 系列历史文章 HBase读取和存储数据的流程 一.HBase读取数据的流程 ...

  6. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  7. (pytorch-深度学习系列)读取和存储数据-学习笔记

    读取和存储数据 我们可以使用pt文件存储Tensor数据: import torch from torch import nnx = torch.ones(3) torch.save(x, 'x.pt ...

  8. 使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

    使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络 本文例程部分主要参考官方文档. JAX简介 JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本 ...

  9. PyTorch系列 (二): pytorch数据读取自制数据集并

    PyTorch系列 (二): pytorch数据读取 PyTorch 1: How to use data in pytorch Posted by WangW on February 1, 2019 ...

  10. python处理表格数据教程_python利用Excel读取和存储测试数据完成接口自动化教程...

    http_request2.py用于发起http请求 #读取多条测试用例 #1.导入requests模块 import requests #从 class_12_19.do_excel1导入read_ ...

最新文章

  1. 10没有基于策略的qos_分布式QoS算法解析
  2. linux 命令 跳过yes,Linux命令之yes
  3. python语言入门编程猫-少儿编程语言Python入门课程,尽在厦门编程猫
  4. IDEA的UML图介绍(一)
  5. Windows访问Linux的Tomcat,显示无法连接
  6. c语言打印数组元素_C程序打印元素差为0或1的子集数
  7. linux ntp手动授时,关于我校NTP授时服务的使用说明
  8. python程序发布 ubuntu_将Windows项目发布到Ubuntu服务器详细教程(Windows编程,Ubuntu服务器做解释器)...
  9. paddleocr识别VIN码
  10. Inceptor导出建表语句、存储过程
  11. 局域网远程桌面无法连接到远程计算机,局域网无法远程连接桌面怎么解决
  12. 3.7V锂电池升压到5V1A,FS2114升压转换芯片设计布局
  13. 数据挖掘之航空公司客户价值分析
  14. qt 打印html 分页打印,QT 打印的简单实现
  15. ARM开发板编译----MYS-6ULX
  16. ORB_SLAM2运行TUM数据和实时数据
  17. 电子学会图形化scratch编程等级考试二级真题答案解析(选择题)2020-9A卷
  18. 翻译视频字幕的软件叫什么?安利这几个软件给你
  19. QG工作室——智能与嵌入式系统小组
  20. 详解ISO13400文档-2

热门文章

  1. python安装pytesser模块
  2. 算法案例之有效字母异位词
  3. 在java中什么是所有类的父类_java中object是所有类的父类吗
  4. 因为一个YYYY-MM-dd的Bug,我被老板骂的狗血淋头!
  5. datagrid 重载本地数据_DataGrid 的DataSource重新加载数据
  6. java设计模式模式组合_Java设计模式---组合模式
  7. golang在windows下编译Linux下的文件
  8. Django结合Bootstrap分页显示mysql中的值
  9. 使用oracle执行txt语句,oracle常用SQL语句.txt
  10. 核磁谱图分析步骤_微谱技术:想要涂料开发,少不了仪器分析……