一、数据集介绍

该数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

下面这幅图就是列举了10各类,每一类展示了随机的10张图片:

二、搭建神经网络模型

使用CIFAR10网路模型,基于pytorch搭建网络模型

import torch
from torch import nnclass Test(nn.Module):def __init__(self):super(Test, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x

三、数据集的准备及加载

使用torchvision.datasets.CIFAR10()加载数据集,train=True表示数据集为训练数据集,train=False表示数据集为测试集,dowwnload=True表示下载数据集,本地存在数据集不会再次下载。

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from model import *# 定义训练设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 准备数据集
train_data = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
# print("训练数据集的长度为{}".format(train_data_size))
# print("测试数据集的长度为{}".format(test_data_size))# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

 四、神经网络、损失函数、优化器等加载

test = Test()
test = test.to(device)# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)# 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
epoch = 30# 添加Tensorboard
writer = SummaryWriter("logs_train")

五、训练、测试、模型保存

start_time = time.time()
for i in range(epoch):print("-----第{}轮训练开始------".format(i+1))# 训练步骤开始test.train()for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)output = test(imgs)loss = loss_fn(output, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time - start_time)print("训练次数{}, Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始test.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss: {}".format(total_test_loss))print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)total_test_step += 1torch.save(test, "test_{}.pth".format(i))print("模型已保存")writer.close()

六、模型的加载及测试

import torch
import torchvision
from PIL import Image
from model import *
CLASS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
image_path = "./imgs/dog.png"
image = Image.open(image_path)
# print(image)transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])image = transform(image)
# print(image.shape)model = torch.load("test_99.pth", map_location=torch.device('cpu'))
# print(model)
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():output = model(image)
ret = output.argmax(1)
ret = ret.numpy()
print("预测结果为:{}".format(CLASS[ret[0]]))

CIFAR10数据集训练及测试相关推荐

  1. caffe学习(五):cifar-10数据集训练及测试(Ubuntu)

    简介 网站链接:CIFAR-10 CIFAR-10数据集包括由10个类别的事物,每个事物各有6000张彩色图像,每张图片的大小是32*32. 整个数据集被分成了5个训练集和1个测试集,各有10000张 ...

  2. cifar10数据集测试有多少张图_pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)...

    首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层: 一,写VGG代码时,首先定义一个 vgg_block(n ...

  3. 基于Keras搭建cifar10数据集训练预测Pipeline

    基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...

  4. TF之CNN:基于CIFAR-10数据集训练、检测CNN(2+2)模型(TensorBoard可视化)

    TF之CNN:基于CIFAR-10数据集训练.检测CNN(2+2)模型(TensorBoard可视化) 目录 1.基于CIFAR-10数据集训练CNN(2+2)模型代码 2.检测CNN(2+2)模型 ...

  5. cifar10数据集训练

    有关CIFAR-10数据集 (1)CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像.有50000个训 练图像和10000个测试图像. (2)数据集分为五个训 ...

  6. yolo-v2 自己的数据集训练以及测试流程(仅供内部使用!)

    warning 该流程仅供内部使用,外部人士使用可能会报很多很多错误! 步骤 先清除backup文件夹中老的权重文件: 将标定好图片以及annotation .txt文件拷贝到obj文件夹,一一对应, ...

  7. 【小白学习keras教程】二、基于CIFAR-10数据集训练简单的MLP分类模型

    @Author:Runsen 分类任务的MLP 当目标(y)是离散的(分类的) 对于损失函数,使用交叉熵:对于评估指标,通常使用accuracy 数据集描述 CIFAR-10数据集包含10个类中的60 ...

  8. 深度学习入门 FashionMNIST数据集训练和测试(30层神经网路)

    使用pytorch框架.模型包含13层卷积层.2层池化层.15层全连接层.为什么叠这么多层?就是玩. FashionMNIST数据集包含训练集6w张图片,测试集1w张图片,每张图片是单通道.大小28× ...

  9. 【Pytorch实战4】基于CIFAR10数据集训练一个分类器

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch中文文档 先是数据的导入与预览. import torch import torchvisio ...

最新文章

  1. sql和泛型方法返回泛型_基于泛型编程的序列化实现方法
  2. 两种求集合全部子集的方法
  3. SAP Commerce Cloud deprecation机制
  4. 数学除了摧残祖国的花朵外,竟然还可以赢钱!
  5. 前端学习(2225):react之类定义组件
  6. oracle 11g 如何实现坏块检查、恢复?
  7. android+ndk+r9+x64下载,Win7 64位中文旗舰版上Cocos2d-x 3.0的Android开发调试环境架设
  8. nginx代理php不能跳转页面,nginx 解决首页跳转问题详解
  9. python3扬州大学校园网认证登录与下线
  10. 数据优化 | CnOpenData中国工业企业绿色专利及引用被引用数据
  11. 《那些年啊,那些事——一个程序员的奋斗史》八
  12. 易语言静态连接器提取_正确易语言链接器link.exe,使易语言支持静态编译
  13. 测试通达信指标胜率的软件,如何利用通达信程序交易评测系统选高胜率小回辙高收益股票...
  14. jenkins 下载插件失败 有效的处理办法(亲测)
  15. 对接有道翻译api中英翻译软件
  16. 私有云Openstack介绍及搭建
  17. opencv实现眼动检测【胡子哥哥】
  18. win10网络适配器不见了_win10设备管理器里没有网络适配器的原因及处理方法
  19. 用GEPHI绘制的 我的微博 好友 关系 与 好友的好友关系图
  20. 准备女儿的学前班毕业典礼

热门文章

  1. 【WLAN常用语】—VAP
  2. 逆矩阵(inverse matrix)的概念及其意义
  3. WhatsApp滥发垃圾消息后被禁止了如何解封?
  4. python能做什么自学_分享一下,我当年是如何自学Python从而从事这方面工作的
  5. 4.1 字符串及其表示
  6. 汽车报告丨分析了比亚迪宋全网口碑,我们得出这个结论
  7. ACL 通配符掩码的应用
  8. 产品经理--如何设计一款老年人O2O产品
  9. 银行卡收单业务____轧差
  10. Docker容器之间的通信