读取和存储数据

我们可以使用pt文件存储Tensor数据:

import torch
from torch import nnx = torch.ones(3)
torch.save(x, 'x.pt')

这样我们就将数据存储在名为x.pt的文件中了
我们可以从文件中将该数据读入内存:

x2 = torch.load('x.pt')
print(x2)

还可以存储Tensor列表到文件中,并读取:

y = torch.zeros(4)
torch.save([x, y], "xy.pt")
xy_list = torch.load("xy.pt")
print(xy_list)

不仅如此,还可以存储一个键值为Tensor变量的字典:

torch.save({'x':x, 'y':y}, "xy_dict")
xy_dict = torch.load("xy_dict")
print(xy_dict)

对模型参数进行读写:

对于Module类的对象,我们可以使用model.parameters()函数来访问模型的参数。而state_dict函数将会返回一个模型的参数名称到参数Tensor对象的一个字典对象。

class my_module(mm.Module):def __init__(self):super(my_module, self)self.hidden = nn.Linear(3, 2)self.action = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):middle = self.action(self.hidden(x))return self.output(middle) net = my_module()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448,  0.1856, -0.5678],[ 0.2030, -0.2073, -0.0104]])),('hidden.bias', tensor([-0.3117, -0.4232])),('output.weight', tensor([[-0.4556,  0.4084]])),('output.bias', tensor([-0.3573]))])

但是,只有具有可变参数(可学习参数)的网络层才会在state_dict中,

同样的,优化器(optim)也有一个state_dict,这个函数返回一个字典,该字典包含优化器的状态以及其超参数信息:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

输出:

{'param_groups': [{'dampening': 0,'lr': 0.001,'momentum': 0.9,'nesterov': False,'params': [4736167728, 4736166648, 4736167368, 4736165352],'weight_decay': 0}],'state': {}}

那么就可以通过保存模型的state_dict来保存模型

torch.save(net.state_dict(), PATH)model = my_module(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

还可以直接保存整个模型:

torch.save(model, PATH)
model = torch.load(PATH)

(pytorch-深度学习系列)读取和存储数据-学习笔记相关推荐

  1. 大数据学习系列:Hadoop3.0苦命学习(一)

    传送门: 大数据学习系列:Hadoop3.0苦命学习(一) 大数据学习系列:Hadoop3.0苦命学习(二) 大数据学习系列:Hadoop3.0苦命学习(三) 大数据学习系列:Hadoop3.0苦命学 ...

  2. 大数据学习系列:Hadoop3.0苦命学习(五)

    传送门: 大数据学习系列:Hadoop3.0苦命学习(一) 大数据学习系列:Hadoop3.0苦命学习(二) 大数据学习系列:Hadoop3.0苦命学习(三) 大数据学习系列:Hadoop3.0苦命学 ...

  3. 2021年大数据HBase(十三):HBase读取和存储数据的流程

    全网最详细的大数据HBase文章系列,强烈建议收藏加关注! 新文章都已经列出历史文章目录,帮助大家回顾前面的知识重点. 目录 系列历史文章 HBase读取和存储数据的流程 一.HBase读取数据的流程 ...

  4. 大数据学习系列:Hadoop3.0苦命学习(七)

    传送门: 大数据学习系列:Hadoop3.0苦命学习(一) 大数据学习系列:Hadoop3.0苦命学习(二) 大数据学习系列:Hadoop3.0苦命学习(三) 大数据学习系列:Hadoop3.0苦命学 ...

  5. Hadoop学习系列之Hadoop、Spark学习路线(很值得推荐)

    Hadoop学习系列之Hadoop.Spark学习路线(很值得推荐) 文章出自:http://www.cnblogs.com/zlslch/p/5448857.html 1 Java基础: 视频方面: ...

  6. 大数据Hadoop学习系列之Hadoop、Spark学习路线

    1 Java基础: 视频方面:推荐毕老师<毕向东JAVA基础视频教程>. 学习hadoop不需要过度的深入,java学习到javase,在多线程和并行化多多理解实践即可. 书籍方面:推荐李 ...

  7. 元强化学习系列(1)之:元学习入门基础

    元强化学习三境界 统计学是人工智能开始发展的一个基础,古老的人们从大量的数据中发现七所存在的规律,在以统计学为基础的 机器学习(machine learning)时代,复杂一点的分类问题效果就不好了, ...

  8. Hadoop学习系列之Hadoop、Spark学习路线

    1 Java基础: 视频方面:推荐毕老师<毕向东JAVA基础视频教程>. 学习hadoop不需要过度的深入,java学习到javase,在多线程和并行化多多理解实践即可. 书籍方面:推荐李 ...

  9. 大数据学习路线图(附上大数据学习资料)

    不知道你是计算机专业应届生还是已经从业者.总之,有java基础的学生学习大数据会轻松很多,零基础的小白都需要从java和linux学起.如果你是一个学习能力特别强,而且自律性也很强的人的话可以通过自学 ...

最新文章

  1. php中的抽象类(abstract class)和接口(interface)
  2. Ubuntu 20.04上安装Git方法
  3. hdu4539 郑厂长系列故事——排兵布阵 + POJ1158 炮兵阵地
  4. 【LeetCode从零单排】No21.MergeTwoSortedLists
  5. spring注解@service(service)括号中的service有什么用?
  6. 图像处理(十)基于特征线的图像变形-Siggraph 1992
  7. 第六届蓝桥杯省赛javaB组真题及答案
  8. 分布式系统保障—混沌工程—初识
  9. Win10 Explorer v1.3 有趣创意WordPress主题
  10. 双螺杆制冷压缩机行业调研报告 - 市场现状分析与发展前景预测
  11. 百度地图java批量获得经纬度_从百度地图API接口批量获取地点的经纬度
  12. kdtree java_KdTree理解与实现(Java)
  13. pcl之将QVTKWidget添加到QtCreator
  14. SharePoint2010企业开发最佳实践(八)---- SPWeb 对象
  15. Cuda: Handle Conflicting Installation Methods
  16. 把黄鸟hcy请求转换为autojs请求
  17. 首次公开!阿里搜索中台开发运维一体化实践
  18. vb6 sp6中文企业版
  19. RichEdit读取rtf格式
  20. P1074 靶形数独

热门文章

  1. 外设驱动库开发笔记37:S1336-5BQ光敏二极管作为光度计驱动
  2. 现代软件工程系列 学生读后感 梦断代码 DTSlob (1)
  3. java线程的优点_Java使用多线程的优势
  4. JAVA入门级教学之(猜数字测试)
  5. Nginx的配置实例(反向代理实例 )
  6. JAVA入门级教学之(super关键字)
  7. mysql 控制id复原_清空mysql表后,自增id复原
  8. list redis 怎样做排行_redis实现商品销量排行榜
  9. python矩阵相乘例题_百道Python入门级练习题(新手友好)第一回合——矩阵乘法...
  10. 自学Java必看的知识点,猿们怎么看?