文章目录

  • 保存和加载模型
    • 保存加载模型参数
    • 保存加载模型和参数

保存和加载模型

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdadevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)# ======================= 数据 =======================
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# ======================= 模型 =======================
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = x.flatten(1)logits = self.linear_relu_stack(x)return logitsdef train_loop(dataloader, model, loss_fn, optimizer):'''训练循环'''size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):X = X.to(device)y = y.to(device)# 计算估计值与损失pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")def test_loop(dataloader, model, loss_fn):'''测试循环'''size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X = X.to(device)y = y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")# 超参数
learning_rate = 1e-3
batch_size = 64
epochs = 10# 模型实例
model = NeuralNetwork().to(device)
# 损失函数实例
loss_fn = nn.CrossEntropyLoss()
# 优化器实例
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

在之前的章节中,我们介绍了如何用Pytorch搭建模型以及训练。在本节中,我们将了解到如何保存、加载模型。

保存加载模型参数

PyTorch模型可以将学习到的参数存储在内部状态字典(称为state_dict)中。我们可以通过torch.save保存这些参数:

torch.save(model.state_dict(), "model_state_dict.pth")

要加载模型参数,需要先创建同一模型的实例,然后使用load_state_dict()方法加载。

model = NeuralNetwork()
model.load_state_dict(torch.load("model_state_dict.pth"))

使用加载的模型进行预测:

# model.eval()
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

输出:

Predicted: "Ankle boot", Actual: "Ankle boot"

注意:
如果在训练模型时使用到dropout或batch normalization,则需要使用model.eval()方法将网络设置为测试模式,否则将产生不一致的推断结果。

保存加载模型和参数

加载模型参数时,我们需要首先实例化模型类,因为该类定义了网络的结构。我们还可以将此类的结构与模型一起保存:将model(而不是model.state_dict())传递给torch.save

torch.save(model, 'model.pth')

然后我们可以像这样加载模型:

model = torch.load("model.pth")
# model.eval()
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"

参考:
[1] https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

【pytorch】(六)保存和加载模型相关推荐

  1. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  2. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  3. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  4. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  5. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  6. TensorFlow 保存和加载模型

    参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化

  7. pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题

    首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件 ...

  8. pytorch保存和加载模型state_dict

    保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...

  9. 机器学习代码实战——保存和加载模型(Save and Load Model)

    文章目录 1.实验目的 2.保存与加载模型 2.1.pickle方法 2.2.joblib方法 1.实验目的 每当我们训练完一个模型后,我们需要保存训练好的模型留给下次用或者再次训练,因此我将给出两种 ...

最新文章

  1. 3NF分解与BCNF分解
  2. 2020-12-11 图片格式互转:base64、PIL Image opencv cv2互转
  3. methanol 模块化的可定制的网页爬虫软件,主要的优点是速度快。
  4. linux是否有免安装程序,在线Ubuntu Linux系统,免安装体验Linux系统
  5. as it exceeds the max of 500KB._IT狂人第一季 | 如何考察员工
  6. Object.defineProperty 详解
  7. STM32之ADC单通道连续例程
  8. 微软 azure_Microsoft Azure,我们迁移数据的第一步
  9. 3.26 Tensorflow 实验记录
  10. 再续上一篇:如果哪天沃尔玛也“.CN”了
  11. 海康IP摄像头rtsp解码(ubuntu上使用)
  12. IBM P750 AIX机器根目录空间满问题解决办法
  13. python注释可用于表明作者和版权信息_vs2017 新建Class 文件时,自动添加作者版权声明注释...
  14. 多人配音怎么做的?这两个多人配音方法分享给你
  15. hello 驱动编写-最简单的驱动程序
  16. 详解互联网平台的资金系统方案 自建支付清结算系统优势明显
  17. C10k-problem
  18. 英语听力采用计算机化考试,北京高考英语听力机考有什么特点?
  19. 基于存储的C语言文件操作常规问题分析(文本文件与二进制文件)
  20. 数据流代替工作流解决方案

热门文章

  1. WPF自定义Popup和弹出菜单
  2. 周欢:区块链下一波的牛市机会
  3. MasteringOpenCV实战源码学习笔记 章节一
  4. 十年项目经验面试官亲传大数据面试__大数据面试独孤九剑
  5. node搭建webrtc信令服务器
  6. redash二次开发和制作镜像
  7. 【逻辑与计算理论】λ演算与组合子逻辑概念简介
  8. P265GH钢板牌号和化学元素
  9. UVA 10480 - Sabotage (最大流)
  10. 集群管理——开机B7问题(内存条ERROR),caffe编译matlab接口错误解决方法