新建 Microsoft PowerPoint 演示文稿 (2).jpg

保存和加载模型

在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials。争取一天一节。不过重点是关注神经网络构建和数据处理部分。

本节主要是用于解决模型的保存和加载。会的直接跳过就好。我也只是做记录,这篇搞定就直接进入NLP部分。

三个核心函数:

  • torch.save:将序列化的对象保存在硬盘上,使用Python的pickle来序列化。
  • torch.load:使用pickle的拆包功能将硬盘上的序列化文件导入内存中。
  • torch.nn.Module.load_state_dict:加载一个模型的参数字典。

主要目录

  1. 什么是state_dict?
  2. Saving & Loading Model for Inference
  3. Saving & Loading a General Checkpoint
  4. Saving Multiple Models in One File
  5. Warmstarting Model Using Parameters from a Different Model
  6. Saving & Loading Model Across Devices

1.什么是state_dict?

在PyTorch中,torch.nn.Module的可学习参数(即权重和偏差),模块模型包含在model's参数中(通过model.parameters()访问)。state_dict是个简单的Python dictionary对象,它将每个层映射到它的参数张量。

注意,只有具有可学习参数的层(卷积层、线性层等)才有model's state_dict中的条目。优化器对象(connector .optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

可以通过遍历模型中的state_dict每一个tensor来查看。

简单构建一个模型,这里我用了两种写法,注释掉的那一种是比较通用的,也是建议使用的。

import torchimport torch.nn as nnimport torch.nn.functional as F# Define modelclass TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def farward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x# class TheModelClass(nn.Module):# def __init__(self):# super(TheModelClass, self).__init__()# self.conv = torch.nn.Sequential()# self.conv.add_module('conv1', nn.Conv2d(3, 6, 5))# self.conv.add_module('pool', nn.MaxPool2d(2, 2))# self.conv.add_module('conv2', nn.Conv2d(6, 16, 5))# self.dense = torch.nn.Sequential()# self.dense.add_module('fc1', nn.Linear(16 * 5 * 5, 120))# self.dense.add_module('fc2', nn.Linear(120, 84))# self.dense.add_module('fc3', nn.Linear(84, 10))## def forward(self, x):# conv_out = self.conv(x)# res = conv_out.view(conv_out.size(0), -1)# out = self.dense(res)# return out# Initialize modelmodel = TheModelClass()# Initialize optimizeroptimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)print(model)

查看以下模型的state_dict

print('Model state dict:')for param in model.state_dict(): print(param, 

pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型相关推荐

  1. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  2. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  3. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  4. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  5. pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题

    首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件 ...

  6. pytorch load state dict_PyTorch 学习笔记(五):Finetune和各层定制学习率

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial @[toc] 我们知道一个 ...

  7. pytorch load state dict_pytorch训练trick

    pytorch随机种子 pytorch随机种子是随机初始化的,如果想复现一个比较好的结果,可以设置固定随机种子. 其中cudnn打开可以提高计算效率,但是会影响每次复现结果.另外如果图像预处理的时候用 ...

  8. pytorch保存和加载模型state_dict

    保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...

  9. 机器学习代码实战——保存和加载模型(Save and Load Model)

    文章目录 1.实验目的 2.保存与加载模型 2.1.pickle方法 2.2.joblib方法 1.实验目的 每当我们训练完一个模型后,我们需要保存训练好的模型留给下次用或者再次训练,因此我将给出两种 ...

最新文章

  1. SpringBoot源码分析之@Scheduled
  2. 2007年3月东北微软技术活动预告
  3. STL 之 list 容器详解
  4. catia 如何提取cgr面_CATIA教程之创成式外形设计金元宝
  5. mysql 节点查根_(三)B数、B+树及在数据库索引中应用
  6. CentOS 命令大全 (转)
  7. tomcat 控制台乱码 windows下
  8. [Java] 蓝桥杯ALGO-48 算法训练 关联矩阵
  9. oracle 基数 选择率,1.1.2.2 可选择率(1)
  10. ROS机器人编程新书推荐(附免费下载)
  11. 【贪玩巴斯】一文学会检索三要素:检索字段、检索词、检索算法检索(二)——「一文学会检索三要素:检索字段、检索词、检索算法」 2021-09-18
  12. uniapp得到用户当前定位以及用户选择位置
  13. Android系统的system/app和system/priv-app
  14. 自然语言处理之词移距离Word Mover's Distance
  15. 对抗攻击常见方法汇总
  16. vue项目执行命令npm run serve运行项目时 停在 98% after emitting CopyPlugin
  17. Linux:ls命令
  18. maxcompute-入门-数据下载
  19. 一文说清楚pytorch和tensorFlow的区别究竟在哪里
  20. [资料分享] 深受电子工程师喜爱的电路资料大合集

热门文章

  1. PyQt5随笔:QSettings 的简单使用详说,进行软件的设置状态数据储存与初始化
  2. oracle+字段科学计数,PL/SQL中查询Oracle大数(17位以上)时显示科学计数法的解决方法...
  3. html页面控制标签,html body标签详解与html常用的控制标记
  4. 重点 (五) : iOS框架搭建-1
  5. 游戏开发中的按键操作管理
  6. mysql左连接多条件,on子句多条件
  7. Python弄懂基础点---print函数格式化输出的几种方式
  8. 两个一阶节的级联型_DSP第八章 时域离散系统的实现ppt课件.ppt
  9. 2021-12-8 Leetcode 914.卡牌分组
  10. leetcode系列--9.回文数