图像分类数据集 (FASHION-MNIST)
文章目录
- 引入
- 1 获取数据集
- 2 简单操作
- 3 读取小批量
- 4 完整代码
- 致谢
引入
图像分类数据集最常用的是手写数字识别数据集MNIST (1),但是大部分模型在其上的分类精度都超过了95%。为了更直观地观察算法之间的差异,将使用一个图像内容更加复杂的数据集[Fashion-MNIST (2)]。
接下来的部分将使用torchvision包,主要用于构建计算机视觉模型,主要由以下4部分组成:
组成 | 功能 |
---|---|
torchvision.datasets | 加载数据的函数及常用的数据集接口 |
torchvision.models | 包含常用的模型结构 (含预训练模型) |
torchvision.transforms | 常用的图片变化,例如裁剪、旋转 |
torchvision…utils | 其他方法 |
代码已上传至github:
https://github.com/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py
1 获取数据集
需要导入的包如下:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display
下面,将通过torchvision.datasets下载数据集,第一次调用时会自动从网上获取数据 (若出现速度较慢,请向后查看注意);通过参数train来指定获取训练集或者测试集;通过transform = transforms.Tensor()将数据转化为Tensor,如果不转换,则返回PIL图片。
transforms.Tensor()将尺寸为 (H×W×CH×W×CH×W×C)且数据位于 (0, 255)的PIL图片或数据类型为np.uint8的Numpy转换为尺寸为 (C×H×WC×H×WC×H×W)且数据类型为torch.float32且位于 (0.0, 1.0)的Tensor。
使用代码如下:
class ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())if __name__ == "__main__":test = ImageDataSet()test.__init__()print(test.mnist_train)print(len(test.mnist_train), len(test.mnist_test))
运行结果:
Dataset FashionMNISTNumber of datapoints: 60000Root location: C:\Users\Administrator/DataSets/FashionMNISTSplit: TrainStandardTransform
Transform: ToTensor()
60000 10000
注意:
1)如果用像素值表示图片数据,那么一律将其类型设置成unit8,以避免不必要的bug;
2)第一次下载时速度也许很慢,推荐在cmd中输入以下代码,并复制出现的http链接下载:
import torchvision
import torchvision.transforms as transforms
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
2 简单操作
可以通过下标来访问任意一个样本:
if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(data.shape)print(label)
运行结果:
torch.Size([1, 28, 28]) # 分别对应通道数、图像高、图像宽
9
Fashion-MNIST共10个类别,分别为t-shirt、trouser、pullover、dress、coat、sandal、shirt、sneaker、bag和ankle boot,以下函数可以将数值标签转换成相应的文本标签:
...def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(test.get_text_labels([label]))
运行结果:
['ankle boot']
现在定义一个可以在一行里画出多张图像和对应标签的函数:
...def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))# zip()接受一系列可迭代对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()if __name__ == "__main__":test = ImageMnist()test.__init__()x, y = [], []for i in range(10):x.append(test.mnist_train[i][0])y.append(test.mnist_train[i][1])test.show_mnist(x, test.get_text_labels(y))
运行结果:
3 读取小批量
torch的DataLoader中一个很方便的功能是运行使用多进程来加速读取数据,这里通过参数num_workers来设置4个进程读取数据。
...def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0 # 0表示不需要额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))
运行结果:
6.65 sec
4 完整代码
'''
@(#)test.py
The class of test.
Author: Yu-Xuan Zhang
Email: inki.yinji@qq.com
Created on May 05, 2020
Last Modified on May 05, 2020@author: inki
'''
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import displayclass ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))
致谢
特别感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书~
图像分类数据集 (FASHION-MNIST)相关推荐
- tensorflow2.0 CNN fashion MNIST图像分类
基于 CNN的 fashion MNIST图像分类 fashion MNIST图像分类 数据集简介 数据的预处理 CNN简介和构建 模型部分代码 CNN实验结果 致谢 fashion MNIST图像分 ...
- 深度学习之自编码器(2)Fashion MNIST图片重建实战
深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码 自编码器 ...
- Fashion MNIST
原文: Fashion MNIST An MNIST-like dataset of 70,000 28x28 labeled fashion images Fashion-MNIST is a da ...
- fashionmnist数据集_Keras实现Fashion MNIST数据集分类
本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...
- Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类
1.引言 1.1.什么是Pytorch PyTorch是一个开源的Python机器学习库. 1.2.什么是CNN 卷积神经网络(Convolutional Neural Networks)是一种深度学 ...
- python cnn程序_python cnn训练(针对Fashion MNIST数据集)
本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...
- 计算机视觉两个入门数据集(mnist和fashion mnist)本地下载地址
1.计算机视觉经典数据集 1.mnist数据集 MNIST(Mixed National Institute of Standards andTechnology database)数据集大家可以说是 ...
- tensorflow卷积神经网络实战:Fashion Mnist 图像分类与人马分类
卷积神经网络实战:Fashion Mnist 图像分类与人马分类 一.FashionMnist的卷积神经网络模型 1.卷积VS全连接 2.卷积网络结构 3.卷积模型结构 1)Output Shape ...
- Fashion MNIST数据集的处理——“...-idx3-ubyte”文件解析
Fashion MNIST MNIST数据集可能是计算机视觉所接触的第一个图片数据集.而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可 ...
- TensorFlow中的Fashion MNIST图像识别实战
1.导入相应的库: 关于Fashion MNIST数据集的介绍:看这位博主: https://blog.csdn.net/qq_28869927/article/details/85079808 im ...
最新文章
- 计算机视觉开源库OpenCV之平滑、模糊和滤波
- 用.XML填充TreeView
- 谷歌浏览器使用IE内核
- 每日一题(26)—— 无限循环的几种形式
- 优化算法笔记|粒子群算法理解及Python实现
- 案例解读:Oracle目录由于TFA触发bug导致jdb文件未自动清理引起空间不足
- cmd后台运行exe_了解运行命令的原理,为QQ制作运行命令启动
- VS2012 产品密钥
- 练习题2 -和可被 K 整除的子数组
- sublime复制一行_sublime怎么快速复制一行,快捷键是什么?
- 分光器光衰多少?分光器如何选购?分光器如何使用?
- ssm毕设项目磐基建筑机械租赁有限公司机械租赁系统41c32(java+VUE+Mybatis+Maven+Mysql+sprnig)
- MyBatis 第二扇门
- 维特比算法(Viterbi algorithm) 的理解
- 2023年五面蚂蚁、三面拼多多、字节跳动最终拿offer入职拼多多
- 10.30系统进程及服务控制,前后台调用,kill,进程信号,top进程动态监控,系统控制systemctl,ssh服务和认证,用户登陆审计
- 第十章 国民收入的决定:收入-支出模型
- Hive 练习(带数据)
- XILINX FPGA和CPLD引脚约束步骤
- 裁判文书网 爬虫 最新更新2020-08-12
热门文章
- 永磁同步电机的标么值系统
- 中国财团收购Opera 为什么要美国监管部门同意?
- BestMPRBaseVtk-003-修改工程,搬运官方代码并尝试理解-2
- php判断浏览器语言国内外,PHP判断浏览器语言
- APP性能测试关注点详细介绍
- 高通9xxx系列模块modem射频 RF LTE B41频段踩过的坑
- Html5 文件上传
- android 仿微信朋友圈发布动态功能
- 【java】org.xml.sax.SAXParseException;在实体引用中, 实体名称必须紧跟在 '' 后面。解决方法
- 程序员必备技能之Markdown