Matlab&深度学习

0.重要的事情

之前我没有注意到Matlab版本的问题,给一些小伙伴造成了困扰,抱歉,以后我会详细说明的
最近很忙,大概率到明年才更新Matlab&深度学习(二),加油

1.为什么使用Matlab?

如今Python语言占据了深度学习,然而Matlab也是可以做的。

  • 好奇心,尝鲜,学习
  • Matlab的优点:
    • 使用应用程序和可视化工具创建、修改和分析深度学习架构
    • 使用应用程序预处理数据,并自动对图像、视频和音频数据进行真值标注
    • 在 NVIDIA® GPU、云和数据中心资源上加速算法,而无需专门编程
    • 与基于 TensorFlow、PyTorch 和 MxNet 等框架的使用者开展协作
    • 使用强化学习仿真和训练动态系统行为
    • 从物理系统的 MATLAB 和 Simulink® 模型生成基于仿真的训练和测试数据

2.入门——手写数字识别

2.0环境介绍
  • 【重要】Matlab 2019b版本以上包括Matlab 2019b,我不知道Matlab 2019a版本怎么样,小伙伴们知道的话请评论一下我会及时更新的
  • RTX 2060,其实其他显卡也可以,只要支持GPU计算即可
  • GPU支持
2.1手写数字图片集

MNIST是手写数字图片数据集,包含60000张训练样本和10000张测试样本。MNIST数据集来自美国国家标准与技术研究所,National Institute of Standards and Technology(NIST),M是Modified的缩写。训练集是由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局的工作人员。测试集也是同样比例的手写数字数据。每张图片有28x28个像素点构成,每个像素点用一个灰度值表示,这里是将28*28的像素展开为一个一维的行向量(每行784个值)。图片标签为one-hot编码:0-9

重要:需要下载Mnist数据集

Mnist数据集下载链接

2.2Matlab读取Mnist数据集获取图像和标签
datapath = "./Mnist/";filenameImagesTrain = strcat(datapath, "train-images-idx3-ubyte");
filenameLabelsTrain = strcat(datapath, "train-labels-idx1-ubyte");
filenameImagesTest = strcat(datapath, "t10k-images-idx3-ubyte");
filenameLabelsTest = strcat(datapath, "t10k-labels-idx1-ubyte");XTrain = processMNISTimages(filenameImagesTrain);
YTrain = processMNISTlabels(filenameLabelsTrain);
XTest = processMNISTimages(filenameImagesTest);
YTest = processMNISTlabels(filenameLabelsTest);
% 处理Mnist数据集图像
function X = processMNISTimages(filename)[fileID,errmsg] = fopen(filename,'r','b');if fileID < 0error(errmsg);endmagicNum = fread(fileID,1,'int32',0,'b');if magicNum == 2051fprintf('\nRead MNIST image data...\n')endnumImages = fread(fileID,1,'int32',0,'b');fprintf('Number of images in the dataset: %6d ...\n',numImages);numRows = fread(fileID,1,'int32',0,'b');numCols = fread(fileID,1,'int32',0,'b');X = fread(fileID,inf,'unsigned char');X = reshape(X,numCols,numRows,numImages);X = permute(X,[2 1 3]);X = X./255;X = reshape(X, [28,28,1,size(X,3)]);X = dlarray(X, 'SSCB');fclose(fileID);
end
% 处理Mnist数据集标签
function Y = processMNISTlabels(filename)[fileID,errmsg] = fopen(filename,'r','b');if fileID < 0error(errmsg);endmagicNum = fread(fileID,1,'int32',0,'b');if magicNum == 2049fprintf('\nRead MNIST label data...\n')endnumItems = fread(fileID,1,'int32',0,'b');fprintf('Number of labels in the dataset: %6d ...\n',numItems);Y = fread(fileID,inf,'unsigned char');Y = categorical(Y);fclose(fileID);
end

运行结果:

Read MNIST image data...
Number of images in the dataset:  60000 ...Read MNIST label data...
Number of labels in the dataset:  60000 ...Read MNIST image data...
Number of images in the dataset:  10000 ...Read MNIST label data...
Number of labels in the dataset:  10000 ...
2.3网络设计——LeNet5

LeNet-5的网络模型如下图所示

网络模型具体参数如下表所示

网络层 卷积核尺寸 步长 填充 输出大小
输入层 32 * 32 * 1
卷积层1 5 1 0 28 * 28 * 6
最大池化层1 2 2 0 14 * 14 * 6
卷积层2 5 1 0 10 * 10 * 6
最大池化层2 2 2 0 5 * 5 * 16
全连接层1 1 * 1 * 120
全连接层2 1 * 1 * 84
全连接层3 1 * 1 * 10
Softmax层 1 * 1 * 10
分类层 1 * 1 * 10
2.4Matlab——LeNet5设计

在Matlab 2019b中的App中有一个App,名为Deep Network Designer,即深度网络设计师,打开它,就可以通过拖动神经网络的组件来设计深度网络了

下表给出Matlab网络设计师中常见组件的相关信息

组件 翻译
imageInputLayer 图像输入层
sequenceInputLayer 序列输入层
convolution2dLayer 卷积层
fullyConnectLayer 全连接层
reluLayer relu层
leakyReluLayer leakyRelu层
tanhLayer tanhLayer层
eluLayer eLu层
batchNormalizationLayer BN层
dropoutLayer dropout层
crossChannelNormalizationLayer CCN层
averagePooling2dLayer 平均池化层
globalAveragePooling2dLayer 全局平均池化层
maxPooling2dLayer 最大池化层
additionLayer 加法层
depthConcatenationLayer 深度连接层
concatenationLayer 连接层
softmaxLayer softmax层
classificationLayer 分类层
regressionLayer 回归层

LeNet5设计图

设计完网络后用分析工具进行分析

分析结果无误后,导出LeNet5网络代码

layers = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];

至此,LeNet5网络Matlab设计已完成

2.5训练LeNet5网络

在Matlab训练网络,可以使用以下代码来设置训练,详见注释。训练时如果可以使用GPU来加速,训练会很快完成

options = trainingOptions('sgdm', ... %优化器'LearnRateSchedule','piecewise', ... %学习率'LearnRateDropFactor',0.2, ... % 'LearnRateDropPeriod',5, ...'MaxEpochs',20, ... %最大学习整个数据集的次数'MiniBatchSize',128, ... %每次学习样本数'Plots','training-progress'); %画出整个训练过程doTraining = true; %是否训练
if doTrainingtrainNet = trainNetwork(XTrain, YTrain,layers,options);% 训练网络,XTrain训练的图片,YTrain训练的标签,layers要训练的网% 络,options训练时的参数
end
save Minist_LeNet5 trainNet %训练完后保存模型
yTest = classify(trainNet, XTest); % 测试训练后的模型
accuracy = sum(yTest == YTest)/numel(YTest); %模型在测试集的准确率
2.6训练结果与测试

训练结果

测试模型
首先,给出测试示例图片

测试代码,详见注释

test_image = imread('5.jpg');
shape = size(test_image);
dimension=numel(shape);
if dimension > 2test_image = rgb2gray(test_image); %灰度化
end
test_image = imresize(test_image, [28,28]); %保证输入为28*28
test_iamge = imbinarize(test_image,0.5); %二值化
test_image = imcomplement(test_image); %反转,使得输入网络时一定要保证图片
% 背景是黑色,数字部分是白色
imshow(test_image);load('Minist_LeNet5');
% test_result = Recognition(trainNet, test_image);
result = classify(trainNet, test_image);
disp(test_result);


结果输出为:5,成功

5.完整的目录结构

3.完整的训练代码
%% 需要数据集 %%
datapath = "./Mnist/";filenameImagesTrain = strcat(datapath, "train-images-idx3-ubyte");
filenameLabelsTrain = strcat(datapath, "train-labels-idx1-ubyte");
filenameImagesTest = strcat(datapath, "t10k-images-idx3-ubyte");
filenameLabelsTest = strcat(datapath, "t10k-labels-idx1-ubyte");XTrain = processMNISTimages(filenameImagesTrain);
YTrain = processMNISTlabels(filenameLabelsTrain);
XTest = processMNISTimages(filenameImagesTest);
YTest = processMNISTlabels(filenameLabelsTest);%% LeNet网络 %%
LeNet = [imageInputLayer([28 28 1],"Name","imageinput")convolution2dLayer([5 5],6,"Name","conv1","Padding","same")tanhLayer("Name","tanh1")maxPooling2dLayer([2 2],"Name","maxpool1","Stride",[2 2])convolution2dLayer([5 5],16,"Name","conv2")tanhLayer("Name","tanh2")maxPooling2dLayer([2 2],"Name","maxpool","Stride",[2 2])fullyConnectedLayer(120,"Name","fc1")fullyConnectedLayer(84,"Name","fc2")fullyConnectedLayer(10,"Name","fc")softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];%% 训练LeNet %%
options = trainingOptions('sgdm', ... %优化器'LearnRateSchedule','piecewise', ... %学习率'LearnRateDropFactor',0.2, ... % 'LearnRateDropPeriod',5, ...'MaxEpochs',20, ... %最大学习整个数据集的次数'MiniBatchSize',128, ... %每次学习样本数'Plots','training-progress'); %画出整个训练过程doTraining = true; %是否训练
if doTrainingtrainLeNet = trainNetwork(XTrain, YTrain,LeNet,options);% 训练网络,XTrain训练的图片,YTrain训练的标签,layers要训练的网% 络,options训练时的参数
end
save Minist_LeNet5 trainLeNet %训练完后保存模型
yTest = classify(trainLeNet, XTest); % 测试训练后的模型
accuracy = sum(yTest == YTest)/numel(YTest); %模型在测试集的准确率%% 函数 %%
%% 处理Mnist数据集图像 %%
function X = processMNISTimages(filename)[fileID,errmsg] = fopen(filename,'r','b');if fileID < 0error(errmsg);endmagicNum = fread(fileID,1,'int32',0,'b');if magicNum == 2051fprintf('\nRead MNIST image data...\n')endnumImages = fread(fileID,1,'int32',0,'b');fprintf('Number of images in the dataset: %6d ...\n',numImages);numRows = fread(fileID,1,'int32',0,'b');numCols = fread(fileID,1,'int32',0,'b');X = fread(fileID,inf,'unsigned char');X = reshape(X,numCols,numRows,numImages);X = permute(X,[2 1 3]);X = X./255;X = reshape(X, [28,28,1,size(X,3)]);X = dlarray(X, 'SSCB');fclose(fileID);
end
%% 处理Mnist数据集标签 %%
function Y = processMNISTlabels(filename)[fileID,errmsg] = fopen(filename,'r','b');if fileID < 0error(errmsg);endmagicNum = fread(fileID,1,'int32',0,'b');if magicNum == 2049fprintf('\nRead MNIST label data...\n')endnumItems = fread(fileID,1,'int32',0,'b');fprintf('Number of labels in the dataset: %6d ...\n',numItems);Y = fread(fileID,inf,'unsigned char');Y = categorical(Y);fclose(fileID);
end
4.完整的测试代码以及测试结果
test_image = imread('5.jpg');
shape = size(test_image);
dimension = numel(shape);
if dimension > 2test_image = rgb2gray(test_image); %灰度化
end
test_image = imresize(test_image, [28,28]); %保证输入为28*28
test_iamge = imbinarize(test_image,0.5); %二值化
test_image = imcomplement(test_image); %反转,使得输入网络时一定要保证图片、
% 背景是黑色,数字部分是白色
imshow(test_image);load('Minist_LeNet5');
% test_result = Recognition(trainNet, test_image);
test_result = classify(trainLeNet, test_image);
disp(test_result);

结束语

这个文章会适时更新的,欢迎大家学习,参考,评论,遇到问题在可以评论里Q我,当然关于“我的训练为啥不能用GPU这种玄学问题”,我能试着给予解答,但不保证能解决你的问题,或许可以用CPU来训练网络,不过我没试。可以看到Matlab在设计一些网络时还是很方便的——通过拖拖拽拽,哈哈,希望Matlab能引入更多的网络层组件,来使我们可以更好更方便地设计深度网络

Matlab深度学习——入门相关推荐

  1. MATLAB深度学习入门之旅

    目录 1. 简介 2. 使用预训练网络:使用已创建和训练后的网络进行分类 2.1 课程示例-识别一些图像中的对象 2.1.1  任务1:读取图像 2.1.2  任务2:显示图像 2.2 进行预测 2. ...

  2. Matlab深度学习入门实例:基于AlexNet的红绿灯识别(附完整代码)

    AlexNet于2012年出现在ImageNet的图像分类比赛中,并取得了当年冠军,从此卷积神经网络开始受到人们的强烈关注.AlexNet是深度卷积神经网络研究热潮的开端,也是研究热点从传统视觉方法过 ...

  3. 【Matlab】基于MNIST数据集的图像识别(深度学习入门、卷积神经网络、附完整学习资料)

    Matlab--数字0~9的图像识别(Phil Kim著.Matlab) 本文可以为那些想对深度学习和人工智能有初步了解的朋友提供一些基础入门的帮助. 本文所用参考书: <MATLAB深度学习 ...

  4. LeCun亲授的深度学习入门课:从飞行器的发明到卷积神经网络

    Root 编译整理 量子位 出品 | 公众号 QbitAI 深度学习和人脑有什么关系?计算机是如何识别各种物体的?我们怎样构建人工大脑? 这是深度学习入门者绕不过的几个问题.很幸运,这里有位大牛很乐意 ...

  5. 深度学习入门教程UFLDL学习实验笔记三:主成分分析PCA与白化whitening

     深度学习入门教程UFLDL学习实验笔记三:主成分分析PCA与白化whitening 主成分分析与白化是在做深度学习训练时最常见的两种预处理的方法,主成分分析是一种我们用的很多的降维的一种手段,通 ...

  6. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  7. 给深度学习入门者的Python快速教程 - 番外篇之Python-OpenCV

    转载自:https://zhuanlan.zhihu.com/p/24425116 本篇是前面两篇教程:给深度学习入门者的Python快速教程 - 基础篇 给深度学习入门者的Python快速教程 - ...

  8. 给深度学习入门者的Python快速教程 - numpy和Matplotlib篇

    转载自:https://zhuanlan.zhihu.com/p/24309547 本篇部分代码的下载地址: https://github.com/frombeijingwithlove/dlcv_f ...

  9. 深度学习入门(转)(备用)

    深度学习入门(转载) 我来总结下我从一个小白到在国际顶会上发 paper 的学习经验. 深度学习的资料非常多,但这也成为了深度学习坑最大的地方,学习者很容易迷失在各种资料当中,最后只看了个皮毛.所以, ...

  10. Tensorflow2.0深度学习入门与实战(日月光华)(学习总结1)

    Tensorflow2.0深度学习入门与实战(学习总结1) 我是刚学的,网易云课堂跟着日月光华老师,现在对每节课的学习课程做一下记录,总结,仅仅作为总结. 1.使用快捷键 shift+enter执行代 ...

最新文章

  1. 程序员硬核“年终大扫除”,清理了数据库 70GB 空间
  2. MATLAB_图形学_形态学课程_找出薛之谦的歌词所有字数
  3. 中职学校的学生计算机基础较弱,中职学校计算机专业教学的现状分析及对策探究.doc...
  4. 第一篇 webApp启航
  5. Spring –添加Spring MVC –第2部分
  6. Linux常用的基本命令08
  7. springboot+mybatis+redis实现分布式缓存
  8. 美国 ZIP Code 一览表
  9. 数据挖掘--决策树ID3+k-means聚类分析西瓜数据
  10. android 国家代码
  11. 将PPT导出图片分辨率提高的方法
  12. 新接口——“淘特”关键词搜索的API接口
  13. 响应式布局的实现方法
  14. 大二Web课程设计——美食网站设计与实现(HTML+CSS+JavaScript)
  15. java中内边距跟外边距,padding和margin——内边距和外边距
  16. 朱有鹏 socket实际编程2(6)
  17. 全站 HTTPS 来了
  18. 北斗三号短报文终端在大坝安全监测方案的应用
  19. 新疆能源产业发展走势及十四五供需规模调研报告2021版
  20. 安卓工程师教你玩转Android

热门文章

  1. 2021美赛成绩查询入口和美赛成绩公布时间
  2. IPVS之Bypass转发模式
  3. UTF-8,Unicode,GBK,希腊字母读法,ASCII码表,HTTP错误码,URL编码表,HTML特殊字符,汉字编码简明对照表...
  4. rootkit学习总结2
  5. MyEclipse详细使用教程
  6. MapOnline在线地图插件,ArcGIS的得力助手
  7. VC 界面库 皮肤库
  8. Bootstrap从入门到精通(全)
  9. oracle11g教程视频教程,最新oracle11g DBA 开发和应用数据库视频教程_IT教程网
  10. android studio 2048游戏