文章目录

  • 引入
  • 1 获取数据集
  • 2 简单操作
  • 3 读取小批量
  • 4 完整代码
  • 致谢

引入

  图像分类数据集最常用的是手写数字识别数据集MNIST (1),但是大部分模型在其上的分类精度都超过了95%。为了更直观地观察算法之间的差异,将使用一个图像内容更加复杂的数据集[Fashion-MNIST (2)]。
  接下来的部分将使用torchvision包,主要用于构建计算机视觉模型,主要由以下4部分组成:

组成 功能
torchvision.datasets 加载数据的函数及常用的数据集接口
torchvision.models 包含常用的模型结构 (含预训练模型)
torchvision.transforms 常用的图片变化,例如裁剪、旋转
torchvision…utils 其他方法

  代码已上传至github:
  https://github.com/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py

1 获取数据集

  需要导入的包如下:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import display

  下面,将通过torchvision.datasets下载数据集,第一次调用时会自动从网上获取数据 (若出现速度较慢,请向后查看注意);通过参数train来指定获取训练集或者测试集;通过transform = transforms.Tensor()将数据转化为Tensor,如果不转换,则返回PIL图片。
  transforms.Tensor()将尺寸为 (H×W×CH×W×CH×W×C)且数据位于 (0, 255)的PIL图片或数据类型为np.uint8的Numpy转换为尺寸为 (C×H×WC×H×WC×H×W)且数据类型为torch.float32且位于 (0.0, 1.0)的Tensor。

  使用代码如下:

class ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())if __name__ == "__main__":test = ImageDataSet()test.__init__()print(test.mnist_train)print(len(test.mnist_train), len(test.mnist_test))

  运行结果:

Dataset FashionMNISTNumber of datapoints: 60000Root location: C:\Users\Administrator/DataSets/FashionMNISTSplit: TrainStandardTransform
Transform: ToTensor()
60000 10000

  注意:
  1)如果用像素值表示图片数据,那么一律将其类型设置成unit8,以避免不必要的bug;
  2)第一次下载时速度也许很慢,推荐在cmd中输入以下代码,并复制出现的http链接下载:

import torchvision
import torchvision.transforms as transforms
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

2 简单操作

  可以通过下标来访问任意一个样本:

if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(data.shape)print(label)

  运行结果:

torch.Size([1, 28, 28])    # 分别对应通道数、图像高、图像宽
9

  Fashion-MNIST共10个类别,分别为t-shirt、trouser、pullover、dress、coat、sandal、shirt、sneaker、bag和ankle boot,以下函数可以将数值标签转换成相应的文本标签:

 ...def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(test.get_text_labels([label]))

  运行结果:

['ankle boot']

  现在定义一个可以在一行里画出多张图像和对应标签的函数:

 ...def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))# zip()接受一系列可迭代对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()if __name__ == "__main__":test = ImageMnist()test.__init__()x, y = [], []for i in range(10):x.append(test.mnist_train[i][0])y.append(test.mnist_train[i][1])test.show_mnist(x, test.get_text_labels(y))

  运行结果:

3 读取小批量

  torch的DataLoader中一个很方便的功能是运行使用多进程来加速读取数据,这里通过参数num_workers来设置4个进程读取数据。

 ...def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0    # 0表示不需要额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))

  运行结果:

6.65 sec

4 完整代码

'''
@(#)test.py
The class of test.
Author: Yu-Xuan Zhang
Email: inki.yinji@qq.com
Created on May 05, 2020
Last Modified on May 05, 2020@author: inki
'''
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from IPython import displayclass ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))

致谢

  特别感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书~

图像分类数据集 (FASHION-MNIST)相关推荐

  1. tensorflow2.0 CNN fashion MNIST图像分类

    基于 CNN的 fashion MNIST图像分类 fashion MNIST图像分类 数据集简介 数据的预处理 CNN简介和构建 模型部分代码 CNN实验结果 致谢 fashion MNIST图像分 ...

  2. 深度学习之自编码器(2)Fashion MNIST图片重建实战

    深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码  自编码器 ...

  3. Fashion MNIST

    原文: Fashion MNIST An MNIST-like dataset of 70,000 28x28 labeled fashion images Fashion-MNIST is a da ...

  4. fashionmnist数据集_Keras实现Fashion MNIST数据集分类

    本篇用keras构建人工神经网路(ANN)和卷积神经网络(CNN)实现Fashion MNIST 数据集单个物品分类,并从模型预测的准确性方面对ANN和CNN进行简单比较. Fashion MNIST ...

  5. Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类

    1.引言 1.1.什么是Pytorch PyTorch是一个开源的Python机器学习库. 1.2.什么是CNN 卷积神经网络(Convolutional Neural Networks)是一种深度学 ...

  6. python cnn程序_python cnn训练(针对Fashion MNIST数据集)

    本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...

  7. 计算机视觉两个入门数据集(mnist和fashion mnist)本地下载地址

    1.计算机视觉经典数据集 1.mnist数据集 MNIST(Mixed National Institute of Standards andTechnology database)数据集大家可以说是 ...

  8. tensorflow卷积神经网络实战:Fashion Mnist 图像分类与人马分类

    卷积神经网络实战:Fashion Mnist 图像分类与人马分类 一.FashionMnist的卷积神经网络模型 1.卷积VS全连接 2.卷积网络结构 3.卷积模型结构 1)Output Shape ...

  9. Fashion MNIST数据集的处理——“...-idx3-ubyte”文件解析

    Fashion MNIST MNIST数据集可能是计算机视觉所接触的第一个图片数据集.而 Fashion MNIST 是在遵循 MNIST 的格式和大小的基础上,提升了一定的难度,在比较算法的性能时可 ...

  10. TensorFlow中的Fashion MNIST图像识别实战

    1.导入相应的库: 关于Fashion MNIST数据集的介绍:看这位博主: https://blog.csdn.net/qq_28869927/article/details/85079808 im ...

最新文章

  1. 计算机视觉开源库OpenCV之平滑、模糊和滤波
  2. 用.XML填充TreeView
  3. 谷歌浏览器使用IE内核
  4. 每日一题(26)—— 无限循环的几种形式
  5. 优化算法笔记|粒子群算法理解及Python实现
  6. 案例解读:Oracle目录由于TFA触发bug导致jdb文件未自动清理引起空间不足
  7. cmd后台运行exe_了解运行命令的原理,为QQ制作运行命令启动
  8. VS2012 产品密钥
  9. 练习题2 -和可被 K 整除的子数组
  10. sublime复制一行_sublime怎么快速复制一行,快捷键是什么?
  11. 分光器光衰多少?分光器如何选购?分光器如何使用?
  12. ssm毕设项目磐基建筑机械租赁有限公司机械租赁系统41c32(java+VUE+Mybatis+Maven+Mysql+sprnig)
  13. MyBatis 第二扇门
  14. 维特比算法(Viterbi algorithm) 的理解
  15. 2023年五面蚂蚁、三面拼多多、字节跳动最终拿offer入职拼多多
  16. 10.30系统进程及服务控制,前后台调用,kill,进程信号,top进程动态监控,系统控制systemctl,ssh服务和认证,用户登陆审计
  17. 第十章 国民收入的决定:收入-支出模型
  18. Hive 练习(带数据)
  19. XILINX FPGA和CPLD引脚约束步骤
  20. 裁判文书网 爬虫 最新更新2020-08-12

热门文章

  1. 永磁同步电机的标么值系统
  2. 中国财团收购Opera 为什么要美国监管部门同意?
  3. BestMPRBaseVtk-003-修改工程,搬运官方代码并尝试理解-2
  4. php判断浏览器语言国内外,PHP判断浏览器语言
  5. APP性能测试关注点详细介绍
  6. 高通9xxx系列模块modem射频 RF LTE B41频段踩过的坑
  7. Html5 文件上传
  8. android 仿微信朋友圈发布动态功能
  9. 【java】org.xml.sax.SAXParseException;在实体引用中, 实体名称必须紧跟在 '' 后面。解决方法
  10. 程序员必备技能之Markdown