Matlab深度学习-手写体数字识别
Matlab深度学习
文章目录
- Matlab深度学习
- 前言
- 一、MNIST手写体数字数据
- 二、用到的深度学习框架-LeNet5
- 2-0 LeNet5的网络架构
- 2-1 框架实现-通过Matlab GUI 拖拽界面
- 2-2 框架实现2-直接写框架代码
- 三、代码
- 3-0 机器学习与深度学习对比
- 3-1 SVM分类手写体数字
- 3-2 ANN分类手写体数字
- 3-3 深度学习LetNet5
- 测试
- 最后
前言
最近想在matlab环境下跑一下深度学习模型,找到了以下的一篇博客,但因为该博客所附录的资源失效,而且也可能因为版本问题导致数据集的预处理容易出bug,所以打算改成.mat格式数据上传到我的博客资源,方便下载和练习。原博客地址如下:
https://blog.csdn.net/longlongsvip/article/details/105466512
环境
Matlab: 2020b
GPU:NVIDIA Quadro P620
(注:Matlab2020b好像不支持NVIDIA安培架构显卡即不支持30系列,用3090在matlab2020b上测试过,报错找不到GPU)
一、MNIST手写体数字数据
MNIST手写体数据包含60000个训练样本(数字0-9),以及测试集数据10000个(数字0-9),这里不再详细叙述数据集背景,上文博客写得很详细。 我已经将数据集转换为.mat文件格式(省去上文博客中的复杂预处理步骤),其数据格式如下:
由上图可以看出,图片是以一维的形式存储的,其数据维度为1 x 784,即原28 x 28 维度的图片压缩成了一维,所以在使用的时候需要将图片恢复为二维形式(代码下面给出)。
二、用到的深度学习框架-LeNet5
2-0 LeNet5的网络架构
2-1 框架实现-通过Matlab GUI 拖拽界面
打开matlab APP 的Deep Network Designer
打开后如下图所示:(可以选择好预训练的网络,也可以新建空白网络)
点击空白网络,根据LetNet5[1]的网络结构搭建网络(拖拽搭建)
搭建完成后点击分析,没问题后就可以导出(导出-生成代码)网络了
2-2 框架实现2-直接写框架代码
很明显,以上步骤是非必要的,老手可以直接写层代码来实现…
layers = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];
三、代码
3-0 机器学习与深度学习对比
在进行实现深度网络前,可以先用传统机器学习的方法进行对比,有助于更好理解深度学习网络架构的优越性。
3-1 SVM分类手写体数字
SVM作为机器学习的老大哥,在小样本,小分类数的情况下一直表现优异,但数据量庞大和分类数量比较多的情况下不一定合适。SVM在本例的评价如下:
运行时间:1星 (不管是训练还是测试都非常慢)
准确率: 1星
用到的SVM分类器为libsvm工具箱,其调用代码如下:
load handwriting.mat %载入数据集
model=svmtrain(y_train,x_train,'-c 2 -g 0.10'); %训练
[predicted_label]=svmpredict(y_test,x_test,model); %测试
3-2 ANN分类手写体数字
相对于SVM,ANN人工神经网络更适合于处理数据量庞大的情况,但是相对与本例的LetNet-5而言,ANN忽略了输入图片的空间信息,即输入的数据是一维训练数据。设计的ANN网络结构如下:
这里我用了3层隐藏层的ANN网络,其中第一层有100个神经元,第二层60个神经元,第三层40个神经元,第四层10个神经元。ANN在本例的表现评价如下:
运行时间:2星
准确率: 3星(Accuracy:95%)
ANN的实现用matlab自带的ANN工具箱,其调用代码如下:
load handwriting.mat %载入手写体数字数据集
for i=1:size(y_train)if y_train(i)==0y_train(i)=10; %这里把数字0标签换成10,不然出bugend
end
for j=1:size(y_test)if y_test(j)==0y_test(j)=10; %这里把数字0标签换成10,不然出bugend
end
class=y_train;
[input,minI,maxI]=premnmx(x_train');
s = length( class) ;
output = zeros( s , 2 ) ;%构造输出矩阵
for i = 1 : s output( i , class( i ) ) = 1 ;
end
%% 网络参数
net = newff( minmax(input) , [100 60 40 10] , { 'logsig' 'logsig' 'logsig' 'purelin' } , 'traingdx' ) ; %创建神经网络
%激活函数有'tansig' 'logsig'以及'purelin'三种
net.trainparam.show = 50 ; %显示中间结果的周期
net.trainparam.epochs = 7000 ; %最大迭代次数(学习次数)
net.trainparam.goal = 0.01 ; %神经网络训练的目标误差
net.trainParam.lr = 0.001 ; %学习速率(Learning rate)
%% 开始训练
net = train( net, input , output' ) ; %其中input为训练集的输入信号,对应output为训练集的输出结果
%% GPU训练
% gpudev=gpuDevice;%事先声明gpudev变量为gpu设备类
% gpudev.AvailableMemory;%实时获得当前gpu的可用内存
% input=single(input);%将double型的P转为single型
% output=single(output);%将double型的T转为single型
% net = train( net, input , output' , 'useGPU','only' ) ; %GPU
%% 测试
tic
testInput=tramnmx(x_test',minI,maxI);
Y=sim(net,testInput);
[s1 , s2] = size( Y ) ; %统计识别正确率
hitNum = 0 ;
predictChar=[]; %输出结果
for i = 1 : s2[m , Index] = max( Y( : , i ) ) ;predictChar=[predictChar;Index];if( Index == y_test(i) ) hitNum = hitNum + 1 ; end
end
sprintf('识别率是 %3.3f%%',100 * hitNum / s2 )
toc
3-3 深度学习LetNet5
通过第二节获得层参数layers后,就可以直接将该参数用于深度学习训练,当然训练前需要进行数据格式的转换,把一维数据转换为二维图片数据。LetNet5在本例的评价如下:
运行时间:3星(GPU训练)
准确率: 4星
代码如下:
Datapre.m(一维数据转换为二维)
load handwriting.mat
% 将一维数据转为二维图像数据
%% 训练集
X=x_train;
X = permute(X,[2 1]); %交换数据维度
X = X./255; %归一化
X=reshape(X,[28,28,1,size(X,2)]);
X = dlarray(X, 'SSCB');
Y=categorical(y_train);
%% 测试集
X2=x_test;
X2 = permute(X2,[2 1]); %交换数据维度
X2 = X2./255; %归一化
X2=reshape(X2,[28,28,1,size(X2,2)]);
X2 = dlarray(X2, 'SSCB');
Y2=categorical(y_test);
%% 保存数据
save XY.mat X Y X2 Y2
CNNTrain.m(网络训练)
load XY.mat
XTrain=X; %训练集
YTrain=Y; %训练集标签
XTest=X2; %测试集
Ytest=Y2; %测试集标签
layers = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];options = trainingOptions('sgdm', ... %优化器'LearnRateSchedule','piecewise', ... %学习率'LearnRateDropFactor',0.2, ... 'LearnRateDropPeriod',5, ...'MaxEpochs',20, ... %最大学习整个数据集的次数'MiniBatchSize',128, ... %每次学习样本数'Plots','training-progress'); %画出整个训练过程%训练网络
trainNet = trainNetwork(XTrain, YTrain,layers,options); save Minist_LeNet5 trainNet %训练完后保存模型
yTest = classify(trainNet, XTest); %测试训练后的模型
accuracy = sum(yTest == Ytest)/numel(yTest); %模型在测试集的准确率
disp(accuracy) %打印测试集准确率
CNNTest.m(通过JPG图片进行测试)
load('Minist_LeNet5'); %导入训练好的LeNet5网络
test_image = imread('1.jpg'); %导入手写体数字图片
shape = size(test_image);
dimension=numel(shape);
if dimension > 2test_image = rgb2gray(test_image); %灰度化
end
test_image = imresize(test_image, [28,28]); %保证输入为28*28
test_image = imcomplement(test_image); %反转,使得输入网络时一定要保证图片 背景是黑色,数字部分是白色
test_image=double(test_image);
test_image=test_image'; %旋转
test_image=test_image./255; %归一化
result = classify(trainNet, test_image); %利用LetNet5分类
disp(result);
测试
测试用的是自己用windows画图工具画的10个数字,如下:
CNNTest.m运行结果:
最后
手写体数字识别.mat数据集已经上传到我的博客资源,到我的博客资源就可以下载了。
MNIST手写体数字数据集官网
http://yann.lecun.com/exdb/mnist/
参考文献
[1] Lecun Y , Bottou L . Gradient-based learning applied to document recognition[J]. Proceedings of the IEEE, 1998, 86(11):2278-2324.
Matlab深度学习-手写体数字识别相关推荐
- 基于深度学习的数字识别GUI的设计
基于深度学习的数字识别GUI的设计 用matlab的deeplearning工具箱搭建了CNN来识别手写数字的GUI. 一.训练CNN 采用的是matlab自带的数字训练集和验证集,搭建的CNN的代码 ...
- lenet5手写数字识别 matlab,LeNet5实现手写体数字识别(基于PyTorch实现)
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from ...
- 基于Matlab深度学习Yolov4-tiny的交通标志识别道路标志识别检测
交通标志检测是辅助驾驶.自动驾驶系统中的重要组成部分,针对交通标志检测任务中复杂环境下的小目标检测精度低的问题,提出一种基于YOLOv4-tiny的交通标志检测方法. 基于Matlab深度学习的道路标 ...
- keras框架下的深度学习(一)手写体数字识别
文章目录 前言 一.keras的介绍及其操作使用 二.手写题数字识别 1.介绍 2.对数据的预处理 3.搭建网络框架 4.编译 5.循环训练 6.测试训练的网络模 7.总代码 三.附:梯度下降算法 1 ...
- 基于AlexNet卷积神经网络的手写体数字识别系统研究-附Matlab代码
⭕⭕ 目 录 ⭕⭕ ✳️ 一.引言 ✳️ 二.手写体数字识别系统 ✳️ 2.1 MNIST 数据集 ✳️ 2.2 CNN ✳️ 2.3 网络训练 ✳️ 三.手写体数字识别结果 ✳️ 四.参考文献 ✳️ ...
- 基于matlab的手写体数字识别系统,基于matlab的手写体数字识别系统研究
基于matlab的手写体数字识别系统研究 丁禹鑫1,丁会2,张红娟2,杨彤彤1 [摘要]随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文 ...
- 基于matlab的手写体数字识别系统
摘要:随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文章在matlab软件的基础上,利用BP神经网络算法完成手写体数字的识别. 机器学习 ...
- 基于matlab支持向量机SVM多分类手写体数字识别
此程序为本人模式识别大作业,参考了网上的代码,并进行了一定的修改,希望对大家有所帮助! 此代码主要参考了以下文章: https://blog.csdn.net/Einperson/article/de ...
- 基于MATLAB的手写体数字识别算法的实现
基于MATLAB的手写体数字识别 一.课题介绍 手写数字识别是模式识别领域的一个重要分支,它研究的核心问题是:如何利用计算机自动识别人手写在纸张上的阿拉伯数字.手写体数字识别问题,简而言之就是识别出1 ...
最新文章
- 2012年总结,2013年的计划
- 2019牛客暑期多校训练营(第十场)C - Gifted Composer (二分+哈希)
- vue 悬浮按钮_Vue@哇!几行代码实现拖拽视图组件
- 挖矿为什么要用显卡_数字货币行情分析 2020/07/17 为什么大佬们都转向显卡挖矿了?...
- linux AB测试
- WP8模拟器需要BIOS开启虚拟化支持(转载)
- 集体智慧编程 - 优化
- 常见危险函数及特殊函数(一)
- 【面经】关于逻辑回归,面试官们都怎么问
- javascript焦点图自动播放
- cocostudio的TextField空件实现光标。
- linux内核源码下载地址
- 一键清空服务器文件,一键清理操作系统垃圾文件的BAT
- 全网最全的 Java 技术栈内容梳理(持续更新中)
- 奥城大学计算机专业,双录取的美国研究生大学有哪些?哪些专业被允许?
- 朋友圈发图多大不会被压缩_微信:朋友圈照片自动压缩 不暴露位置信息
- 【转】常见蓝屏错误信息
- MySQL高级-(存储引擎、索引、锁)
- 【教学类-30-01】5以内加法题不重复(一页两份)(包含1以内、2以内、3以内、4以内、5以内加法,抽取最大不重复数量)
- j-4 大炮打蚊子 (10 分)关于最后一个测试点出错及本题的具体思路(以作者思路为例)