trainNetwork - Matlab官网介绍的中文版
trainNetwork训练神经网络进行深度学习
原地址 https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
几种调用方法
net = trainNetwork(imds,layers,options)
net = trainNetwork(ds,layers,options)
net = trainNetwork(X,Y,layers,options)
net = trainNetwork(sequences,Y,layers,options)
net = trainNetwork(tbl,layers,options)
net = trainNetwork(tbl,responseName,layers,options)
[net,info] = trainNetwork(___)
描述
使用trainNetwork
训练卷积神经网络(ConvNet,CNN),长短期记忆(LSTM)网络,或双向LSTM(BiLSTM)网络的深度学习分类和回归的问题。您可以在CPU或GPU上训练网络。对于图像分类和图像回归,您可以使用多个GPU或并行进行训练。使用GPU,多GPU和并行选项需要Parallel Computing Toolbox™。要使用深层学习GPU,你还必须有一个CUDA ®启用NVIDIA ® GPU计算能力3.0或更高版本。使用指定培训选项,包括用于执行环境的选项trainingOptions
。
为图像分类问题训练网络。图像数据存储区net
= trainNetwork(imds
,layers
,options
)imds
存储输入的图像数据,layers
定义网络体系结构,并options
定义训练选项。
使用数据存储训练网络net
= trainNetwork(ds
,layers
,options
)ds
。对于具有多个输入的网络,请将此语法与组合或转换后的数据存储区结合使用。
为图像分类和回归问题训练网络。数字数组net
= trainNetwork(X
,Y
,layers
,options
)X
包含预测变量,并Y
包含分类标签或数字响应。
训练网络以解决序列分类和回归问题(例如LSTM或BiLSTM网络),其中net
= trainNetwork(sequences
,Y
,layers
,options
)sequences
包含序列或时间序列预测变量并Y
包含响应。对于分类问题,Y
是分类向量或分类序列的单元格数组。对于回归问题,Y
是目标矩阵或数字序列的单元格数组。
为分类和回归问题训练网络。该表net
= trainNetwork(tbl
,layers
,options
)tbl
包含数字数据或数据的文件路径。预测变量必须位于的第一列中tbl
。有关目标或响应变量的信息,请参见tbl。
为分类和回归问题训练网络。预测变量必须位于的第一列中net
= trainNetwork(tbl
,responseName
,layers
,options
)tbl
。该responseName
参数指定在响应变量tbl
。
[
net,info
] = trainNetwork(___) 还可以使用先前语法中的任何输入参数返回有关训练的信息。
例子
- 图像分类训练网络
将数据作为ImageDatastore
对象加载。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet',... 'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath,... 'IncludeSubfolders',true,... 'LabelSource','foldernames');
数据存储区包含10,000个从0到9的数字合成图像。这些图像是通过对使用不同字体创建的数字图像应用随机转换而生成的。每个数字图像为28 x 28像素。数据存储区每个类别包含相等数量的图像。
显示数据存储中的某些图像。
figure
numImages = 10000;
perm = randperm(numImages,20);
for i = 1:20subplot(4,5,i);imshow(imds.Files{perm(i)});
end
指定卷积神经网络架构。对于回归问题,请在网络末端包括一个回归层。
layers = [ ...imageInputLayer([28 28 1])convolution2dLayer(5,20)reluLayermaxPooling2dLayer(2,'Stride',2)fullyConnectedLayer(10)softmaxLayerclassificationLayer];
指定网络训练选项。将初始学习速率设置为0.001。
options = trainingOptions('sgdm',... 'InitialLearnRate',0.001,... 'Verbose',false,... 'Plots','training-progress');
训练网络。
net = trainNetwork(imdsTrain,layers,options);
通过评估测试数据的预测准确性来测试网络的性能。使用predict
预测验证图像的旋转角度。
[XTest,〜,YTest] = digitTest4DArrayData;
YPred =predict(net,XTest);
通过计算预测旋转角和实际旋转角的均方根误差(RMSE)来评估模型的性能。
rmse = sqrt(mean((YTest-YPred)。^ 2))
rmse = single6.0655
序列分类训练网络
查看MATLAB命令
训练用于序列到标签分类的深度学习LSTM网络。
如[1]和[2]中所述加载日语元音数据集。XTrain
是包含270个长度可变且特征尺寸为12的序列的单元格数组。Y
是标签1,2,...,9的分类向量。中的条目XTrain
是具有12行(每个要素一行)和不同列数(每个时间步长一列)的矩阵。
[XTrain,YTrain] = japaneseVowelsTrainData;
可视化图中的第一个时间序列。每行对应一个特征。
数字
情节(XTrain {1}')
标题(“训练观察1”)
numFeatures = size(XTrain {1},1);
图例(“ Feature” + string(1:numFeatures),'Location','northeastoutside')
定义LSTM网络体系结构。将输入大小指定为12(输入数据的特征数)。指定一个LSTM层,使其具有100个隐藏单元并输出序列的最后一个元素。最后,通过包括大小为9的完全连接的层,其后是softmax层和分类层,来指定九个类。
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;层数= [ ...sequenceInputLayer(inputSize)lstmLayer(numHiddenUnits,'OutputMode','last')fullyConnectedLayer(numClasses)softmaxLayer分类图层]
层数= 具有层的5x1层阵列:1英寸序列输入序列输入具有12个尺寸2英寸LSTM LSTM具有100个隐藏单元3英寸全连接9个全连接层4英寸Softmax softmax5''分类输出交叉熵
指定训练选项。将求解器指定为'adam'
和'GradientThreshold'
。1.将小批量大小设置为27,并将最大纪元数设置为100。
由于小批量生产的序列短,因此CPU更适合训练。设置'ExecutionEnvironment'
到'cpu'
。要在GPU上进行训练(如果有),请设置'ExecutionEnvironment'
为'auto'
(默认值)。
maxEpochs = 100;
miniBatchSize = 27;options = trainingOptions('adam',... 'ExecutionEnvironment','cpu',... 'MaxEpochs',maxEpochs,... 'MiniBatchSize',miniBatchSize,... 'GradientThreshold',1,... '详细'',false,... ``情节'',``培训进度'');
使用指定的培训选项来培训LSTM网络。
net = trainNetwork(XTrain,YTrain,图层,选项);
加载测试集并将序列分类为扬声器。
[XTest,YTest] = japaneseVowelsTestData;
分类测试数据。指定用于训练的相同的小批量大小。
YPred = classify(net,XTest,'MiniBatchSize',miniBatchSize);
计算预测的分类准确性。
acc = sum(YPred == YTest)./ numel(YTest)
acc = 0.9541
输入参数
全部收缩
imds
— 图像数据存储
ImageDatastore
对象
图像数据存储,指定为ImageDatastore
对象。
ImageDatastore
允许使用预取功能批量读取JPG或PNG图像文件。如果您使用自定义功能读取图像,则 ImageDatastore
不会预取。
小费
使用augmentedImageDatastore
针对深度学习包括图像大小调整图像的高效预处理。
不要使用readFcn
选项,imageDatastore
因为此选项通常会明显变慢。
ds
— 数据存储数据
存储
数据存储,用于内存不足数据和预处理。
对于只有一个输入的网络,数据存储区返回的表或单元格数组有两列,分别指定了网络输入和期望的响应。
对于具有多个输入的网络,数据存储区必须是组合或转换后的数据存储区,该数据存储区将返回具有(numInputs
+1)列的单元格数组,其中包含预测变量和响应,其中 numInputs
是网络输入 numResponses
的数量,是响应的数量。对于i
小于或等于的值,单元阵列numInputs
的i
第th个元素对应于input layers.InputNames(i)
,其中 layers
是定义网络体系结构的层图。单元格数组的最后一列对应于响应。
下表列出了直接与兼容的数据存储 trainNetwork
。您可以使用transform
和combine
函数将其他内置数据存储区用于训练深度学习网络。这些函数可以将从数据存储中读取的数据转换为所需的表或单元格数组格式 trainNetwork
。有关更多信息,请参阅用于深度学习的数据存储。
数据存储类型 | 描述 |
---|---|
CombinedDatastore
|
水平串联从两个或多个基础数据存储读取的数据。 |
TransformedDatastore
|
根据您自己的预处理管道,转换来自底层数据存储的批量读取数据。 |
AugmentedImageDatastore
|
应用随机仿射几何变换,包括调整大小,旋转,反射,剪切和平移,以训练深度神经网络。 |
PixelLabelImageDatastore
|
将相同的仿射几何变换应用于图像和相应的地面真相标签,以训练语义分割网络(需要Computer Vision Toolbox™)。 |
RandomPatchExtractionDatastore
|
从图像或像素标签图像中提取成对的随机色块(需要Image Processing Toolbox™)。您可以选择将相同的随机仿射几何变换应用于面片对。 |
DenoisingImageDatastore
|
将随机生成的高斯噪声应用于训练降噪网络(需要“图像处理工具箱”)。 |
定制小批量数据存储 | 创建序列,时间序列或文本数据的迷你批。有关详细信息,请参阅开发自定义微型批处理数据存储。 |
X
— 图像数据
数字数组
图像数据,指定为数字数组。数组的大小取决于图像输入的类型:
输入项 | 描述 |
---|---|
2D影像 | 甲 ħ -by- 瓦特 -by- Ç -by- Ñ 数字阵列,其中ħ,瓦特,和 Ç分别是高度,宽度,和图像的信道数,和Ñ是图像的数量。 |
3D影像 | 甲 ħ -by- 瓦特 -by- d -by- Ç -by- Ñ 数字阵列,其中ħ,瓦特, d,和c ^是高度,宽度,深度和数量的图像,分别的通道,并且 Ñ是图像数。 |
如果数组包含NaN
,则它们将通过网络传播。
sequences
— 数字数组的序列或时间序列数据
单元格数组 | 数值数组 | 数据存储
序列或时间序列数据,指定为N乘1的数字数组单元格数组,其中N是观察数,代表单个序列的数字数组或数据存储。
对于单元格数组或数字数组输入,包含序列的数字数组的维数取决于数据类型。
输入项 | 描述 |
---|---|
矢量序列 | Ç -by- 小号矩阵,其中 Ç是的序列的特征的数量和š是序列长度。 |
二维图像序列 | h- by- w - c- by- s 数组,其中h,w和 c分别对应于图像的高度,宽度和通道数,而s是序列长度。 |
3-D图像序列 | ħ -by- 瓦特 -by- d -by- Ç -by- 小号,其中ħ,瓦特, d,和Ç对应的高度,宽度,深度和3-d的图像,分别的通道数,和s是序列长度。 |
对于数据存储区输入,数据存储区必须以序列的单元格数组或第一列包含序列的表的形式返回数据。序列数据的尺寸必须与上表相对应。
Y
— 响应
标签的分类向量 | 数值数组 | 分类序列的单元格数组 | 数字序列的单元格数组
响应,指定为标签的分类向量,数字数组,分类序列的单元格数组或数字序列的单元格数组。的格式Y
取决于任务的类型。响应中不得包含NaN
。
分类
任务 | 格式 |
---|---|
图片分类 | 标签的N ×1分类向量,其中N是观察数。 |
序列到标签分类 | |
序列到序列分类 |
标签分类序列的N ×1单元格数组,其中 N是观察数。在将 |
对于一个观察到的序列到序列分类问题, sequences
也可以是向量。在这种情况下, Y
必须是标签的分类序列。
回归
任务 | 格式 |
---|---|
二维图像回归 |
|
3-D图像回归 |
|
序列一回归 | N × R矩阵,其中N是序列数,R是响应数。 |
序列到序列回归 |
数字序列的N ×1单元格数组,其中N 是序列数。序列是具有R行的矩阵 ,其中R是响应数。在将 |
对于只有一个观察值的逐序列回归问题, sequences
可以将其作为矩阵。在这种情况下, Y
必须是响应矩阵。
标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅 训练卷积神经网络进行回归。
tbl
— 输入数据
table
输入数据,指定为包含第一列中的预测变量和其余列中的响应的表。表格中的每一行都对应一个观察值。
表列中预测变量和响应的排列方式取决于问题的类型。
分类
任务 | 预测变量 | 回应 |
---|---|---|
图片分类 |
|
分类标签 |
序列到标签分类 |
包含序列或时间序列数据的MAT文件的绝对或相对文件路径。 MAT文件必须包含一个由矩阵表示的时间序列,该矩阵具有与数据点相对应的行和与时间步长相对应的列。 |
分类标签 |
序列到序列分类 |
MAT文件的绝对或相对文件路径。MAT文件必须包含一个由分类向量表示的时间序列,并且每个时间步的标签均对应于其条目。 |
对于分类问题,如果您未指定 responseName
,则该函数默认使用的第二列中的响应tbl
。
回归
任务 | 预测变量 | 回应 |
---|---|---|
图像回归 |
|
|
序列一回归 |
包含序列或时间序列数据的MAT文件的绝对或相对文件路径。 MAT文件必须包含一个由矩阵表示的时间序列,该矩阵具有与数据点相对应的行和与时间步长相对应的列。 |
|
序列到序列回归 |
MAT文件的绝对或相对文件路径。MAT文件必须包含一个由矩阵表示的时间序列,其中行对应于响应,列对应于时间步长。 |
对于回归问题,如果不指定 responseName
,则该函数默认使用的其余列tbl
。标准化响应通常有助于稳定和加速训练神经网络以进行回归。有关更多信息,请参阅训练卷积神经网络进行回归。
响应中不能包含NaN
。如果预测变量数据包含NaN
,则它们将通过训练传播。但是,在大多数情况下,培训无法收敛。
资料类型: table
responseName
— 输入表字符向量中的响应变量的名称| 向量的元胞数组 | 字符串数组
输入表中响应变量的名称,指定为字符向量,字符向量的单元格数组或字符串数组。对于一个响应的问题, responseName
是中相应的变量名称 tbl
。对于具有多个响应变量的回归问题, responseName
是中对应变量名称的数组 tbl
。
数据类型:char
| cell
|string
layers
— 网络层
Layer
阵列 | LayerGraph
目的
网络层,指定为Layer
数组或LayerGraph
对象。
要创建依次连接所有层的网络,可以使用Layer
数组作为输入参数。在这种情况下,返回的网络是一个SeriesNetwork
对象。
有向无环图(DAG)网络具有复杂的结构,其中各层可以具有多个输入和输出。要创建DAG网络,请将网络体系结构指定为LayerGraph
对象,然后将该层图用作的输入参数 trainNetwork
。
有关内置层的列表,请参阅深度学习层列表。
options
— 培训选项
TrainingOptionsSGDM
| TrainingOptionsRMSProp
|TrainingOptionsADAM
培训选项,指定为TrainingOptionsSGDM
, TrainingOptionsRMSProp
或者 TrainingOptionsADAM
对象通过返回的trainingOptions
功能。要指定求解器和其他用于网络训练的选项,请使用 trainingOptions
。
输出参数
全部收缩
net
—训练有素的网络
SeriesNetwork
对象| DAGNetwork
目的
经过训练的网络,作为SeriesNetwork
对象或DAGNetwork
对象返回。
如果使用Layer
数组作为 layers
输入参数来训练网络,则它 net
是一个SeriesNetwork
对象。如果使用LayerGraph
对象作为输入参数来训练网络,则 net
该DAGNetwork
对象为对象。
info
—培训信息
结构
训练信息,以结构形式返回,其中每个字段是标量或数字向量,每个训练迭代具有一个元素。
对于分类问题,info
包含以下字段:
TrainingLoss
—损失函数值TrainingAccuracy
-训练精度ValidationLoss
—损失函数值ValidationAccuracy
—验证准确性BaseLearnRate
—学习率FinalValidationLoss
—最终验证损失FinalValidationAccuracy
—最终验证准确性
对于回归问题,info
包含以下字段:
TrainingLoss
—损失函数值TrainingRMSE
—训练RMSE值ValidationLoss
—损失函数值ValidationRMSE
—验证RMSE值BaseLearnRate
—学习率FinalValidationLoss
—最终验证损失FinalValidationRMSE
—最终验证RMSE
结构只包含的字段ValidationLoss
, ValidationAccuracy
,ValidationRMSE
,FinalValidationLoss
, FinalValidationAccuracy
和 FinalValidationRMSE
在options
指定的验证数据。所述'ValidationFrequency'
的选择trainingOptions
确定哪些迭代软件将计算验证指标。对于软件未计算验证指标的迭代,结构中的对应值为NaN
。
如果您的网络包含批处理规范化层,则最终验证指标通常与培训期间评估的验证指标不同。这是因为最终网络中的批处理归一化层执行的操作与训练期间不同。
更多关于
全部收缩
保存检查点网络并继续培训
深度学习工具箱™使您可以在训练期间的每个时期之后将网络另存为.mat文件。当您拥有大型网络或大型数据集并且训练需要很长时间时,这种定期保存特别有用。如果培训由于某种原因而中断,则可以从上次保存的检查点网络恢复培训。如果要 trainNetwork
保存检查点网络,则必须使用的'CheckpointPath'
名称/值对参数指定路径的名称trainingOptions
。如果指定的路径不存在,则 trainingOptions
返回错误。
trainNetwork
自动为检查点网络文件分配唯一的名称。在示例名称中 net_checkpoint__351__2018_04_12__18_09_52.mat
,351是迭代编号,2018_04_12
日期和保存网络18_09_52
的时间trainNetwork
。您可以通过双击或在命令行中使用load命令来加载检查点网络文件。例如:
<span style="color:#404040"><span style="color:inherit">加载net_checkpoint__351__2018_04_12__18_09_52.mat</span></span>
然后,您可以使用网络的各层作为的输入参数来恢复训练 trainNetwork
。例如:
<span style="color:#404040"><span style="color:inherit">trainNetwork(XTrain,YTrain,net.Layers,options)</span></span>
您必须手动指定培训选项和输入数据,因为检查点网络不包含此信息。有关示例,请参阅从Checkpoint Network继续培训。
浮点运算
深度学习工具箱中用于深度学习训练,预测和验证的所有功能都使用单精度浮点算术执行计算。深学习功能包括trainNetwork
,predict
, classify
,和 activations
。当您同时使用CPU和GPU训练网络时,该软件使用单精度算术。
参考资料
[1] Kudo,M.,J。Toyama和M.Shimbo。“使用通过区域的多维曲线分类”。 模式识别字母。卷 20,第11-13号,第1103-1111页。
[2] Kudo,M.,J。Toyama和M.Shimbo。日本元音数据集。https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
扩展功能
自动并行支持通过使用Parallel Computing Toolbox™自动并行
运行计算来加速代码。
trainNetwork - Matlab官网介绍的中文版相关推荐
- Hadoop 之 Distcp官网介绍和注意事项
Hadoop 之 Distcp方式 官网:https://hadoop.apache.org/docs/r2.10.0/hadoop-distcp/DistCp.html 一.概述 DistCp(分布 ...
- DCASE官网介绍——多模态智能感知与应用(课程报告)
文章目录 课程内容概述 DECASE官网介绍 Introduction Challenge Status 任务选择 Foley Sound Synthesis Summary of Task Desc ...
- 高端大气上档次的官网介绍导航页源码
介绍: 一款非常高端大气上档次官网导航页,非常利于收录,可以下载看看 网盘下载地址: http://kekewl.cc/jw8xBUiCkGm0 图片:
- Cocos2d 官网介绍,新手必看!!!!!!!!!!!!!!!!!!!!!!!!!
1.之前一直没仔细看cocos2d官网,后来发现很坑 http://cocos2d-x.org/ 官网分成英文,中文,还有日本语,建议大家用英语. 我之前给大家做教程,发现下载的一些东西你在中文网 ...
- hadoop官网介绍及如何下载hadoop(2.4)各个版本与查看hadoop API介绍
1.如何访问hadoop官网? 2.如何下载hadoop各个版本? 3.如何查看hadoop API? 很多同学开发都没有二手资料,原因很简单觉得不会英语,但是其实作为软件行业,多多少少大家会英语的, ...
- 2022公司邮箱登录入口官网介绍,个人邮箱用户登录
随着线上业务的发展,越来越多的公司会使用企业邮箱来和客户进行交流,毕竟邮箱的安全性和长久保存性在日常工作中,都起到了很大的作用.那么,在注册了邮箱账号后,要如何登录企业邮箱呢?毕竟邮箱登录入口官网有着 ...
- 深入理解蓝牙BLE之“Nordic官网介绍”
目录 1. Nordic官网及资料下载 2. Nordic infocenter(文档中心) 3. Nordic Devzone(开发者论坛) 4. Nordic Github 转载原地址:http: ...
- Nordic官网介绍(老版本)
1. Nordic官网及资料下载 Nordic官网主页:https://www.nordicsemi.com/,进入官网后,一般点击"Products"标签页,即进入Nordic产 ...
- Nordic老版官网介绍(2018-11-30停止更新)
1. Nordic官网及资料下载 Nordic官网主页:https://www.nordicsemi.com/,进入官网后,一般点击"Products"标签页,即进入Nordic产 ...
- 《flask日志logging一》flask官网介绍logging
官网地址:http://flask.pocoo.org/docs/dev/logging/ logging实例: @app.route('/login', methods=['POST']) def ...
最新文章
- js php 获取时间倒计时_,js实现倒计时及时间对象
- Jmeter 多台机器产生负载及问题解决方法
- Codeforces 1201
- OpenCASCADE:拓扑 API之3D模型周期性
- 【老王来了】之不眠不休教网络协议(RIP、OSPF、DHCP、VRRP、ACL、NAT)
- 转: 深入浅出-网络七层模型
- 【C++grammar】格式化输出与I/O流函数
- 如何选择神经网络的超参数
- Android反编译分析工具
- uniapp 发布网站遇到的问题(跨域,nginx代理失败,index无法打开,手机端无法访问等)
- php preg_split 正则截取字符串
- 《概率论与数理统计》之常见概率分布
- 肢体语言识别系统OpenPose问世,它甚至能明白你的表情
- [Python][sklearn] 使用from sklearn.neighbors import NearestNeighbors计算相似度
- MongoDB数据迁移之迁移工具Kettle
- 用python程序编写问卷调查_如何使用Python实现调查问卷的自动填写
- ftp - Internet 文件传输程序 (file transfer program)
- [洛谷P3262]战争调度
- 考研英语一历年真题写作(小作文+大作文)自己练习与背诵
- [vim与gvim技巧]vimgvim技巧大全(1)