模型保存与提取

  • 1. 整个模型 保存-载入
  • 2. 仅模型参数 保存-载入
  • 3. GPU/CPU模型保存与导入
  • 4. net.eval()--固定模型随机项

神经网络模型在线训练完之后需要保存下来,以便下次使用时可以直接导入已经训练好的模型。pytorch 提供两种方式保存模型:

方式1:保存整个网络,载入时直接载入整个网络,优点:代码简单,缺点需要的存储空间大

方式2:只保存网络参数,载入时需要先建立与原来网络一样结构的网络,然后将网络参数导入到该网络中,方式2的优缺点与方式1相反。

1. 整个模型 保存-载入

模型的结构参数都保存下来了

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet1.pkl"
# 导入官方提供的预训练模型
net1=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1, PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/mynet1.pkl"
net1_1=torch.load(PATH)

2. 仅模型参数 保存-载入

保存时–只保存网络中的参数 (速度快, 占内存少), 载入时–需要提前创建好结构和net2是一样的

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet2.pkl"
# 导入官方提供的预训练模型
net2=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1.state_dict(), PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/net2.pkl"
# 新建一个网络
net2_2=torchvision.models.alexnet(pretrained=True)
# 载入模型参数
net2_2.load_state_dict(torch.load(PATH))

迷糊的现象

在使用莫烦的文档做实验时,保存的两个文件:net.pkl,net_params.pkl大小差异比较大。保证在导入模型是比较快。

但是使用torchvision.models.模块中的一系列网络时,因为网络的参数很大,所以实验过程中用两种方法保存模型的文件大小是一致的。(猜测是内置模型使用torch.save(net1, ‘net.pkl’)时默认保存的是模型参数)

提供一个神经网络模型占用空间大小的计算方法:

参考文档:经典CNN模型计算量与内存需求分析

3. GPU/CPU模型保存与导入

在训练是模型是GPU/CPU,决定了模型载入时的模型原型。可以分为下面三种情况
(只展示导入整个网络模型的情况,具体实验还没做过):

1.CPU(原型)->CPU, GPU(原型)->GPU

torch.load( ‘net.pkl’)

2.GPU(原型)->CPU

torch.load(‘model_dict.pkl’, map_location=lambda storage, loc: storage)

3.CPU(模型文件)->GPU

torch.load(‘model_dic.pkl’, map_location=lambda storage, loc: storage.cuda)

参考文档:https://blog.csdn.net/u012135425/article/details/85217542

4. net.eval()–固定模型随机项

两种模型载入方式、.eval() 作用实验demo

step1: 载入模型

# 20191204 pytorch 模型载入测试
import torchvision as tvt
import torch
net1=tvt.models.alexnet(pretrained=True)  # 1.自动从网上下载的预先训练模型
net2=torch.load("./model/mynet1.pkl")     # 2.导入事先训练好的保存的整个网络net3=tvt.models.alexnet(pretrained=True)  # 3.导入只保存模型参数的网络,需要新建一个网络
net3.load_state_dict(torch.load("../model/mynet2.pkl"))
net3.eval()                              #   固定dropout和归一化层,否则每次推理会生成不同的结果。

step2:输出三个网络同一层参数的和,net2 和net3 对应参数相等。可以看出来,两种模型保存和导入方式是等价的。

net1 tensor(-21257.7656, grad_fn=<SumBackward0>)
net2 tensor(-21253.9473, device='cuda:0', grad_fn=<SumBackward0>)
net3 tensor(-21253.9551, grad_fn=<SumBackward0>)

step3: 产生一个随机输入a,输入到网络1,2,3,打印输出结果。

a=torch.randn([1,3,224,224])
y1=net1(a)
y2=net2(a)
y3=net3(a)
# 第二次输入
y11=net1(a)
y22=net2(a)
y33=net3(a)
# 打印y1,y2,y3,y11,y22,y33(1000维的和)
y1: tensor(-5.2689, grad_fn=<SumBackward0>)
y2: tensor(-1.6695, device='cuda:0', grad_fn=<SumBackward0>)
y3: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)y11: tensor(-4.4205, grad_fn=<SumBackward0>)
y22: tensor(-5.9475, device='cuda:0', grad_fn=<SumBackward0>)
y33: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)

只有net3的输出是固定的,因为在模型导入的时候执行了net3.eval().
结论:无论采用 方式1 还是 方式2 导入的模型, 在模型测试时,都需要用.eval()方法固定一下网络在训练过程中的随机项目,如dropout 等,避免网络在同一个输入下产生不一样的结果。

Pytorch(4)-模型保存-载入-eval()相关推荐

  1. TensorFlow 模型保存/载入的两种方法

    TensorFlow 模型保存/载入 我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个skl ...

  2. pytorch - swa_model模型保存的问题

    AttributeError: Can't pickle local object 'AveragedModel.__init__.<locals>.avg_fn' 解决办法一: 解决办法 ...

  3. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...

  4. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

  5. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  6. Pytorch —— 模型保存与加载

    1.序列化与反序列化 模型的保存与加载就是序列化与反序列化,序列化与反序列化主要将内存与硬盘之间的数据转换关系,模型在内存中以对象的形式存储,在内存中对象不能长久地保存,所以需要将训练好的模型保存到硬 ...

  7. pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...

  8. paddlepaddle(六)模型保存与载入

    目录 1.API分类 1.1基础API 1.2高级API 2.训练调优场景的模型&参数保存载入 2.1动态图参数保存载入 2.2静态图参数保存载入 3.训练部署场景的模型参数保存载入 3.1 ...

  9. PyTorch模型保存与加载

    torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对 ...

最新文章

  1. android test.apk,app-debug.apk和app-debug-androidTest.apk在安装macaca-android模块的时候build失败...
  2. WPF error: does not contain a static 'Main' method suitable for an entry point
  3. 数据挖掘实践(金融风控)
  4. 计算机学具制作,6计算机模板教案6算机模板教案.doc
  5. 64位linux并行计算大气模型效率优化研究,64位Linux并行计算大气模型效率优化研究...
  6. 实例57:python
  7. 什么是 IP 地址?
  8. Swift 语言概览 -自己在Xcode6 动手写1
  9. 眼前一亮的UI设计案例|插画世界里的网页首图
  10. 放下偏见,原来嵌入式程序员如此“妖娆”!
  11. L2-030 冰岛人 (25 分)-PAT 团体程序设计天梯赛 GPLT
  12. Spring 3.0: Unable to locate Spring NamespaceHandler for XML schema namespace
  13. 金蝶连服务器显示演示版,金蝶正版和金蝶演示版的区别
  14. 计算机描绘的基因结构图,推荐一款好用的基因结构图在线绘制工具!
  15. 基于FPGA实现的数字位同步锁相环设计
  16. HBase2.4.8详细教程(三)Java操作HBase
  17. ASP.NET MVC 分部页 PartialViewResult
  18. LP wizard无法生成PCB封装
  19. 微信名片加好友服务器繁忙,还傻乎乎的微信加好友?这些细节你要注意!
  20. python爬虫必备防检测工具

热门文章

  1. Asterisk配置文件说明
  2. string也可以很精彩
  3. java jni librtmp_librtmp 编译集成
  4. html 遍历div内check,vue+element中checkbox 实现遍历分组全选
  5. maven springboot 除去指定的jar包_SpringBoot的运行机制
  6. 高斯投影坐标系为什么是六七八位数
  7. 躺平也要看,2022年计算机相关考试汇总
  8. Idea Maven报错找不到程序包
  9. 【2050 Programming Competition - 2050 一万人码 】非官方部分题解(HDU)
  10. 【CodeForces - 706C】Hard problem(dp,字典序)