【pytorch】(六)保存和加载模型
文章目录
- 保存和加载模型
- 保存加载模型参数
- 保存加载模型和参数
保存和加载模型
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdadevice = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)# ======================= 数据 =======================
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# ======================= 模型 =======================
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = x.flatten(1)logits = self.linear_relu_stack(x)return logitsdef train_loop(dataloader, model, loss_fn, optimizer):'''训练循环'''size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):X = X.to(device)y = y.to(device)# 计算估计值与损失pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")def test_loop(dataloader, model, loss_fn):'''测试循环'''size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X = X.to(device)y = y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")# 超参数
learning_rate = 1e-3
batch_size = 64
epochs = 10# 模型实例
model = NeuralNetwork().to(device)
# 损失函数实例
loss_fn = nn.CrossEntropyLoss()
# 优化器实例
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)for t in range(epochs):print(f"Epoch {t + 1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")
在之前的章节中,我们介绍了如何用Pytorch搭建模型以及训练。在本节中,我们将了解到如何保存、加载模型。
保存加载模型参数
PyTorch模型可以将学习到的参数存储在内部状态字典(称为state_dict
)中。我们可以通过torch.save
保存这些参数:
torch.save(model.state_dict(), "model_state_dict.pth")
要加载模型参数,需要先创建同一模型的实例,然后使用load_state_dict()
方法加载。
model = NeuralNetwork()
model.load_state_dict(torch.load("model_state_dict.pth"))
使用加载的模型进行预测:
# model.eval()
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
输出:
Predicted: "Ankle boot", Actual: "Ankle boot"
注意:
如果在训练模型时使用到dropout或batch normalization,则需要使用model.eval()
方法将网络设置为测试模式,否则将产生不一致的推断结果。
保存加载模型和参数
加载模型参数时,我们需要首先实例化模型类,因为该类定义了网络的结构。我们还可以将此类的结构与模型一起保存:将model
(而不是model.state_dict()
)传递给torch.save
:
torch.save(model, 'model.pth')
然后我们可以像这样加载模型:
model = torch.load("model.pth")
# model.eval()
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():x = x.to(device)pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "Ankle boot", Actual: "Ankle boot"
参考:
[1] https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
【pytorch】(六)保存和加载模型相关推荐
- python保存模型与参数_基于pytorch的保存和加载模型参数的方法
当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...
- pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型
新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...
- PyTorch | 保存和加载模型教程
点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...
- Pytorch 保存和加载模型
当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...
- tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)
最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...
- TensorFlow 保存和加载模型
参考: 保存和恢复模型官方教程 tensorflow2保存和加载模型 TensorFlow2.0教程-keras模型保存和序列化
- pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题
首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件 ...
- pytorch保存和加载模型state_dict
保存模型: torch.save({'epoch': epoch + 1,'state_dict': model.state_dict(),'optimizer': optimizer.state_d ...
- 机器学习代码实战——保存和加载模型(Save and Load Model)
文章目录 1.实验目的 2.保存与加载模型 2.1.pickle方法 2.2.joblib方法 1.实验目的 每当我们训练完一个模型后,我们需要保存训练好的模型留给下次用或者再次训练,因此我将给出两种 ...
最新文章
- 3NF分解与BCNF分解
- 2020-12-11 图片格式互转:base64、PIL Image opencv cv2互转
- methanol 模块化的可定制的网页爬虫软件,主要的优点是速度快。
- linux是否有免安装程序,在线Ubuntu Linux系统,免安装体验Linux系统
- as it exceeds the max of 500KB._IT狂人第一季 | 如何考察员工
- Object.defineProperty 详解
- STM32之ADC单通道连续例程
- 微软 azure_Microsoft Azure,我们迁移数据的第一步
- 3.26 Tensorflow 实验记录
- 再续上一篇:如果哪天沃尔玛也“.CN”了
- 海康IP摄像头rtsp解码(ubuntu上使用)
- IBM P750 AIX机器根目录空间满问题解决办法
- python注释可用于表明作者和版权信息_vs2017 新建Class 文件时,自动添加作者版权声明注释...
- 多人配音怎么做的?这两个多人配音方法分享给你
- hello 驱动编写-最简单的驱动程序
- 详解互联网平台的资金系统方案 自建支付清结算系统优势明显
- C10k-problem
- 英语听力采用计算机化考试,北京高考英语听力机考有什么特点?
- 基于存储的C语言文件操作常规问题分析(文本文件与二进制文件)
- 数据流代替工作流解决方案