softmax是一个非线性函数,但softmax回归是一个线性模型(linear model):是不是线性的是由决策面是否是线性函数决定的,不是由拟合的数据分布决定的。softmax只是对数据分布做了非线性的处理,但它的决策函数形式还是Xw+b的线性形式。

Fashion-MNIST数据集:包含70000张灰度图像,其中包含60,000个示例的训练集和10,000个示例的测试集,每个示例都是一个28x28灰度图像。主要分为:T恤(T-shirt)、裤子(Trouser)、套头衫(Pullover)、连衣裙(Dress)、外套(Coat)、凉鞋(Sandal)、衬衫(Shirt)、运动鞋(Sneaker)、包(Bag)、靴子(Ankle boot)

1.导入包

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()

use_svg_display()函数指定matplotlib软件包输出svg图表以获得更清晰的图像

2.读取数据集

通过框架中的内置函数将Fashion_MNIST数据集下载并读取到内存中

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans,download=True)len(mnist_train), len(mnist_test)#训练集与测试集中样本的数量
mnist_train[0][0].shape#训练集中第一个图片

transforms.ToTensor():将图片转化为Tensor.

mnist_train是训练集,mnist_test是测试集,两者是torch.utils.data.Dataset的子类

root="../data", train=True, transform=trans, download=True:将Fashion-MNIST的训练集(train=True)从网上下载(download=True)到(root="../data")上级目录的data中,并确保得到是tensor而不是图片(transform=trans)

输出结果:

Out[3]:表示训练集有60000张图片,测试集有10000张图片。

Out[4]:因为通过transforms.ToTensor()的转换,变成了尺寸为(CxHxW),数据类型为torch.float32,位于[0.0, 1.0] 的Tensor,输出结果[1,28,28]的‘1’表示的是第一维的通道数为1,所以是灰度图像,后面两维中的‘28‘表示图像的高和宽。

3.定义两个可视化的数据集函数

def get_fashion_mnist_labels(labels):text_labels=['t-shirt','trouser','pullover','dress','cost','sandal','shirt','sneaker','bag','ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols,titles=None,scale=1.5):figsize=(num_cols*scale,num_rows*scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i,(ax,img) in enumerate(zip(axes,imgs)):if torch.is_tensor(img):#是否为张量ax.imshow(img.numpy())#图片张量else:ax.imshow(img)#PIL图片ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
X,y=next(iter(data.DataLoader(mnist_train,batch_size=18)))#data.DataLoader(mnist_train,batch_size=18):在mist_train数据集中加载数据,每批次要装载18个样品,最后将这些数据封装为Tensor.
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))

1)plt.subplots():plt.subplots()是matplotlib中绘制子图的一种方法。在matplotlib中整个图像为一个Figure对象,在Figure对象中可以包含一个或多个Axes对象,每个Axes(ax)对象都是一个拥有自己坐标系统的绘图区域。plt.subplots()直接在函数内部设置子图纸信息,返回两个变量,一个是Figure实例fig,另一个是AxesSubplot实例ax。fig代表整个图像,ax代表坐标轴和子图。d2l.plt.subplots(num_rows, num_cols, figsize=figsize)中,第一个参数代表子图的行数,第二个参数代表该行图像的列数,第三个参数代表每行的第几个图像。

2)axes.flatten():flatten()是numpy.ndarray.flatten的一个函数,即返回一个一维数组。axes.flatten()表示把axes数组降到一维,默认为按行的方向降。

3)enumerate():获取可迭代对象的每个元素的索引值及该元素值,进行拆包,多用于for循环

4)zip():用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表

5)axe.get_xaxis().set_visible():是设置坐标轴显示与否,包括了刻度与标签,如果设置为False则表示不显示,True为显示。

6)next()与iter():两者要一起使用。iter()函数将Iterable转换为Iterator;对获取到的迭代器(Iterator)不断使用next()函数来获取下一条数据

#flatten()实例
from numpy import *
a = arange(1, 7).reshape(3, 2)
print(a)
print(a.flatten())

输出结果:

[[1 2][3 4][5 6]]
[1 2 3 4 5 6]
#zip()实例
b = [4, 5, 6]
c = [7, 8]
print(list(zip(b, c)))#元素个数与最短的列表一致

输出结果:

[(4, 7), (5, 8)]
#enumerate()实例
for i, value in enumerate(['a', 'b', 'c', 'd']):print(i, value)

输出结果:

0 a
1 b
2 c
3 d
#axe.get_xaxis().set_visible()实例
import matplotlib.pyplot as pltfig=plt.figure(figsize=(5,5),dpi=100)   #创建画布
axe=plt.subplot(1,1,1)    #创建子图
axe.set_title('test')   #设置子图标题
fig.savefig('test.png',dpi=100) #保存图片
axe.get_xaxis().set_visible(True)
plt.show()  #展示图片

输出结果:

#next()&&iter()实例
# 首先获得Iteration对象
it = iter([1, 2, 3, 4, 5])
# 循环
while True:try:# 获得下一个值x = next(it)print(x)except StopIteration:# 遇到StopIteration就退出循环break

输出结果:

1
2
3
4
5

4.读取小批量数据

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())#读取训练数据所需的时间
timer = d2l.Timer()
for X, y in train_iter:continue
f'{timer.stop():.2f} sec'

5.整合所有的函数

定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))#修改图片大小trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

transforms.Compose():串联多个图片变换的操作

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break

动手学深度学习--课堂笔记图片分类数据集相关推荐

  1. 《动手学深度学习》笔记——深度学习简介

    原文链接: https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter01_DL-intro/deep-learning-intro 机器学习与深度学习 ...

  2. 动手学深度学习课程笔记ch02

    ch_02 线性代数 线性代数李老师讲得比较少,需要自己下去多看看书,后期还是需要一些矩阵论的知识. 基本知识 标量:由只有一个元素的张量表示(一般为数据的标签). # 创建标量进行运算 import ...

  3. 《动手学深度学习》笔记---3.16

    3.16 实战Kaggle比赛:房价预测 3.16.2 读取数据集 # 导入所需的模块和包%matplotlib inline import d2lzh as d2l from mxnet impor ...

  4. 李沐动手学深度学习V2-机器翻译和数据集

    一. 机器翻译和数据集 1. 介绍 机器翻译的数据集是由源语言和目标语言的文本序列对组成的,因此需要一种完全不同的方法来预处理机器翻译数据集, 而不是复用语言模型的预处理程序. 2. 下载和预处理数据 ...

  5. 【动手学深度学习】Softmax 回归 + 损失函数 + 图片分类数据集

    学习资料: 09 Softmax 回归 + 损失函数 + 图片分类数据集[动手学深度学习v2]_哔哩哔哩_bilibili torchvision.transforms.ToTensor详解 | 使用 ...

  6. 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...

  7. 动手学深度学习笔记3.4+3.5+3.6+3.7

    系列文章目录 动手学深度学习笔记系列: 动手学深度学习笔记3.1+3.2+3.3 文章目录 系列文章目录 前言 一.softmax回归 1.1 分类问题 1.2 网络架构 1.3 全连接层的参数开销 ...

  8. 【李沐动手学深度学习】读书笔记 01前言

    虽然之前已经学过这部分内容和深度学习中的基础知识,但总觉得学的不够系统扎实,所以希望再通过沐神的课程以及书籍,系统条理的学习一遍.在读书过程中,利用导图做了一下梳理,形成了这个读书笔记.如有侵权,请联 ...

  9. (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(2)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅰ

    开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...

最新文章

  1. mysql如何创建简单索引_mysql 如何创建索引呢,这个其实很简单
  2. TCP/IP详解--五层协议的作用以及对应的设备
  3. sql参数化还是被注入了_SQL注入是什么?
  4. linux下面获取当前bing-国内版的壁纸
  5. c语言网络在线人数统计,教你用ASP程序实现网站在线人数统计
  6. 工具栏自定义_EXCEL LESSON12 自定义功能区菜单及工具栏(1/3)
  7. .NET Core 2.0及.NET Standard 2.0
  8. java 03_Java基础03—流程控制
  9. 11月12号 用户登录输入密码错误达到指定次数后,锁定账户 004
  10. 【人工智能沙龙】未来,语音识别可能应用于哪些商业化场景?
  11. HTML跳转下一行快捷键,wps常用快捷键有哪些?
  12. 微软对开发者献真爱,全面支持开源,加速研发云升级
  13. red5搭建流媒体直播系统
  14. Python 三维姿态估计+Unity3d 实现 3D 虚拟现实交互游戏
  15. 蓝屏代码大全(留着自己看)
  16. CentOS8下超详细安装配置kubernetes(K8S)
  17. php订单表设计,订单详情表,与,订单表 怎么做?
  18. TCP/IP层次模型
  19. 南阳理工学院ACM多乐赛暨16级退役纪念赛 C PK没有女朋友
  20. matlab求解erfc方程

热门文章

  1. 指数型生成函数(EGF)略解
  2. CTF中的md5弱类型(ALL_IN_ONE)
  3. 教学一体化服务平台——学生选课系统需求分析
  4. 计算传奇客户端中NPC外观代码的方法
  5. java之高质量代码优化技巧
  6. 利用java连wifi_wifi 连接
  7. Java可变参数类型实例
  8. 将caj文件整篇转换成Word的教程
  9. 配置OpenGL(Linux)
  10. iOS 打开html、txt、PDF、PPT等文件