1.数据引入

import torch from torch
import nn from torch.utils.data
import DataLoader from torchvision
import datasets from torchvision.transforms import ToTensor

2.训练集与测试集

我们用到的数据集是FashionMNIST,是一个图像数据集,用它来进行分类任务。
dataloader用来存放相应的训练数据以及对应的标签
dataset将包装一个可迭代的数据集,

在这里插入代码片
# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)
'''
dataset作为dataloader的参数传入,在数据集上包裹一个可迭代的对象,支持batchsize 加载,混淆,多进程数据加载
'''
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

3.创建模型

从nn.Module继承并创建一个类,用来定义网络结构

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits
model = NeuralNetwork().to(device)
print(model)

4.模型超参优化

定义损失函数和用的优化器

loss_fn = nn.CrossEntropyLoss()
optimizer =torch.optim.SGD(model.parameters(), lr=1e-3)

训练与反向传播代码

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

在训练过程中模型的测试代码

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程会持续几个epoch,每经历一个epoch模型学的参数会进行更好的预测。随着训练的进行,模型良好的趋势是模型的accuracy会不断的提高,loss会逐步的下降。

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

5.模型保存

保存模型的一个常见方法是序列化内部状态字典(包含模型参数)。

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

6.模型加载与预测

加载模型的过程包括重新创建模型结构并将状态字典加载到其中。

model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))

加载完模型后可用来进行预测

classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')

7.参考资料

https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html.

pytorch笔记-实现一个图像分类模型相关推荐

  1. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型...

    「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...

  2. 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始​​构建图像分类模型

    @Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...

  3. 【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型

    「@Author:Runsen」 当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词.这可以称为记忆.卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网 ...

  4. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  5. 把一个dataset的表放在另一个dataset里面_现在开始:用你的Mac训练和部署一个图像分类模型...

    可能有些同学学习机器学习的时候比较迷茫,不知道该怎么上手,看了很多经典书籍介绍的各种算法,但还是不知道怎么用它来解决问题,就算知道了,又发现需要准备环境.准备训练和部署的机器,啊,好麻烦. 今天,我来 ...

  6. 使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)

    开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template​​​​​ 使用pytorch框架搭建一个图像分类模型通常包含 ...

  7. 独家 | 手把手教你用Python构建你的第一个多标签图像分类模型(附案例)

    翻译:吴金笛 校对:郑滋 本文约4600字,建议阅读12分钟. 本文明确了多标签图像分类的概念,并讲解了如何构建多标签图像分类模型. 介绍 你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情 ...

  8. python如何训练模型生产_手把手教你用Python构建你的第一个多标签图像分类模型(附案例)...

    你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情: 对象检测 图像分割 图像翻译 对象跟踪(实时),还有更多-- 这让我思考--如果一个图像中有多个对象类别,我们该怎么办?制作一个图像分类 ...

  9. 手把手教你用Python构建你的第一个多标签图像分类模型(附案例)

    原文链接: https://www.analyticsvidhya.com/blog/2019/04/build-first-multi-label-image-classification-mode ...

  10. python图片分类技术介绍_手把手教你用Python构建你的第一个多标签图像分类模型(附案例)!...

    介绍 你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情:对象检测 图像分割 图像翻译 对象跟踪(实时),还有更多-- 这让我思考--如果一个图像中有多个对象类别,我们该怎么办?制作一个图像 ...

最新文章

  1. 戴尔服务器远程访问管理卡iDRAC 7详解
  2. 谷歌设计规范_[图]谷歌Play商城启用圆角矩形图标设计规范 6月24日强制生效
  3. 几种网站后门排查 不全面
  4. QT解析 JSON 格式的数据
  5. flink scala shell命令行使用示例
  6. java定义一个door的类_再探Java抽象类与接口的设计理念差异
  7. 计组-CISC和RISC的基本概念
  8. HDU - 5335 Walk Out(bfs+路径输出+贪心)
  9. php 同一行,php – 如何在同一行中对类方法进行多个调用?
  10. 甜蜜暴击,情人节插画素材,甜而不腻!
  11. 姿态坐标c语言,判断 AR 中坐标系的姿态和位置的简单方法
  12. Cadence Gerber文件制作过程
  13. ZEGO 自研客户端配置管理系统 —— 云控
  14. 土木学matlab还是python_五行属土的字大全
  15. 女神瓦萨比-小黑中国力鉴淘宝给力明星店
  16. 8/11 Perl和Postgresql联合在京交流会 Perl6项目经理远道参加
  17. android anr 文件路径,android出现ANR 如何导出anr文件
  18. 获取固定到任务栏的快捷方式的图标
  19. 【收藏】实验室十大常见危险操作,关乎生命!
  20. 坚持到底就是成功,坚持到底就是富有,坚持到底就是胜利

热门文章

  1. web.config中特殊字符的处理
  2. Jquery 学习心得和资料
  3. 基于OleDb的Excel数据访问
  4. 文章,记录按内容分页显示,根据文章内容按字数进行分页(转)
  5. c#抽取pdf文档标题(1)
  6. 笔记本电脑触摸板的正确使用方法 --转摘
  7. Django2.1简介及安装
  8. P4692 [Ynoi2016]谁的梦
  9. 深浅拷贝和数列,变量的区别
  10. 由脚本创建的新元素事件不触发和用的easyUI插件中的多选框不起作用的解决方法...