1、Deep Network Designer工具箱使用介绍

2、神经网络的GPU训练

3、预测与分类

一、Deep Network Designer工具箱使用介绍

相比BP、GRNN、RBF、NARX神经网络的简单结构,深度神经网络结构更加复杂,比如卷积神经网络CNN,长短时序神经网络LSTM等,matlab集成了深度学习工具箱,可输入如下指令调用:

Deep Network Designer

可以使用别人的网络架构也可以自己创建,点击“空白网络”创建。如下图最左侧是常用的各种网络层,可根据文献上的网络结构或者自己设计的结构任意组合,具体模块参数双击进行设计,前提是网络数据维度没有错误。如图所示,为作者创建的用于RGB图像分类的卷积神经网络CNN结构,具体设计过程后续出。构建完成,点击“分析”可查看是否有错误,无错误之后可通过“导出”得到网络架构的代码即layers。

layers = [imageInputLayer([120 160 3],"Name","imageinput")   %输入相机帧convolution2dLayer([3 3],15,"Name","conv_1","Padding","same")   %卷积层reluLayer("Name","relu_1")averagePooling2dLayer([2 2],"Name","avgpool2d_1","Stride",[2 2])  %平均池化层convolution2dLayer([3 3],15,"Name","conv_2","Padding","same")  %卷积层reluLayer("Name","relu_2")averagePooling2dLayer([2 2],"Name","avgpool2d_2","Stride",[2 2])  %平均池化层convolution2dLayer([3 3],12,"Name","conv_3","Padding","same")   %卷积层reluLayer("Name","relu_3")averagePooling2dLayer([2 2],"Name","avgpool2d_3","Stride",[2 2])  %平均池化层dropoutLayer(0.3,"Name","dropout_2")  %随机失活,失活率为30%fullyConnectedLayer(256,"Name","fc_1","WeightL2Factor",6)  %全连接层reluLayer("Name","relu_4")dropoutLayer(0.3,"Name","dropout_1")  %随机失活,失活率为30%fullyConnectedLayer(20,"Name","fc_2","WeightL2Factor",6)  %全连接层softmaxLayer("Name","softmax") classificationLayer("Name","classoutput")];   %输出层

创建一个m程序,将此代码复制进去。

二、神经网络的GPU训练

网络构建好以后,就是编写训练的代码,主要过程分为:读取数据集、归一化(可有可无)、划分训练集与测试集、反归一化(可有可无)、训练配置与训练。作者此处给出图像分类的代码,详细过程可见代码注释。

%% 工具箱导出的网络结构
layers = [imageInputLayer([120 160 3],"Name","imageinput")convolution2dLayer([3 3],15,"Name","conv_1","Padding","same")reluLayer("Name","relu_1")averagePooling2dLayer([2 2],"Name","avgpool2d_1","Stride",[2 2])convolution2dLayer([3 3],15,"Name","conv_2","Padding","same")reluLayer("Name","relu_2")averagePooling2dLayer([2 2],"Name","avgpool2d_2","Stride",[2 2])convolution2dLayer([3 3],12,"Name","conv_3","Padding","same")reluLayer("Name","relu_3")averagePooling2dLayer([2 2],"Name","avgpool2d_3","Stride",[2 2])dropoutLayer(0.3,"Name","dropout_2")fullyConnectedLayer(256,"Name","fc_1","WeightL2Factor",6)reluLayer("Name","relu_4")dropoutLayer(0.3,"Name","dropout_1")fullyConnectedLayer(20,"Name","fc_2","WeightL2Factor",6)softmaxLayer("Name","softmax")classificationLayer("Name","classoutput")];
%% 读取数据集
digitDatasetPath=fullfile('.\');  %打开数据集文件夹路径
% 注释:此路径下放有30个文件夹,每个文件夹为一个类别,每个文件夹里面有等数量的图片,这些图片都已经预处理。
imds=imageDatastore(digitDatasetPath,...'IncludeSubfolders',true,'LabelSource','foldernames');  %读取图片数据集,标签Label设置为文件名。
% 注释:每个文件夹的名字即为分类的类别标签
%% 划分数据集(训练集和验证集)
numTrainFiles=round(2/3*30);   % 20为类别文件夹数量,测试集作者放在另外的地方,训练时候只需要训练集和验证集。
[imdsTrain,imdsValidation]=splitEachLabel(imds,numTrainFiles,'randomize'); % 随机划分每类文件夹下的训练集和验证集%若数据图片大小与网络输入不一样,可通过下面三行代码处理。若相同可去掉此三行代码
inputSize=layers(1).InputSize;  %读取网络输入层的输入图像的大小尺寸
imdsTrain=augmentedImageDatastore(inputSize(1:2),imdsTrain);  %整合训练集的尺寸1与inputSize的第一二个维度相同。
augimdsValidation=augmentedImageDatastore(inputSize(1:2),imdsValidation);
%%  训练配置
ExecutionEnvironment='gpu';  %此处设置用GPU或者CPU训练,建议GPU快
%具体一些需要改动的配置说明,可以上matlab官网查看trainingOptions函数文档
options_train=trainingOptions('sgdm',...'MaxEpochs',100,...   % 训练轮数为65次'InitialLearnRate',0.0001,...  %初始学习率'Verbose',true,'MiniBatchSize',10,... 'LearnRateSchedule','piecewise',...'LearnRateDropFactor',0.6,...'LearnRateDropPeriod',5,...'Plots','training-progress',...'ValidationData',augimdsValidation,...'ValidationFrequency',10,...'ExecutionEnvironment',ExecutionEnvironment);
net=trainNetwork(imdsTrain,layers,options_train);   %开始训练
save('train.mat'); %保存训练完的网络模型为train.mat。

三、预测与分类

此处我们是属于分类任务,所以在第一步创建网络最后一层模块是分类块,如果是数据回归即数据预测则不同,本文不详细说明。下面给出利用已训练好的网络模型进行分类的代码。再创建一个m程序用来放分类的代码:

load('train.mat'); %先下载同一文件夹下之前训练好的模型
x=imread('1.jpg'); %读取一张事先准备好的图片1,命名为x
YPred=classify(net,x); %用训练好的网络net对x进行分类识别 ,分类结果为YPred
sprintf('测试结果为%s',YPred) 将结果YPred显示。注意这个YPred是一个奇怪的数据类型categorical
%为了后续GUI界面的方便使用,作者的数据集名字即类别lable都是数字哦
%下面就是将categorical数据类型转化为矩阵mat类型,命名为nn。
M=string(YPred);
nn=double(M);

结语

读者可能需要一些图片的预处理和数据增强,视频帧读取,GUI的网络嵌入与端到端识别等程序,可以参考其他博主的文章,作者后续闲暇之余有可能会出相关博客。本文神经网络和识别的一些原理算法,后续博客直接给出本科毕设论文以供参考。

matlab卷积神经网络的创建与图片识别相关推荐

  1. PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

    PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN 目录 训练过程 代码设计 训练过程 代码设计 #PyTorch:利用PyTorch实现 ...

  2. Top2:CNN 卷积神经网络实现猫狗图片识别二分类

    Top2:CNN 卷积神经网络实现猫狗图片识别二分类 系统:Windows10 Professional 环境:python=3.6 tensorflow-gpu=1.14 ```python &qu ...

  3. Matlab卷积神经网络(CNN)手写数字识别(一)

    今天买的书到了,开始接触卷积神经网络,展示书中内容~ Matlab卷积神经网络手写数字识别(一) 机器学习的基本流程 加载Matlab自带数据集 机器学习的基本流程 在机器学习中,一般将数据集划分为两 ...

  4. matlab深度学习——【卷积神经网络】手写字的识别

    > 本文所使用的数据集在文章最后,不需要积分就可以下载! 数据集下载 这里主要是基于卷积神经网络的手写字的识别,我是用matlab做的,如果有对卷积神经网络不太熟悉的伙伴可以搜下,网上资源比较多 ...

  5. 【图像识别】基于卷积神经网络cnn实现银行卡数字识别matlab源码

    1 基于卷积神经网络cnn实现银行卡数字识别模型 模型参考这里. 2 部分代码 %印刷体识别 clc;clear;close all; addpath('util/'); addpath('data/ ...

  6. 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)

    1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...

  7. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  8. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  9. 基于卷积神经网络VGG实现水果分类识别

    基于卷积神经网络VGG实现水果分类识别 一. 前言 二. 模型介绍 三. 数据处理 四. 模型搭建 4.1 定义卷积池化网络 4.2 搭建VGG网络 4.3 参数配置 4.4 模型训练 4.5 绘制l ...

  10. 用卷积神经网络实现猫狗图片分类

    该例程使用数据集来源于 kaggle cat_VS _dog 数据集中的一部分, 用卷积神经网络实现猫狗图片二分类,例程序比较简单,就不多解释了,代码中会有相应的注释,直接上代码: import nu ...

最新文章

  1. 大白话讲解Promise(二)理解Promise规范
  2. A* 算法之父、人工智能先驱Nils Nilsson逝世 | 缅怀
  3. mysql与ofbiz,ofbiz+mysql安装求教
  4. 皮一皮:这样的消息我也想收...
  5. leetcode954. 二倍数对数组(treemap)
  6. 队列模块(Queue)
  7. 音频服务器未运行怎么办,音频服务未运行怎么办 音频服务未运行解决方法【详细介绍】...
  8. mysql 编码utfmb4
  9. 双android手机同步工具,android手机同步数据PC(SyncDroid)
  10. ios 身份证照片识别信息
  11. Android穿山甲SDK激励视频
  12. 使用腾讯乐固加固安卓APK
  13. steam动态令牌源码(python版本)
  14. phd计算机考试,美国计算机PHD院校申请难度有多大?
  15. Mysql基础入门篇(二)
  16. python撤销_Python 实现还原已撤回的微信消息
  17. python 画图colorbar 颜色大全 plt.cm.get_cmap
  18. Pomotroid 使用指南:一款高颜值 PC 端番茄时钟
  19. 新型1688分销网店让想开店的店主轻松无忧
  20. Graph500教程

热门文章

  1. html5swf小游戏源码,亲测可用120个H5小游戏实例源码
  2. SRC漏洞挖掘之信息收集
  3. 基于Matlab的极限学习机(ELM)实现
  4. 修改mysql字段长度
  5. 关于STM32编译报错:Error: L6218E: Undefined symbol SystemInit (referred from startup_stm32f10x_md.o).
  6. Golang中MYSQL驱动
  7. Alex 的 Hadoop 菜鸟教程: 第19课 华丽的控制台 HUE 安装以及使用教程
  8. MySQL联合查询分页
  9. 用户故事与敏捷方法—Scrum与用户故事
  10. java蓝字代表什么_蓝是什么意思 蓝字五行属什么