作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 MNIST数据集

2.1 MNIST数据集介绍

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

2.4 对样本数据预处理

2.5 批量数据读取与显示


第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageNet Flower
  • Imagenet-12
  • STL10

第2章 MNIST数据集

2.1 MNIST数据集介绍

MNIST数据集: http://yann.lecun.com/exdb/

备注 :可以先把样本数据下载本地,以提升程序调试的效率。最终的产品可以远程下载数据。

  • 每张图片大小:28*28.
  • 单通道的黑白色图片,即(batch_size, channels, Height, Width) =(batch_size, 1, 28, 28)

2.2 样本数据与样本标签格式

2.3 MNIST数据的下载与导入

(1)操作函数MNIST()的解读

MNIST (root, train=True, transform=None, target_transform=None, download=False)

参数说明:

  • root : 文件存放路的根路径,下载的文件存放在该路径下,processed/training.pt 和 processed/test.pt 的主目录
  • train : True = 训练集, False = 测试集
  • target_transform:导入数据时,是否需要对数据格式进行转换,一个函数,原始图片作为输入,返回一个转换后的图片。有时候神经网络所需要的尺寸与数据集提供的尺寸不一致,则可以通过此方法进行转换。
  • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。

(2)代码实例

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
Hello World
1.8.0
False
#2-1 准备数据集
train_data = dataset.MNIST(root = "mnist",train = True,transform = transforms.ToTensor(),download = True)#2-1 准备数据集
test_data = dataset.MNIST(root = "mnist",train = False,transform = transforms.ToTensor(),download = True)print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Dataset MNISTNumber of datapoints: 60000Root location: mnistSplit: TrainStandardTransform
Transform: ToTensor()
size= 60000Dataset MNISTNumber of datapoints: 10000Root location: mnistSplit: TestStandardTransform
Transform: ToTensor()
size= 10000

2.4 对样本数据预处理

(1)原图不叠加噪声显示

#原图不叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)print("\n不叠加噪声, 原图显示")plt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5不叠加噪声, 原图显示

(2)原图叠加噪声

#原图叠加噪声
#获取一张图片数据
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n单通道原始图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + meanplt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5单通道原始图片:numpy
numpy image shape: (28, 28, 1)
numpy image label: 5叠加噪声, 平滑显示

(3)#叠加噪声,灰度显示图片

#叠加噪声,灰度显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)print("\n叠加噪声, 平滑显示")
std = [0.5]
mean = [0.5]
image = image * std + meanplt.imshow(image)
plt.show()
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5叠加噪声, 平滑显示

(4)#不叠加噪声,黑白显示图片

#不叠加噪声,黑白显示图片
print("原始图片")
image, label = train_data[0]
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n三通道灰度图片:torch")
image = utils.make_grid(image)
print("torch image shape:", image.shape)
print("torch image label:", label)print("\n三通道灰度图片:numpy")
image = image.numpy().transpose(1,2,0)
print("numpy image shape:", image.shape)
print("numpy image label:", label)print("\n不叠加噪声,黑白显示")
plt.imshow(image)
plt.show()
print("numpy image shape:", image.shape)
原始图片
torch image shape: torch.Size([1, 28, 28])
torch image label: 5三通道灰度图片:torch
torch image shape: torch.Size([3, 28, 28])
torch image label: 5三通道灰度图片:numpy
numpy image shape: (28, 28, 3)
numpy image label: 5不叠加噪声,黑白显示

2.5 批量数据读取与显示

(1)batch批量图片的读取

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,batch_size = 64,shuffle = True)test_loader = data_utils.DataLoader(dataset = test_data,batch_size = 64,shuffle = True)print(train_loader)
print(test_loader)
print(len(train_loader), len(train_data)/64)
print(len(test_loader),  len(test_data)/64)
<torch.utils.data.dataloader.DataLoader object at 0x000002461EF4A1C0>
<torch.utils.data.dataloader.DataLoader object at 0x000002461ED66610>
938 937.5
157 156.25

(2)一个batch图片的显示

显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs)
print(images.shape)
print(labels.shape)print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0)
print(images.shape)
print(labels.shape)print("\n显示样本标签")
#打印图片标签
for i in range(64):print(labels[i], end=" ")i += 1#换行if i%8 == 0:print(end='\n')print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([64, 1, 28, 28])
torch.Size([64])
64合并成一张三通道灰度图片
torch.Size([3, 242, 242])
torch.Size([64])转换成imshow格式
(242, 242, 3)
torch.Size([64])显示样本标签
tensor(0) tensor(8) tensor(3) tensor(7) tensor(5) tensor(7) tensor(9) tensor(7)
tensor(1) tensor(1) tensor(1) tensor(8) tensor(8) tensor(6) tensor(0) tensor(1)
tensor(4) tensor(8) tensor(1) tensor(3) tensor(3) tensor(6) tensor(4) tensor(4)
tensor(0) tensor(5) tensor(8) tensor(5) tensor(9) tensor(3) tensor(7) tensor(5)
tensor(2) tensor(1) tensor(0) tensor(6) tensor(8) tensor(8) tensor(9) tensor(6)
tensor(1) tensor(3) tensor(5) tensor(3) tensor(4) tensor(4) tensor(3) tensor(1)
tensor(4) tensor(1) tensor(4) tensor(4) tensor(9) tensor(8) tensor(7) tensor(2)
tensor(3) tensor(1) tensor(2) tensor(0) tensor(8) tensor(1) tensor(1) tensor(4) 显示图片



作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055489

[Pytorch系列-33]:数据集 - torchvision与MNIST数据集相关推荐

  1. 使用Pytorch实现手写数字识别(Mnist数据集)

    目标 知道如何使用Pytorch完成神经网络的构建 知道Pytorch中激活函数的使用方法 知道Pytorch中torchvision.transforms中常见图形处理函数的使用 知道如何训练模型和 ...

  2. 508任务一:用pytorch简单实现LeNet5网络对MNIST数据集训练

    看了一些pytorch教学视频,结合别人的代码,按自己的喜好写出来了比较简单的实现,其实还可以把loss数据绘个表,还可以在训练时2加个循环,多训练几次. import torch import to ...

  3. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  4. pytorch,tensorflow加载本地mnist数据集

    1. pytorch import torch import torch.nn as nn from torchvision import datasets, transforms import to ...

  5. 使用mnist数据集_使用MNIST数据集上的t分布随机邻居嵌入(t-SNE)进行降维

    使用mnist数据集 It is easy for us to visualize two or three dimensional data, but once it goes beyond thr ...

  6. pytorch实现手写字体识别(Mnist数据集)

    1.加载数据集 一个快速体验学习的小tip在google的云jupyter上做实验,速度快的飞起. import torch from torch.nn import Linear, ReLU imp ...

  7. PyTorch Sequential Models - 简化神经网络(pytorch系列-33)

    PyTorch Sequential Models - 简化神经网络 这一集中,我们将学习如何使用PyTorch的Sequential类来构建神经网络. PyTorch Sequential Mode ...

  8. python调用数据集mnist_Python读取MNIST数据集

    importnumpy as npimportmatplotlib.pyplot as plt'''试验transpose() def back (a,b): return a,b if __name ...

  9. [Pytorch系列-35]:卷积神经网络 - 搭建LeNet-5网络与CFAR10分类数据集

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

最新文章

  1. UML学习总结(3)——StarUML指导手册
  2. 【创业】史上最完整创业数据,30岁以下创业白皮书
  3. c语言试卷大全,C语言试题大全
  4. BZOJ_1009_[HNOI2008]_GT考试_(动态规划+kmp+矩阵乘法优化+快速幂)
  5. 精通ASP.NET MVC ——视图
  6. Spark生态圈及安装
  7. ARM, MIPS, Power PC
  8. 信息化项目甲方采购的准备与实施
  9. 电源大师课笔记 2.5
  10. 如何去掉PDF右下角的全能扫描王水印
  11. 计算机论文的研究思路与方法,硕士论文中研究方法怎么写 介绍3种简单的方法...
  12. 云运维拓扑图_云平台网络拓扑图
  13. linux防恶意软件防病毒 防护工具
  14. python统计大写辅音字母_大写
  15. word 2007中在页眉中插入或这删除下划线
  16. 树莓派 Linux 操作系统大全
  17. java 手机智能拨号_智能拨号 CeleDial v1.8
  18. 经典的排错过程 expected unqualified-id before string constant
  19. MACD指标为什么不灵了?试试QMACD
  20. java.sql.SQLException: Access denied for user ‘root’@‘localhost’ (using password: YES)和错误原因 解决方案:

热门文章

  1. 计算机教育软件参评作品例子,2018年东莞计算机教育软件评审活动.doc
  2. 单相变压器的等效电路
  3. 仓库智能分拣机器人RFID,如何实现分拣工作
  4. 【免费分享】2000-2019 年中国各省、市、区县分年、分月、逐日平均降水量
  5. 信息系统分析与设计 第七章 用例建模
  6. Python批量爬取Win10锁屏壁纸,根本不用浪费钱!
  7. html css如何渐变阴影,CSS 实现文字阴影 + 文字渐变色
  8. 自动化/控制工程专业英语02——拉普拉斯变换[考研/保研面试]
  9. 咪咕音乐java笔试题_咪咕音乐链接歌词封面搜索等接口API
  10. Java Web应用实践