动手学深度学习--课堂笔记图片分类数据集
softmax是一个非线性函数,但softmax回归是一个线性模型(linear model):是不是线性的是由决策面是否是线性函数决定的,不是由拟合的数据分布决定的。softmax只是对数据分布做了非线性的处理,但它的决策函数形式还是Xw+b的线性形式。
Fashion-MNIST数据集:包含70000张灰度图像,其中包含60,000个示例的训练集和10,000个示例的测试集,每个示例都是一个28x28灰度图像。主要分为:T恤(T-shirt)、裤子(Trouser)、套头衫(Pullover)、连衣裙(Dress)、外套(Coat)、凉鞋(Sandal)、衬衫(Shirt)、运动鞋(Sneaker)、包(Bag)、靴子(Ankle boot)
1.导入包
%matplotlib inline import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2ld2l.use_svg_display()
use_svg_display()函数指定matplotlib软件包输出svg图表以获得更清晰的图像
2.读取数据集
通过框架中的内置函数将Fashion_MNIST数据集下载并读取到内存中
trans = transforms.ToTensor() mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans,download=True)len(mnist_train), len(mnist_test)#训练集与测试集中样本的数量 mnist_train[0][0].shape#训练集中第一个图片
transforms.ToTensor():将图片转化为Tensor.
mnist_train是训练集,mnist_test是测试集,两者是torch.utils.data.Dataset的子类
root="../data", train=True, transform=trans, download=True:将Fashion-MNIST的训练集(train=True)从网上下载(download=True)到(root="../data")上级目录的data中,并确保得到是tensor而不是图片(transform=trans)
输出结果:
Out[3]:表示训练集有60000张图片,测试集有10000张图片。
Out[4]:因为通过transforms.ToTensor()的转换,变成了尺寸为(CxHxW),数据类型为torch.float32,位于[0.0, 1.0] 的Tensor,输出结果[1,28,28]的‘1’表示的是第一维的通道数为1,所以是灰度图像,后面两维中的‘28‘表示图像的高和宽。
3.定义两个可视化的数据集函数
def get_fashion_mnist_labels(labels):text_labels=['t-shirt','trouser','pullover','dress','cost','sandal','shirt','sneaker','bag','ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols,titles=None,scale=1.5):figsize=(num_cols*scale,num_rows*scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):#是否为张量ax.imshow(img.numpy())#图片张量else:ax.imshow(img)#PIL图片ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes X,y=next(iter(data.DataLoader(mnist_train,batch_size=18)))#data.DataLoader(mnist_train,batch_size=18):在mist_train数据集中加载数据,每批次要装载18个样品,最后将这些数据封装为Tensor. show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))
1)plt.subplots():plt.subplots()是matplotlib中绘制子图的一种方法。在matplotlib中整个图像为一个Figure对象,在Figure对象中可以包含一个或多个Axes对象,每个Axes(ax)对象都是一个拥有自己坐标系统的绘图区域。plt.subplots()直接在函数内部设置子图纸信息,返回两个变量,一个是Figure实例fig,另一个是AxesSubplot实例ax。fig代表整个图像,ax代表坐标轴和子图。d2l.plt.subplots(num_rows, num_cols, figsize=figsize)中,第一个参数代表子图的行数,第二个参数代表该行图像的列数,第三个参数代表每行的第几个图像。
2)axes.flatten():flatten()是numpy.ndarray.flatten的一个函数,即返回一个一维数组。axes.flatten()表示把axes数组降到一维,默认为按行的方向降。
3)enumerate():获取可迭代对象的每个元素的索引值及该元素值,进行拆包,多用于for循环
4)zip():用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表
5)axe.get_xaxis().set_visible():是设置坐标轴显示与否,包括了刻度与标签,如果设置为False则表示不显示,True为显示。
6)next()与iter():两者要一起使用。iter()函数将Iterable转换为Iterator;对获取到的迭代器(Iterator)不断使用next()函数来获取下一条数据
#flatten()实例 from numpy import * a = arange(1, 7).reshape(3, 2) print(a) print(a.flatten())
输出结果:
[[1 2][3 4][5 6]] [1 2 3 4 5 6]
#zip()实例 b = [4, 5, 6] c = [7, 8] print(list(zip(b, c)))#元素个数与最短的列表一致
输出结果:
[(4, 7), (5, 8)]
#enumerate()实例 for i, value in enumerate(['a', 'b', 'c', 'd']):print(i, value)
输出结果:
0 a 1 b 2 c 3 d
#axe.get_xaxis().set_visible()实例 import matplotlib.pyplot as pltfig=plt.figure(figsize=(5,5),dpi=100) #创建画布 axe=plt.subplot(1,1,1) #创建子图 axe.set_title('test') #设置子图标题 fig.savefig('test.png',dpi=100) #保存图片 axe.get_xaxis().set_visible(True) plt.show() #展示图片
输出结果:
#next()&&iter()实例 # 首先获得Iteration对象 it = iter([1, 2, 3, 4, 5]) # 循环 while True:try:# 获得下一个值x = next(it)print(x)except StopIteration:# 遇到StopIteration就退出循环break
输出结果:
1 2 3 4 5
4.读取小批量数据
batch_size = 256def get_dataloader_workers(): #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())#读取训练数据所需的时间 timer = d2l.Timer() for X, y in train_iter:continue f'{timer.stop():.2f} sec'
5.整合所有的函数
定义
load_data_fashion_mnist
函数,用于获取和读取Fashion-MNIST数据集。def load_data_fashion_mnist(batch_size, resize=None): #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))#修改图片大小trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
transforms.Compose():串联多个图片变换的操作
train_iter, test_iter = load_data_fashion_mnist(32, resize=64) for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
动手学深度学习--课堂笔记图片分类数据集相关推荐
- 《动手学深度学习》笔记——深度学习简介
原文链接: https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter01_DL-intro/deep-learning-intro 机器学习与深度学习 ...
- 动手学深度学习课程笔记ch02
ch_02 线性代数 线性代数李老师讲得比较少,需要自己下去多看看书,后期还是需要一些矩阵论的知识. 基本知识 标量:由只有一个元素的张量表示(一般为数据的标签). # 创建标量进行运算 import ...
- 《动手学深度学习》笔记---3.16
3.16 实战Kaggle比赛:房价预测 3.16.2 读取数据集 # 导入所需的模块和包%matplotlib inline import d2lzh as d2l from mxnet impor ...
- 李沐动手学深度学习V2-机器翻译和数据集
一. 机器翻译和数据集 1. 介绍 机器翻译的数据集是由源语言和目标语言的文本序列对组成的,因此需要一种完全不同的方法来预处理机器翻译数据集, 而不是复用语言模型的预处理程序. 2. 下载和预处理数据 ...
- 【动手学深度学习】Softmax 回归 + 损失函数 + 图片分类数据集
学习资料: 09 Softmax 回归 + 损失函数 + 图片分类数据集[动手学深度学习v2]_哔哩哔哩_bilibili torchvision.transforms.ToTensor详解 | 使用 ...
- 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记
伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...
- 动手学深度学习笔记3.4+3.5+3.6+3.7
系列文章目录 动手学深度学习笔记系列: 动手学深度学习笔记3.1+3.2+3.3 文章目录 系列文章目录 前言 一.softmax回归 1.1 分类问题 1.2 网络架构 1.3 全连接层的参数开销 ...
- 【李沐动手学深度学习】读书笔记 01前言
虽然之前已经学过这部分内容和深度学习中的基础知识,但总觉得学的不够系统扎实,所以希望再通过沐神的课程以及书籍,系统条理的学习一遍.在读书过程中,利用导图做了一下梳理,形成了这个读书笔记.如有侵权,请联 ...
- (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(2)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅰ
开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...
最新文章
- mysql如何创建简单索引_mysql 如何创建索引呢,这个其实很简单
- TCP/IP详解--五层协议的作用以及对应的设备
- sql参数化还是被注入了_SQL注入是什么?
- linux下面获取当前bing-国内版的壁纸
- c语言网络在线人数统计,教你用ASP程序实现网站在线人数统计
- 工具栏自定义_EXCEL LESSON12 自定义功能区菜单及工具栏(1/3)
- .NET Core 2.0及.NET Standard 2.0
- java 03_Java基础03—流程控制
- 11月12号 用户登录输入密码错误达到指定次数后,锁定账户 004
- 【人工智能沙龙】未来,语音识别可能应用于哪些商业化场景?
- HTML跳转下一行快捷键,wps常用快捷键有哪些?
- 微软对开发者献真爱,全面支持开源,加速研发云升级
- red5搭建流媒体直播系统
- Python 三维姿态估计+Unity3d 实现 3D 虚拟现实交互游戏
- 蓝屏代码大全(留着自己看)
- CentOS8下超详细安装配置kubernetes(K8S)
- php订单表设计,订单详情表,与,订单表 怎么做?
- TCP/IP层次模型
- 南阳理工学院ACM多乐赛暨16级退役纪念赛 C PK没有女朋友
- matlab求解erfc方程