神经网络训练后我们需要将模型进行保存,要用的时候将保存的模型进行加载,PyTorch 中保存和加载模型主要分为两类:

  • 保存加载整个模型
  • 只保存加载模型参数

https://zhuanlan.zhihu.com/p/73893187

一、保存加载模型基本用法

1、保存加载整个模型

保存整个网络模型(网络结构+权重参数)。

torch.save(model, 'net.pkl')

直接加载整个网络模型(可能比较耗时)。

model = torch.load('net.pkl')

2、只保存加载模型参数

只保存模型的权重参数(速度快,占内存少)。

torch.save(model.state_dict(), 'net_params.pkl')

因为我们只保存了模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。

# 构建一个网络结构
model = ClassNet()
# 将模型参数加载到新模型中
state_dict = torch.load('net_params.pkl')
model.load_state_dict(state_dict)

保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。

主要用法就是上面这些,接下来讲一下PyTorch中保存加载模型内部的一些原理,以及我们可能会遇到的一些特殊的需求。

二、保存加载自定义模型

上面保存加载的 net.pkl 其实一个字典,通常包含如下内容:

  • 网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
  • 模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict() 方法来获取,比如前面介绍只保存模型权重参数时用到的 model.state_dict()。
  • 优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
  • 其他信息:有时我们需要保存一些其他的信息,比如 epoch,batch_size 等超参数。

知道了这些,那么我们就可以自定义需要保存的内容,比如:

# saving a checkpoint assuming the network class named ClassNet
checkpoint = {'model': ClassNet(),'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch}
​
torch.save(checkpoint, 'checkpoint.pkl')

上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。

然后我们要加载上面保存的自定义的模型:

def load_checkpoint(filepath):checkpoint = torch.load(filepath)model = checkpoint['model']  # 提取网络结构model.load_state_dict(checkpoint['model_state_dict'])  # 加载网络权重参数optimizer = TheOptimizerClass()optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器参数for parameter in model.parameters():parameter.requires_grad = Falsemodel.eval()return modelmodel = load_checkpoint('checkpoint.pkl')

如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropout 和 batch normalization 层进行固定,否则模型的预测结果每次都会不同。

如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。

state_dict() 也是一个Python字典对象,model.state_dict() 将每一层的可学习参数映射为参数矩阵,其中只包含具有可学习参数的层(卷积层、全连接层等)。

# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 8, 5)self.bn = nn.BatchNorm2d(8)self.conv2 = nn.Conv2d(8, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 10)
​def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.bn(x)x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize modelmodel = TheModelClass()
​# Initialize optimizeroptimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
​print("Model's state_dict:")for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())
​print("Optimizer's state_dict:")for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

输出为:

Model's state_dict:
conv1.weight            torch.Size([8, 3, 5, 5])
conv1.bias              torch.Size([8])
bn.weight               torch.Size([8])
bn.bias                 torch.Size([8])
bn.running_mean         torch.Size([8])
bn.running_var          torch.Size([8])
bn.num_batches_tracked  torch.Size([])
conv2.weight            torch.Size([16, 8, 5, 5])
conv2.bias              torch.Size([16])
fc1.weight              torch.Size([120, 400])
fc1.bias                torch.Size([120])
fc2.weight              torch.Size([10, 120])
fc2.bias                torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [139805696932024, 139805483616008, 139805483616080, 139805483616152, 139805483616440, 139805483616512, 139805483616584, 139805483616656, 139805483616728, 139805483616800]}]

可以看到 model.state_dict() 保存了卷积层,BatchNorm层和最大池化层的信息;而 optimizer.state_dict() 则保存的优化器的状态和相关的超参数。

三、跨设备保存加载模型

1、在 CPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on CPU):

device = torch.device('cpu')
model = TheModelClass()
# Load all tensors onto the CPU device
model.load_state_dict(torch.load('net_params.pkl', map_location=device))

map_location:a function, torch.device, string or a dict specifying how to remap storage locations

令 torch.load() 函数的 map_location 参数等于 torch.device(‘cpu’) 即可。 这里令 map_location 参数等于 ‘cpu’ 也同样可以。

2、在 GPU 上加载在 GPU 上训练并保存的模型(Save on GPU, Load on GPU):

device = torch.device("cuda")
model = TheModelClass()
model.load_state_dict(torch.load('net_params.pkl'))
model.to(device)

在这里使用 map_location 参数不起作用,要使用 model.to(torch.device(“cuda”)) 将模型转换为CUDA优化的模型。

还需要对将要输入模型的数据调用 data = data.to(device),即将数据从CPU转移到GPU。请注意,调用 my_tensor.to(device) 会返回一个 my_tensor 在 GPU 上的副本,它不会覆盖 my_tensor。因此需要手动覆盖张量:my_tensor = my_tensor.to(device)。

3、在 GPU 上加载在 GPU 上训练并保存的模型(Save on CPU, Load on GPU)

device = torch.device("cuda")
model = TheModelClass()
model.load_state_dict(torch.load('net_params.pkl', map_location="cuda:0"))
model.to(device)

当加载包含GPU tensors的模型时,这些tensors 会被默认加载到GPU上,不过是同一个GPU设备。

当有多个GPU设备时,可以通过将 map_location 设定为 cuda:device_id 来指定使用哪一个GPU设备,上面例子是指定编号为0的GPU设备。

其实也可以将 torch.device(“cuda”) 改为 torch.device(“cuda:0”) 来指定编号为0的GPU设备。

最后调用 model.to(torch.device(‘cuda’)) 来将模型的tensors转换为 CUDA tensors。

下面是PyTorch官方文档上的用法,可以进行参考:

>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

四、CUDA 的用法

在PyTorch中和GPU相关的几个函数:

import torch# 判断cuda是否可用;
print(torch.cuda.is_available())# 获取gpu数量;
print(torch.cuda.device_count())# 获取gpu名字;
print(torch.cuda.get_device_name(0))# 返回当前gpu设备索引,默认从0开始;
print(torch.cuda.current_device())# 查看tensor或者model在哪块GPU上
print(torch.tensor([0]).get_device())

我的电脑输出为:

True
1
GeForce RTX 2080 Ti
0

有时我们需要把数据和模型从cpu移到gpu中,有以下两种方法:

use_cuda = torch.cuda.is_available()# 方法一:
if use_cuda:data = data.cuda()model.cuda()# 方法二:
device = torch.device("cuda" if use_cuda else "cpu")
data = data.to(device)
model.to(device)

个人比较习惯第二种方法,可以少一个 if 语句。而且该方法还可以通过设备号指定使用哪个GPU设备,比如使用0号设备:

device = torch.device("cuda:0" if use_cuda else "cpu")

1月13 PyTorch 中模型的使用,保存加载模型相关推荐

  1. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  2. 【待更新】GPU 保存模型参数,GPU 加载模型参数

    GPU 保存模型参数,GPU 加载模型参数 保存 # 模型 device = torch.device('cuda') net = KGCN(num_user, num_entity, num_rel ...

  3. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  4. 加载dict_Pytorch模型resume training,加载模型基础上继续训练

    Step1:首先查看源码train.py中如何保存模型的: checkpoint_dict = {'epoch': epoch, 'model_state_dict': model.state_dic ...

  5. Tensorflow学习(二)之——保存加载模型、Saver的用法

    1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类 ...

  6. pytorch学习笔记(6):GPU和如何保存加载模型

    参考文档:https://mp.weixin.qq.com/s/kmed_E4MaDwN-oIqDh8-tg 上篇文章我们完成了一个 vgg 网络的实现,那么现在已经掌握了一些基础的网络结构的实现,距 ...

  7. pytorch: 在训练中保存模型,加载模型

    文章目录 1. 保存整个模型 2.仅保存和加载模型参数(推荐使用) 3. 保存其他参数到模型中,比如optimizer,epoch 等 1. 保存整个模型 torch.save(model, 'mod ...

  8. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  9. OpenGL通过Assimp加载模型

    OpenGL通过Assimp加载模型 OpenGL通过Assimp加载模型简介 源代码剖析 主要源代码 OpenGL通过Assimp加载模型简介 到目前为止,我们已经使用了手动创建的模型.如您所见,为 ...

  10. OSG —— 笔记2 - 加载模型(附源码)

    效果         相关文章      OSG -- 笔记1 - 指令调用模型      OSG -- 笔记2 - 加载模型(附源码)      OSG -- 笔记3 - 绘制矩形(附源码)     ...

最新文章

  1. Android 中文 API (25) —— ZoomControls
  2. numpy 中的 random.rand() 函数
  3. 【算法】哈希表 ( 两数之和 )
  4. Ubuntu 16.04下Caffe-SSD的应用(二)——准备与处理VOC2007数据集
  5. pmbook 知识领域 第六版_PMP项目管理10大知识领域脑图
  6. 在ArcGIS调坐标系引发的一系列问题
  7. python怎么导入大小字母_isort-用于对python导入的库按照字母进行排序的工具
  8. 设置WordPress文章关键词自动获取,文章所属分类名称,描述自动获取文章内容,给文章的图片自动加上AlT标签...
  9. BZOJ 2324: [ZJOI2011]营救皮卡丘(带上下限的最小费用最大流)
  10. abovedisplayskip无效_latex公式图片行间距段间距调整心得 -
  11. 安装Linux操作系统完成必做几件事
  12. C加加学习之路 1——开始
  13. Android自定义GridView显示一行,并且可以左右滑动
  14. 开源字体下载——思源黑体
  15. 30天自己制作操作系统中二进制编辑器BZ-1621
  16. Sql Server卸载安装
  17. window sserver 2008 r2安装教程
  18. 上海嵌联自控供应车流量统计系统
  19. #1829 : Tomb Raider(哈希)
  20. Error: Cannot find module ‘@/xxx‘

热门文章

  1. PIX4D工作手册分享
  2. ENVI软件中决策树分类和监督分类算法比较
  3. ArcGIS操作:矢量shp编辑
  4. html5开发播放器,larkplayer: 插件化的 HTML5 播放器
  5. python爬虫-- 爬取51job网招聘信息
  6. android zip4j之--解压zip文件并实时显示解压进度
  7. 【React Native 安卓开发】----(Flexbox布局)【第二篇】
  8. Linux系统kill端口占用简书,MAC/Linux解决端口占用
  9. 点赞功能java_jquery点赞功能实现代码 点个赞吧!
  10. 第二周函数-的基本格式: