Pytorch(4)-模型保存-载入-eval()
模型保存与提取
- 1. 整个模型 保存-载入
- 2. 仅模型参数 保存-载入
- 3. GPU/CPU模型保存与导入
- 4. net.eval()--固定模型随机项
神经网络模型在线训练完之后需要保存下来,以便下次使用时可以直接导入已经训练好的模型。pytorch 提供两种方式保存模型:
方式1:保存整个网络,载入时直接载入整个网络,优点:代码简单,缺点需要的存储空间大
方式2:只保存网络参数,载入时需要先建立与原来网络一样结构的网络,然后将网络参数导入到该网络中,方式2的优缺点与方式1相反。
1. 整个模型 保存-载入
模型的结构参数都保存下来了
# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet1.pkl"
# 导入官方提供的预训练模型
net1=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1, PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/mynet1.pkl"
net1_1=torch.load(PATH)
2. 仅模型参数 保存-载入
保存时–只保存网络中的参数 (速度快, 占内存少), 载入时–需要提前创建好结构和net2是一样的
# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet2.pkl"
# 导入官方提供的预训练模型
net2=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1.state_dict(), PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/net2.pkl"
# 新建一个网络
net2_2=torchvision.models.alexnet(pretrained=True)
# 载入模型参数
net2_2.load_state_dict(torch.load(PATH))
迷糊的现象
在使用莫烦的文档做实验时,保存的两个文件:net.pkl,net_params.pkl大小差异比较大。保证在导入模型是比较快。
但是使用torchvision.models.模块中的一系列网络时,因为网络的参数很大,所以实验过程中用两种方法保存模型的文件大小是一致的。(猜测是内置模型使用torch.save(net1, ‘net.pkl’)时默认保存的是模型参数)
提供一个神经网络模型占用空间大小的计算方法:
参考文档:经典CNN模型计算量与内存需求分析
3. GPU/CPU模型保存与导入
在训练是模型是GPU/CPU,决定了模型载入时的模型原型。可以分为下面三种情况
(只展示导入整个网络模型的情况,具体实验还没做过):
1.CPU(原型)->CPU, GPU(原型)->GPU
torch.load( ‘net.pkl’)
2.GPU(原型)->CPU
torch.load(‘model_dict.pkl’, map_location=lambda storage, loc: storage)
3.CPU(模型文件)->GPU
torch.load(‘model_dic.pkl’, map_location=lambda storage, loc: storage.cuda)
参考文档:https://blog.csdn.net/u012135425/article/details/85217542
4. net.eval()–固定模型随机项
两种模型载入方式、.eval() 作用实验demo
step1: 载入模型
# 20191204 pytorch 模型载入测试
import torchvision as tvt
import torch
net1=tvt.models.alexnet(pretrained=True) # 1.自动从网上下载的预先训练模型
net2=torch.load("./model/mynet1.pkl") # 2.导入事先训练好的保存的整个网络net3=tvt.models.alexnet(pretrained=True) # 3.导入只保存模型参数的网络,需要新建一个网络
net3.load_state_dict(torch.load("../model/mynet2.pkl"))
net3.eval() # 固定dropout和归一化层,否则每次推理会生成不同的结果。
step2:输出三个网络同一层参数的和,net2 和net3 对应参数相等。可以看出来,两种模型保存和导入方式是等价的。
net1 tensor(-21257.7656, grad_fn=<SumBackward0>)
net2 tensor(-21253.9473, device='cuda:0', grad_fn=<SumBackward0>)
net3 tensor(-21253.9551, grad_fn=<SumBackward0>)
step3: 产生一个随机输入a,输入到网络1,2,3,打印输出结果。
a=torch.randn([1,3,224,224])
y1=net1(a)
y2=net2(a)
y3=net3(a)
# 第二次输入
y11=net1(a)
y22=net2(a)
y33=net3(a)
# 打印y1,y2,y3,y11,y22,y33(1000维的和)
y1: tensor(-5.2689, grad_fn=<SumBackward0>)
y2: tensor(-1.6695, device='cuda:0', grad_fn=<SumBackward0>)
y3: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)y11: tensor(-4.4205, grad_fn=<SumBackward0>)
y22: tensor(-5.9475, device='cuda:0', grad_fn=<SumBackward0>)
y33: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)
只有net3的输出是固定的,因为在模型导入的时候执行了net3.eval().
结论:无论采用 方式1 还是 方式2 导入的模型, 在模型测试时,都需要用.eval()方法固定一下网络在训练过程中的随机项目,如dropout 等,避免网络在同一个输入下产生不一样的结果。
Pytorch(4)-模型保存-载入-eval()相关推荐
- TensorFlow 模型保存/载入的两种方法
TensorFlow 模型保存/载入 我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来.tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个skl ...
- pytorch - swa_model模型保存的问题
AttributeError: Can't pickle local object 'AveragedModel.__init__.<locals>.avg_fn' 解决办法一: 解决办法 ...
- TensorFlow模型保存和加载方法
TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...
- tensor和模型 保存与加载 PyTorch
PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- Pytorch —— 模型保存与加载
1.序列化与反序列化 模型的保存与加载就是序列化与反序列化,序列化与反序列化主要将内存与硬盘之间的数据转换关系,模型在内存中以对象的形式存储,在内存中对象不能长久地保存,所以需要将训练好的模型保存到硬 ...
- pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...
- paddlepaddle(六)模型保存与载入
目录 1.API分类 1.1基础API 1.2高级API 2.训练调优场景的模型&参数保存载入 2.1动态图参数保存载入 2.2静态图参数保存载入 3.训练部署场景的模型参数保存载入 3.1 ...
- PyTorch模型保存与加载
torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对 ...
最新文章
- android test.apk,app-debug.apk和app-debug-androidTest.apk在安装macaca-android模块的时候build失败...
- WPF error: does not contain a static 'Main' method suitable for an entry point
- 数据挖掘实践(金融风控)
- 计算机学具制作,6计算机模板教案6算机模板教案.doc
- 64位linux并行计算大气模型效率优化研究,64位Linux并行计算大气模型效率优化研究...
- 实例57:python
- 什么是 IP 地址?
- Swift 语言概览 -自己在Xcode6 动手写1
- 眼前一亮的UI设计案例|插画世界里的网页首图
- 放下偏见,原来嵌入式程序员如此“妖娆”!
- L2-030 冰岛人 (25 分)-PAT 团体程序设计天梯赛 GPLT
- Spring 3.0: Unable to locate Spring NamespaceHandler for XML schema namespace
- 金蝶连服务器显示演示版,金蝶正版和金蝶演示版的区别
- 计算机描绘的基因结构图,推荐一款好用的基因结构图在线绘制工具!
- 基于FPGA实现的数字位同步锁相环设计
- HBase2.4.8详细教程(三)Java操作HBase
- ASP.NET MVC 分部页 PartialViewResult
- LP wizard无法生成PCB封装
- 微信名片加好友服务器繁忙,还傻乎乎的微信加好友?这些细节你要注意!
- python爬虫必备防检测工具
热门文章
- Asterisk配置文件说明
- string也可以很精彩
- java jni librtmp_librtmp 编译集成
- html 遍历div内check,vue+element中checkbox 实现遍历分组全选
- maven springboot 除去指定的jar包_SpringBoot的运行机制
- 高斯投影坐标系为什么是六七八位数
- 躺平也要看,2022年计算机相关考试汇总
- Idea Maven报错找不到程序包
- 【2050 Programming Competition - 2050 一万人码 】非官方部分题解(HDU)
- 【CodeForces - 706C】Hard problem(dp,字典序)