目录

项目背景

加载序列数据

定义 LSTM 网络架构

训练LSTM网络

测试 LSTM 网络

使用 classify 对测试数据进行分类。

计算预测的准确度。

全部源代码

参考文献


项目背景

  • 此示例使用从佩戴在身体上的智能手机获得的传感器数据。该示例训练一个 LSTM 网络,旨在根据表示三个不同方向上的加速度计读数的时间序列数据来识别佩戴者的活动。训练数据包含七个人的时间序列数据。每个序列有三个特征,且长度不同。该数据集包含六个训练观测值和一个测试观测值。
  • 本文初衷在于帮助初学者创建LSTM网络,网络上很多介绍停留理论层面,反复地说LSTM网络的特点而忽略matlab创建的实际过程,或者以收费的形式贩卖mathwork里的实例,对学生党不友好。
  • LSTM网络简介:LSTM 网络常用于对序列数据进行分类 。LSTM 网络是一种循环神经网络 (RNN),可学习序列数据的时间步之间的长期依存关系。
  • 输入输出模式:LSTM网络按输入输出模式可分sequence to sequence网络与sequence to last(label)网络,前者多输入多输出,后者多输入单输出,本实例为sequence to sequence输入输出模式,例子来源mathwork。
  • 输入输出详细情况:该实例输入类型为cell型1*Y矩阵,每个元胞内为3*X的矩阵,3个特征*X个时间步(每个元胞内X可以不同)。输出类型为1*Ycell类型,每个cell内矩阵大小为1*X,其中是对对应输入元胞每个X标注的类别。(sequence to sequence。

重点:对于LSTM网络的两种工作模式,

  • 1、输入 特征数*时间步 矩阵,输出对每个时间步的模式分类(sequence to sequence)。  2、输入 特征数*时间步 矩阵,输出对整体(最后一个时间步)的模式分类(sequence to last)。
  • 这两种工作模式输入相同输出不同,前者输出为1*Ycell类型矩阵,每个cell内矩阵大小为:1*X(它对每一个时间步作分类)。后者输出为categorical类型分类列向量,大小为Y*1(它对输入中每个元胞包含序列的整体作分类)
  • 要训练深度神经网络以对序列数据的每个时间步进行分类,可以使用“序列到序列”LSTM 网络。通过“序列到序列”LSTM 网络,您可以对序列数据的每个时间步进行不同预测。

加载序列数据

加载人体活动识别数据。该数据包含从佩戴在身体上的智能手机获得的七个时间序列的传感器数据。每个序列有三个特征,且长度不同。这三个特征对应于三个不同方向上的加速度计读数。该例子来源mathwork,直接输入下列代码即可得到数据

load HumanActivityTrain

XTrain
XTrain=6×1 cell array
    {3×64480 double}
    {3×53696 double}
    {3×56416 double}
    {3×50688 double}
    {3×51888 double}
    {3×54256 double}

在绘图中可视化一个训练序列。绘制第一个训练序列的第一个特征,并按照对应的活动为绘图着色。

X = XTrain{1}(1,:);
classes = categories(YTrain{1});figure
for j = 1:numel(classes)label = classes(j);idx = find(YTrain{1} == label);hold onplot(idx,X(idx))
end
hold offxlabel("Time Step")
ylabel("Acceleration")
title("Training Sequence 1, Feature 1")
legend(classes,'Location','northwest')

定义 LSTM 网络架构

定义 LSTM 网络架构。将输入指定为大小为 3(输入数据的特征数量)的序列。指定包含 200 个隐含单元的 LSTM 层,并输出完整序列。最后,在网络中包含一个大小为 5 的全连接层,后跟 softmax 层和分类层,以此来指定五个类。

numFeatures = 3;
numHiddenUnits = 200;
numClasses = 5;layers = [ ...sequenceInputLayer(numFeatures)lstmLayer(numHiddenUnits,'OutputMode','sequence')fullyConnectedLayer(numClasses)softmaxLayerclassificationLayer];
%指定训练选项。将求解器设置为 'adam'。进行 60 轮训练。要防止梯度爆炸,请将梯度阈值设置为 2。options = trainingOptions('adam', ...'MaxEpochs',60, ...'GradientThreshold',2, ...'Verbose',0, ...'Plots','training-progress');
%使用 trainNetwork 以指定的训练选项训练 LSTM 网络。每个小批量都包含整个训练集,因此每训练一轮便更新一次绘图。序列非常长,因此处理每个小批量并更新绘图可能需要一些时间。

训练LSTM网络

net = trainNetwork(XTrain,YTrain,layers,options);

测试 LSTM 网络

加载测试数据并对每个时间步的活动进行分类。

加载人体活动测试数据。XTest 包含一个维度为 3 的序列。YTest 包含对应于每个时间步的活动的分类标签序列。

load HumanActivityTest
figure
plot(XTest{1}')
xlabel("Time Step")
legend("Feature " + (1:numFeatures))
title("Test Data")

使用 classify 对测试数据进行分类。

YPred = classify(net,XTest{1});

您也可以使用 classifyAndUpdateState 一次对一个时间步进行预测。这在时间步的值以流的方式到达时非常有用。通常,对完整序列进行预测比一次对一个时间步进行预测更快。

计算预测的准确度。

acc = sum(YPred == YTest{1})./numel(YTest{1})

acc = 0.9998

%通过绘图将预测值与测试数据进行比较。figure
plot(YPred,'.-')
hold on
plot(YTest{1})
hold offxlabel("Time Step")
ylabel("Activity")
title("Predicted Activities")
legend(["Predicted" "Test Data"])

全部源代码

load HumanActivityTrainX = XTrain{1}(1,:);
classes = categories(YTrain{1});figure
for j = 1:numel(classes)label = classes(j);idx = find(YTrain{1} == label);hold onplot(idx,X(idx))
end
hold offxlabel("Time Step")
ylabel("Acceleration")
title("Training Sequence 1, Feature 1")
legend(classes,'Location','northwest')numFeatures = 3;
numHiddenUnits = 200;
numClasses = 5;layers = [ ...sequenceInputLayer(numFeatures)lstmLayer(numHiddenUnits,'OutputMode','sequence')fullyConnectedLayer(numClasses)softmaxLayerclassificationLayer];
%指定训练选项。将求解器设置为 'adam'。进行 60 轮训练。要防止梯度爆炸,请将梯度阈值设置为 2。options = trainingOptions('adam', ...'MaxEpochs',60, ...'GradientThreshold',2, ...'Verbose',0, ...'Plots','training-progress');
%使用 trainNetwork 以指定的训练选项训练 LSTM 网络。每个小批量都包含整个训练集,因此每训练一轮便更新一次绘图。序列非常长,因此处理每个小批量并更新绘图可能需要一些时间。net = trainNetwork(XTrain,YTrain,layers,options);load HumanActivityTest
figure
plot(XTest{1}')
xlabel("Time Step")
legend("Feature " + (1:numFeatures))
title("Test Data")YPred = classify(net,XTest{1});acc = sum(YPred == YTest{1})./numel(YTest{1})%通过绘图将预测值与测试数据进行比较。figure
plot(YPred,'.-')
hold on
plot(YTest{1})
hold offxlabel("Time Step")
ylabel("Activity")
title("Predicted Activities")
legend(["Predicted" "Test Data"])

参考文献

[1]Mathwork:Sequence-to-Sequence Classification Using Deep Learning

MATLAB LSTM多输入多输出 模式分类 示例解析(含代码)相关推荐

  1. 径向基神经网络RBF:Matlab实现多输入多输出RBF神经网络(含例子及代码)

    创建5输入2输出RBF神经网络: x=2*rand(5,1000)-1;%输入为5维度共1000个数据 y(1,:)=sin(2*sum(x,1));%输出的第一维数据 y(2,:)=cos(3*su ...

  2. 广义回归神经网络GRNN:Matlab实现多输入多输出广义回归神经网络GRNN (含例子及代码)

    创建5输入,2输出的GRNN,随机产生1000个5维数据x作为输入,输出值为y: %net = newgrnn(P,T,spread) %参数P为输入向量: %T为输出向量: %spread 为径向基 ...

  3. 【LSTM预测】基于卷积神经网络结合双向长短时记忆CNN-BiLSTM(多输入单输出)数据预测含Matlab源码

    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信.

  4. 3.1 matlab数据的输入和输出

    1.数据的输入 A=input(提示信息,选项); >> a = input('请输入变量a的值:') 请输入变量a的值:100a =100 2.数据的输出 disp(输出项); > ...

  5. input输入search查找关键词时,实现(即时搜索)边输入边输出目标内容的例子代码

    先充电: (1)change事件    触发事件必须满足两个条件: a)当前对象属性改变,并且是由键盘或鼠标事件激发的(脚本触发无效) b)当前对象失去焦点(onblur) (2)keypress  ...

  6. 一篇文章带你搞定数学建模中的灰色预测模型(05年长江水质问题示例讲解含代码)

    文章目录 一.题目分析 二.原理步骤 三.MATLAB实现 G(1,1) 预测未来10年的污水情况 四.MATLAB 实现预测六类污染程度的河流长度比例 五.扩展灰色预测知识 一.题目分析 假如不采取 ...

  7. 一篇文章带你搞定19年数学建模机场出租车优化问题示例讲解含代码

    文章目录 一.问题分析 二.数据介绍 三.模型的求解 四.结果分析 一.问题分析 收集国内某一机场及其所在城市出租车的相关数据,给出该机场出租车司机的选择方案,并分析模型的合理性和对相关因素的依赖性. ...

  8. 一篇文章带你搞定数学建模中的载荷矩阵、相关系数矩阵、主成分分析(11年土壤重金属污染示例讲解含代码)

    文章目录 一.题目分析 二.基于主成分分析法的重金属污染评价模型 1. 模型建立 2. 模型求解 三.问题求解代码 四.相关系数矩阵的了解 五.载荷矩阵的了解 一.题目分析 通过数据分析,说明重金属污 ...

  9. 一篇文章带你搞定数学建模中的元胞思想(11年土壤重金属污染示例讲解含代码)

    文章目录 一.题目分析 二.重金属污染物的传播特征 1. 定性分析 2. 定量分析 三.确定污染源的位置 1. 模型建立 2. 模型求解 四.模型代码 1. 元胞思想求解极大值点 2. 确定污染源坐标 ...

  10. 一篇文章带你搞定单因子污染指数和卡梅罗污染指数(11年土壤重金属污染示例讲解含代码)

    文章目录 一.题目分析 二.重金属的元素空间分布的代码 三.重金属污染程度分析的代码 一.题目分析 给出8种主要重金属元素在该城区的空间分布,并分析该城区内不同区域重金属的污染程度 在问题一中,根据三 ...

最新文章

  1. linux虚拟文件系统浅析
  2. Wampserver之 virtualHost
  3. java 分贝_java11教程--jhsdb命令
  4. UNIX TCP回射服务器/客户端之使用epoll模型的服务器
  5. php中如何将验证码放入页面,如何在php中生成验证码图片
  6. MVC设计之MVC设计模式(介绍)
  7. Python+ZeroMQ快速实现消息发布与订阅
  8. ZOJ 1709 Oil Deposits
  9. ORACLE AWR简介
  10. 关于Access数据库安全
  11. Tcl Tutorial 笔记6 ·while
  12. 三维重建_基于RGB-D相机的三维重建总览(静态动态)
  13. 【Sublime Text 3】编译环境
  14. bp网络拟合函数 matlab_神经网络案例分析—基于Matlab的预测
  15. 用最简单的方法解决:linux系统重启网络delaying initialization错误
  16. Ctfhub解题 彩蛋
  17. 如何在eclipse中导入Java项目文件包(方法截图详细步骤)
  18. 解决浏览器 Microsoft Edge主页被hao123恶意篡改
  19. 功能覆盖率与代码覆盖率区别
  20. js挂马,臭名昭著nu99.com

热门文章

  1. HarmoneyOS鸿蒙系统零代码编程入门
  2. 怎么用C语言求解线性规划,线性规划习题详细解析,包括线性规划方程求解步骤...
  3. Python父与子的编程之旅 第八章答案
  4. vue移动端项目使用自定义字体
  5. 网页版 QQ授权登录
  6. 【Spring揭秘】Spring简介
  7. python数据分析特训营课件,Python数据分析PPT学习课件
  8. 004-读书笔记-企业IT架构转型之道-阿里巴巴中台战略思想与架构实战-共享服务中心建设原则...
  9. json类型大小 mysql_MySQL数据类型 - JSON数据类型 (1)
  10. win10下Redis安装教程(新手)