pytorch实现对Fashion-MNIST数据集进行图像分类

导入所需模块:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys

对数据集的操作(读取数据集):
由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug
通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时数据集目录不存在,会自动从网上获取数据。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

获取标签:

def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

定义画图函数:

def use_svg_display():# 用矢量图显示display.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()# 设置图的尺寸plt.rcParams['figure.figsize'] = figsizedef show_fashion_mnist(images, labels):use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()

展示数据集中的前10个样本:

X, y = [], []
for i in range(10):X.append(mnist_train[i][0])y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

PyTorch的DataLoader允许使用多进程来加速数据读取
通过参数num_workers来设置4个进程读取数据:

batch_size = 256
if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

初始化模型参数:

num_inputs = 784 # 784个样本
num_outputs = 10 # 分为10类W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
# 模型参数梯度回调
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

定义softmax运算:
softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任意一行元素代表了一个样本在各个输出类别上的预测概率。

def softmax(X):#对于随机输入,每个元素变成了非负数,且每一行和为1X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True) # 按行求和return X_exp / partition  # 这里应用了广播机制,会自动对partition进行扩展到与X_exp尺度相同

定义模型:

def net(X): # 定义模型return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)def cross_entropy(y_hat, y): # 交叉熵损失函数return - torch.log(y_hat.gather(1, y.view(-1, 1)))def accuracy(y_hat, y): # 定义准确率计算return (y_hat.argmax(dim=1) == y).float().mean().item()

这里讲一下为什么使用softmax可以将输出定在(0-9)这样的离散类别中

因为,模型中使用到的线性规划层:输出=输入*w+b
其中w是一个(输入,输出)维度的张量,b是一个(输出)维度的向量
根据矩阵运算,输入*w+b的结果就是一个向量,该向量列数为输出(也就是类别数),而softmax做的,就是将该向量的元素进行归一化处理,使得其变为概率向量,这样,该模型的输出就可以代表样本为某一类别的概率,其和为1.

训练模型:

num_epochs, lr = 5, 0.1def sgd(params, lr, batch_size):  # 随机梯度下降算法for param in params:param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.datadef evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / ndef train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad() # 这里我们没有用到优化器,所以直接对参数进行梯度清零elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:sgd(params, lr, batch_size)else:optimizer.step()  # 简洁实现将用到优化器这里train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0] test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

注意,这里的每一个train_iter是以batch_size个样本为实例的,所以每一个train_iter包含多个(tensor, type)组合。

训练完毕进行预测:

X, y = iter(test_iter).next()true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])

(pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类相关推荐

  1. Pytorch深度学习(五):加载数据集以及mini-batch的使用

    Pytorch深度学习(五):加载数据集以及mini-batch的使用 参考B站课程:<PyTorch深度学习实践>完结合集 传送门:<PyTorch深度学习实践>完结合集 一 ...

  2. [PyTorch] 深度学习框架PyTorch中的概念和函数

    Pytorch的概念 Pytorch最重要的概念是tensor,意为"张量". Variable是能够构建计算图的 tensor(对 tensor 的封装).借用Variable才 ...

  3. Python深度学习之分类模型示例,MNIST数据集手写数字识别

    MNIST数据集是机器学习领域中非常经典的一个数据集,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片. 我们把60000个训练样本分成两部分,前 ...

  4. (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(3)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅲ(概率)

    开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...

  5. 《深度学习之PyTorch物体检测实战》—读书笔记

    随书代码 物体检测与PyTorch 深度学习 为了赋予计算机以人类的理解能力与逻辑思维,诞生了人工智能(Artificial Intelligence, AI)这一学科.在实现人工智能的众多算法中,机 ...

  6. 基于PyTorch深度学习无人机遥感影像目标检测、地物分类及语义分割

    随着无人机自动化能力的逐步升级,它被广泛的应用于多种领域,如航拍.农业.植保.灾难评估.救援.测绘.电力巡检等.但同时由于无人机飞行高度低.获取目标类型多.以及环境复杂等因素使得对无人机获取的数据处理 ...

  7. 深度学习之Pytorch基础教程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...

  8. 深度学习准「研究僧」预习资料:图灵奖得主Yann LeCun《深度学习(Pytorch)》春季课程...

    视学算法报道 编辑:蛋酱 转载自公众号:机器之心 开学进入倒计时,深度学习方向的准「研究僧」们,你们准备好了吗? 转眼 2020 年已经过半,又一届深度学习方向的准研究生即将踏上「炼丹」之路.对于这一 ...

  9. 5天玩转PyTorch深度学习,从GAN到词嵌入都有实例丨教程资源

    郭一璞 发自 凹非寺  量子位 报道 | 公众号 QbitAI 学PyTorch深度学习,可能5天就够了. 法国深度学习研究者Marc Lelarge出品的这套名为<Hands-on tour ...

  10. python tensorflow pytorch 深度学习 车牌识别

    车牌识别相关资料收集整理 1.License Plate Detection with RetinaFace 链接:https://github.com/zeusees/License-Plate-D ...

最新文章

  1. 新版Bintray-极简上传Library到JCenter
  2. IOS7开发~API变化
  3. python与数据思维基础笔记_Python小课笔记--Python基础:数据和函数(二)
  4. fastjson 输出null值字段
  5. 4月24日Serverless Developer Meetup上海亮相
  6. Python模块包中__init__.py文件的作用(转载)
  7. 【最简解法】1048 Find Coins (25 分)_18行代码AC
  8. abap中的弹出窗体函数
  9. MySQL优化原理分析及优化方案总结
  10. 浏览器图片解析失败(裂开,空白)排查思路
  11. C语言教程第六章:指针(1)
  12. 《python自动化》学习笔记:电话地址和E-mail地址提取程序
  13. Hive 求全局Top N
  14. 懒人用日志分析-awstats的docker应用
  15. php1008打印机驱动器,hp laserjet p1008打印机驱动
  16. Redis 客户端哪家强? Lettuce手下见真香!
  17. .NET Standard中配置TargetFrameworks输出多版本类库
  18. java 序列化版本号_序列化版本号serialVersionUID的作用
  19. python协程爬取斗鱼美女图片
  20. IntelliJ IDEA—SVN的配置及使用

热门文章

  1. python的socket模块_python模块:socket模块
  2. mysql按字段同步_MySQL同步(一) 基础知识
  3. eclipse mat 分析dump文件,打开文件报错,out of memeory
  4. 对象引用 String引用 基本类型引用 差别
  5. android 常用开发插件,Android Studio 开发利器【常用插件】
  6. loss下降auc下降_梯度下降算法 线性回归拟合(附Python/Matlab/Julia源代码)
  7. WSL安装Oracle,折腾记录:WSL(Windows Subsystem for Linux,Windows上的Linux子系统)安装后的环境配置-Go语言中文社区...
  8. html中dl标签和ul标签,html中dl,dt,dd,ul,li,ol标签区别和使用
  9. mysql集群fuzhi_MySQL集群 和MySQL主从复制的不同
  10. ftm模块linux驱动,飞思卡尔k系列_ftm模块详解.doc