有时候,我们想要保存训练好的模型,等需要用来进行图像分类等任务的时候,不经训练,直接加载使用。
这时,可以采用torch.save(model, 'data/cnn_model.pt')保存模型(放在代码最后面),如:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as npimport torch.nn as nn
import torch.nn.functional as Ftransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=0)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=0)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
net.to(device)import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)for epoch in range(2):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputsinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))torch.save(net, 'data/net_model.pt')

然后新建predict.py,采用model = torch.load('./data/cnn_model.pt')加载前面训练好的模型,如:

import numpy as np
import torch
import torch.nn as nn
from utils import *
import time
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchinfodevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=0)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=0)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 加载模型
model = torch.load('./data/net.pt')model.eval()
test_correct_count = 0
test_total_count = 0
test_loss = 0
with torch.no_grad():for batch_count, (test_image, test_label) in enumerate(test_loader):test_image, test_label = test_image.to(device), test_label.to(device)outputs = model(test_image)_, pred = torch.max(outputs.data, 1)# loss = criterion(outputs, test_label)# test_loss += loss.item()test_correct_count += torch.sum(pred == test_label).item()test_total_count += test_label.size(0)
# test_loss /= len(test_loader)test_accuracy = test_correct_count / (num_classes * num_samples_test)
print('测试集正确率:',format(100 * test_accuracy, '.2f'), '%',)

pytorch实战(四)——模型的保存与读取相关推荐

  1. sklearn与pytorch模型的保存与读取

    当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...

  2. Keras——模型的保存、读取及加载

    本文将会介绍如何利用Keras来实现模型的保存.读取以及加载.   本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下: 具体的模型参数可以参考文章 ...

  3. Pytorch中参数和模型的保存与读取

    Tensor变量的存取(包括parameter) 对于普通Tensor变量的存取,如下代码所示: import torch import torch.nn as nn x = torch.ones(3 ...

  4. Keras中的各种Callback函数示例(含Checkpoint模型的保存、读取示例)-----记录

    本文整理了绝大多数keras里的Callback回调)函数,并且收集了代码调用示例. 大多数内容整理自网络,参考资料已在文章最后给出. 回调函数Callbacks 回调函数是一组在训练的特定阶段被调用 ...

  5. Pytorch实战_Seq2seq模型

    1. Sequence-to-Sequence 简介 大多数常见的 sequence-to-sequence (seq2seq) model 为 encoder-decoder model,主要由两个 ...

  6. 如何:在OpenText Workflow 6.5模型中保存和读取多行数据

    在Captaris Workflow 6.0和之前的版本中,保存多行数据似乎没有被提及,因此大部分(包括我的团队)都要自己建立(利用IDE的向导也算)数据库来保存订单项.物品列表.人员列表这样的多行集 ...

  7. 2-3实战分类模型之数据的读取与展示

    模块名称:tf_keras_classification_model 引入相关的包 sklearn 是一个机器学习常用的库对机器学习的常用算法进行封装 %matplotlib inline impor ...

  8. Tensorflow2 图像分类-Flowers数据深度学习模型保存、读取、参数查看和图像预测

    目录 1.原文完整代码 1.1 模型运行参数总结 1.2模型训练效果 ​编辑2.模型的保存 3.读取模型model 4.使用模型进行图片预测 5.补充 如何查看保存模型参数 5.1 model_wei ...

  9. PyTorch模型的保存加载以及数据的可视化

    文章目录 PyTorch模型的保存和加载 模块和张量的序列化和反序列化 模块状态字典的保存和载入 PyTorch数据的可视化 TensorBoard的使用 总结 PyTorch模型的保存和加载 在深度 ...

最新文章

  1. Android Service学习之本地服务
  2. python画剖面图_如何创建Matplotlib图形与图像和剖面图相匹配?
  3. js中while死循环语句_Java系列教程day06——循环语句
  4. 使用C#编程解决数独求解过程(从图片识别到数独求解)第二篇
  5. 围棋经典棋谱_秀秀老师:茶艺师也要学好围棋
  6. 一群阿里人如何用 10 年自研洛神云网络平台?技术架构演进全揭秘!
  7. 神龙X-Dragon,这技术“范儿”如何?| 问底中国IT技术演进
  8. 好程序员分享居中一个float元素
  9. 学习OpenStack之(6):Neutron 深入学习之 OVS + GRE 之 Compute node 篇
  10. set获取元素_C++与STL入门(4):关联容器:集合set
  11. 190808每日一句
  12. windows 启动c\windows\systen32\spool\DRIVERS\W32x86\3\ssnetmon. dll 时出现问题 找不到指定模块
  13. lisp 图层字体式样替换_ps将不同图层字体修改成相同字体的方法
  14. thinkpad10平板电脑装linux,ThinkPad X61上经历Ubuntu 8.10(安装笔记)
  15. 【释义详解】Software License (软件许可证)是什么?
  16. 用计算机公式计算优良,『excel怎样合并单元格』如何在EXCEL中如何用公式计算全年级各班各科平均分、优秀率、合格率的方法...
  17. C语言bool类型定义
  18. C#序列化与反序列化学习
  19. rfid射频前端的主要组成部分有_7.3.1 RFID电子标签射频前端的结构
  20. mysql fnv64函数_FNV哈希算法 - osc_tiaoycd5的个人空间 - OSCHINA - 中文开源技术交流社区...

热门文章

  1. Win PE在移动硬盘上“安家”
  2. 微信小程序:去除自带顶部导航栏
  3. 安装java1.6_JAVA1.6怎么安装
  4. IT30: IT人怎样成为解决问题的高手
  5. 华工计算机组成原理第一次作业,华工网络教育学院2018计算机组成原理作业
  6. python写一个爬虫、爬取网站漫画信息_python爬取漫画
  7. 微服务化的不同阶段 Kubernetes 的不同玩法
  8. 启动盘制作 rufus3.8下载
  9. 【黑金动力社区】【531体验板教程】 第三章 开发环境(三)
  10. 2022中国中医药产业展,山东医药保健展,济南药交会9月举办