加载序列数据

加载示例数据。chickenpox_dataset 包含一个时序,其时间步对应于月份,值对应于病例数。输出是一个元胞数组,其中每个元素均为单一时间步。将数据重构为行向量。

data = chickenpox_dataset;%加载数据集 chickenpox_dataset是一个函数data = [data{:}];%将数据集变为数组的形式,此时得到的是一个1*n维的数组,n代表n个时刻,其中存储的是每个时刻的值,即对于时序预测,只有发生的顺序,不存在实际的时间figure %创建一个用来显示图形输出的一个窗口对象。各种属性都使用的默认设置 plot(data) %若data是向量,则以data的分量为纵坐标,以元素序号为横坐标,用直线以此连接数据点,绘制曲线。若y为实矩阵,则按列绘制每列对应的曲线xlabel("Month") %纵坐标的名称ylabel("Cases") %横坐标的名称title("Monthy Cases of Chickenpox") %曲线图的标题

对训练数据和测试数据进行分区。序列的前 90% 用于训练,后 10% 用于测试。

numTimeStepsTrain = floor(0.9*numel(data));%将90%的数据设定为训练集dataTrain = data(1:numTimeStepsTrain+1);%定义训练集dataTest = data(numTimeStepsTrain+1:end);%定义测试集

标准化数据

为了获得较好的拟合并防止训练发散,将训练数据标准化为具有零均值和单位方差。在预测时,您必须使用与训练数据相同的参数来标准化测试数据。

mu = mean(dataTrain);%求均值,为以后的归一化做准备sig = std(dataTrain);%求均差,为以后的归一化做准备dataTrainStandardized = (dataTrain - mu) / sig;%归一化,这里方式数据发散

准备预测变量和响应

要预测序列在将来时间步的值,请将响应指定为将值移位了一个时间步的训练序列。也就是说,在输入序列的每个时间步,LSTM 网络都学习预测下一个时间步的值。预测变量是没有最终时间步的训练序列。

XTrain = dataTrainStandardized(1:end-1);YTrain = dataTrainStandardized(2:end);

定义 LSTM 网络架构

创建 LSTM 回归网络。指定 LSTM 层有 200 个隐含单元。

numFeatures = 1;%输入特征维数numResponses = 1;%输出特征维数numHiddenUnits = 200;%每一层lsmt网络中存在多少神经单元layers = [ ...sequenceInputLayer(numFeatures)%输入层,参数是输入特征维数lstmLayer(numHiddenUnits)%lsmt层,如果想要构建多层lstm,修改参数即可fullyConnectedLayer(numResponses)%全连接层 也就是输出的维数regressionLayer];%该参数说明是在进行回归问题,而不是分类问题

指定训练选项。将求解器设置为 'adam' 并进行 250 轮训练。要防止梯度爆炸,请将梯度阈值设置为 1。指定初始学习率 0.005,在 125 轮训练后通过乘以因子 0.2 来降低学习率。

options = trainingOptions('adam', ...'MaxEpochs',250, ... %这个参数是最大迭代次数,即进行250次训练,每次训练后更改神经网络参数'GradientThreshold',1, ...%设置梯度阀值为1 ,防止梯度爆炸'InitialLearnRate',0.005, ... %设置初始学习率'LearnRateSchedule','piecewise', ...'LearnRateDropPeriod',125, ... %训练125次后学习率下降,衰落因子为0.2'LearnRateDropFactor',0.2, ... %训练125次后的学习率的衰落因子为0.2'Verbose',0, ...'Plots','training-progress'); %构建曲线图

训练 LSTM 网络

使用 trainNetwork 以指定的训练选项训练 LSTM 网络。

net = trainNetwork(XTrain,YTrain,layers,options);

预测将来时间步

要预测将来多个时间步的值,请使用 predictAndUpdateState 函数一次预测一个时间步,并在每次预测时更新网络状态。对于每次预测,使用前一次预测作为函数的输入。

使用与训练数据相同的参数来标准化测试数据。

dataTestStandardized = (dataTest - mu) / sig; %归一化处理,防止数据发散XTest = dataTestStandardized(1:end-1);%预测变量

要初始化网络状态,请先对训练数据 XTrain 进行预测。接下来,使用训练响应的最后一个时间步 YTrain(end) 进行第一次预测。循环其余预测并将前一次预测输入到 predictAndUpdateState。

对于大型数据集合、长序列或大型网络,在 GPU 上进行预测计算通常比在 CPU 上快。其他情况下,在 CPU 上进行预测计算通常更快。对于单时间步预测,请使用 CPU。要使用 CPU 进行预测,请将 predictAndUpdateState 的 'ExecutionEnvironment'选项设置为 'cpu'。

net = predictAndUpdateState(net,XTrain);[net,YPred] = predictAndUpdateState(net,YTrain(end));numTimeStepsTest = numel(XTest);
for i = 2:numTimeStepsTest
[net,YPred(:,i)] = predictAndUpdateState(net,YPred(:,i1),'ExecutionEnvironment','cpu');
end

使用先前计算的参数对预测去标准化。

YPred = sig*YPred + mu;

训练进度图会报告根据标准化数据计算出的均方根误差 (RMSE)。根据去标准化的预测值计算 RMSE。

YTest = dataTest(2:end);rmse = sqrt(mean((YPred-YTest).^2))

rmse = single

209.5295

使用预测值绘制训练时序。

figureplot(dataTrain(1:end-1))hold on % 保存图像  若想在新画图像之后不想覆盖原图像,就需要加上idx = numTimeStepsTrain:(numTimeStepsTrain+numTimeStepsTest);plot(idx,[data(numTimeStepsTrain) YPred],'.-')hold off %画图的时候,取消(覆盖)原来的图xlabel("Month")ylabel("Cases")title("Forecast")legend(["Observed" "Forecast"]) %legend(string1,string2,string3, ...)分别将字符串1、字符串           2、字符串3…… 标注到图中,每个字符串对应的图标为画图时的图标。

将预测值与测试数据进行比较。

figuresubplot(2,1,1) %将多个图画到一个平面上  subplot(m,n,p)或者subplot(m n p) 。其中,m表示是图排成m行,n表示图排成n列,也就是整个figure中有n个图是排成一行的,一共m行,如果m=2就是表示2行图。p表示图所在的位置,p=1表示从左到右从上到下的第一个位置。plot(YTest)hold onplot(YPred,'.-')hold offlegend(["Observed" "Forecast"])ylabel("Cases")title("Forecast")subplot(2,1,2)stem(YPred - YTest) % stem函数用于绘制火柴梗图。  本句指 在XPred的指定点处画出数据序列YTest. xlabel("Month")ylabel("Error")title("RMSE = " + rmse)

使用观测值更新网络状态

如果您可以访问预测之间的时间步的实际值,则可以使用观测值而不是预测值更新网络状态。

首先,初始化网络状态。要对新序列进行预测,请使用 resetState 重置网络状态。重置网络状态可防止先前的预测影响对新数据的预测。重置网络状态,然后通过对训练数据进行预测来初始化网络状态。

net = resetState(net); %重置网络状态net = predictAndUpdateState(net,XTrain);

对每个时间步进行预测。对于每次预测,使用前一时间步的观测值预测下一个时间步。将 predictAndUpdateState 的 'ExecutionEnvironment' 选项设置为 'cpu'。

YPred = []; numTimeStepsTest = numel(XTest);for i = 1:numTimeStepsTest[net,YPred(:,i)] = predictAndUpdateState(net,XTest(:,i),'ExecutionEnvironment','cpu');end

使用先前计算的参数对预测去标准化。

YPred = sig*YPred + mu;

计算均方根误差 (RMSE)。

rmse = sqrt(mean((YPred-YTest).^2))rmse = 122.2634

将预测值与测试数据进行比较。

figuresubplot(2,1,1)plot(YTest)hold onplot(YPred,'.-')hold offlegend(["Observed" "Predicted"])ylabel("Cases")title("Forecast with Updates")subplot(2,1,2)stem(YPred - YTest)xlabel("Month")ylabel("Error")title("RMSE = " + rmse)

这里,当使用观测值而不是预测值更新网络状态时,预测更准确。

感谢:https://blog.csdn.net/weixin_42791427/article/details/886807

官方文档:https://ww2.mathworks.cn/help/deeplearning/examples/time-series-forecasting-using-deep-learning.html

利用深度学习进行时间序列预测相关推荐

  1. 【深度学习】利用深度学习进行时间序列预测

    作者 | Christophe Pere 编译 | VK 来源 | Towards Datas Science 介绍 长期以来,我听说时间序列问题只能用统计方法(AR[1],AM[2],ARMA[3] ...

  2. 深度学习多变量时间序列预测:LSTM算法构建时间序列多变量模型预测交通流量+代码实战

    深度学习多变量时间序列预测:LSTM算法构建时间序列多变量模型预测交通流量+代码实战 LSTM(Long Short Term Memory Network)长短时记忆网络,是一种改进之后的循环神经网 ...

  3. 深度学习多变量时间序列预测:GRU算法构建时间序列多变量模型预测交通流量+代码实战

    深度学习多变量时间序列预测:GRU算法构建时间序列多变量模型预测交通流量+代码实战 GRU是LSTM网络的一种效果很好的变体,它较LSTM网络的结构更加简单,而且效果也很好,因此也是当前非常流形的一种 ...

  4. 基于深度学习的时间序列预测方法

    之前对时间序列预测的方法大致梳理了一下,最近系统的学习了深度学习,同时也阅读了一些处理序列数据的文献,发现对于基于深度学习的时间序列预测的方法,还可以做进一步细分:RNN.Attention和TCN. ...

  5. 深度学习多变量时间序列预测:Bi-LSTM算法构建时间序列多变量模型预测交通流量+代码实战

    深度学习多变量时间序列预测:Bi-LSTM算法构建时间序列多变量模型预测交通流量+代码实战 人类并不是每时每刻都从一片空白的大脑开始他们的思考.在你阅读这篇文章时候,你都是基于自己已经拥有的对先前所见 ...

  6. 深度学习多变量时间序列预测:Encoder-Decoder LSTM算法构建时间序列多变量模型预测交通流量+代码实战

    深度学习多变量时间序列预测:Encoder-Decoder LSTM算法构建时间序列多变量模型预测交通流量+代码实战 LSTM是一种时间递归神经网络,适合于处理和预测时间序列中间隔和延迟相对较长的重要 ...

  7. 深度学习多变量时间序列预测:卷积神经网络(CNN)算法构建时间序列多变量模型预测交通流量+代码实战

    深度学习多变量时间序列预测:卷积神经网络(CNN)算法构建时间序列多变量模型预测交通流量+代码实战 卷积神经网络,听起来像是计算机科学.生物学和数学的诡异组合,但它们已经成为计算机视觉领域中最具影响力 ...

  8. matlab使用深度学习进行时间序列预测

    使用深度学习进行时间序列预测 - MATLAB & Simulink - MathWorks 中国 Deep Learning Toolbox 快速入门 - MathWorks 中国

  9. 基于深度学习的时间序列预测

    # 技术黑板报 # 第十一期 推荐阅读时长:15min 前言 时间序列建模历来是学术和工业界的关键领域,比如用于气候建模.生物科学和医学等主题应用,零售业的商业决策和金融等.虽然传统的统计方法侧重于从 ...

最新文章

  1. SQL中的case when then else end用法
  2. 超越对手之四、五、六
  3. mysql 1418 错误原因及解决
  4. 使用Dottrace跟踪代码执行时间
  5. Learning Cocos2d-x for WP8(7)——让Sprite动起来
  6. istio springcloud_手牵手一起学Springcloud(1)微服务这么流行,你理解了嘛?
  7. 在FLEX中获得当前PLAYER版本等信息.
  8. php过滤除了文字数据英文,正则:过滤除英文和汉字的其它特殊符号
  9. 深度学习项目实施流程
  10. Jmeter接口测试+压力测试
  11. Visual Studio2010当前不会命中代码,源代码与原始版本不同问题的解决方法
  12. ue4vr插件_UE4虚幻引擎可视化VR实例3dsMax全流程中级教学
  13. 利用pe系统重装电脑
  14. c语言 [Error] expected declaration or statement at end of input的解决方法
  15. 桂电 数电实验 期末考试 试卷+解析(74LS192 + 74LS153 + 74LS139 + 74LS00 / 74LS20)
  16. platform平台驱动模型简述(linux驱动开发篇)
  17. JSD-2204-Java语言基础-数组-方法-Day06
  18. Android通讯录开发之通讯录联系人搜索功能最新实现,kotlin入门到精通pdf
  19. HTML浪漫动态表白代码+音乐(附源码)
  20. Intellij IDEA如何将包的层级目录完全展现出来(树状结构)

热门文章

  1. 如何屏蔽网页上的烦人广告
  2. C++排序——Bookshelf B
  3. 就业两年国企辞职考研经验心得
  4. 计算机二级c类考试试题及答案,2016最新计算机二级C上机考试试题及答案
  5. 图片拼图软件哪个好?建议收藏这些软件
  6. Group equivariant capsule networks(组等变胶囊网络) 论文翻译
  7. 有点Python编程基础,怎么赚点小钱?
  8. Netty 入门教程
  9. 微软正式抛弃UWP!
  10. 【GIT】源仓库新建分支如何同步到fork的自有仓库