目的:通过一段时间的数据,预测后面一段时间的类别,比如输入数据是1-50步的变量,预测的是50-60步的类别。

输入变量的数目:15

预测的类别数:0,1,2,3,4,10 (1类数目最多,数据不均衡)

GRU模型参数解释:

参考链接:[Pytorch系列-54]:循环神经网络 - torch.nn.GRU()参数详解_文火冰糖的硅基工坊的博客-CSDN博客_nn.gru参数

input_size: 输入序列的变量的数目。

hidden_size: 隐藏层的特征的数目。

num_layers: GRU层的数目。

bias:是否需要偏置,默认是True(需要)。

batch_first: 用于确定batch size是否需要放到输入输出数据形状的最前面。

若为True, 则输入、输出的tensor的格式为(batch, seq_len, feature)

若为False,则输入、输出的tensor的格式为(seq_len,batch,feature)

默认是False。

为什么需要该参数呢?

在CNN网络和全连接网络,batch通常位于输入数据形状的最前面。

而对于具有时间信息的序列化数据,通常需要把seq放在最前面,需要把序列数据串行地输入网络中。(那我的模型不能设置为True???)

seq_len: 输入序列的长度。在我的情形下可以为50。

搭建GRU网络:

参考链接:pytorch使用torch.nn.Sequential快速搭建神经网络 - pytorch中文网

self.gru = nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True, dropout=self.dropout)
self.fc = nn.Sequential(nn.Linear(self.hidden_size, self.output_size), nn.Sigmoid())
self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.fc1 = torch.nn.Linear(self.hidden_size, 4)
self.fc2 = torch.nn.Linear(self.hidden_size, 4)
self.fc3 = torch.nn.Linear(self.hidden_size, 4)
self.fc4 = torch.nn.Linear(self.hidden_size, 4)
self.fc5 = torch.nn.Linear(self.hidden_size, 4)
self.softmax = torch.nn.Softmax(dim=1)

nn.Sequential:是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。使用torch.nn.Sequential会自动加入激励函数。

torch.nn.Sequential与torch.nn.Module区别与选择

  • 使用torch.nn.Module,我们可以根据自己的需求改变传播过程,如RNN

  • 如果你需要快速构建或者不需要过多的过程,直接使用torch.nn.Sequential即可

nn.Linear(input_dimoutput_dim)

torch.nn.Softmax(dim=1)

参考链接: torch.nn.Softmax_CtrlZ1的博客-CSDN博客_torch.nn.softmax

tensor([[0.3458, 0.0596, 0.5147],

[0.3774, 0.7503, 0.3705],

[0.2768, 0.1901, 0.1148]])

dim=0表示对于第一个维度的对应下标之和是1, 即0.3458+0.3774+0.2768=1、0.0596+0.7503+0.1901=1。

tensor([[0.3381, 0.1048, 0.5572],

[0.1766, 0.6315, 0.1919],

[0.3711, 0.4586, 0.1704]])

dim=1表示对于第二维度而言,对应下标之和为1,0.3381+0.1048+0.5572=1, 0.1766+0.6315+0.1919=1,即所有列的对应下标之和为1。

 一些报错记录:

1. 计算交叉熵损失使用的output必须是softmax输出的概率而不是argmax之后得到的类别。

RuntimeError: Expected floating point type for target with class probabilities, got Long

语义分割损失函数系列(1):交叉熵损失函数_spectrelwf的博客-CSDN博客_语义分割交叉熵

2. 加载生成训练数据集的时候报错。

Ran out of input

python报错Ran out of input_在上树的路上的博客-CSDN博客

因为生成的数据集太大了,要减少数据集。(The actually error is OverflowError: cannot serialize a bytes object larger than 4 GiB. You have to reduce the size of the input.)

3.  输入张量和隐藏张量不在一个device上。

h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
Input and hidden tensors are not at the same device, found input tensor at cuda:0 and hidden tensor at cpu
h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(Train.device)
Input and hidden tensors are not at the same device, found input tensor at cpu and
and hidden tensor at cuda:0 

解决方法:

output, _ = self.gru(input_seq.to(Train.device), h_0)

(input_seq后面加上to(Train.device))

4. 预测和真实标签长度不一致。

报错:Found input variables with inconsistent numbers of samples

y_true.shape

y_predict.shape

查看真实值和预测值的形状。

GRU时间序列数据分类预测相关推荐

  1. 分类预测 | MATLAB实现WOA-CNN-GRU鲸鱼算法优化卷积门控循环单元数据分类预测

    分类预测 | MATLAB实现WOA-CNN-GRU鲸鱼算法优化卷积门控循环单元数据分类预测 分类效果 基本描述 1.Matlab实现WOA-CNN-GRU多特征分类预测,多特征输入模型,运行环境Ma ...

  2. 深度学习时间序列预测:卷积神经网络(CNN)算法构建单变量时间序列预测模型预测空气质量(PM2.5)+代码实战

    深度学习时间序列预测:卷积神经网络(CNN)算法构建单变量时间序列预测模型预测空气质量(PM2.5)+代码实战 神经网络(neual networks)是人工智能研究领域的一部分,当前最流行的神经网络 ...

  3. 时间序列挖掘-预测算法-三次指数平滑法(Holt-Winters)——三次指数平滑算法可以很好的保存时间序列数据的趋势和季节性信息...

    from:http://www.cnblogs.com/kemaswill/archive/2013/04/01/2993583.html 在时间序列中,我们需要基于该时间序列当前已有的数据来预测其在 ...

  4. Keras之MLPR:利用MLPR算法(3to1【窗口法】+【Input(3)→(12+8)(relu)→O(mse)】)实现根据历史航空旅客数量数据集(时间序列数据)预测下月乘客数量问题

    Keras之MLPR:利用MLPR算法(3to1[窗口法]+[Input(3)→(12+8)(relu)→O(mse)])实现根据历史航空旅客数量数据集(时间序列数据)预测下月乘客数量问题 目录 输出 ...

  5. Keras之MLPR:利用MLPR算法(1to1+【Input(1)→8(relu)→O(mse)】)实现根据历史航空旅客数量数据集(时间序列数据)预测下月乘客数量问题

    Keras之MLPR:利用MLPR算法(1to1+[Input(1)→8(relu)→O(mse)])实现根据历史航空旅客数量数据集(时间序列数据)预测下月乘客数量问题 目录 输出结果 设计思路 实现 ...

  6. keras时间序列数据预测_使用Keras的时间序列数据中的异常检测

    keras时间序列数据预测 Anomaly Detection in time series data provides e-commerce companies, finances the insi ...

  7. java三次指数平滑_时间序列挖掘-预测算法-三次指数平滑法(Holt-Winters)

    所有移动平均法都存在很多问题. 它们都太难计算了.每个点的计算都让你绞尽脑汁.而且也不能通过之前的计算结果推算出加权移动平均值. 移动平均值永远不可能应用于现有的数据集边缘的数据,因为它们的窗口宽度是 ...

  8. python学习日志3--ARIMA时间序列模型预测

    前言 这篇文章主要讲述如何使用python实现时间序列ARIMA预测算法 一.代码 代码如下(示例): #跟着视频学习的代码,记录一下. import numpy as np import panda ...

  9. 时间序列模型预测_时间序列预测,使用facebook先知模型预测股价

    时间序列模型预测 1.简介 (1. Introduction) 1.1. 时间序列和预测模型 (1.1. Time-series & forecasting models) Tradition ...

最新文章

  1. Android开发之发送邮件功能的实现(源代码分享)
  2. python中__init__后面加特殊符号_详解Python中的__new__、__init__、__call__三个特殊方法...
  3. ganglia-介绍安装(二)
  4. 50、Power Query-Text.Contains的学习
  5. 重磅快讯:CCF发布最新版推荐中文科技期刊目录
  6. 解决plsql中中文乱码问题
  7. 解决:PHP Deprecated: Comments starting with '#' are deprecated in ……
  8. 在线HTTP POST/GET接口测试 地址
  9. 涠洲岛形成及地形地貌特征
  10. win10去掉快捷方式小箭头_快捷方式小箭头很烦人 一招教你取消
  11. 土味情话恋爱话术微信小程序源码下载
  12. [工具使用]搜索引擎 Hacking
  13. chorme vue中使用audio自动播放问题
  14. Hyperledger Fabric网络节点架构
  15. 计算机组成:中断向量的相关计算
  16. 2020.10.19 第18节 预处理和宏定义
  17. 腾讯企业邮箱HTTPS设置
  18. 电子器件系列40:高压放电电阻(绕线电阻)
  19. 网络鞋城HTML和css代码,基于jsp的网上鞋城系统-JavaEE实现网上鞋城系统 - java项目源码...
  20. FreeType像素格式:FT_PIXEL_MODE_MONO

热门文章

  1. Hadoop第七天--MapReduceYarn详解(二)
  2. EMNLP 2021中预训练模型最新研究进展
  3. 1265 最近公共祖先
  4. 微信翻译生日快乐的代码_广外,54岁生日快乐!校庆日专属头像上线!
  5. [CISCN2019 华北赛区 Day2 Web1]Hack World
  6. 移动硬盘的“磁盘结构损坏且无法读取”问题的解决方法
  7. TKinter布局之pack
  8. Apache Tomcat优化
  9. 机器学习数学知识(一) 自然数e
  10. 数据赋能:Uber的数据治理实践分享