加载dict_PyTorch 7.保存和加载pytorch模型的两种方法
众所周知,python的对象都可以通过torch.save和torch.load函数进行保存和加载(不知道?那你现在知道了(*^_^*)),比如:
x1 = {"d":"ddf","dd":'fdsf'}
torch.save(x1, 'a1.pt')x2 = ["ddf",'fdsf']
torch.save(x2, 'a2.pt')x3 = 1
torch.save(x3, 'a3.pt')x4 = torch.ones(3)
torch.save(x4, 'a4.pt')
读取的时候也是一样:
x5 = torch.load('a1.pt')x6 = torch.load('a2.pt')x7 = torch.load('a3.pt')x8 = torch.load('a4.pt')
这种非常简单粗暴,直接把整个对象扔进磁盘文件里保存,所以对于我们训练好的模型来说,因为训练好的模型也是一个对象,所以我们也可以使用这个方法把训练好的模型对象直接扔进去。但是这样有一个问题,就是模型对象开销比较大,比如最近包含1350亿个参数的那个有名的神经网络模型,如果把它保存到磁盘里面没有百八十T是保存不下的。所以我们是不是可以仅仅保存模型里面的关键数据呢?
答案是,可以!
因为决定一个模型是什么样有两方面的因素,一个是模型的结构是什么,另一个是模型的参数是什么,这两个定了,这个模型也就确定了。模型的结构在我们初始化模型对象的时候就定了,比如对于任意一个模型类,我们初始化它的两个对象,这两个对象代表的模型的结构肯定是一样的,区别就在于它们的参数不一样。所以我们保存模型的关键就是保存模型的参数,而模型的结构每次用的时候新建一个对象就好了,然后从磁盘里把模型的参数读取出来赋给这个对象。是不是超级简单?
那我们怎么拿到模型的参数呢?巧了!
模型的state_dict()函数就是返回模型的所有参数的(这个函数是nn.Module的,所以所有继承了nn.Module的模型类都有这个函数),比如:
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()
输出:
OrderedDict([('hidden.weight',tensor([[-0.4195, 0.2609, 0.4325],[-0.4031, 0.2078, 0.2077]])),('hidden.bias', tensor([ 0.0755, -0.1408])),('output.weight', tensor([[0.2473, 0.6614]])),('output.bias', tensor([0.6191]))])
有的同学可能注意到了,self.act层的参数没有包含进来!
大哥,self.act层没有参数好吗(捂脸)
还有的同学可能想问,那有的层有参数、有的层没有参数,那万一加载的时候把某个参数给错了层怎么办?
完全不会!注意看,state_dict()返回的是一个字典,每一个张量都对应的有层的名字,清清楚楚,绝对没有问题。
那这样就简单了,举个例子看一下:
X = torch.randn(2, 3)
Y = net(X) # 这个net就是上面创建的那个对象,我们把它的参数保存起来,然后新建一个net2,然后把保存的这些参数加载进net2,这样我们把X输入net2得到的Y2应该与Y是相等的PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
输出:
tensor([[1],[1]], dtype=torch.uint8)
输出的张量,代表Y2 == Y比较结果为true,也就是说是一样的,验证了我们的猜想(上面代码注释中的那个猜想)。
好了,以上就是pytorch保存和加载模型的两种方法,是不是非常简单?
阳阳:保存和加载pytorch模型的两种方法,选哪个好?zhuanlan.zhihu.com
加载dict_PyTorch 7.保存和加载pytorch模型的两种方法相关推荐
- pytorch保存模型的两种方法
文章目录 前言 一.保存整个模型 二.只保存参数 模型不同后缀名的区别 总结 前言 模型的本质是一堆用某种结构存储起来的参数 用数据对模型进行训练后得到了比较理想的模型,就需要将其存储起来,然后在需要 ...
- 【PyTorch】保存和载入模型的两种方法
import torch import argparseparser = argparse.ArgumentParser("-") parser.add_argument(&quo ...
- 将图片保存到系统相冊的两种方法
第一种:採用系统的api直接使用: ContentResolver cr = getContentResolver();String url = MediaStore.Images.Media.ins ...
- 保存DC到bmp图片的两种方法
这里主要记录一下平时经常用到的控件贴图方法,在必要的时候将DC保存成bmp文件方便检查程序中贴图有时背景不正确的情况. 方法1: 纯Win32 GDI的方法,保存HBITMAP用的是CImage类 v ...
- php如果单数前面加0,php左边用0填充补齐的两种方法
如果要自动生成学号,自动生成某某编号,就像这样的形式"d0000009"."d0000027"时,那么就会面临一个问题,怎么把左边用0补齐成这样8位数的编码呢? ...
- pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型
新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...
- PyTorch 深度剖析:如何保存和加载PyTorch模型?
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...
- 保存和加载pytorch模型
当保存和加载模型时,需要熟悉三个核心功能: torch.save:将序列化对象保存到磁盘.此函数使用Python的pickle模块进行序列化.使用此函数可以保存如模型.tensor.字典等各种对象. ...
- python torch exp_Python:PyTorch 保存和加载训练过的网络 (八十)
保存和加载模型 在这个 notebook 中,我将为你展示如何使用 Pytorch 来保存和加载模型.这个步骤十分重要,因为你一定希望能够加载预先训练好的模型来进行预测,或是根据新数据继续训练. %m ...
最新文章
- 科普:String hashCode 方法为什么选择数字31作为乘子
- 部署 H3C CAS E0306
- Java编程思想读书笔记--第21章并发
- 通过SQL Server 2008 访问MySQL
- DDD领域驱动设计---战略设计(包括四色原型建模)
- SQL Server数据库中使用sql脚本删除指定表的列
- android 自定义flowlayout,Android 自定义ViewGroup之实现FlowLayout-标签流容器
- 使用matplotlib画图时不能同时打开太多张图
- jqGrid 使用案例及笔记
- 你都有哪些丢人的经历?
- 关于安装VS2008后SQL server 2005安装的问题
- 创建ipadWEB应用程序到主屏幕
- SEO网站优化是什么
- kaggle员工离职预测案例(3)
- java 对象逃逸 解决_Java中的逃逸问题心得
- wlh机器人_机器人小组活动实施方案
- 电子封装行业市场专项调查分析
- 洛谷-P1830 轰炸III
- 【网络】应用层-HTTP协议
- 学习java核心技术第3章的读书笔记