pytorch保存模型有两种方法:

  1. 保存整个模型 (结构+参数)
  2. 只保存参数(官方推荐)

两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现。
两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model里获取的存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作。

保存整个模型

这种方法很简单,保存和加载就两行代码,和Python pickle包的用法是一样的,把model当作一个对象直接保存加载就行。

# 保存
model = Mymodel()
torch.save(model, path)
# 加载
model = torch.load(path)

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

保存参数

重点介绍一下这种方法,一般训完一个模型之后我们不会单独只保存一个模型的参数,为了方便后续操作,比如恢复训、参数迁移等,我们会保存当前转态的一个快照,具体信息可以根据自己的需要,下面列出几个方面:

  • 模型参数
  • 优化器参数
  • loss
  • epoch
  • args

把这些信息用字典包装起来,然后保存即可。

这种方式保存的模型只是它的参数,所以我们在加载时需要先创建好模型,然后再把参数加载进去,如下:

# 获得保存信息
save_data = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,'epoch': epoch,'args': args...
}
# 保存
torch.save(save_data , path)
load_data = torch.load(path)
model = Mymodel()
optimizer = Myoptimizer()
# 加载参数
model.load_state_dict(load_data ['model_state_dict'])
optimizer.load_state_dict(load_data ['optimizer_state_dict'])
...

Note:PyTorch约定使用.pt或.pth后缀命名保存文件。

Pytorch之模型加载/保存相关推荐

  1. PyTorch框架学习十九——模型加载与保存

    PyTorch框架学习十九--模型加载与保存 一.序列化与反序列化 二.PyTorch中的序列化与反序列化 1.torch.save 2.torch.load 三.模型的保存 1.方法一:保存整个Mo ...

  2. Python时间序列模型推理预测实战:时序推理数据预处理(特征生成、lstm输入结构组织)、模型加载、模型预测结果保存、条件判断模型循环运行

    Python时间序列模型推理预测实战:时序推理数据预处理(特征生成.lstm输入结构组织).模型加载.模型预测结果保存.条件判断模型循环运行 目录

  3. 【TensorFlow】TensorFlow从浅入深系列之十三 -- 教你深入理解模型持久化(模型保存、模型加载)

    本文是<TensorFlow从浅入深>系列之第13篇 TensorFlow从浅入深系列之一 -- 教你如何设置学习率(指数衰减法) TensorFlow从浅入深系列之二 -- 教你通过思维 ...

  4. 时间序列预测:SVR用于时间序列预测代码+模型保存+模型加载+网格搜索+交叉验证

    本文关于SVR时间序列的预测,详细步骤如下: 1.数据读取 2.数据集的划分(采用滑动窗口重叠切片) 3.训练数据集掷乱 4.SVR参数设置(网格搜索+交叉验证) 5.SVR模型训练+模型保存 6.S ...

  5. pytorch无法将模型加载到gpu上

    通常是model = model.to(cuda)就好了 但由于搭建模型的时候,forward函数的代码直接调用这个类外部的函数,如图所示: 在这里直接调用了外部的函数, 这个函数里面有torch.n ...

  6. PyTorch:数据加载,数学原理,猫鱼分类,CNN,预训练,迁移学习

    1,数据加载 PyTorch开发了与数据交互的标准约定,所以能一致地处理数据,而不论处理图像.文本还是音频.与数据交互的两个主要约定是数据集(dataset)和数据加载器(dataloader).数据 ...

  7. OpenGL OBJ模型加载.

    在我们前面绘制一个屋,我们可以看到,需要每个立方体一个一个的自己来推并且还要处理位置信息.代码量大并且要时间.现在我们通过加载模型文件的方法来生成模型文件,比较流行的3D模型文件有OBJ,FBX,da ...

  8. mmsegmentation的demo测试-模型加载

    想学下语义分割框架,首先是跑demo,但是模型加载不熟悉,半天不知道从哪里去加载模型,并且查不到相关资料,现记一下以便后用. model_urls = {'fcn_resnet50_coco': 'h ...

  9. Tensorflow 获取model中的变量列表,用于模型加载等

    目录 前言 1. 用tensorflow自带的工具 2. 用tensorflow.contrib.slim. 3. 从保存的model中提取var_list 4. 其他 前言 在加载预训练的网络模型时 ...

  10. 草图大师sketchup的模型加载到cesium里显示 带贴图

    草图大师sketchup的模型加载到cesium里显示 带贴图 前言 一.sketchUp导出obj格式 二.在Blender中转换数据 1.导入数据 2.调整尺寸和视图 三.cesium中加载 总结 ...

最新文章

  1. JAVA线程池的简单实现及优先级设置
  2. 大数据之Linux常用命令
  3. ECS 支持 IPv6 啦,快来尝鲜吧~
  4. 超越IEtab、网银支付助手,无需再次登陆的Firefox的IE插件
  5. Flask 知识总结
  6. android黑科技系列——爆破一款应用的签名验证问题
  7. 安卓文本编辑器php cpp,用安卓原生控件封装一个简易的富文本编辑器
  8. 测试markdown的发布
  9. 中国剩余定理 —— 入门
  10. php打出等边三角形,CSS 如何进行单一div的正多边形变换
  11. 厉害,Spring Boot 2.3.0 刚刚发布了!
  12. php url地址栏传中文乱码解决方法集合
  13. 利用gretna计算小世界网络属性等图论指标笔记
  14. Launcher folder、foldericon
  15. Windows系统封装(二)导入封装工具安装软件,安装系统。
  16. 我国的频段划分,请参考~
  17. 系统运维工程师30岁学python_一名Linux系统运维工程师的自述
  18. 巴比特 | 元宇宙每日必读:微博动漫将招募全球各类虚拟偶像并为其提供扶持...
  19. 加尔布雷思:人类永恒的愚蠢,就是把莫名其妙的担忧当成智力超群。
  20. 物联网板开发入门指南

热门文章

  1. mysql如何从两个表取出内容_如何从mysql中的两个表中获取数据?
  2. crossentropyloss 输入_Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解
  3. c++ map 析构函数_面向偷懒的编程 - C/C++项目中使用Go的分布式系统库
  4. Spring AOP 讲解(Pointcut、Before、Around、AfterReturning、After)
  5. abb外部轴零位校准_【ABB】ABB机器人外部轴参数(KpKvTi)调试
  6. c语言乘法除法结合律,有关C语言运算符优先级和结合律的思考
  7. linux 打开cgm软件,cgm文件扩展名,cgm文件怎么打开?
  8. php元素排序算法,php 4大基础排序算法
  9. 字符数组的ss.toString()和new String(ss)的问题
  10. 异步读写之利用完成历程