PyTorch | 保存和加载模型教程
点击上方“算法猿的成长”,选择“加为星标”
第一时间关注 AI 和 Python 知识
图片来自 Unsplash,作者: Jenny Caywood
2019 年第 72 篇文章,总第 96 篇文章
总共 7000 字,建议收藏阅读
原题 | SAVING AND LOADING MODELS
作者 | Matthew Inkawhich
原文 | https://pytorch.org/tutorials/beginner/saving_loading_models.html
译者 | kbsc13("算法猿的成长"公众号作者)
声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途
简介
本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:
torch.save
:把序列化的对象保存到硬盘。它利用了 Python 的pickle
来实现序列化。模型、张量以及字典都可以用该函数进行保存;torch.load
:采用pickle
将反序列化的对象从存储中加载进来。torch.nn.Module.load_state_dict
:采用一个反序列化的state_dict
加载一个模型的参数字典。
本文主要内容如下:
什么是状态字典(state_dict)?
预测时加载和保存模型
加载和保存一个通用的检查点(Checkpoint)
在同一个文件保存多个模型
采用另一个模型的参数来预热模型(Warmstaring Model)
不同设备下保存和加载模型
1. 什么是状态字典(state_dict)
PyTorch 中,一个模型(torch.nn.Module
)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters()
)中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm
的 running_mean
)。优化器对象(torch.optim
)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。
由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。
下面是一个简单的使用例子,例子来自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])
上述代码先是简单定义一个 5 层的 CNN,然后分别打印模型的参数和优化器参数。
输出结果:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
2. 预测时加载和保存模型
加载/保存状态字典(推荐做法)
保存的代码:
torch.save(model.state_dict(), PATH)
加载的代码:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save()
来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。
通常会用 .pt
或者 .pth
后缀来保存模型。
记住
在进行预测之前,必须调用
model.eval()
方法来将dropout
和batch normalization
层设置为验证模型。否则,只会生成前后不一致的预测结果。load_state_dict()
方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用torch.load()
,而不是直接model.load_state_dict(PATH)
加载/保存整个模型
保存:
torch.save(model, PATH)
加载:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle
模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle
并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors
后采用都可能出现错误。
3. 加载和保存一个通用的检查点(Checkpoint)
保存的示例代码:
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)
加载的示例代码:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()
当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict
,比如说优化器的 state_dict
也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch
,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding
层等等。
上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save
方法,一般保存的文件后缀名是 .tar
。
加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load
加载对应的 state_dict
。然后通过不同的键来获取对应的数值。
加载完后,根据后续步骤,调用 model.eval()
用于预测,model.train()
用于恢复训练。
4. 在同一个文件保存多个模型
保存模型的示例代码:
torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)
加载模型的示例代码:
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
当我们希望保存的是一个包含多个网络模型 torch.nn.Modules
的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict
和对应优化器的 state_dict
。除此之外,还可以继续保存其他相同的信息。
加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar
作为后缀名。
5. 采用另一个模型的参数来预热模型(Warmstaring Model)
保存模型的示例代码:
torch.save(modelA.state_dict(), PATH)
加载模型的示例代码:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。
这种做法通常是加载预训练模型的部分网络参数作为模型的初始化参数,然后可以加快模型的收敛速度。
加载预训练模型的代码如上述所示,其中设置参数 strict=False
表示忽略不匹配的网络层参数,因为通常我们都不会完全采用和预训练模型完全一样的网络,通常输出层的参数就会不一样。
当然,如果希望加载参数名不一样的参数,可以通过修改加载的模型对应的参数名字,这样参数名字匹配了就可以成功加载。
6. 不同设备下保存和加载模型
在GPU上保存模型,在 CPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load()
的时候,设置参数 map_location
,指定采用的设备是 torch.device('cpu')
,这个做法会将张量都重新映射到 CPU 上。
在GPU上保存模型,在 GPU 上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
在 GPU 上训练和加载模型,调用 torch.load()
加载模型后,还需要采用 model.to(torch.device('cuda'))
,将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)
。
在CPU上保存,在GPU上加载模型
保存模型的示例代码:
torch.save(model.state_dict(), PATH)
加载模型的示例代码:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location
指定设备。然后继续记得调用 model.to(torch.device('cuda'))
。
保存 torch.nn.DataParallel 模型
保存模型的示例代码:
torch.save(model.module.state_dict(), PATH)
torch.nn.DataParallel
是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()
。
加载模型的代码也是一样的,采用 torch.load()
,并可以放到指定的 GPU 显卡上。
完整的代码:
https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py
欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!
如果觉得不错,在看、转发就是对小编的一个支持!
推荐阅读
快速入门Pytorch(1)--安装、张量以及梯度
快速入门PyTorch(2)--如何构建一个神经网络
快速入门PyTorch(3)--训练一个图片分类器和多 GPUs 训练
PyTorch系列 | 快速入门迁移学习
PyTorch系列 | 如何加快你的模型训练速度呢?
PyTorch 系列 | 数据加载和预处理教程
PyTorch | 保存和加载模型教程相关推荐
- Pytorch 保存和加载模型
当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...
- pytorch保存和加载模型state_dict
保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...
- python保存模型与参数_基于pytorch的保存和加载模型参数的方法
当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...
- pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型
新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...
- 【pytorch】(六)保存和加载模型
文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...
- tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)
最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...
- python torch exp_Python:PyTorch 保存和加载训练过的网络 (八十)
保存和加载模型 在这个 notebook 中,我将为你展示如何使用 Pytorch 来保存和加载模型.这个步骤十分重要,因为你一定希望能够加载预先训练好的模型来进行预测,或是根据新数据继续训练. %m ...
- TensorFlow 保存和加载模型
参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化
- pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题
首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件 ...
最新文章
- 创新驱动未来,浪潮持续深耕信息安全市场
- IIS7 授权配置错误
- 记一次Java动态代理实践
- linux 程序 加密码忘了怎么办,linux忘记了密码怎么办
- Bootstrap--导航栏样式编辑
- a good website to test OTP
- 优化 Go 中的 map 并发存取
- 51Nod - 1385 凑数字
- 通过互联网搜索接口更新拼写语法库的设计
- Spring Boot入门教程(三十六):支付宝集成-当面付
- 宜家IKEA EDIFACT PRODAT报文详解
- python输入一组数字求平均值和标准差_如何计算PySpark DataFrame的平均值和标准差?...
- 第三方登陆--接入谷歌和FaceBook
- BPM端到端流程解决方案分享
- icss之继承inherit
- iOS之UITextField怎么自定义键盘的return键
- SLA--如何学习英语理论篇
- OFBiz终于起航了
- Linux Shell 通配符、元字符、转义符使用实例介绍--Learning the korn shell
- [附源码]JAVA+ssm视频网站(程序+Lw)
热门文章
- ecs服务器数据迁移_如何非常方便地从Windows文件服务器把数据完整地迁移到ONTAP Select...
- 一维有限元法matlab,有限元matlab研究.ppt
- 按键 使用WinHttp实现POST方式用户模拟登录网站
- NetCore NW714 v2.0路由器TTL救砖
- [Python2.x] 标准库 urllib2 的使用细节
- [Linux] VIM 代码折叠
- [Redux/Mobx] 在React中你是怎么对异步方案进行选型的?
- React开发(232):传参可以转变思路
- [html] 如何动态修改`<title>`的标题名称?
- [html] html的元素有哪些(包含H5)?