1.引言

1.1.什么是Pytorch

PyTorch是一个开源的Python机器学习库。

1.2.什么是CNN

卷积神经网络(Convolutional Neural Networks)是一种深度学习模型或类似于人工神经网络的多层感知器,常用来分析视觉图像。

1.2.什么是MNIST

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片,以及每一张图片对应的标签,告诉我们这个是数字几。
这里我们选择的是由kaggle提供的Fashion MNIST,包含的是各种服装的图片,难度相较于原始的MNIST而言要更高。
数据集介绍与下载:点此
数据中包含10种不同的衣物,分别为:

  • 0 T-shirt/top
  • 1 Trouser
  • 2 Pullover
  • 3 Dress
  • 4 Coat
  • 5 Sandal
  • 6 Shirt
  • 7 Sneaker
  • 8 Bag
  • 9 Ankle boot

2.分析

完整代码于文末给出。

2.1.CNN结构

本文采用如下结构的卷积神经网络:

Layer Input Kernel Output
INPUT 28×28 / 28×28
CONV1 28×28 5×5(padding=2) 16×28×28
POOL1 16×28×28 2×2 16×14×14
CONV2 16×14×14 3×3 32×12×12
CONV3 32×12×12 3×3 64×10×10
POOL2 64×10×10 2×2 64×5×5
FC 64×5×5 / 10

各层的详细分析如下:

  • 输入层INPUT: 数据集中的原始数据为28×28的图像,无需额外调整。
  • 卷积层CONV1: 采用5×5的卷积核,并采用2个像素进行边缘填充,保证卷积得到的16张特征图仍保持28×28的大小不变。
    之后,还进行数据的归一化,防止数据在进行Relu之前因为数据过大而导致网络性能的不稳定;之后再进行Relu处理。
  • 池化层POOL1: 采用2×2的采样空间,进行最大池化。输出得到16张14×14的特征图。
  • 卷积层CONV2: 采用3×3的卷积核,之后同样进行归一化与Relu处理。最后得到32张12×12的特征图。
  • 卷积层CONV3: 采用3×3的卷积核,之后同样进行归一化与Relu处理。最后得到64张10×10的特征图。
  • 池化层POOL2: 采用2×2的采样空间,进行最大池化。输出得到64张5×5的特征图。
  • 输出层FC: 输入为64张5×5的特征图。将这些特征图先压缩成向量,然后进行全连接,最后得到10维的向量。

2.2.CNN代码

class CNN(nn.Module):def __init__(self):#nn.Module子类的函数必须在构造函数中执行父类的构造函数super(CNN, self).__init__()#卷积层conv1self.conv1 = nn.Sequential(   nn.Conv2d(1, 16, kernel_size=5, padding=2),nn.BatchNorm2d(16), nn.ReLU()) #池化层pool1self.pool1=nn.MaxPool2d(2) #卷积层conv2self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3),nn.BatchNorm2d(32),nn.ReLU())#卷积层conv3self.conv3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),nn.BatchNorm2d(64),nn.ReLU()) #池化层pool2self.pool2=nn.MaxPool2d(2)  #全连接层fc(输出层)self.fc = nn.Linear(5*5*64, 10)#前向传播def forward(self, x):out = self.conv1(x)out = self.pool1(out)out = self.conv2(out)out = self.conv3(out)out = self.pool2(out)#压缩成向量以供全连接out = out.view(out.size(0), -1)out = self.fc(out)return out

2.3.数据观察

打开下载得到的fashion-mnist_train.csv与fashion-mnist_test.csv:

可以发现训练集包含60000个样本,测试集包含10000个样本;每个样本包含其实际对应的图片类型label以及对应的各像素pixel1~pixel784灰度值,即28×28的图像。

2.4.数据导入

构造FashionMNISTDataset类,以方便使用pytorch的dataloader进行数据加载。
该类需要指定三个函数:

  • init:主要作用是进行数据的加载,指定X(特征),Y(标签)与len(样本容量)三个变量。这里将原始的784维向量调整为28×28的矩阵作为特征X。
  • len:样本容量。
  • getitem:返回(样本,标签)元组,其实就是返回了一张图片及其对应的分类。
class FashionMNISTDataset(Dataset):def __init__(self, csv_file, transform = None):data = pd.read_csv(csv_file)self.X = np.array(data.iloc[:, 1:]).reshape(-1, 1, 28, 28).astype(float)self.Y = np.array(data.iloc[:, 0])   self.len = len(self.X)del datadef __len__(self):return self.lendef __getitem__(self, idx):item = self.X[idx]label = self.Y[idx]return (item, label)

然后读取数据,创建训练集train_dataset与测试集test_dataset:

from pathlib import Path
DATA_PATH = Path('./data/')
train_dataset = FashionMNISTDataset(csv_file = DATA_PATH / "fashion-mnist_train.csv")
test_dataset = FashionMNISTDataset(csv_file = DATA_PATH / "fashion-mnist_test.csv")

最后利用dataloader进行导入。
关于超参数BATCH_SIZE:

  • 影响的是每次训练的样本个数
  • 一般设置为2的幂或者2的倍数
  • 越大的话,内存利用率更高,矩阵乘法的并行化效率提高,跑完一次epoch(全数据集)所需要的迭代次数减小,对于相同数据量的处理速度进一步加快,并且训练震荡也可能越小
  • 但是越大的话也会对内存容量要求更高,超出电脑性能限制则可能引发OOM异常

shuffle则影响的是是否打乱数据,一般只需要打乱训练集数据,测试集数据不需要打乱。

from torch.utils.data import DataLoader as dataloader
BATCH_SIZE = 256
train_loader = dataloader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

2.5.训练

2.5.1.开始

首先实例化一个CNN对象,并指定是使用CPU还是GPU进行训练:

cnn = CNN()
DEVICE = torch.device("cpu")
if torch.cuda.is_available():DEVICE = torch.device("cuda")
cnn = cnn.to(DEVICE)
2.5.2.损失函数

由于本问题是一个多分类问题,使用了softmax回归将神经网络前向传播得到的结果变成概率分布,所以使用交叉熵损失。

criterion = nn.CrossEntropyLoss().to(DEVICE)
2.5.3.优化器

Adam优化器在大多数情况下都能取得不错的结果:

LEARNING_RATE = 0.01
optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)
2.5.4.开始训练

注意这里为了演示方便,训练批次TOTAL_EPOCHS设置成了5,为了更好的训练结果可以增大训练批次。

TOTAL_EPOCHS = 5
losses = []
for epoch in range(TOTAL_EPOCHS):#在每个批次下,遍历每个训练样本for i, (images, labels) in enumerate(train_loader):images = images.float().to(DEVICE)labels = labels.to(DEVICE)#清零optimizer.zero_grad()outputs = cnn(images)#计算损失函数loss = criterion(outputs, labels)loss.backward()optimizer.step()losses.append(loss.cpu().data.item());if (i+1) % 100 == 0:print ('Epoch : %d/%d, Iter : %d/%d,  Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, loss.data.item()))
2.5.5.结果评估

可视化训练结果:

plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.plot(losses)
plt.show()


保存模型:

torch.save(cnn.state_dict(), "fm-cnn3.pth")

最终结果评估具体的流程如下:

  • 将CNN模型切换成eval模式。eval模式是相对于train模式而言的,前者用于模型评估阶段,后者用于模型训练阶段。
  • 将图片放入网络中进行运算,得到结果outputs。
  • 将概率分布形式的outputs数据进行独热处理,即选择“可能性最大”的分类作为当前图片的分类。
  • 判断分类结果是否正确,并最终统计正确率。
cnn.eval()
correct = 0
total = 0
for images, labels in test_loader:images = images.float().to(DEVICE)outputs = cnn(images).cpu()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()
print('准确率: %.4f %%' % (100 * correct / total))

在笔者机器上运算得到的准确率为91.63%

3.完整代码

# -*- coding: utf-8 -*-
import torch
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as dataloader
import torch.nn as nnclass FashionMNISTDataset(Dataset):def __init__(self, csv_file, transform = None):data = pd.read_csv(csv_file)self.X = np.array(data.iloc[:, 1:]).reshape(-1, 1, 28, 28).astype(float)self.Y = np.array(data.iloc[:, 0])   self.len = len(self.X)del datadef __len__(self):return self.lendef __getitem__(self, idx):item = self.X[idx]label = self.Y[idx]return (item, label)class CNN(nn.Module):def __init__(self):#nn.Module子类的函数必须在构造函数中执行父类的构造函数super(CNN, self).__init__()#卷积层conv1self.conv1 = nn.Sequential(   nn.Conv2d(1, 16, kernel_size=5, padding=2),nn.BatchNorm2d(16), nn.ReLU()) #池化层pool1self.pool1=nn.MaxPool2d(2) #卷积层conv2self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3),nn.BatchNorm2d(32),nn.ReLU())#卷积层conv3self.conv3 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),nn.BatchNorm2d(64),nn.ReLU()) #池化层pool2self.pool2=nn.MaxPool2d(2)  #全连接层fc(输出层)self.fc = nn.Linear(5*5*64, 10)#前向传播def forward(self, x):out = self.conv1(x)out = self.pool1(out)out = self.conv2(out)out = self.conv3(out)out = self.pool2(out)#压缩成向量以供全连接out = out.view(out.size(0), -1)out = self.fc(out)return outDATA_PATH = Path('./data/')
train_dataset = FashionMNISTDataset(csv_file = DATA_PATH / "fashion-mnist_train.csv")
test_dataset = FashionMNISTDataset(csv_file = DATA_PATH / "fashion-mnist_test.csv")
BATCH_SIZE=256
train_loader = dataloader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader = dataloader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)
cnn = CNN()
DEVICE = torch.device("cpu")
if torch.cuda.is_available():DEVICE = torch.device("cuda")
cnn = cnn.to(DEVICE)
criterion = nn.CrossEntropyLoss().to(DEVICE)
LEARNING_RATE = 0.01
optimizer = torch.optim.Adam(cnn.parameters(), lr=LEARNING_RATE)
TOTAL_EPOCHS = 5
losses = []
for epoch in range(TOTAL_EPOCHS):#在每个批次下,遍历每个训练样本for i, (images, labels) in enumerate(train_loader):images = images.float().to(DEVICE)labels = labels.to(DEVICE)#清零optimizer.zero_grad()outputs = cnn(images)#计算损失函数loss = criterion(outputs, labels)loss.backward()optimizer.step()losses.append(loss.cpu().data.item());if (i+1) % 100 == 0:print ('Epoch : %d/%d, Iter : %d/%d,  Loss: %.4f'%(epoch+1, TOTAL_EPOCHS, i+1, len(train_dataset)//BATCH_SIZE, loss.data.item()))plt.xlabel('Epoch #')
plt.ylabel('Loss')
plt.plot(losses)
plt.show()
torch.save(cnn.state_dict(), "fm-cnn3.pth")
cnn.eval()
correct = 0
total = 0
for images, labels in test_loader:images = images.float().to(DEVICE)outputs = cnn(images).cpu()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()
print('准确率: %.4f %%' % (100 * correct / total))

Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类相关推荐

  1. tensorflow2.0 CNN fashion MNIST图像分类

    基于 CNN的 fashion MNIST图像分类 fashion MNIST图像分类 数据集简介 数据的预处理 CNN简介和构建 模型部分代码 CNN实验结果 致谢 fashion MNIST图像分 ...

  2. tensorflow卷积神经网络实战:Fashion Mnist 图像分类与人马分类

    卷积神经网络实战:Fashion Mnist 图像分类与人马分类 一.FashionMnist的卷积神经网络模型 1.卷积VS全连接 2.卷积网络结构 3.卷积模型结构 1)Output Shape ...

  3. 【Pytorch神经网络实战案例】21 基于Cora数据集实现Multi_Sample Dropout图卷积网络模型的论文分类

    Multi-sample Dropout是Dropout的一个变种方法,该方法比普通Dropout的泛化能力更好,同时又可以缩短模型的训练时间.XMuli-sampleDropout还可以降低训练集和 ...

  4. 【MIMIC-IV/pytorch实战】基于word2vec、transformer进行英文影像报告文本分类

    完整流程可以分以下几步 数据整理 word2vec 构建transformer模型 训练模型 测试模型 资源下载介绍 若懒得看程序,也可以直接下载全部程序,在最后一部分进行了资源的介绍. [MIMIC ...

  5. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

  6. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  7. 网易云课程:深度学习与PyTorch入门实战

    网易云课程:深度学习与PyTorch入门实战 01 深度学习初见 1.1 深度学习框架简介 1.2 pytorch功能演示 2开发环境安装 3回归问题 3.1简单的回归问题(梯度下降算法) 3.3回归 ...

  8. 《机器学习实战:基于Scikit-Learn、Keras和TensorFlow(第2版)》学习笔记

    文章目录 书籍信息 技术和工具 Scikit-Learn TensorFlow Keras Jupyter notebook 资源 书籍配套资料 流行的开放数据存储库 元门户站点(它们会列出开放的数据 ...

  9. 视频教程-深度学习与PyTorch入门实战教程-深度学习

    深度学习与PyTorch入门实战教程 新加坡国立大学研究员 龙良曲 ¥399.00 立即订阅 扫码下载「CSDN程序员学院APP」,1000+技术好课免费看 APP订阅课程,领取优惠,最少立减5元 ↓ ...

最新文章

  1. ABAP日期函数应用
  2. 线上日志集中化可视化管理:ELK
  3. OO Design之SOLID原则
  4. C# string类型和byte[]类型相互转换
  5. python的变量命名及其使用
  6. 年度旗舰机广告片遭电视台泄露 三星:我有句话不知当讲不当讲
  7. 归一化方法 Normalization Method
  8. C++中char[]转string
  9. mongodb日志分析工具mtools之mplotqueries
  10. 各省产业结构-高级化指数(二产与三产比值)合理化指数
  11. digispark开发板attiny85烧写digispark bootloader
  12. 浅识Flutter 基本组件之showDatePicker方法
  13. pb11.5调用系统打印机 Function ulong ShellExecute(ulong hwnd,ref string lpOperation,ref string lpFile,ref st
  14. 使用poi导出excel生成复杂多级表头通用方法
  15. java毕业设计鸿鹄教育培训(附源码、数据库)
  16. 哲学家就餐问题实验报告
  17. Linux,下载安装minio
  18. Angular4_select设置默认选中
  19. python自动生成五言绝句,一定让你学会他!!
  20. 基于加密软件的加密和解密

热门文章

  1. signature=d522a0024e7d20dbfee94b566a5dfed5,End-to-end (e2e) application packet flow visibility
  2. python列表解析,生成表达式(一分钟读懂)
  3. matlab机液位置伺服系统,基于MATLAB的电液位置伺服系统仿真分析
  4. rabbitmq消费者获取消息慢_RabbitMQ:快速生产者和慢速消费者
  5. table合并单元格_element ui el-table 合并单元格
  6. 计算机窗口跳转列表,别小看它!Windows跳转列表效率高
  7. Hierarchical line matching based on Line–Junction–Line structure
  8. 如何找到适合结婚的女朋友
  9. Python之 dict(字典)(回)
  10. 《LaTeX写作》——LaTeX编写环境的安装笔记