pytorch笔记-实现一个图像分类模型
1.数据引入
import torch from torch
import nn from torch.utils.data
import DataLoader from torchvision
import datasets from torchvision.transforms import ToTensor
2.训练集与测试集
我们用到的数据集是FashionMNIST,是一个图像数据集,用它来进行分类任务。
dataloader用来存放相应的训练数据以及对应的标签
dataset将包装一个可迭代的数据集,
在这里插入代码片
# Download training data from open datasets.
training_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),
)
# Download test data from open datasets.
test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor(),
)
'''
dataset作为dataloader的参数传入,在数据集上包裹一个可迭代的对象,支持batchsize 加载,混淆,多进程数据加载
'''
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break
3.创建模型
从nn.Module继承并创建一个类,用来定义网络结构
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# Define model
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits
model = NeuralNetwork().to(device)
print(model)
4.模型超参优化
定义损失函数和用的优化器
loss_fn = nn.CrossEntropyLoss()
optimizer =torch.optim.SGD(model.parameters(), lr=1e-3)
训练与反向传播代码
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
在训练过程中模型的测试代码
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
训练过程会持续几个epoch,每经历一个epoch模型学的参数会进行更好的预测。随着训练的进行,模型良好的趋势是模型的accuracy会不断的提高,loss会逐步的下降。
epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")
5.模型保存
保存模型的一个常见方法是序列化内部状态字典(包含模型参数)。
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
6.模型加载与预测
加载模型的过程包括重新创建模型结构并将状态字典加载到其中。
model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))
加载完模型后可用来进行预测
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",
]model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():pred = model(x)predicted, actual = classes[pred[0].argmax(0)], classes[y]print(f'Predicted: "{predicted}", Actual: "{actual}"')
7.参考资料
https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html.
pytorch笔记-实现一个图像分类模型相关推荐
- 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始构建图像分类模型...
「@Author:Runsen」 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先 ...
- 【小白学习PyTorch教程】六、基于CIFAR-10 数据集,使用PyTorch 从头开始构建图像分类模型
@Author:Runsen 图像识别本质上是一种计算机视觉技术,它赋予计算机"眼睛",让计算机通过图像和视频"看"和理解世界. 在开始阅读本文之前,建议先了解 ...
- 【小白学习PyTorch教程】九、基于Pytorch训练第一个RNN模型
「@Author:Runsen」 当阅读一篇课文时,我们可以根据前面的单词来理解每个单词的,而不是从零开始理解每个单词.这可以称为记忆.卷积神经网络模型(CNN)不能实现这种记忆,因此引入了递归神经网 ...
- 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络
Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...
- 把一个dataset的表放在另一个dataset里面_现在开始:用你的Mac训练和部署一个图像分类模型...
可能有些同学学习机器学习的时候比较迷茫,不知道该怎么上手,看了很多经典书籍介绍的各种算法,但还是不知道怎么用它来解决问题,就算知道了,又发现需要准备环境.准备训练和部署的机器,啊,好麻烦. 今天,我来 ...
- 使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)
开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template 使用pytorch框架搭建一个图像分类模型通常包含 ...
- 独家 | 手把手教你用Python构建你的第一个多标签图像分类模型(附案例)
翻译:吴金笛 校对:郑滋 本文约4600字,建议阅读12分钟. 本文明确了多标签图像分类的概念,并讲解了如何构建多标签图像分类模型. 介绍 你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情 ...
- python如何训练模型生产_手把手教你用Python构建你的第一个多标签图像分类模型(附案例)...
你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情: 对象检测 图像分割 图像翻译 对象跟踪(实时),还有更多-- 这让我思考--如果一个图像中有多个对象类别,我们该怎么办?制作一个图像分类 ...
- 手把手教你用Python构建你的第一个多标签图像分类模型(附案例)
原文链接: https://www.analyticsvidhya.com/blog/2019/04/build-first-multi-label-image-classification-mode ...
- python图片分类技术介绍_手把手教你用Python构建你的第一个多标签图像分类模型(附案例)!...
介绍 你正在处理图像数据吗?我们可以使用计算机视觉算法来做很多事情:对象检测 图像分割 图像翻译 对象跟踪(实时),还有更多-- 这让我思考--如果一个图像中有多个对象类别,我们该怎么办?制作一个图像 ...
最新文章
- 戴尔服务器远程访问管理卡iDRAC 7详解
- 谷歌设计规范_[图]谷歌Play商城启用圆角矩形图标设计规范 6月24日强制生效
- 几种网站后门排查 不全面
- QT解析 JSON 格式的数据
- flink scala shell命令行使用示例
- java定义一个door的类_再探Java抽象类与接口的设计理念差异
- 计组-CISC和RISC的基本概念
- HDU - 5335 Walk Out(bfs+路径输出+贪心)
- php 同一行,php – 如何在同一行中对类方法进行多个调用?
- 甜蜜暴击,情人节插画素材,甜而不腻!
- 姿态坐标c语言,判断 AR 中坐标系的姿态和位置的简单方法
- Cadence Gerber文件制作过程
- ZEGO 自研客户端配置管理系统 —— 云控
- 土木学matlab还是python_五行属土的字大全
- 女神瓦萨比-小黑中国力鉴淘宝给力明星店
- 8/11 Perl和Postgresql联合在京交流会 Perl6项目经理远道参加
- android anr 文件路径,android出现ANR 如何导出anr文件
- 获取固定到任务栏的快捷方式的图标
- 【收藏】实验室十大常见危险操作,关乎生命!
- 坚持到底就是成功,坚持到底就是富有,坚持到底就是胜利