Matlab搭建AlexNet实现手写数字识别

个人博客地址

文章目录

  • Matlab搭建AlexNet实现手写数字识别
    • 环境
    • 内容
    • 步骤
      • 准备MNIST数据集
      • 数据预处理
      • 定义网络模型
      • 定义训练超参数
      • 网络训练和预测
    • 代码下载

环境

  • Matlab 2020a
  • Windows10

内容

使用Matlab对MNIST数据集进行预处理,搭建卷积神经网络进行训练,实现识别手写数字的任务。在训练过程中,每隔30个batch输出一次模型在验证集上的准确率和损失值。在训练结束后会输出验证集中每个数字的真实值、网络预测值和判定概率,并给出总的识别准确率。

步骤

准备MNIST数据集

为了方便进行测试,本次只选用500张MNIST数据集,每个数字50张。

下载数据集后并解压,为每个数字创建单独文件夹并将该数字的所有图片放在对应的文件夹下,如图1所示。
数据集下载地址 提取码:af6n

手动分类结束后每个文件夹中应有50张图片。

数据预处理

% 加载数据集
imds = imageDatastore(..."./data",...'IncludeSubfolders', true,...'LabelSource','foldernames');

使用imageDatastore加载数据集。第一个参数填写数据集路径。由于本次实验data目录下含有子文件夹所以IncludeSubfolders需要指定为true。LabelSource表示标签来源,这里使用文件夹名字来代表标签。

  ImageDatastore - 属性:Files: {'D:\data\0\0_1.bmp';'D:\data\0\0_10.bmp';'D:\data\0\0_11.bmp'... and 497 more}Folders: {'D:\data'}Labels: [0; 0; 0 ... and 497 more categorical]AlternateFileSystemRoots: {}ReadSize: 1SupportedOutputFormats: [1×5 string]DefaultOutputFormat: "png"ReadFcn: @readDatastoreImage

上面内容为执行imageDatastore后返回变量的属性。可以看出已经成功将数据集读入并对每张图片进行label处理。

由于每个数字有50张图像,因此本次实验每个数字选用30张进行训练,另20张进行验证。使用splitEachLabel进行划分,得到训练集和验证集。

% 数据打乱
imds = shuffle(imds);% 划分训练集和验证集。每一个类别训练集有30个,验证集有20个
[imdsTrain,imdsValidation] = splitEachLabel(imds, 30);

使用shuffle进行数据打乱。得到的imdsTrain和imdsValidation分别有300和200张图片。

% 将训练集与验证集中图像的大小调整成与输入层的大小相同
augimdsTrain = augmentedImageDatastore([28 28],imdsTrain);
augimdsValidation = augmentedImageDatastore([28 28],imdsValidation);

定义网络模型

% 构建alexnet卷积网络
alexnet = [imageInputLayer([56,56,1], 'Name', 'Input')convolution2dLayer([11,11],48,'Padding','same','Stride',4, 'Name', 'Conv_1')batchNormalizationLayer('Name', 'BN_1')reluLayer('Name', 'Relu_1')maxPooling2dLayer(3,'Padding','same','Stride',2, 'Name', 'MaxPooling_1')convolution2dLayer([5,5],128,'Padding',2,'Stride',1, 'Name', 'Conv_2')batchNormalizationLayer('Name', 'BN_2')reluLayer('Name', 'Relu_2')maxPooling2dLayer(3,'Padding','same','Stride',2, 'Name', 'MaxPooling_2')convolution2dLayer([3 3],192,'Padding',1,'Stride',1, 'Name', 'Conv_3')batchNormalizationLayer('Name', 'BN_3')reluLayer('Name', 'Relu_3')convolution2dLayer([3 3],192,'Padding',1,'Stride',1, 'Name', 'Conv_4')batchNormalizationLayer('Name', 'BN_4')reluLayer('Name', 'Relu_4')convolution2dLayer([3 3],128,'Stride',1,'Padding',1, 'Name', 'Conv_5')batchNormalizationLayer('Name', 'BN_5')reluLayer('Name', 'Relu_5')maxPooling2dLayer(3,'Padding','same','Stride',2, 'Name', 'MaxPooling_3')fullyConnectedLayer(4096, 'Name', 'FC_1')reluLayer('Name', 'Relu_6')fullyConnectedLayer(4096, 'Name', 'FC_2')reluLayer('Name', 'Relu_7')fullyConnectedLayer(10, 'Name', 'FC_3')    % 将新的全连接层的输出设置为训练数据中的种类softmaxLayer('Name', 'Softmax')            % 添加新的Softmax层classificationLayer('Name', 'Output') ];   % 添加新的分类层

使用上面的代码即可构建AlexNet模型。

% 对构建的网络进行可视化分析
lgraph = layerGraph(mynet);
analyzeNetwork(lgraph)

定义训练超参数

% 配置训练选项
options = trainingOptions('sgdm', ...'InitialLearnRate',0.001, ...    'MaxEpochs',100, ...               'Shuffle','every-epoch', ...'ValidationData',augimdsValidation, ...'ValidationFrequency',30, ...'Verbose',true, ...'Plots','training-progress');

本次实验选用sgdm作为优化器,初始学习率设置为0.001,最大迭代次数为100,每次迭代都会打乱数据,每隔30个batch进行一次验证。

网络训练和预测

% 对网络进行训练
net = trainNetwork(augimdsTrain, mynet, options); % 将训练好的网络用于对新的输入图像进行分类,得到预测结果和判定概率
[YPred, err] = classify(net, augimdsValidation);

其中,YPred是存放网络对验证集预测结果的数组,err存放着每个数字的判定概率。

% 打印真实数字、预测数字、判定概率和准确率
YValidation = imdsValidation.Labels;
for i=1:200
fprintf("真实数字:%d  预测数字:%d", double(YValidation(i,1))-1, double(YPred(i, 1))-1);
fprintf("  判定概率:%f\n", max(err(i, :)));
end

运行上面代码即可打印相关结果。

... ...
真实数字:4  预测数字:4  判定概率:0.814434
真实数字:0  预测数字:0  判定概率:0.657829
真实数字:8  预测数字:8  判定概率:0.874560
真实数字:0  预测数字:0  判定概率:0.988826
真实数字:6  预测数字:6  判定概率:0.970034
... ...
真实数字:5  预测数字:5  判定概率:0.806220
真实数字:4  预测数字:4  判定概率:0.938233
真实数字:7  预测数字:7  判定概率:0.906994
真实数字:7  预测数字:7  判定概率:0.837794
真实数字:6  预测数字:6  判定概率:0.951572
真实数字:6  预测数字:1  判定概率:0.415834
真实数字:5  预测数字:5  判定概率:0.789031
真实数字:2  预测数字:2  判定概率:0.363526
真实数字:7  预测数字:7  判定概率:0.930049准确率:0.880000

代码下载

GitHub下载

Matlab搭建AlexNet实现手写数字识别相关推荐

  1. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  2. matlab朴素贝叶斯手写数字识别_TensorFlow手写数字识别(一)

    本篇文章通过TensorFlow搭建最基础的全连接网络,使用MNIST数据集实现基础的模型训练和测试. MNIST数据集 MNIST数据集 :包含7万张黑底白字手写数字图片,其中55000张为训练集, ...

  3. matlab朴素贝叶斯手写数字识别_从“手写数字识别”学习分类任务

    机器学习问题可以分为回归问题和分类问题,回归问题已经在线性回归讲过,本文学习分类问题.分类问题跟回归问题有明显的区别,回归问题是连续的数值,而分类问题是离散的类别,比如将性别分为[男,女],将图片分为 ...

  4. Python纯手动搭建BP神经网络--手写数字识别

    1 实验介绍 实验要求: 实现一个手写数字识别程序, 如下图所示, 要求神经网络包含一个隐层, 隐层的神经元个数为 15. 整体思路:主要参考西瓜书第五章神经网络部分的介绍,使用批量梯度下降对神经网络 ...

  5. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  6. matlab朴素贝叶斯手写数字识别_机器学习系列四:MNIST 手写数字识别

    4. MNIST 手写数字识别 机器学习中另外一个相当经典的例子就是MNIST的手写数字学习.通过海量标定过的手写数字训练,可以让计算机认得0~9的手写数字.相关的实现方法和论文也很多,我们这一篇教程 ...

  7. Matlab:神经网络实现手写数字识别

    如今人工智能发展的时代,机器学习有着不可或缺的地位,而其中最为突出的模型该属于神经网络.从提出神经网络开始,历经感知机.人工神经网络.BP神经网络.进化神经网络.卷积神经网络.图神经网络等,不断的深入 ...

  8. matlab朴素贝叶斯手写数字识别_基于MNIST数据集实现手写数字识别

    介绍 在TensorFlow的官方入门课程中,多次用到mnist数据集.mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte ...

  9. MATLAB人工神经网络的手写数字识别系统

    函数MouseDraw实现手写识别系统GUI界面的建立和鼠标手写的实现.(使用时保存为MouseDraw.m) function MouseDraw(action) % MouseDraw 本例展示如 ...

最新文章

  1. Java项目:网上电子书城项目(java+SSM+JSP+maven+Mysql)
  2. 如何垂直居中一个浮动元素
  3. Oracle SQL Developer语言设置
  4. eclipse如何快速查找某个类
  5. mysql查看日志命令_面对成百上千台服务器产生的日志,试试这款轻量级日志搬运神器!...
  6. 阿里在美申请区块链专利;Win10 最新漏洞被发现;MongoDB 4.2 发布​ | 极客头条...
  7. 为了进大厂,韩顺平Java教程百度云
  8. 京东商品评论的文本主题分析
  9. pytorch实现多种经典GAN
  10. 【电源专题】脉宽调制(PWM)与脉冲频率调制(PFM)
  11. 录入人员照片注意事项无身份证人员录入
  12. 远方测试软件,远方测试仪操作指导书
  13. 磁盘压缩卷只能压缩一半
  14. 初来乍到,请多多指教
  15. cut and choose
  16. MarkText ctrl+num 切换 标题级别快捷键 失效问题
  17. Lingo练习 选拔问题
  18. 随机森林算法(Random Forest)R语言实现
  19. 互联网最有价值行业分析
  20. 计算机组成及层次结构

热门文章

  1. Springer期刊投稿的部分(latex)模板资料
  2. 中国外汇交易中心员工英语培训四次合作TutorABC
  3. 无字天书之Python第十五页(Excel表格操作)
  4. 各路大咖云集探讨eBPF技术在可观测性领域的落地现状和未来可能
  5. Ubuntu 18.04 独显和集显切换
  6. 利用JQeury屏蔽网页广告(针对手机端,pc端有很多很好用的插件,请自行google)
  7. 穷人瞧不起的暴利赚钱项目, 富人却在偷偷地发着横财!
  8. android的Bundle
  9. JavaXml教程(八)使用JDOM将Java对象转换为XML
  10. 谷歌浏览器访问localhost跨域解决