PyTorch教程-7:PyTorch中保存与加载tensor和模型详解
保存和读取Tensor
PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save()方法保存张量,使用torch.load()来读取张量:

x = torch.rand(4,5)
torch.save(x, "./myTensor.pt")y = torch.load("./myTensor.pt")
print(y)
tensor([[0.9363, 0.2292, 0.1612, 0.9558, 0.9414],[0.3649, 0.9622, 0.3547, 0.5772, 0.7575],[0.7005, 0.8115, 0.6132, 0.6640, 0.1173],[0.6999, 0.1023, 0.8544, 0.7708, 0.1254]])

当然,save和load方法也适用于其他数据类型,比如list、tuple、dict等:

a = {'a':torch.rand(2,2), 'b':torch.rand(3,4)}
torch.save(a, "./myDict.pth")b = torch.load("./myDict.pth")
print(b)
{'a': tensor([[0.9356, 0.0240],[0.6004, 0.3923]]), 'b': tensor([[0.0222, 0.1799, 0.9172, 0.8159],[0.3749, 0.6689, 0.4796, 0.5772],[0.5016, 0.5279, 0.5109, 0.0592]])}

保存Tensor的纯数据
PyTorch中,使用 torch.save 保存的不仅有其中的数据,还包括一些它的信息,包括它与其它数据(可能存在)的关系,这一点是很有趣的。
详细的原文可以参考:https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-tensors-preserves-views

这里结合例子给出一个简单的解释。

x = torch.arange(20)
y = x[:5]torch.save([x,y], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")y_ *= 100print(x_)
tensor([  0, 100, 200, 300, 400,   5,   6,   7,   8,   9,  10,  11,  12,  13, 14,  15,  16,  17,  18,  19])

比如在上边的例子中,我们看到y是x的一个前五位的切片,当我们同时保存x和y后,它们的切片关系也被保存了下来,再将他们加载出来,它们之间依然保留着这个关系,因此可以看到,我们将加载出来的 y_ 乘以100后,x_ 也跟着变化了。

如果不想保留他们的关系,其实也很简单,再保存y之前使用 clone 方法保存一个只有数据的“克隆体”,这样就能只保存数据而不保留关系:

x = torch.arange(20)
y = x[:5]torch.save([x,y.clone()], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")y_ *= 100print(x_)
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

当我们只保存y而不同时保存x会怎样呢?这样的话确实可以避免如上的情况,即不会再在读取数据后保留他们的关系,但是实际上有一个不容易被看到的影响存在,那就是保存的数据所占用的空间会和其“父亲”级别的数据一样大:

x = torch.arange(1000)
y = x[:5]torch.save(y, "./myTensor1.pth")
torch.save(y.clone(), "./myTensor2.pth")y1_ = torch.load("./myTensor1.pth")
y2_ = torch.load("./myTensor2.pth")print(y1_.storage().size())
print(y2_.storage().size())
1000
5

如果你去观察他们保存的文件,会发现占用的空间确实存在很大的差距:

myTensor1.pth      9KB
myTensor2.pth      1KB

综上所述,对于一些“被关系”的数据来说,如果不想保留他们的关系,最好使用 clone 来保存其“纯数据”

保存与加载模型
保存与加载state_dict
这是一种较为推荐的保存方法,即只保存模型的参数,保存的模型文件会较小,而且比较灵活。但是当加载时,需要先实例化一个模型,然后通过加载将参数赋给这个模型的实例,也就是说加载的时候也需要直到模型的结构。

保存:

torch.save(model.state_dict(), PATH)

加载:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

比较重要的点是:

保存模型时调用 state_dict() 获取模型的参数,而不保存结构
加载模型时需要预先实例化一个对应的结构
加载模型使用 load_state_dict 方法,其参数不是文件路径,而是 torch.load(PATH)
如果加载出来的模型用于验证,不要忘了使用 model.eval() 方法,它会丢弃 dropout、normalization 等层,因为这些层不能在inference的时候使用,否则得到的推断结果不一致。
一个例子:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()# convolution layersself.conv1 = nn.Conv2d(1,6,3)self.conv2 = nn.Conv2d(6,16,3)# fully-connection layersself.fc1 = nn.Linear(16*6*6,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):# max pooling over convolution layersx = F.max_pool2d(F.relu(self.conv1(x)),2)x = F.max_pool2d(F.relu(self.conv2(x)),2)# fully-connected layers followed by activation functionsx = x.view(-1,16*6*6)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))# final fully-connected without activation functonx = self.fc3(x)return xnet = Net()torch.save(net.state_dict(), "./myModel.pth")loaded_net = Net()
loaded_net.load_state_dict(torch.load("./myModel.pth"))
loaded_net.eval()
Net((conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))(fc1): Linear(in_features=576, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

保存与加载整个模型
这种方式不仅保存、加载模型的数据,也包括模型的结构一并存储,存储的文件会较大,好处是加载时不需要提前知道模型的结构,解来即用。实际上这与上文提到的保存Tensor是一致的。

保存:

torch.save(model, PATH)

加载:

model = torch.load(PATH)
model.eval()

同样的,如果加载的模型用于inference,则需要使用 model.eval()

保存与加载模型与其他信息
有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint,往往除了模型,还要保存它的epoch、loss、optimizer等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。这些其实也很简单,正如我们上文提到的,torch.save 可以保存dict、list、tuple等多种数据结构,所以一个字典可以很完美的解决这个问题,比如一个简单的例子:

# saving
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# loading
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

跨设备存储与加载
跨设备的情况指对于一些数据的保存、加载在不同的设备上,比如一个在CPU上,一个在GPU上的情况,大致可以分为如下几种情况:

从CPU保存,加载到CPU
实际上,这就是默认的情况,我们上文提到的所有内容都没有关心设备的问题,因此也就适应于这种情况。

从CPU保存,加载到GPU
保存:依旧使用默认的方法
加载:有两种可选的方式
使用 torch.load() 函数的 map_location 参数指定加载后的数据保存的设备
对于加载后的模型使用 to() 函数发送到设备

torch.save(net.state_dict(), PATH)device = torch.device("cuda")loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
# or
loaded_net.to(device)

从GPU保存,加载到CPU
保存:依旧使用默认的方法
加载:只能使用 torch.load() 函数的 map_location 参数指定加载后的数据保存的设备

torch.save(net.state_dict(), PATH)device = torch.device("cuda")loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

从GPU保存,加载到GPU
保存:依旧使用默认的方法
加载:只能使用 对于加载后的模型进行 to() 函数发送到设备

torch.save(net.state_dict(), PATH)device = torch.device("cuda")loaded_net = Net()
loaded_net.to(device)

tensor和模型 保存与加载 PyTorch相关推荐

  1. pytorch模型保存与加载总结

    pytorch模型保存与加载总结 模型保存与加载方式 模型保存 方式一 只存储模型中的参数,该方法速度快,占用空间少(官方推荐使用) model = VGGNet() torch.save(model ...

  2. Pytorch —— 模型保存与加载

    1.序列化与反序列化 模型的保存与加载就是序列化与反序列化,序列化与反序列化主要将内存与硬盘之间的数据转换关系,模型在内存中以对象的形式存储,在内存中对象不能长久地保存,所以需要将训练好的模型保存到硬 ...

  3. PyTorch系列入门到精通——模型保存与加载

    PyTorch系列入门到精通--模型保存与加载

  4. 飞桨框架2.0RC新增模型保存、加载方案,与用户场景完美匹配,更全面、更易用

    通过一段时间系统的课程学习,算法攻城狮张同学对于飞桨框架的使用越来越顺手,于是他打算在企业内尝试使用飞桨进行AI产业落地. 但是AI产业落地并不是分秒钟的事情,除了专业技能过硬,熟悉飞桨的使用外,在落 ...

  5. [tensorflow] 模型保存、加载与转换详解

    TensorFlow模型加载与转换详解 本次讲解主要涉及到TensorFlow框架训练时候模型文件的管理以及转换. 首先我们需要明确TensorFlow模型文件的存储格式以及文件个数: model_f ...

  6. 机器学习之模型——保存与加载

    机器学习之模型--保存与加载 知识点 fit() transform() fit_transform() 目的 API 流程 获取数据 划分数据集 标准化 预估器 保存模型 加载模型 得出模型 模型评 ...

  7. TensorFlow2.0 —— 模型保存与加载

    目录 1.Keras版本模型保存与加载 2.自定义版本模型保存与加载 3.总结 1.Keras版本模型保存与加载 保存模型权重(model.save_weights) 保存HDF5文件(model.s ...

  8. gensim bm25模型保存与加载

    gensim bm25模型保存与加载 1. 模型保存 2. 模型加载 20210719修改: python version:3.6.12 gensim version:3.8.3 使用bm25模型计算 ...

  9. tf第七讲:模型保存与加载(tf.train.Saver()tf.saved_model)及fine_tune(梯度冻结)

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

最新文章

  1. Python 计算Mesh顶点法向量
  2. nagios自定义监控脚本
  3. Linux 查看 占用内存最多 占用cpu最多 程序(类似top,监视)
  4. 中考 计算机录取 步骤,中考录取时间及录取流程详解
  5. 程序员被公司辞退12天,前领导要求回公司讲清楚代码,结果懵了
  6. Spring3开发实战 之 第四章:对JDBC和ORM的支持
  7. 机器学习(时间序列):线性回归之虚拟变量 dummy variables
  8. migration java_如何重置migration
  9. ie11不兼容java_IE11浏览器网页不兼容的四种解决方法
  10. typora上传图片出现Can‘t find smms config错误
  11. MySQL必知必会总结
  12. 如何使用IP摄像头进行电脑直播
  13. 路由控制——ACL、IP-Prefix List
  14. 内网渗透-横向渗透2
  15. 【75】颜色分类--荷兰国旗问题
  16. L2-001 紧急救援
  17. 简易方法提高手机3G上网速度(2G转3G)
  18. 为Apple Watch设计:产品策略
  19. 计算机类专业本科生毕业论文+答辩那点事
  20. 【EoSL】Introduction

热门文章

  1. 2022-2028年中国测绘设备行业研究及前瞻分析报告
  2. pip install faiss-gpu失败unable to execute ‘swig‘: No such file or directory
  3. python重难点之装饰器详解
  4. 【C#】数组的最大最小值
  5. 机房收费系统总结【5】——无用功
  6. AI 芯片的分类及技术
  7. Linux 2 的 Windows 子系统上发布 CUDA
  8. 2021年大数据ELK(二十七):数据可视化(Visualize)
  9. 商城数据库表设计介绍
  10. 谷歌不更新android studio,彻底迈向64位:谷歌宣布 Android Studio 将停止 32 位版本更新...