当保存和加载模型时,需要熟悉三个核心功能:

torch.save:将序列化对象保存到磁盘。此函数使用Python的pickle模块进行序列化。使用此函数可以保存如模型、tensor、字典等各种对象。
torch.load:使用pickle的unpickling功能将pickle对象文件反序列化到内存。此功能还可以有助于设备加载数据。
torch.nn.Module.load_state_dict:使用反序列化函数 state_dict 来加载模型的参数字典。

Python中对于模型数据的保存和加载操作都是引用Python内置的pickle包,使用pickle.dump()和pickle.load()方法。在Pytorch中也有同样功能的方法提供。

>>>torch.save(model,'model.pkl') #保存整个模型
>>>model = torch.load('model.pkl') #加载整个模型
>>>torch.save(alexnet.state_dict(),'params.pkl') #保存网格中的参数
>>>alexnet.load_state_dict(torch.load('params.pkl')) #加载网格中的参数

在torchvision.models模块里,PyTorch提供了一些常用的模型:

常用模型
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3

可以使用torch.util.model_zoo来预加载它们,具体设置通过参数pretrained=True来实现。

>>>import torchvision.models as models
>>>ResNet18 = models.ResNet18(pretrained=True)
>>>alexnet = models.alexnet(pretrained=True)
>>>squeezenet = models.squeezenet1_0(pretrained=True)
>>>vgg16 = models.vgg16(pretrained=True)
>>>densenet = models.densenet161(pretrained=True)
>>>inception = models.inception_v3(pretrained=True)

加载这类预训练模型的过程中,还可以进行微处理。

>>>pretrained_dict = model_zoo.load_url(model_urls['resnet134'])
>>>model_dict = model.state_dict()
>>>pretrained_dict = {k:v for k,v in pretrained_dict.items()if k in model_dict}#将pretrained_dict里不属于model_dict的键剔除掉
>>>model_dict.update(pretrained_dict)  #更新现在有的model_dict
>>>model.load_state_dict(model_dict)

参考
《PyTorch机器学习从入门到实战》

保存和加载pytorch模型相关推荐

  1. 加载dict_PyTorch 7.保存和加载pytorch模型的两种方法

    众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如: x1 = {"d":" ...

  2. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  3. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  4. TensorFlow模型保存和加载方法

    TensorFlow模型保存和加载方法 模型保存 import tensorflow as tfw1 = tf.Variable(tf.constant(2.0, shape=[1]), name=& ...

  5. keras模型保存和加载

    (一)保存和加载整个模型 ​ 包含模型的结构.权重.训练配置项(损失函数.优化器).优化器状态,允许准确地从上次结束的地方开始训练. 1.训练完模型后 path='.../.../xxx.h5' mo ...

  6. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  7. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  8. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  9. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

最新文章

  1. 劳动节特别活动,钉钉、支付宝合种,2-4天领证,限量9个名额
  2. 东芝发布15nm SG5固态硬盘 容量高达1TB
  3. hadoop命令帮助
  4. 盘点区块链的2018:技术与工具演进篇
  5. 产品功能上线前,如何高效的埋点?
  6. 运行时常量池_从String的intern()到常量池
  7. 现在电脑的主流配置_玩手游是因为电脑配置差?现在来告诉你这些网游需要啥配置...
  8. 2018年最受大家欢迎的五大机器学习工具和五大数据学习工具
  9. jq控制div是否展示_jQuery控制多个DIV的显示和隐藏
  10. 清远机器人编程_致敬逆行者:棒棒贝贝为清远援鄂人员子女免费提供一年乐高编程课...
  11. 1.3亿豪宅被拆 抱头痛哭的房主,真实身份竟是保健品大佬,曾涉嫌传销...
  12. CentOS7 安装xen(在虚拟机上成功,实体机测试死机!)
  13. GNU ARM汇编--(五)中断汇编之嵌套中断处理
  14. 【一天一个C++小知识】010.malloc/free和new/delete
  15. 2019湖南多校第四场
  16. BootStrap运行流程解析
  17. 《大数据之路:阿里巴巴大数据实践》-第1篇 数据技术篇 -第3章数据同步
  18. 7月23日云栖精选夜读丨前端leader找我谈心:我是如何从刚毕业的前端菜鸟一步步成长为前端工程师的?...
  19. 计算机网络应用基础总结,(完整版)计算机网络应用基础高教版对口高考复习资料总结...
  20. 敏捷团队的病与药——阿里健康医药B2B团队敏捷转型手记

热门文章

  1. 难道前途真的比钱重要吗
  2. Canvas3 汉化QA和BUG反馈
  3. Cisco Catalyst 2960系列交换机资料
  4. 输出400以内的smith数java,史密斯(A.O.Smith) 空气净化器 KJ400F-B11
  5. python实现cc攻击_运维纪录:遭遇CC攻击,防御与查水表
  6. mysql缓解oom发生的方法_MySQL Slave 触发 oom-killer解决方法_MySQL
  7. c 中ajax不起作用,Jquery AJAX調用:$(this)在成功后不起作用
  8. Ajax Session失效跳转登录页面的方法
  9. java怎么将前端的数据存到关联的表中_MySQL数据库性能优化
  10. 2018帮助_字节跳动扶贫获“北京市扶贫协作奖”,一年帮助8万贫困人口增收