众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如:

x1 = {"d":"ddf","dd":'fdsf'}
torch.save(x1, 'a1.pt')x2 = ["ddf",'fdsf']
torch.save(x2, 'a2.pt')x3 = 1
torch.save(x3, 'a3.pt')x4 = torch.ones(3)
torch.save(x4, 'a4.pt')

读取的时候也是一样:

x5 = torch.load('a1.pt')x6 = torch.load('a2.pt')x7 = torch.load('a3.pt')x8 = torch.load('a4.pt')

这种非常简单粗暴,直接把整个对象扔进磁盘文件里保存,所以对于我们训练好的模型来说,因为训练好的模型也是一个对象,所以我们也可以使用这个方法把训练好的模型对象直接扔进去。但是这样有一个问题,就是模型对象开销比较大,比如最近包含1350亿个参数的那个有名的神经网络模型,如果把它保存到磁盘里面没有百八十T是保存不下的。所以我们是不是可以仅仅保存模型里面的关键数据呢?

答案是,可以!

因为决定一个模型是什么样有两方面的因素,一个是模型的结构是什么,另一个是模型的参数是什么,这两个定了,这个模型也就确定了。模型的结构在我们初始化模型对象的时候就定了,比如对于任意一个模型类,我们初始化它的两个对象,这两个对象代表的模型的结构肯定是一样的,区别就在于它们的参数不一样。所以我们保存模型的关键就是保存模型的参数,而模型的结构每次用的时候新建一个对象就好了,然后从磁盘里把模型的参数读取出来赋给这个对象。是不是超级简单?

那我们怎么拿到模型的参数呢?巧了!

模型的state_dict()函数就是返回模型的所有参数的(这个函数是nn.Module的,所以所有继承了nn.Module的模型类都有这个函数),比如:

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

输出:

OrderedDict([('hidden.weight',tensor([[-0.4195,  0.2609,  0.4325],[-0.4031,  0.2078,  0.2077]])),('hidden.bias', tensor([ 0.0755, -0.1408])),('output.weight', tensor([[0.2473, 0.6614]])),('output.bias', tensor([0.6191]))])

有的同学可能注意到了,self.act层的参数没有包含进来!

大哥,self.act层没有参数好吗(捂脸)

还有的同学可能想问,那有的层有参数、有的层没有参数,那万一加载的时候把某个参数给错了层怎么办?

完全不会!注意看,state_dict()返回的是一个字典,每一个张量都对应的有层的名字,清清楚楚,绝对没有问题。

那这样就简单了,举个例子看一下:

X = torch.randn(2, 3)
Y = net(X)  # 这个net就是上面创建的那个对象,我们把它的参数保存起来,然后新建一个net2,然后把保存的这些参数加载进net2,这样我们把X输入net2得到的Y2应该与Y是相等的PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

输出:

tensor([[1],[1]], dtype=torch.uint8)

输出的张量,代表Y2 == Y比较结果为true,也就是说是一样的,验证了我们的猜想(上面代码注释中的那个猜想)。

好了,以上就是pytorch保存和加载模型的两种方法,是不是非常简单?

阳阳:保存和加载pytorch模型的两种方法,选哪个好?​zhuanlan.zhihu.com

加载dict_PyTorch 7.保存和加载pytorch模型的两种方法相关推荐

  1. pytorch保存模型的两种方法

    文章目录 前言 一.保存整个模型 二.只保存参数 模型不同后缀名的区别 总结 前言 模型的本质是一堆用某种结构存储起来的参数 用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要 ...

  2. 【PyTorch】保存和载入模型的两种方法

    import torch import argparseparser = argparse.ArgumentParser("-") parser.add_argument(&quo ...

  3. 将图片保存到系统相冊的两种方法

    第一种:採用系统的api直接使用: ContentResolver cr = getContentResolver();String url = MediaStore.Images.Media.ins ...

  4. 保存DC到bmp图片的两种方法

    这里主要记录一下平时经常用到的控件贴图方法,在必要的时候将DC保存成bmp文件方便检查程序中贴图有时背景不正确的情况. 方法1: 纯Win32 GDI的方法,保存HBITMAP用的是CImage类 v ...

  5. php如果单数前面加0,php左边用0填充补齐的两种方法

    如果要自动生成学号,自动生成某某编号,就像这样的形式"d0000009"."d0000027"时,那么就会面临一个问题,怎么把左边用0补齐成这样8位数的编码呢? ...

  6. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  7. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  8. 保存和加载pytorch模型

    当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...

  9. python torch exp_Python:PyTorch 保存和加载训练过的网络 (八十)

    保存和加载模型 在这个 notebook 中,我将为你展示如何使用 Pytorch 来保存和加载模型.这个步骤十分重要,因为你一定希望能够加载预先训练好的模型来进行预测,或是根据新数据继续训练. %m ...

最新文章

  1. 科普:String hashCode 方法为什么选择数字31作为乘子
  2. 部署 H3C CAS E0306
  3. Java编程思想读书笔记--第21章并发
  4. 通过SQL Server 2008 访问MySQL
  5. DDD领域驱动设计---战略设计(包括四色原型建模)
  6. SQL Server数据库中使用sql脚本删除指定表的列
  7. android 自定义flowlayout,Android 自定义ViewGroup之实现FlowLayout-标签流容器
  8. 使用matplotlib画图时不能同时打开太多张图
  9. jqGrid 使用案例及笔记
  10. 你都有哪些丢人的经历?
  11. 关于安装VS2008后SQL server 2005安装的问题
  12. 创建ipadWEB应用程序到主屏幕
  13. SEO网站优化是什么
  14. kaggle员工离职预测案例(3)
  15. java 对象逃逸 解决_Java中的逃逸问题心得
  16. wlh机器人_机器人小组活动实施方案
  17. 电子封装行业市场专项调查分析
  18. 洛谷-P1830 轰炸III
  19. 【网络】应用层-HTTP协议
  20. 学习java核心技术第3章的读书笔记

热门文章

  1. linux script 命令
  2. MoeCTF 2021Re部分------RedC4Bomb
  3. Windows消息机制学习笔记(一)—— 消息队列
  4. 160个Crackme032用ProcessMonitor拆解KeyFile保护
  5. Windbg无源码调试驱动
  6. 通过SEH 非inline hook
  7. 关于Uncaught SyntaxError: Unexpected identifier
  8. 9、MySQL定义条件和处理程序
  9. Python之多进程
  10. 【PAT乙级】1077 互评成绩计算 (20 分)