前言

手写字体MNIST数据集是一组常见的图像,其常用于测评和比较机器学习算法的性能,本文使用pytorch框架来实现对该数据集的识别,并对结果进行逐步的优化。

一、数据集

MNIST数据集是由28x28大小的0-255像素值范围的灰度图像(如下图所示),其中6万张用于训练模型,1万张用于测试模型。

该数据集可从以下链接获取:
训练数据集:
https://pjreddie.com/media/files/mnist_train.csv
测试数据集:
https://pjreddie.com/media/files/mnist_test.csv
数据集一行有785个值,第一个值为图像中的数字标签,其余784个值为图像的像素值。
读取数据实例代码如下:

import pandas
import matplotlib.pyplot as pltdf = pandas.read_csv(r'./data/mnist_train.csv', header=None)
# print(df.head())  # 显示前5行
# print(df.info())   # 显示DataFrame概况
row = 0
data = df.iloc[row]
label = data[0],
img = data[1:].values.reshape(28, 28)
plt.title('label = ' + str(label))
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()

二、建立模型

# 构建模型
import torch
import torch.nn as nn
from torch.utils.data import Datasetclass Classifier(nn.Module):def __init__(self):# 初始化pytorch父类super().__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.Sigmoid(),nn.Linear(200, 10),nn.Sigmoid())self.loss_function = nn.MSELoss()self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)self.counter = 0self.progress = []def forward(self, inputs):return self.model(inputs)def train_model(self, inputs, targets):outputs = self.forward(inputs)loss = self.loss_function(outputs, targets)self.optimizer.zero_grad()  # 梯度归零 ,因为反向传播计算的梯度会累计loss.backward()  # 反向传播self.optimizer.step()  # 更新权重# 可视化训练过程self.counter += 1if self.counter % 10 == 0:self.progress.append(loss.item())  # 获取单张张量里的数字passif self.counter % 10000 == 0:print('counter = ', self.counter)passdef plot_progress(self):df = pandas.DataFrame(self.progress, columns=['loss'])df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))plt.show()passclass MnistDataset(Dataset):def __init__(self, csv_file):self.data_df = pandas.read_csv(csv_file, header=None)passdef __len__(self):return len(self.data_df)def __getitem__(self, index):label = self.data_df.iloc[index, 0]target = torch.zeros((10))target[label] = 1image_value = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0return label, image_value, targetdef plot_image(self, index):arr = self.data_df.iloc[index, 1:].values.reshape(28, 28)plt.title('label = ' + str(self.data_df.iloc[index, 0]))plt.imshow(arr, interpolation='none', cmap='Blues')plt.show()passpass

以上建立模型框架,并对训练过程进行可视化,建立读取数据类。

三、训练分类模型

mnist_train_dataset = MnistDataset(r'./data/mnist_train.csv')
# mnist_dataset.plot_image(9)# 训练分类模型
start_time = time.time()
C = Classifier()
epochs = 3  # 训练3轮
for i in range(epochs):print('training epoch ', i+1, 'of', epochs)for lable, image_tensor, target_tensor in mnist_train_dataset:C.train_model(image_tensor, target_tensor)passpass
C.plot_process()
print('run time = ', (time.time()-start_time) / 60)

训练3轮所花费的时间大约不到3min,效率还不错

四、测试模型

# 测试模型
mnist_test_dataset = MnistDataset(r'./data/mnist_test.csv')
record = 19
mnist_test_dataset.plot_image(record)  # 图像里的数字
image_data = mnist_test_dataset[record][1]
output = C.forward(image_data)
pandas.DataFrame(output.detach().numpy()).plot(kind='bar', legend=False, ylim=(0, 1))  # 预测的数字
plt.show()score = 0
items = 0
for label, img_tensor, label_tensor in mnist_test_dataset:ans = C.forward(img_tensor)if ans.argmax() == label:score += 1passitems += 1pass
print(score, items, score / items)

模型的测试分数是87%,考虑到这是一个简单的网络,这个分数不算太差。

五、模型优化

模型的优化主要从四个方面着手:

  • 1、损失函数
    在上面的模型中设计损失函数为MSEloss,这里将其更改为二元交叉熵损失((binary cross entropy loss)
self.loss_function = nn.BCELoss()

训练3轮,发现分数由87%提升到91%了

  • 2、激活函数
    Sigmoid激活函数的一个缺点是,当输入值变大时,梯度会变得非常小甚至消失。现在常用的是改进过的线性整流函数Leaky ReLU,也叫带泄露线性整流函数。
self.model = nn.Sequential(nn.Linear(784, 200),# nn.Sigmoid(),nn.LeakyReLU(0.02),nn.Linear(200, 10),# nn.Sigmoid()nn.LeakyReLU(0.02))

损失函数为原来的MSEloss,训练3轮,分数由87%上升到97%,这是一个很大的提升。

  • 3 、优化器
    上面模型所使用的是梯度下降法,该方法的一个缺点是会陷入损失函数的局部最小值,另一个缺点是对所有可学习参数都使用同一学习率。常见的替代方案是Adam,它利用动量减少陷入局部最小的可能,另外它对每个可学习参数使用单独的学习率,这些学习率随着每个参数在训练期间的变化而变化。
self.optimizer = torch.optim.Adam(self.parameters())

仅改变优化器发现模型达到和修改激活函数一样的效果,分数由87%提升到97%。

  • 4、标准化
    标准化是指减少网络中的参数和信号的取值范围,将均值转换为0,常见做法是在信号输入到神经网络前将其进行标准化。
self.model = nn.Sequential(nn.Linear(784, 200),nn.Sigmoid(),# nn.LeakyReLU(0.02),nn.LayerNorm(200),     # 标准化nn.Linear(200, 10),nn.Sigmoid()# nn.LeakyReLU(0.02))

向网络中添加标准化,模型的分数87%提升到91%
将以上所有方法进行整合,由于二元交叉熵函数只能处理0~1的值,而LeakyReLU可能会输出范围外的值,将后一层激活函数保留为原来的Sigmoid函数:

 self.model = nn.Sequential(nn.Linear(784, 200),# nn.Sigmoid(),nn.LeakyReLU(0.02),nn.LayerNorm(200),nn.Linear(200, 10),nn.Sigmoid()# nn.LeakyReLU(0.02))

3周期训练完后,模型的分数为97%,整合的优化方案无法使模型分数大于97%。

END

参考资料

-[英]塔里克•拉希德(Tariq Rashid)著,韩江雷译. PyTorch生成对抗网络编程. 人民邮电出版社

pytorch应用于MNIST手写字体识别相关推荐

  1. MNIST手写字体识别入门编译过程遇到的问题及解决

    MNIST手写字体识别入门编译过程遇到的问题及解决 以MNIST手写字体识别作为神经网络及各种网络模型的作为练手,将遇到的问题在这里记录与交流. 激活tensorflow环境后,运行spyder或者j ...

  2. matlab文字bp识别,MNIST手写字体识别(CNN+BP两种实现)-Matlab程序

    [实例简介] MNIST手写字 Matlab程序,包含BP和CNN程序.不依赖任何库,包含MNIST数据,BP网络可达到98.3%的识别率,CNN可达到99%的识别率.CNN比较耗时,关于CNN的程序 ...

  3. linux手写数字识别,OpenCV 3.0中的SVM训练 mnist 手写字体识别

    前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的 ...

  4. TensorFlow | 使用Tensorflow带你实现MNIST手写字体识别

    github:https://github.com/MichaelBeechan CSDN:https://blog.csdn.net/u011344545 涉及代码:https://github.c ...

  5. PyTorch MNIST手写字体识别

    代码: # 1 加载必要的库 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim ...

  6. Pytorch入门:LeNet手写字体识别案例

    # 1 加载必要的库 # 2 定义超参数 # 3 构建pipeline(transforms),对图像进行处理 # 4 下载,加载数据集(MNIST) # 5 创建网络模型 # 6 定义优化器 # 7 ...

  7. (二)Tensorflow搭建卷积神经网络实现MNIST手写字体识别及预测

    1 搭建卷积神经网络 1.0 网络结构 图1.0 卷积网络结构 1.2 网络分析 序号 网络层 描述 1 卷积层 一张原始图像(28, 28, 1),batch=1,经过卷积处理,得到图像特征(28, ...

  8. PyTorch手写字体识别MNIST

    手写字体识别MNIST 1.准备工作 可以看这个老师的视频进行学习,讲解的非常仔细:视频学习 2.项目代码 2.1 导入模块 # 1.加载相关库 import torch import torch.n ...

  9. pytorch CNN手写字体识别

    ## """CNN手写字体识别"""import torch import torch.nn as nn from torch.autogr ...

最新文章

  1. python对于办公有什么帮助-日常工作中python能够有哪些帮助?
  2. 10_上午回顾数据库事务
  3. java迪杰斯特拉算法_迪杰斯特拉算法完整代码(Java)
  4. (Object-C)学习笔记(一)--开发环境配置和与c语言的区别
  5. Linux命令解释之rpm
  6. 直接插入排序和冒泡排序有什么区别 直接插入排序和冒泡排序有哪些不同
  7. excel文件下载下来损坏 js_js实现txt/excel文件下载
  8. 形态学图像处理之边界提取与跟踪
  9. seo需要编程技术吗?学黑帽seo需要什么技术?
  10. 博客中常用的Emoji表情整理,欢迎自取
  11. 解决QML debugging is enabled.Only use this in a safe environment.警告
  12. php删除双引号,php如何去除引号
  13. 认识越南语的发音体系
  14. elementUI合并表头
  15. Java自定义导出列_后台生成EXCEL文档,自定义列
  16. 【MATLAB教程案例47】基于双目相机拍摄图像的三维重建matlab仿真
  17. USB_HID协议基础
  18. 【历史上的今天】3 月 12 日:万维网概念被提出;Google Code 停运;仙童半导体公司被收购
  19. 基于uni-app开发的一款视频播放器插件
  20. 【亲测有效】Linux系统安装NVIDIA显卡驱动

热门文章

  1. 宽带连接远程计算机691,电脑宽带连接错误691怎么办? 爱问知识人
  2. 禁止iphone浏览器拖动反弹(橡皮筋效果)
  3. Hugging Face(1)——Transformer Models
  4. GPU与GPGPU泛淡
  5. 信息爆炸时代的纳米技术-分子通信
  6. java删除表格_Java 创建、删除Word表格
  7. Monte Carlo Integration 蒙特卡罗方法求积分 附简单例题+代码
  8. access与trunk详细解析+区别
  9. Java基础知识----字符串
  10. java菜鸟1:jdk 安装