CIFAR10数据集训练及测试
一、数据集介绍
该数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
二、搭建神经网络模型
使用CIFAR10网路模型,基于pytorch搭建网络模型
import torch
from torch import nnclass Test(nn.Module):def __init__(self):super(Test, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, padding=2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, padding=2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024, 64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return x
三、数据集的准备及加载
使用torchvision.datasets.CIFAR10()加载数据集,train=True表示数据集为训练数据集,train=False表示数据集为测试集,dowwnload=True表示下载数据集,本地存在数据集不会再次下载。
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
from model import *# 定义训练设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 准备数据集
train_data = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
# print("训练数据集的长度为{}".format(train_data_size))
# print("测试数据集的长度为{}".format(test_data_size))# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
四、神经网络、损失函数、优化器等加载
test = Test()
test = test.to(device)# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)# 优化器
learning_rate = 0.01
optimizer = torch.optim.SGD(test.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
epoch = 30# 添加Tensorboard
writer = SummaryWriter("logs_train")
五、训练、测试、模型保存
start_time = time.time()
for i in range(epoch):print("-----第{}轮训练开始------".format(i+1))# 训练步骤开始test.train()for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)output = test(imgs)loss = loss_fn(output, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print(end_time - start_time)print("训练次数{}, Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始test.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = test(imgs)loss = loss_fn(outputs, targets)total_test_loss += loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的Loss: {}".format(total_test_loss))print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)total_test_step += 1torch.save(test, "test_{}.pth".format(i))print("模型已保存")writer.close()
六、模型的加载及测试
import torch
import torchvision
from PIL import Image
from model import *
CLASS = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
image_path = "./imgs/dog.png"
image = Image.open(image_path)
# print(image)transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),torchvision.transforms.ToTensor()])image = transform(image)
# print(image.shape)model = torch.load("test_99.pth", map_location=torch.device('cpu'))
# print(model)
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():output = model(image)
ret = output.argmax(1)
ret = ret.numpy()
print("预测结果为:{}".format(CLASS[ret[0]]))
CIFAR10数据集训练及测试相关推荐
- caffe学习(五):cifar-10数据集训练及测试(Ubuntu)
简介 网站链接:CIFAR-10 CIFAR-10数据集包括由10个类别的事物,每个事物各有6000张彩色图像,每张图片的大小是32*32. 整个数据集被分成了5个训练集和1个测试集,各有10000张 ...
- cifar10数据集测试有多少张图_pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作)...
首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层: 一,写VGG代码时,首先定义一个 vgg_block(n ...
- 基于Keras搭建cifar10数据集训练预测Pipeline
基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...
- TF之CNN:基于CIFAR-10数据集训练、检测CNN(2+2)模型(TensorBoard可视化)
TF之CNN:基于CIFAR-10数据集训练.检测CNN(2+2)模型(TensorBoard可视化) 目录 1.基于CIFAR-10数据集训练CNN(2+2)模型代码 2.检测CNN(2+2)模型 ...
- cifar10数据集训练
有关CIFAR-10数据集 (1)CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像.有50000个训 练图像和10000个测试图像. (2)数据集分为五个训 ...
- yolo-v2 自己的数据集训练以及测试流程(仅供内部使用!)
warning 该流程仅供内部使用,外部人士使用可能会报很多很多错误! 步骤 先清除backup文件夹中老的权重文件: 将标定好图片以及annotation .txt文件拷贝到obj文件夹,一一对应, ...
- 【小白学习keras教程】二、基于CIFAR-10数据集训练简单的MLP分类模型
@Author:Runsen 分类任务的MLP 当目标(y)是离散的(分类的) 对于损失函数,使用交叉熵:对于评估指标,通常使用accuracy 数据集描述 CIFAR-10数据集包含10个类中的60 ...
- 深度学习入门 FashionMNIST数据集训练和测试(30层神经网路)
使用pytorch框架.模型包含13层卷积层.2层池化层.15层全连接层.为什么叠这么多层?就是玩. FashionMNIST数据集包含训练集6w张图片,测试集1w张图片,每张图片是单通道.大小28× ...
- 【Pytorch实战4】基于CIFAR10数据集训练一个分类器
参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch中文文档 先是数据的导入与预览. import torch import torchvisio ...
最新文章
- sql和泛型方法返回泛型_基于泛型编程的序列化实现方法
- 两种求集合全部子集的方法
- SAP Commerce Cloud deprecation机制
- 数学除了摧残祖国的花朵外,竟然还可以赢钱!
- 前端学习(2225):react之类定义组件
- oracle 11g 如何实现坏块检查、恢复?
- android+ndk+r9+x64下载,Win7 64位中文旗舰版上Cocos2d-x 3.0的Android开发调试环境架设
- nginx代理php不能跳转页面,nginx 解决首页跳转问题详解
- python3扬州大学校园网认证登录与下线
- 数据优化 | CnOpenData中国工业企业绿色专利及引用被引用数据
- 《那些年啊,那些事——一个程序员的奋斗史》八
- 易语言静态连接器提取_正确易语言链接器link.exe,使易语言支持静态编译
- 测试通达信指标胜率的软件,如何利用通达信程序交易评测系统选高胜率小回辙高收益股票...
- jenkins 下载插件失败 有效的处理办法(亲测)
- 对接有道翻译api中英翻译软件
- 私有云Openstack介绍及搭建
- opencv实现眼动检测【胡子哥哥】
- win10网络适配器不见了_win10设备管理器里没有网络适配器的原因及处理方法
- 用GEPHI绘制的 我的微博 好友 关系 与 好友的好友关系图
- 准备女儿的学前班毕业典礼