torch.save()模型的保存于加载
保存模型主要分为两类:保存整个模型和只保存模型参数
1.保存加载整个模型(不推荐):
保存整个网络模型,网络结构+权重参数
torch.save(model,'net.pth')
加载整个网络模型(可能比较耗时)
model=torch.load('net.pth')
2.只保存加载模型参数(推荐)
保存模型的权重参数(速度快,占内存少)
torch.save(model.state_dict(),'net_params.pth')
load 模型参数
因为我们只保存了 模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。
model=ClassNet()
#将模型参数加载到新模型中,torch.load返回的是一个OrderedDict,说明.state_dict()只是把所有模型的参数都已OrderedDict的形式存下来。
state_dict=torch.load('net_params.pth')
model.load_state_dict(state_dict)
Note:保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。
load_state_dict的参数strict=False
new_model.load_state_dict(state_dict,strict=False)
如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key. 但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。
3.保存加载自定义模型
上面“保存加载整个模型”加载的net.pt其实是一个字典,通常包含以下内容:
网络结构:输入尺寸,输出尺寸以及隐藏层的信息,以便能够在加载时重建模型。
模型的权重参数,包含各网络层训练后的可学习参数,可以在模型实例上调用state_dict()方法来获取,比如只保存模型权重参数时用到的model.state_dict().
优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和其所使用的超参数,也是在优化器实例上调用state_dict()方法来获取这些参数。
其他信息:有时我们需要保存一些其他的信息,比如epoch, batch_size等超参数
我们可以自定义需要save的内容
#saving a checkpoint assuming the network class named ClassNet
checkpoint={'modle':ClassNet(),'model_state_dict':model.state_dict(),'optimize_state_dict':optimizer.state_dict(),'epoch':epoch}
torch.save(checkpoint,'checkpoint.pkl')
上面的checkpoint是个字典,里面有4各键值对,分别表示网络模型的不同信息。
然后我们要loda上面保存的自定义模型
def load_checkpoint(filepath):checkpoint=torch.load(filepath)def load_checkpoint(filepath):checkpoint = torch.load(filepath)model=checkpoint['model']#提前网络结构model.load_state_dict(checkpoint['model_state_dict'])#加载网络权重参数optimizer=TheOptimizerClass()optimizer.load_state_dict(checkpoint['optimizer_state_dict'])#加载优化器参数for parameter in model.parameters():parameter.requires_grad=Falsemodel.eval()return model
modle=load_checkpoint('checkpoint.pkl')
后续使用
如果加载模型只是为了进行推理测试,则将每一层的requires_grad置为False,即固定这些权重参数,还需要调用model.eval()将模型置为测试模式,主要是将dropout和batch normalization层进行固定,否则模型的预测结果每次都会不同。
如果需要继续训练,则调用model.train(),以确保网络模型处于训练模式。
跨设备保存加载模型
在CPU上加载在GPU上训练并保存的模型(save on GPU, load on CPU
device=torch.device('cpu')
model=TheModelClass()
#load all tensors onto the CPU device
model.load_state_dict(torch.load('net_params.pkl',map_location=device))
令torch.load()函数的map_location参数等于torch.device('cpu')即可,这里令map_location参数等于‘cpu'也同样可以。
torch.save()模型的保存于加载相关推荐
- pytorch模型的保存与加载
我们先创建一个模型,使用的是pytorch笔记--简易回归问题_刘文巾的博客-CSDN博客 的主体框架,唯一不同的是,我这里用的是torch.nn.Sequential来定义模型框架,而不是那篇博客里 ...
- TensorFlow2.0:模型的保存与加载
** 一.权重参数的保存与加载 ** network.save_weights('weights.ckpt') network.load_weights('weights.ckpt') 权重参数的保存 ...
- [TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式)
[TensorFlow深度学习入门]实战七·简便方法实现TensorFlow模型参数保存与加载(ckpt方式) 个人网站–> http://www.yansongsong.cn TensorFl ...
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()
Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...
- pytorch保存模型pth_Day159:模型的保存与加载
网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法: 保存 整个模型 (结构+参数) 只保存模型参数(官方推荐) # 保存整个网络torch.save(model, check ...
- MXNET学习笔记(二):模型的保存与加载
当序列化 NDArray 的时候,我们序列化的是NDArray 中保存的 tensor 值.当序列化 Symbol 的时候,我们序列化的是 Graph. Symbol序列化 当序列化 Symbol 的 ...
- 【xgboost】xgboost模型的保存与加载
xgboost模型的保存方法 有多种方法可以保存xgboost模型,包括pickle,joblib,以及原生的save_model,load_model函数 其中Pickle是Python中序列化对象 ...
- tensorflow1.0模型的保存、加载、在训练
1.checkpoint文件总览 tensorflow保存的模型文件如下所示: .meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量.op.集合等. c ...
- pb 保存变量文件名_【Tensorflow 2.0 正式版教程】模型的保存、加载与迁移
模型的保存和加载可以直接通过Model类的save_weights和load_weights实现.默认的保存格式为tensorflow的checkpoint格式,也可以手动设置保存为h5文件. mo ...
最新文章
- [Markdown] 数学公式
- 【数据挖掘】分类任务简介 ( 分类概念 | 分类和预测 | 分类过程 | 训练集 | 测试集 | 数据预处理 | 有监督学习 )
- 经典论文复现 | 基于深度卷积网络的图像超分辨率算法
- NoSQL架构实践(三)——以NoSQL为缓存
- 草稿 断开式绑定combobox 1128
- c语言常用库函数使用方法,c语言常用库函数使用方法及用途
- JAVA并发包内容_java并发包
- 谷歌承诺未来三年将支付10亿美元新闻费用
- Android Fragment (一)
- dl360 g7安装linux,HPDL360G7服务器安装说明.ppt
- word保存不了磁盘已满_磁盘到底该不该分区?容量不够怎么办?
- php 清除浮动,清除浮动的几种方法
- ncverilog脚本_NC-Verilog控制命令
- 计算机cf编程,警察牧马人宏自定义编程计算机游戏鼠标有线大声笑/ cf英雄联盟光速质量保证....
- Qt自定义实现的日历控件
- php音频怎么打开,音频管理器怎么设置
- redis管理_Redis 桌面管理工具Redis Desktop Manager
- 零售EDI:家乐福Carrefour EDI需求分析
- python-docx 标题字体设置失败如何解决?
- 小猫爪:i.MX RT1050学习笔记26-RT1xxx系列的FlexCAN详解
热门文章
- idea加密解密C++实现
- 上半年要写的博客文章21
- 基于js实现的简易记账小本
- 如何利用Slack客户端漏洞窃取Slack用户下载的所有文件
- java 微信定位到市_java 微信公众号地理位置获取
- 《寓言中的经济学》简明纪要 - Part 1
- layui 横向表单_fwr-layui-formdesigner
- 苹果笔记本显卡性能测试软件,测试结果来了!新款Macbook Pro显卡性能怎样?
- JDK8 | 字符串收集器 Collectors.joining()
- 计算机查看流量记录,教你用路由器查看电脑数据流量使用情况的方法