PyTorch | 模型的保存和加载

  • 一、模型参数的保存和加载
  • 二、完整模型的保存和加载

一、模型参数的保存和加载

  • torch.save(module.state_dict(), path):使用module.state_dict()函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path所指定的文件存放路径(常用文件格式为.pt.pth.pkl)。
  • torch.nn.Module.load_state_dict(state_dict):从state_dict中加载参数和缓冲区到Module及其子类中 。
  • torch.nn.Module.state_dict()函数返回python中的一个OrderedDict类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict中,例如:卷积层、线性层等。
  • Python中的字典类以“键:值”方式存取数据,OrderedDict是它的一个子类,实现了对字典对象中元素的排序(OrderedDict根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict字典对象会被当做是两个不同的对象。
  • 示例:
    import torch
    import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return x# 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()
    # 获取state_dict
    state_dict = net.state_dict()
    # 字典的遍历默认是遍历key,所以param_tensor实际上是键值
    for param_tensor in state_dict: print(param_tensor,':\n',state_dict[param_tensor])
    # 保存模型参数
    torch.save(state_dict,"net_params.pth")
    # 通过加载state_dict获取模型参数
    net.load_state_dict(state_dict)
    

    输出:

二、完整模型的保存和加载

  • torch.save(module, path):将训练完的整个网络模型module保存到path所指定的文件存放路径(常用文件格式为.pt.pth)。
  • torch.load(path):加载保存到path中的整个神经网络模型。
  • 示例:
    import torch
    import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 2, 3)self.pool1 = nn.MaxPool2d(2, 2)def forward(self, x):x = self.conv1(x)x = self.pool1(x)return x# 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()
    # 保存整个网络
    torch.save(net,"net.pth")
    # 加载网络
    net = torch.load("net.pth")
    

PyTorch | 模型的保存和加载相关推荐

  1. pytorch模型的保存和加载、checkpoint

    pytorch模型的保存和加载.checkpoint 其实之前笔者写代码的时候用到模型的保存和加载,需要用的时候就去度娘搜一下大致代码,现在有时间就来整理下整个pytorch模型的保存和加载,开始学习 ...

  2. PyTorch模型的保存加载以及数据的可视化

    文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...

  3. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  4. numpy将所有数据变为0和1_PyTorch 学习笔记(二):张量、变量、数据集的读取、模组、优化、模型的保存和加载...

    一. 张量 PyTorch里面最基本的操作对象就是Tensor,Tensor是张量的英文,表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 ...

  5. 线性回归之模型的保存和加载

    线性回归之模型的保存和加载 1 sklearn模型的保存和加载API from sklearn.externals import joblib   [目前这行代码报错,直接写import joblib ...

  6. paddlepaddle模型的保存和加载

    导读 深度学习中模型的计算图可以被分为两种,静态图和动态图,这两种模型的计算图各有优劣. 静态图需要我们先定义好网络的结构,然后再进行计算,所以静态图的计算速度快,但是debug比较的困难,因为只有当 ...

  7. PyTorch基础-模型的保存和加载-09

    模型的保存 import numpy as np import torch from torch import nn,optim from torch.autograd import Variable ...

  8. tensorflow 模型的保存和加载

    为了让训练结果可以复用,需要将训练得到的神经网络模型持久化,也就是把模型的参数保存下来,并保证可以持久化后的模型文件中还原出保存的模型. 1. 保存模型 tensorflow提供了一个API可以方便的 ...

  9. Pytorch模型训练保存/加载(搭建完整流程)

    文章目录 前言 模型训练完整步骤 模型保存与加载 GPU训练 "借鸡生蛋" 模型使用 本博文优先在掘金社区发布! 前言 我们这边还是以CIARF10这个模型为例子. 现在的话先说明 ...

最新文章

  1. Hive代码组织及架构简单介绍
  2. Sql 删除不保留日志
  3. Softmax vs. SoftmaxWithLoss 推导过程
  4. 【100题】第五十三题 字符串的全排列
  5. 接口返回的类型是html页面_1.10 PhalApi 2.x 接口文档
  6. github 开放_GitHub为女性开发人员所做的工作,Tim O'Reilly谈开放数据等
  7. C/C++ stack栈的理解以及使用
  8. iPhone 12包装盒设计曝光,没充电器没耳机实锤?
  9. SVM支持向量机通俗导论(理解SVM的三层境界)
  10. python实现将文件下内每张图片按顺序命名为txt文本文件中的内容
  11. 自有数据集上,如何用keras最简单训练YOLOv3目标检测
  12. Zephyr调整Main栈大小
  13. luogu1970 花匠
  14. securecrt破解版64位
  15. Linux 安装 菜鸟教程,Linux安装Nginx(菜鸟教程简单易懂)
  16. 用ffmpeg批量转换WAV文件采样率
  17. 【AI视野·今日CV 计算机视觉论文速览 第241期】Wed, 1 Dec 2021
  18. 嗅图狗——更新与反馈专贴
  19. 网易云音乐信息爬取(存储为 csv文件)喜马拉雅音乐爬取
  20. java rgb转yuv_RGB,CMY(K),YUV,YIQ,YCbCr颜色的转换算法(java实现)

热门文章

  1. Android开发_简单的网络编程
  2. Docker知识点导航
  3. truncate(截断)与delete(删除)的区别
  4. 计算机录入中级职称,中级会计职称无纸化考试公式怎么输入的 V模式是什么?...
  5. cpua55和a53哪个好_ARM正式发布A75和A55,助华为海思赶超高通
  6. 云计算在物联网中的应用
  7. Windows系统剪切板不可用
  8. 计算机复合材料缠绕机,[转载]最先进的复合材料纤维缠绕仿真软件CADWIND
  9. 好分数教师版服务器维护,好分数教师版app
  10. postman 返回json乱码_POSTMAN发起请求收到乱码 http 406错误