-柚子皮-

神经网络训练后我们需要将模型进行保存,要用的时候将保存的模型进行加载。

PyTorch 中保存模型主要分为两类:保存整个模型和只保存模型参数。

A common PyTorch convention is to save models using either a .pt or .pth file extension.

保存加载整个模型(不推荐)

保存整个网络模型

(网络结构+权重参数)

torch.save(model, 'net.pth')

出错:

AttributeError: Can't pickle local object 'AtomicModel.get_metrics.<locals>.<lambda>'
AttributeError: Can't pickle local object 'AtomicModel._get_metrics.<locals>._accuracy_score'
原因:pickle不能序列化lambda函数,或者是闭包。[python模块 - pickle模块]
加载整个网络模型

(可能比较耗时)

model = torch.load('net.pth')

只保存加载模型参数(推荐)

保存模型的权重参数

(速度快,占内存少)

torch.save(model.state_dict(), 'net_params.pth')

load模型参数
因为我们只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。

# 构建一个网络结构
model = ClassNet()
# 将模型参数加载到新模型中,torch.load 返回的是一个 OrderedDict,说明.state_dict()只是把所有模型的参数都以OrderedDict的形式存下来。
state_dict = torch.load('net_params.pth')
model.load_state_dict(state_dict)

Note: 保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。

load_state_dict的参数Strict=False

new_model.load_state_dict(state_dict, strict=False)
      如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key。但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。

[Pytorch学习(十七)--- 模型load各种问题解决]

保存加载自定义模型

上面“保存加载整个模型”加载的 net.pt 其实一个字典,通常包含如下内容:

网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict() 方法来获取,比如只保存模型权重参数时用到的 model.state_dict()。
优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
其他信息:有时我们需要保存一些其他的信息,比如 epoch,batch_size 等超参数。

我们可以自定义需要save的内容

# saving a checkpoint assuming the network class named ClassNet
checkpoint = {'model': ClassNet(),
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'epoch': epoch}
torch.save(checkpoint, 'checkpoint.pkl')
上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。

然后我们要load上面保存的自定义的模型

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']  # 提取网络结构
    model.load_state_dict(checkpoint['model_state_dict'])  # 加载网络权重参数
    optimizer = TheOptimizerClass()
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器参数
    
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.eval()
    
    return model
    
model = load_checkpoint('checkpoint.pkl')

后续使用

如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropout 和 batch normalization 层进行固定,否则模型的预测结果每次都会不同。

如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。

跨设备保存加载模型

在 CPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on CPU):

device = torch.device('cpu')
model = TheModelClass()
# Load all tensors onto the CPU device
model.load_state_dict(torch.load('net_params.pkl', map_location=device))
map_location:a function, torch.device, string or a dict specifying how to remap storage locations

令 torch.load() 函数的 map_location 参数等于 torch.device('cpu') 即可。 这里令 map_location 参数等于 'cpu' 也同样可以。

[PyTorch 中模型的使用]

from: -柚子皮-

ref: [SAVING AND LOADING MODELS]

PyTorch:模型save和load相关推荐

  1. keras/tensorflow 模型保存后重新加载准确率为0 model.save and load giving different result

    我在用别人的代码跑程序的时候遇到了这个问题: keras 模型保存后重新加载准确率为0 GitHub上有个issue:model.save and load giving different resu ...

  2. pytorch教程:save and load

    torch有一个"save and load"机制,即训练意外终止了,可以保存训练的中间文件(ckpt.mdl),然后恢复训练的最开始重新加载进来.

  3. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  4. TensorRT和PyTorch模型的故事

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨伯恩legacy 来源丨https://zhuanlan.zh ...

  5. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

  6. 在C++平台上部署PyTorch模型流程+踩坑实录

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文主要讲解如何将pytorch的模型部署到c++平台上的模 ...

  7. 如何使用TensorRT对训练好的PyTorch模型进行加速?

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨伯恩legacy@知乎 来源丨https://zhuanlan.zhihu.com/p/8831 ...

  8. sklearn与pytorch模型的保存与读取

    当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...

  9. 保存和加载pytorch模型

    当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...

最新文章

  1. zabbix简介(第一章第一节)
  2. java中使用base64加密解密16进制方法
  3. sqoop数据倾斜_北京卓越讯通大数据岗位面试题分享
  4. 使用composer_在Google Cloud Composer(Airflow)上使用Selenium搜寻网页
  5. SCCM 2012系列16 操作系统播发⑤
  6. python中parse是什么_Python中optparse模块使用浅析
  7. win10计算机修改底色,win10电脑如何修改登陆背景
  8. 飞鸽传书(http://www.freeeim.com)软件下载
  9. 20155117王震宇 2006-2007-2 《Java程序设计》第5周学习总结
  10. laravel 扩展包
  11. 数字图像识别学习笔记(第二章-数字图像基础(1))
  12. 分享几个Python小技巧函数里的4个小花招 1
  13. 2021东北四省赛J. Transform(空间几何)
  14. OULU-NPU数据说明
  15. html中文字不自动换行 white-space style
  16. 心跳与超时:高并发高性能的时间轮超时器
  17. 【Niagara Vykon N4 】物联网学习 03照明控制及照明时间表
  18. win10下 oracle安装(11g)
  19. 养生:拔火罐有什么好处?
  20. 分享毕业后在北京租房的经验

热门文章

  1. 高通charge杂记
  2. java中log日志的使用(完全版)
  3. TensorFlow 2.3.0在Adding visible gpu devices: 0要卡很久,很慢,已解决
  4. 基于PLC控制的导热油温控系统如何实现远程监控
  5. 100种思维模型之机会成本思维模型-001
  6. 微会动资讯:2018年中国展览业发展回顾与2019年展望
  7. 【力扣-LeetCode】LCP 07. 传递信息 C++题解
  8. 删除maya阿诺德渲染器所有AOVS层
  9. Linux使用解压命令unzip报错:unzip: cannot find zipfile directory in one of xxx.zip
  10. 【JAVA进阶系列】JAVA 设计模式 -- 抽象工厂模式(Abstract Factory)