2021-08-06MATLAB深度学习简单应用
Matlab深度学习的简单实现
该实例源自MATLAB的官方学习课程中的《深度学习入门之旅》的课后练习项目,对已给的“显微镜下蛔虫图像”进行训练,最后识别出dead or alive。
课程官网链接:https://matlabacademy.mathworks.com/cn
1.准备部分
- 练习项目的数据集所需的数据集及标签链接:https://pan.baidu.com/s/1wWWTgGMDqJgEEhOGbDzERA 提取码:dp9m
- MATLAB中有关深度学习的网络可能需要自己去 附加功能资源管理器 安装。
图1 MATLAB附加功能资源管理器
2.具体实现
导入数据集。
wormds = imageDatastore('Roundworms\WormImages');
MATLAB中的
imageDatastore
函数用来创建一个数据集合存储,指定文件夹名称或者文件名,这样可以批量引入多个文件。此处导入数据集中的显微镜蛔虫图像。
montage(wormds);
然后可以用
montage
来预览批量导入的图像,如下图所示:
图2 montage函数下的WormImages
标注数据集
lifelabel = readtable('Roundworms\WormData.csv'); wormds.Labels = categorical(lifelabel.Status);
训练所需的标签,一般存储在图像存储集合(wormds)中的Labels属性,将.csv表格文件中的图像分类结果读入,并作为图像标签;
图3 部分标签worm.Labels
拆分 & 训练数据
% 80%的数据用来训练,20%用来测试 (随机) [wormTrain,wormTest] = splitEachLabel(wormds,0.8,'randomized');
splitEachLabel(images,n,'randomized')
函数可以将一个数据集合分成两部分,这样可以一部分图片作为训练集,另一部分作为测试集。其中n如果是(0,1)范围,则表示百分比;如果是整数,则直接代表确切数量。
randomized
是一个随机排序的方法,因为splitEachLabel
本身默认保持文件有序,可以添加随机标志来实现随机排序。调整神经网络
AlexNet由2012年的ImageNet冠军Hinton及其学生Alex Krizhevsky设计的一种卷积网络框架。 这次练习利用“迁移学习”的概念对其进行改动,调整网络层(network layers)、训练数据(training data)、训练选项(training options)等要素。
图4 AlexNet网络结构图
net = alexnet; ly = net.layers; %用ly表示网络的各个层 inputlayer = ly(1); outlayer = ly(end); categoryname = outlayer.Classes;
也可以通过MATLAB的附加资源管理器下载如GoogleNet等其他神经网络。
AlexNet输入为227x227x3,可以查看输入层的
InputSize
属性来了解一个网络的输入大小要求,也可以通过size()
函数查看数据集中的图像尺寸,来验证是否需要预处理调整尺寸。 前馈网络在MATLAB中以一个层数组来表示,方便进行索引和调整更改。
%修改网络参数 %2个标签 创建一个新的全连接层 fclayer = fullyConnectedLayer(2); %用新层替换原有的最后一个全连接层 ly(23) = fclayer; %用classificationLayer函数为图像分类网络构建一个新的输出层 ly(end) = classificationLayer;
设置训练选项
通常改动的训练选项包括:
- **Batch Size:**每步训练的数据量;
- **Max Iterations:**最大迭代次数;
- **Learning Rate:**学习率;
%opts = trainingOptions('sgdm','Name',value) opts = trainingOptions('sgdm','InitialLearnRate',0.001);
创建一个包含动态随机梯度下降
SGDM
优化器的训练算法变量opts
其中batch size和max iterations都使用默认,因为数据量少,训练的速度比较快,学习率可以先设的低一点,0.001。输入数据,执行训练
由于本身的WormImages尺寸并不完全是AlexNet所要求的‘227x227x3’,即尺寸为227x227的RGB图像,因此我们需要进行一下调整,否则使用net进行训练和预测时会报错。
audsTrain = augmentedImageDatastore([227 227],wormTrain,'ColorPreprocessing','gray2rgb'); audsTest = augmentedImageDatastore([227 227],wormTest,'ColorPreprocessing','gray2rgb');
在第3步中我将原有的数据集拆分成了wormTrain和wormTest两部分,且后面都会用到,所以都进行预处理。
ColorPreprocessing
属性可以将灰度图转化为三维数组。 完成对图像的预处理后,输入调整好的网络执行训练:
[wormnet,info] = trainNetwork(audsTrain,ly,opts);
放入训练数据、网络、训练选项。
图5 训练信息
性能评估
第6步中的训练部分获得了结构体
info
,这里面包含了三个训练信息:TrainingLoss(损失率)、TrainingAccuracy(准确率)、BaseLearnRate(学习率)。 我们可以绘制损失率和准确率的图像来观察训练效果。
图6 TrainingAccuracy和TrainingLoss图像
从准确率和损失率看起来,网络的性能很好,接下来要验证在真实数据上的表现:
wormpreds = classify(wormnet,audsTest);
将6中训练好的网络
wormnet
用来对准备好的测试集audsTest
进行测试,并将结果存在wormpreds
中。wormactual = wormTest.Labels; %已知的分类存储在测试集的Label标签属性中; num_correct = nnz(wormpreds == wormactual); %预测正确的数量; pencent_correct = num_correct/numel(wormpreds); %预测正确百分比; confusionchart(wormactual,wormpreds); %预测结果的混淆矩阵;
nnz
函数可以确定一个非零矩阵中元素的数目,可以用来比较有多少个预测值与真实值相匹配。
confusionchart
函数可以计算并显示预测分类的混淆矩阵:confusionchart(knownclass,predictedclass);
它会绘制一个横、纵轴标签相同的矩阵,横轴为预测值,纵轴为真实值,因此在对角线上的元素代表正确分类,非对角线上的元素代表误分类。最终绘制结果如下图所示:
图7 预测结果的混淆矩阵
因为原始数据WormImages中共93个图像,在第3步拆分中,我将其中的80%(74个图像)用作训练集,20%(19个)图像作 为测试集。从图7混淆矩阵结果可以看出,网络对19个测试图像进行了准确的分类。
- 完整代码(网盘下载的压缩包中有官方的代码,它用的是另一个网络,但方法和原理都是一样的)。
%深度学习课程检验%1.构建数据集
wormds = imageDatastore('Roundworms\WormImages');%1.1标注
lifelabel = readtable('Roundworms\WormData.csv');
wormds.Labels = categorical(lifelabel.Status);%display images
%montage(wormds);%1.2处理数据集
%60%的数据用来训练,40%用来测试 (随机)
[wormTrain,wormTest] = splitEachLabel(wormds,0.8,'randomized');%1.3网络
net = alexnet;
ly = net.Layers;
inputlayer = ly(1);
outlayer = ly(end);
categoryname = outlayer.Classes;%图像大小
%expectsz = inputlayer.InputSize;%1.4图像预处理
% img1 = imread('Roundworms\WormImages\wormA01.tif');
% imshow(img1);
% img1_sz = size(img1);
% img1= imresize(img1,[227 227]);
% imshow(img1);%1.5图像批量处理
%灰度图转为三维rgb图像
audsTrain = augmentedImageDatastore([227 227],wormTrain,'ColorPreprocessing','gray2rgb');
audsTest = augmentedImageDatastore([227 227],wormTest,'ColorPreprocessing','gray2rgb');%1.6研究预测
% [preds,scores] = classify(net,audsTrain);% %预测分数图像绘制
% highscores = scores > 0.01;
% bar(scores(highscores));%1.7修改网络
fclayer = fullyConnectedLayer(2);
ly(23) = fclayer;
ly(end) = classificationLayer;%1.8设置训练选项
opts = trainingOptions('sgdm','InitialLearnRate',0.001);%1.9性能评估
[wormnet,info] = trainNetwork(audsTrain,ly,opts);%损失率
trainingloss = info.TrainingLoss;
trainingAccu = info.TrainingAccuracy;
subplot(2,1,1);
plot(trainingAccu,'r');
legend('Accuracy');title('TrainingAccuracy');
xlabel('Batch(s)');ylabel('Accuracy(%)')
grid on;
subplot(2,1,2);
plot(trainingloss);title('TrainingLoss');
xlabel('Batch(s)');ylabel('Loss(%)')
legend('Loss');
grid on;%真实数据验证
wormpreds = classify(wormnet,audsTest);wormactual = wormTest.Labels;
num_correct = nnz(wormpreds == wormactual);
pencent_correct = num_correct/numel(wormpreds);
confusionchart(wormactual,wormpreds);
- 参考
- [1] MATLAB深度学习入门之旅
2021-08-06MATLAB深度学习简单应用相关推荐
- 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Deep Learning Task Tips
文章目录 一.深度学习步骤回顾 二.常规指导 三.训练过程中Loss很大 3.1 原因1:模型过于简单 3.2 原因2:优化得不好 3.3 原因1 or 原因2 ? 四.训练过程Loss小.测试过程L ...
- 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Convolutional Neural NetWork(CNN)
文章目录 一.图片分类问题 二.观察图片分类问题的特性 2.1 观察1 2.2 简化1:卷积 2.3 观察2 2.4 简化2:共享参数 - 卷积核 2.5 观察3 2.6 简化3:池化 2.6.1 M ...
- 【李宏毅机器学习2021】Task04 深度学习介绍和反向传播机制
[李宏毅机器学习2021]本系列是针对datawhale<李宏毅机器学习-2022 10月>的学习笔记.本次是对深度学习介绍和反向传播机制的学习总结.本节针对上节课内容,对batch.梯度 ...
- 干货 | 2021年,深度学习还有哪些研究方向可以做?
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨谢凌曦.数据误码率.Zhifeng 来源丨知乎问答 编辑丨极市 ...
- 2021年,深度学习还有哪些未饱和、有潜力且处于上升期的研究方向?
作者丨谢凌曦.数据误码率.Zhifeng 来源丨知乎问答 编辑丨极市平台 问题链接: https://www.zhihu.com/question/460500204 0 1 作者:谢凌曦 来源链接: ...
- 2021年,深度学习还有哪些有潜力且处于上升期的研究方向?
点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨谢凌曦.数据误码率.Zhifeng 来源丨知乎问答 编辑丨极市平台 AI博士笔记系列推荐 周志华< ...
- 2021年,深度学习的发展趋势是什么?有哪些值得关注的新动向?
作者丨刘斯坦,电光幻影炼金术 来源丨知乎问答 编辑丨极市平台 [导读]到目前为止,深度学习领域的发展趋势是什么?有哪些值得关注的新动向?在应用领域,诸如cv,nlp等,研究思路是否有新的变化? 问题来 ...
- 笔记:计算机视觉与深度学习-简单的实现一个网络-番外篇
写在开头 1.课程来源及所有代码来源,后面不再每一步中标明了:pytorch官方网站的Tutorials.和Docs. 2.笔记目的:个人学习+增强记忆+方便回顾 3.时间:2021年4月19日 4. ...
- 【深度学习】李宏毅2021/2022春深度学习课程笔记 - 机器学习的可解释性
文章目录 一.为什么我们需要可解释性的机器学习 二.可解释性的 vs 强大的(Powerful) 三.可解释性机器学习的目标 四.可解释性的机器学习 4.1 Local Explanation 局部的 ...
- 【深度学习】李宏毅2021/2022春深度学习课程笔记 - Auto Encoder 自编码器 + PyTorch实战
文章目录 一.Basic Idea of Auto Encoder 1.1 Auto Encoder 结构 1.2 Auto Encoder 降维 1.3 Why Auto Encoder 1.4 D ...
最新文章
- python:文件操作
- linux下/proc/cpuinfo文件
- 使用.NET,郁闷之余,写下的废话
- Linux系统下GCC编译错误:“undefined reference to ‘sqrt‘”
- 短学期实训——第二篇
- java socket 一边关闭_java socket - 半关闭
- [转]虚拟机网络模式简介
- SAP CRM BSPWDApplication.do
- 软件工程编码阶段_软件工程的编码阶段
- 欧盟回应Meta退出欧洲威胁:没有Facebook生活一样很美好
- 博客园上海地区活动——LinkCoder主题社区第二期:淘宝服务化架构的设计和实践...
- Windows 10 之修改登录背景(Win10BGChanger)
- VMware vSphere 5.5的12个更新亮点(1)
- fetch oracle 1007,Oracle 教程 Fetch子句 - 闪电教程JSRUN
- Git - 设置签名(Autograph)
- 百度UNIT 机器人多轮对话技能创建以及API调用
- signal 11 定位
- Ubuntu查看usb设备驱动/usb以太网卡设备驱动
- Android 科大讯飞语音SDK集成步骤
- 生成带大写英文字母和数字的验证码(手机或邮箱)