Matlab深度学习

文章目录

  • Matlab深度学习
  • 前言
  • 一、MNIST手写体数字数据
  • 二、用到的深度学习框架-LeNet5
    • 2-0 LeNet5的网络架构
    • 2-1 框架实现-通过Matlab GUI 拖拽界面
    • 2-2 框架实现2-直接写框架代码
  • 三、代码
    • 3-0 机器学习与深度学习对比
    • 3-1 SVM分类手写体数字
    • 3-2 ANN分类手写体数字
    • 3-3 深度学习LetNet5
    • 测试
  • 最后

前言

   最近想在matlab环境下跑一下深度学习模型,找到了以下的一篇博客,但因为该博客所附录的资源失效,而且也可能因为版本问题导致数据集的预处理容易出bug,所以打算改成.mat格式数据上传到我的博客资源,方便下载和练习。原博客地址如下:
https://blog.csdn.net/longlongsvip/article/details/105466512
环境
Matlab: 2020b
GPU:NVIDIA Quadro P620
(注:Matlab2020b好像不支持NVIDIA安培架构显卡即不支持30系列,用3090在matlab2020b上测试过,报错找不到GPU)


一、MNIST手写体数字数据

   MNIST手写体数据包含60000个训练样本(数字0-9),以及测试集数据10000个(数字0-9),这里不再详细叙述数据集背景,上文博客写得很详细。 我已经将数据集转换为.mat文件格式(省去上文博客中的复杂预处理步骤),其数据格式如下:

由上图可以看出,图片是以一维的形式存储的,其数据维度为1 x 784,即原28 x 28 维度的图片压缩成了一维,所以在使用的时候需要将图片恢复为二维形式(代码下面给出)。

二、用到的深度学习框架-LeNet5

2-0 LeNet5的网络架构

2-1 框架实现-通过Matlab GUI 拖拽界面

打开matlab APP 的Deep Network Designer

打开后如下图所示:(可以选择好预训练的网络,也可以新建空白网络)

点击空白网络,根据LetNet5[1]的网络结构搭建网络(拖拽搭建)

搭建完成后点击分析,没问题后就可以导出(导出-生成代码)网络了

2-2 框架实现2-直接写框架代码

很明显,以上步骤是非必要的,老手可以直接写层代码来实现…

layers = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];

三、代码

3-0 机器学习与深度学习对比

   在进行实现深度网络前,可以先用传统机器学习的方法进行对比,有助于更好理解深度学习网络架构的优越性。

3-1 SVM分类手写体数字

   SVM作为机器学习的老大哥,在小样本,小分类数的情况下一直表现优异,但数据量庞大和分类数量比较多的情况下不一定合适。SVM在本例的评价如下:
运行时间:1星 (不管是训练还是测试都非常慢)
准确率: 1星
用到的SVM分类器为libsvm工具箱,其调用代码如下:

load handwriting.mat                               %载入数据集
model=svmtrain(y_train,x_train,'-c 2 -g 0.10');    %训练
[predicted_label]=svmpredict(y_test,x_test,model); %测试

3-2 ANN分类手写体数字

   相对于SVM,ANN人工神经网络更适合于处理数据量庞大的情况,但是相对与本例的LetNet-5而言,ANN忽略了输入图片的空间信息,即输入的数据是一维训练数据。设计的ANN网络结构如下:

   这里我用了3层隐藏层的ANN网络,其中第一层有100个神经元,第二层60个神经元,第三层40个神经元,第四层10个神经元。ANN在本例的表现评价如下:
运行时间:2星
准确率: 3星(Accuracy:95%)
ANN的实现用matlab自带的ANN工具箱,其调用代码如下:

load handwriting.mat     %载入手写体数字数据集
for i=1:size(y_train)if y_train(i)==0y_train(i)=10;    %这里把数字0标签换成10,不然出bugend
end
for j=1:size(y_test)if y_test(j)==0y_test(j)=10;     %这里把数字0标签换成10,不然出bugend
end
class=y_train;
[input,minI,maxI]=premnmx(x_train');
s = length( class) ;
output = zeros( s , 2 ) ;%构造输出矩阵
for i = 1 : s output( i , class( i )  ) = 1 ;
end
%% 网络参数
net = newff( minmax(input) , [100 60 40 10] , { 'logsig' 'logsig' 'logsig' 'purelin' } , 'traingdx' ) ;  %创建神经网络
%激活函数有'tansig' 'logsig'以及'purelin'三种
net.trainparam.show = 50 ;                %显示中间结果的周期
net.trainparam.epochs = 7000 ;            %最大迭代次数(学习次数)
net.trainparam.goal = 0.01 ;              %神经网络训练的目标误差
net.trainParam.lr = 0.001 ;               %学习速率(Learning rate)
%% 开始训练
net = train( net, input , output' ) ;     %其中input为训练集的输入信号,对应output为训练集的输出结果
%% GPU训练
% gpudev=gpuDevice;%事先声明gpudev变量为gpu设备类
% gpudev.AvailableMemory;%实时获得当前gpu的可用内存
% input=single(input);%将double型的P转为single型
% output=single(output);%将double型的T转为single型
% net = train( net, input , output' , 'useGPU','only' ) ;     %GPU
%% 测试
tic
testInput=tramnmx(x_test',minI,maxI);
Y=sim(net,testInput);
[s1 , s2] = size( Y ) ;                   %统计识别正确率
hitNum = 0 ;
predictChar=[];                           %输出结果
for i = 1 : s2[m , Index] = max( Y( : ,  i ) ) ;predictChar=[predictChar;Index];if( Index  == y_test(i)   ) hitNum = hitNum + 1 ; end
end
sprintf('识别率是 %3.3f%%',100 * hitNum / s2 )
toc

3-3 深度学习LetNet5

   通过第二节获得层参数layers后,就可以直接将该参数用于深度学习训练,当然训练前需要进行数据格式的转换,把一维数据转换为二维图片数据。LetNet5在本例的评价如下:
运行时间:3星(GPU训练)
准确率: 4星
代码如下:
Datapre.m(一维数据转换为二维)

load handwriting.mat
% 将一维数据转为二维图像数据
%% 训练集
X=x_train;
X = permute(X,[2 1]);    %交换数据维度
X = X./255;              %归一化
X=reshape(X,[28,28,1,size(X,2)]);
X = dlarray(X, 'SSCB');
Y=categorical(y_train);
%% 测试集
X2=x_test;
X2 = permute(X2,[2 1]);    %交换数据维度
X2 = X2./255;              %归一化
X2=reshape(X2,[28,28,1,size(X2,2)]);
X2 = dlarray(X2, 'SSCB');
Y2=categorical(y_test);
%% 保存数据
save XY.mat X Y X2 Y2

CNNTrain.m(网络训练)

load XY.mat
XTrain=X;      %训练集
YTrain=Y;      %训练集标签
XTest=X2;      %测试集
Ytest=Y2;      %测试集标签
layers = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];options = trainingOptions('sgdm', ...         %优化器'LearnRateSchedule','piecewise', ...      %学习率'LearnRateDropFactor',0.2, ...             'LearnRateDropPeriod',5, ...'MaxEpochs',20, ...                       %最大学习整个数据集的次数'MiniBatchSize',128, ...                  %每次学习样本数'Plots','training-progress');             %画出整个训练过程%训练网络
trainNet = trainNetwork(XTrain, YTrain,layers,options); save Minist_LeNet5 trainNet                      %训练完后保存模型
yTest = classify(trainNet, XTest);               %测试训练后的模型
accuracy = sum(yTest == Ytest)/numel(yTest);     %模型在测试集的准确率
disp(accuracy)                                   %打印测试集准确率

CNNTest.m(通过JPG图片进行测试)

load('Minist_LeNet5');                      %导入训练好的LeNet5网络
test_image = imread('1.jpg');               %导入手写体数字图片
shape = size(test_image);
dimension=numel(shape);
if dimension > 2test_image = rgb2gray(test_image);      %灰度化
end
test_image = imresize(test_image, [28,28]); %保证输入为28*28
test_image = imcomplement(test_image);      %反转,使得输入网络时一定要保证图片 背景是黑色,数字部分是白色
test_image=double(test_image);
test_image=test_image';                     %旋转
test_image=test_image./255;                 %归一化
result = classify(trainNet, test_image);    %利用LetNet5分类
disp(result);

测试

测试用的是自己用windows画图工具画的10个数字,如下:

CNNTest.m运行结果:

最后

   手写体数字识别.mat数据集已经上传到我的博客资源,到我的博客资源就可以下载了。

   MNIST手写体数字数据集官网
   http://yann.lecun.com/exdb/mnist/

   参考文献
   [1] Lecun Y , Bottou L . Gradient-based learning applied to document recognition[J].    Proceedings of the IEEE, 1998, 86(11):2278-2324.

Matlab深度学习-手写体数字识别相关推荐

  1. 基于深度学习的数字识别GUI的设计

    基于深度学习的数字识别GUI的设计 用matlab的deeplearning工具箱搭建了CNN来识别手写数字的GUI. 一.训练CNN 采用的是matlab自带的数字训练集和验证集,搭建的CNN的代码 ...

  2. lenet5手写数字识别 matlab,LeNet5实现手写体数字识别(基于PyTorch实现)

    import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from  ...

  3. 基于Matlab深度学习Yolov4-tiny的交通标志识别道路标志识别检测

    交通标志检测是辅助驾驶.自动驾驶系统中的重要组成部分,针对交通标志检测任务中复杂环境下的小目标检测精度低的问题,提出一种基于YOLOv4-tiny的交通标志检测方法. 基于Matlab深度学习的道路标 ...

  4. keras框架下的深度学习(一)手写体数字识别

    文章目录 前言 一.keras的介绍及其操作使用 二.手写题数字识别 1.介绍 2.对数据的预处理 3.搭建网络框架 4.编译 5.循环训练 6.测试训练的网络模 7.总代码 三.附:梯度下降算法 1 ...

  5. 基于AlexNet卷积神经网络的手写体数字识别系统研究-附Matlab代码

    ⭕⭕ 目 录 ⭕⭕ ✳️ 一.引言 ✳️ 二.手写体数字识别系统 ✳️ 2.1 MNIST 数据集 ✳️ 2.2 CNN ✳️ 2.3 网络训练 ✳️ 三.手写体数字识别结果 ✳️ 四.参考文献 ✳️ ...

  6. 基于matlab的手写体数字识别系统,基于matlab的手写体数字识别系统研究

    基于matlab的手写体数字识别系统研究 丁禹鑫1,丁会2,张红娟2,杨彤彤1 [摘要]随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文 ...

  7. 基于matlab的手写体数字识别系统

    摘要:随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文章在matlab软件的基础上,利用BP神经网络算法完成手写体数字的识别. 机器学习 ...

  8. 基于matlab支持向量机SVM多分类手写体数字识别

    此程序为本人模式识别大作业,参考了网上的代码,并进行了一定的修改,希望对大家有所帮助! 此代码主要参考了以下文章: https://blog.csdn.net/Einperson/article/de ...

  9. 基于MATLAB的手写体数字识别算法的实现

    基于MATLAB的手写体数字识别 一.课题介绍 手写数字识别是模式识别领域的一个重要分支,它研究的核心问题是:如何利用计算机自动识别人手写在纸张上的阿拉伯数字.手写体数字识别问题,简而言之就是识别出1 ...

最新文章

  1. 2012年总结,2013年的计划
  2. 2019牛客暑期多校训练营(第十场)C - Gifted Composer (二分+哈希)
  3. vue 悬浮按钮_Vue@哇!几行代码实现拖拽视图组件
  4. 挖矿为什么要用显卡_数字货币行情分析 2020/07/17 为什么大佬们都转向显卡挖矿了?...
  5. linux AB测试
  6. WP8模拟器需要BIOS开启虚拟化支持(转载)
  7. 集体智慧编程 - 优化
  8. 常见危险函数及特殊函数(一)
  9. 【面经】关于逻辑回归,面试官们都怎么问
  10. javascript焦点图自动播放
  11. cocostudio的TextField空件实现光标。
  12. linux内核源码下载地址
  13. 一键清空服务器文件,一键清理操作系统垃圾文件的BAT
  14. 全网最全的 Java 技术栈内容梳理(持续更新中)
  15. 奥城大学计算机专业,双录取的美国研究生大学有哪些?哪些专业被允许?
  16. 朋友圈发图多大不会被压缩_微信:朋友圈照片自动压缩 不暴露位置信息
  17. 【转】常见蓝屏错误信息
  18. MySQL高级-(存储引擎、索引、锁)
  19. 【教学类-30-01】5以内加法题不重复(一页两份)(包含1以内、2以内、3以内、4以内、5以内加法,抽取最大不重复数量)
  20. j-4 大炮打蚊子 (10 分)关于最后一个测试点出错及本题的具体思路(以作者思路为例)

热门文章

  1. Oracle数据库•笔记
  2. SPOJ 5 The Next Palindrome
  3. Go语言-【包package】-包的基本概念
  4. JAVA架构师之路十五:设计模式之策略模式
  5. python元组有啥用_python元组是什么?python元组的用法介绍
  6. < 每日知识点:关于Javascript 精进小妙招 ( Js技巧 ) >
  7. 面向对象整体GIS数据模型的设计与实现
  8. 如何快速处理图片?超简单实用的图片处理工具推荐
  9. 【NYOJ 289 】
  10. 智能图像处理:基于边缘去除和迭代式内容矫正的复杂文档图像校正