前期在学习特征分类的时候确实花了不少功夫,想去了解一下长短时记忆网络的分类效果如何。这里主要分享一下LSTM的一些简介和代码。


这个例子展示了如何使用长短时记忆(LSTM)网络对序列数据进行分类。

若要训练深度神经网络对序列数据进行分类,您可以使用LSTM网络。LSTM网络使您能够将序列数据输入到网络中,并根据序列数据的单个时间步长进行预测。

本示例使用了日语元音数据集。这个例子训练一个LSTM网络来识别给定的代表两个连续日语元音的时间序列数据。训练数据包含了9名演讲者的时间序列数据。每个序列有12个特征,并且长度也有所不同。该数据集包含270个训练观察结果和370个测试观察。


Load Sequence Data
加载日语元音训练数据。
XTrain是一个包含270个长度12维序列的单元格阵列。Y是标签“1”,“2”,...,“9”的分类向量,对应于9个扬声者。XTrain中的条目是包含12行(每个特性为一行)和不同数量的列(每个时间步长为一列)的矩阵。

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)

Visualize the first time series in a plot. Each line corresponds to a feature.

figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

Prepare Data for Padding

在训练期间,默认情况下,软件会将训练数据分成小批,并填充序列,使它们具有相同的长度。过多的填充物可能会对网络性能产生负面影响。为了防止训练过程添加过多的填充,可以按序列长度对训练数据进行排序,并选择一个小批的大小,以便小批中的序列具有相似的长度。

Get the sequence lengths for each observation.

numObservations = numel(XTrain);
for i=1:numObservationssequence = XTrain{i};sequenceLengths(i) = size(sequence,2);
end
Sort the data by sequence length.
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);
View the sorted sequence lengths in a bar chart.
figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

选择一个27的小批量大小,以均匀地划分训练数据,并减少小批量中的填充量。

miniBatchSize = 27;

Define LSTM Network Architecture:

定义LSTM网络架构。将输入大小指定为12大小的序列(输入数据的尺寸)。指定一个包含100个隐藏单元的双向LSTM层,并输出序列的最后一个元素。最后,包含一个大小为9的全连接层,然后是一个softmax层和一个分类层。如果在预测时可以访问完整的序列,那么可以在网络中使用双向LSTM层。双向LSTM层在每个时间步长都从完整序列中学习。如果在预测时无法访问完整的序列,例如,如果正在预测值或一次预测一个时间步长,那么就使用LSTM层来代替。
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
layers = [ ...sequenceInputLayer(inputSize)bilstmLayer(numHiddenUnits,'OutputMode','last')fullyConnectedLayer(numClasses)softmaxLayerclassificationLayer]
layers =
                5×1 Layer array with layers:
                        4 1 '' Sequence Input Sequence input with 12 dimensions
                        2 '' BiLSTM BiLSTM with 100 hidden units
                        3 '' Fully Connected 9 fully connected layer
                        4 '' Softmax softmax
                        5 '' Classification Output crossentropyex


现在,指定培训选项。指定求解器为“adam”,梯度阈值为1,最大周期数为100。要减少小批量的填充量,请选择27。要将数据填充为与最长序列相同的长度,请指定序列长度为“longest”。要确保数据仍然按序列长度排序,请指定永远不要打乱数据。由于小批量处理很小,序列很短,所以训练更适合CPU。请将“ExecutionEnvironment”指定为“cpu”。若要在GPU上进行训练,如果可用,请将“ExecutionEnvironment”设置为“auto”(这是默认值)。

maxEpochs = 100;
miniBatchSize = 27;
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'GradientThreshold',1, ...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'SequenceLength','longest', ...'Shuffle','never', ...'Verbose',0, ...'Plots','training-progress');
Train LSTM Network
使用trainNetwork训练LSTM网络。
net = trainNetwork(XTrain,YTrain,layers,options);
Test LSTM Network
加载测试集,并将序列分类为扬声器。加载日语元音测试数据。
XTest是一个包含370个不同长度为12的序列的单元格阵列。YTest是标签“1”,“2”,...“9”的分类向量,对应于9个扬声器。
[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)

LSTM网络网络使用相似长度的小批量序列进行训练。确保测试数据以相同的方式组织。按序列长度对测试数据进行排序。

numObservationsTest = numel(XTest);
for i=1:numObservationsTestsequence = XTest{i};sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

对测试数据进行分类。为了减少分类过程中引入的填充量,请将小批量大小设置为27。要应用与训练数据相同的填充,请指定序列长度为'longest'

miniBatchSize = 27;
YPred = classify(net,XTest, ...'MiniBatchSize',miniBatchSize, ...'SequenceLength','longest');

计算预测的分类精度:

acc = sum(YPred == YTest)./numel(YTest)
Copyright 2021.11.25 15:17 The MathWorks, Inc.

关于输入数据类型可参考另一篇文章,包含代码和数据类型介绍,便于大家直接上手。链接如下:

LSTM程序输入数据转化_Cloudning的博客-CSDN博客为了便于大家快速上手LSTM,这里给出简单的MATLAB转换程序,仅供参考。data = readmatrix('original_data.xlsx')'; %% 原始数据label = readmatrix('fault_labels.xlsx'); %% 分类标签 %% 4个特征值,700组数据%% 划分测试集和训练集%% train datadata0 =data(1:4,1:2:700); https://blog.csdn.net/weixin_45168197/article/details/121539274

LSTM matlab实现相关推荐

  1. 论文笔记(三):深度学习在水文水资源中的应用综述

    A Comprehensive Review of Deep Learning Applications in Hydrology and Water Resources 深度学习在水文水资源中的应用 ...

  2. 【LSTM】基于LSTM网络的人脸识别算法的MATLAB仿真

    1.软件版本 matlab2021a 2.本算法理论知识 长短时记忆模型LSTM是由Hochreiter等人在1997年首次提出的,其主要原理是通过一种特殊的神经元结构用来长时间存储信息.LSTM网络 ...

  3. matlab LSTM序列分类的官方示例

    matlab版本是2018b及其以上. %% %加载序列数据 %数据描述:总共270组训练样本共分为9类,每组训练样本的训练样个数不等,每个训练训练样本由12个特征向量组成, [XTrain,YTra ...

  4. 【LSTM分类】基于双向长短时记忆(BiLSTM)实现数据分类含Matlab源码

    1 简介 LSTM 是循环神经网络中的一个特殊网络,它能够很好的处理序列信息并从中学习有效特征,它把以往的神经单元用一个记忆单元( memory cell) 来代替,解决了以往循环神经网络在梯度反向传 ...

  5. 组合预测 | MATLAB实现EMD-KPCA-LSTM、EMD-LSTM、LSTM多变量时间序列预测对比

    组合预测 | MATLAB实现EMD-KPCA-LSTM.EMD-LSTM.LSTM多变量时间序列预测对比 目录 组合预测 | MATLAB实现EMD-KPCA-LSTM.EMD-LSTM.LSTM多 ...

  6. 时序预测 | MATLAB实现基于EMD-LSTM时间序列预测(EMD分解结合LSTM长短期记忆神经网络)

    时序预测 | MATLAB实现基于EMD-LSTM时间序列预测(EMD分解结合LSTM长短期记忆神经网络) 目录 时序预测 | MATLAB实现基于EMD-LSTM时间序列预测(EMD分解结合LSTM ...

  7. 时序预测 | MATLAB实现LSTM长短期记忆神经网络时间序列预测

    目录 时序预测 | MATLAB实现LSTM长短期记忆神经网络时间序列预测 预测效果 程序设计 案例1 案例2 参考资料 时序预测 | MATLAB实现LSTM长短期记忆神经网络时间序列预测 预测效果 ...

  8. MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测

    MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测 摘要 近些年,随着计算机技术的不断发展,神 ...

  9. 【LSTM时间序列预测】基于matlab鲸鱼算法优化LSTM时间序列预测【含Matlab源码 105期】

    ⛄一.鲸鱼算法及LSTM简介 1 鲸鱼优化算法(Whale Optimization Algorithm,WOA)简介 鲸鱼优化算法(WOA),该算法模拟了座头鲸的社会行为,并引入了气泡网狩猎策略. ...

  10. 【LSTM时间序列数据】基于matlab LSTM时间序列数据预测【含Matlab源码 1949期】

    ⛄一.获取代码方式 获取代码方式1: 完整代码已上传我的资源:[LSTM时间序列数据]基于matlab LSTM时间序列数据预测[含Matlab源码 1949期] 获取代码方式2: 付费专栏Matla ...

最新文章

  1. c java http_[C] 类似于HttpClient的C语言实现Http POST功能如何实现?
  2. 最近一次.Dragon4444勒索病毒的成功解密过程
  3. 处理数字_6_NULL值的列的个数
  4. 大数据之Spark简介及RDD说明
  5. css js写在一起 vue_如何把vue2.0 和 animate.css合并在一起使用(详细教程)
  6. 【shell编程基础0】bash shell编程的基本配置
  7. 站长之家bbs.chinaz.com宣布将于2018年7月15日永久关站
  8. 手把手教你实现热更新功能,带你了解 Arthas 热更新背后的原理
  9. 封装0603和0805的区别
  10. CTC loss 理解
  11. tumblr_如何在WordPress中添加Tumblr共享按钮
  12. Android 高效安全加载图片
  13. C#实现向手机发送验证码短信
  14. js用正则表达式完成邮箱验证
  15. 分层自动化测试模型变与不变
  16. 调用第三方地图app导航(高德、百度、腾讯)
  17. 本地通过cmd开启一个服务
  18. iOS App被Apple拒绝的原因
  19. 计算机二级vb考试代码,二级计算机vb考试常用代码(看完必过).doc
  20. 紫外可见吸收光谱测试仪器

热门文章

  1. 点云数据集汇总整理(匠心之作,附官方下载地址)
  2. MySQL软件下载安装配置——详细教程
  3. 大数据集群安装02之Hadoop配置
  4. Vue3项目中使用AE+bodymovin+lottie的模式制作特效
  5. CPU226怎么与西门子变频器通讯
  6. 优麒麟桌面闪烁_优麒麟 19.10 正式发布—百尺竿头,更进一步
  7. 笔记 黑马程序员C++教程从0到1入门编程——核心编程
  8. 吴昂雄回应Arm中国控制权争夺:Arm罢免我无效
  9. js怎么实现ftp上传文件到服务器上,js ftp上传文件到服务器上
  10. 面试题整理|45个CSS面试题