一、网上的资源

网上有不少用LSTM来预测时间序列的资源,如下面:

深度学习(08)_RNN-LSTM循环神经网络-03-Tensorflow进阶实现

http://blog.csdn.net/u013082989/article/details/73693392

Applying Deep Learning to Time Series Forecasting with TensorFlow

https://mapr.com/blog/deep-learning-tensorflow/

Tensorflow 笔记 RNN 预测时间序列

https://www.v2ex.com/t/339544

tf19: 预测铁路客运量

http://blog.csdn.net/u014365862/article/details/53869802

但是调试起来,都很困难!借鉴比较多的是tf19:预测铁路客运量这篇博文。这篇博文首先是基本上可以运行的。但是训练模型和测试模型分开,需要通过文件来传递模型参数。而且训练和测试不能同时运行。因此调试起来也费了不少功夫!

二、LSTM时间序列预测

1. 用namedtuple来配置模型的超参数。

HParams = namedtuple('HParams', 'seq_size, hidden_size, learning_rate')

这种方式比定义一个Config类好。

2. 构建时间序列预测模型类TS_LSTM

class TS_LSTM(object):def __init__(self, hps):self._X = X = tf.placeholder(tf.float32, [None, hps.seq_size, 1])  self._Y = Y = tf.placeholder(tf.float32, [None, hps.seq_size])      W = tf.Variable(tf.random_normal([hps.hidden_size, 1]), name='W')  b = tf.Variable(tf.random_normal([1]), name='b')  lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hps.hidden_size)  #测试cost 1.3809outputs, states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32)  W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1])  output = tf.nn.xw_plus_b(outputs, W_repeated, b)  self._output = output = tf.squeeze(output)  self._cost = cost = tf.reduce_mean(tf.square(output - Y))  self._train_op = tf.train.AdamOptimizer(hps.learning_rate).minimize(cost)  @propertydef X(self):return self._X@propertydef Y(self):return self._Y   @propertydef cost(self):return self._cost@propertydef output(self):return self._output@propertydef train_op(self):return self._train_op

这种方式比用函数定义模型更加方便。@property的设计使得模型用起来更加方便! 
模型的关键就是: 
1). 设定BasicLSTMCell的隐藏节点个数 
2). 调用dynamic_rnn(lstm_cell,X)来计算输出outputs 
3). 调用xw_plus_b将outputs计算为单个的output 
模型中各变量的维度如下:(batch_size=100, seq_size=3, hidden_size=6) 
- X定义为[None, hps.seq_size, 1]是因为dynamic_rnn的输入针对的是二维图像样本的输入,因此,必须多定义一个1的维度,传入的实际应该为100*3*1。 
- Y的维度维持与图像标签输入数据维度相同,传入的实际应该为100*3。 
- W为6*1 
- b为1*1 
- outputs为100*3*6 
- W_repeated为100*6*1,其变化过程6*11*6*1100*6*1。 
- output在squeeze之前为100*3*1,squeeze后为100*3 
- cost为1*1

3. 训练和测试函数train_test

def train_test(hps, data):#训练数据准备train_data_len = len(data)*2//3train_x, train_y = [], []  for i in range(train_data_len - hps.seq_size - 1):  train_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist())  train_y.append(data[i + 1 : i + hps.seq_size + 1].tolist())  #测试数据准备    test_data_len = len(data)//3test_x, test_y = [], []  for i in range(train_data_len,train_data_len+test_data_len - hps.seq_size - 1):  test_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist())  test_y.append(data[i + 1 : i + hps.seq_size + 1].tolist())  with tf.Graph().as_default(), tf.Session() as sess:  with tf.variable_scope('model',reuse=None):m_train = TS_LSTM(hps)         #训练tf.global_variables_initializer().run()for step in range(20000):  _, train_cost = sess.run([m_train.train_op, m_train.cost], feed_dict={m_train.X: train_x, m_train.Y: train_y})  #预测 test_cost, output = sess.run([m_train.cost, m_train.output],feed_dict={m_train.X: test_x, m_train.Y: test_y})  #print(hps, train_cost, test_cost)return train_cost, test_cost

这里的关键是测试用是训练模型,我也不知道为什么好多网络资源都将训练模型和测试模型分离开来。测试不就是用测试数据来测试训练模型的效果吗?因此这里把2/3的数据划给训练,1/3的数据用于测试。自己动手编代码时一定要对session.run函数用法和原理熟悉。

4. 主函数(对超参数组合的测试误差进行比较)

def main():#读取原始数据f=open('铁路客运量.csv')  df=pd.read_csv(f)   data = np.array(df['铁路客运量_当期值(万人)'])  normalized_data = (data - np.mean(data)) / np.std(data)     #测试不同组合的超参数对测试误差的影响costs =[]for seq_size in [4,6,12,16,24]:for hidden_size in [6,10,20,30]:print(seq_size, hidden_size)hps = HParams(seq_size, hidden_size, 0.003)train_cost, test_cost = train_test(hps, normalized_data) costs.append([train_cost,test_cost])

进行了初步比较,感觉有两个: 
1)同一个超参数,测试误差相差挺大。 
2)不同超参数,训练时误差基本都很小,但是测试误差相差很大,如何限制学习过程中的过拟合是一个很大的问题。 
可以看看我运行的训练误差和测试误差的比较。代码已放到csdn下载资源,csdn下载代码来!

     训练误差            测试误差
[[  4.04044241e-02   4.97651482e+00][  3.57200466e-02   6.96304381e-01][  2.97380015e-02   1.77482967e+01][  3.09452992e-02   2.62166214e+00][  3.62494551e-02   2.53422332e+00][  2.57663596e-02   1.44900203e+00][  2.24006996e-02   2.28607416e+00][  2.28729844e-02   1.12727535e+00][  2.58173030e-02   1.43265343e+00][  1.48035632e-02   1.05281734e+00][  1.24982912e-02   6.59598827e+00][  1.27354050e-02   1.69984627e+00][  1.60749555e-02   4.03962803e+00][  1.18473349e-02   7.92685986e-01][  7.39684049e-03   6.16959620e+00][  7.60479691e-03   3.01771784e+00][  1.40351299e-02   4.48093843e+00][  7.94599950e-03   3.78614712e+00][  5.50406286e-03   5.83478451e-01][  4.54067113e-03   8.15259743e+00]]

Tensorflow LSTM时间序列预测的尝试相关推荐

  1. 时序预测 | python实现仿生算法优化LSTM时间序列预测(全网最全仿生算法)

    ** 时序预测 | python实现仿生算法优化LSTM时间序列预测(全网最全仿生算法) ** 多变量/单变量预测程序 多变量/单变量预测程序 多变量/单变量预测程序 A ABC-LSTM--人工蜂群 ...

  2. LSTM 时间序列预测+股票预测案例(Pytorch版)

    文章目录 LSTM 时间序列预测 股票预测案例 数据特征 对收盘价(Close)单特征进行预测 1. 导入数据 2. 将股票数据收盘价(Close)进行可视化展示 3. 特征工程 4. 数据集制作 5 ...

  3. python时间序列分析航空旅人_Python深度学习教程:LSTM时间序列预测小练习—国航乘客数量预测...

    Python深度学习教程:LSTM时间序列预测小练习-国航乘客数量预测 参考数据: 数据一共两列,左边是日期,右边是乘客数量 对数据做可视化:import math import numpy as n ...

  4. LSTM时间序列预测代码超通俗解释(MATLAB)

    数据在评论区 可以查看这一篇博客有更好的代码和可视化: 多序列:http://t.csdn.cn/a4pM0 单序列:https://blog.csdn.net/m0_62526778/article ...

  5. Kesci:Tensorflow 实现 LSTM——时间序列预测(超详细)

    云脑项目3 -真实业界数据的时间序列预测挑战 https://www.kesci.com/home/project/5a391c670e1fc52691fde623 这篇文章将讲解如何使用lstm进行 ...

  6. Kesci:Tensorflow 实现 LSTM——时间序列预测

    LSTM https://www.kesci.com/home/project/5a38a9c00e1fc52691fd9c72 这篇文章将讲解如何使用lstm进行时间序列方面的预测,重点讲lstm的 ...

  7. [转载] lstm时间序列预测_pytorch入门使用PyTorch进行LSTM时间序列预测

    参考链接: 在Python中使用LSTM和PyTorch进行时间序列预测 想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用" ...

  8. 大数据毕业设计 LSTM时间序列预测算法 - 股票预测 天气预测 房价预测

    文章目录 0 简介 1 基于 Keras 用 LSTM 网络做时间序列预测 2 长短记忆网络 3 LSTM 网络结构和原理 3.1 LSTM核心思想 3.2 遗忘门 3.3 输入门 3.4 输出门 4 ...

  9. Kesci: Keras 实现 LSTM——时间序列预测

    博主之前参与的一个科研项目是用 LSTM 结合 Attention 机制依据作物生长期内气象环境因素预测作物产量.本篇博客将介绍如何用 keras 深度学习的框架搭建 LSTM 模型对时间序列做预测. ...

最新文章

  1. 波卡链Substrate (3)SRML框架
  2. 【转】PYTHON open/文件操作
  3. 雌性激素过高怎么办?
  4. RKNN Toolkit使用教程
  5. linux上安装java环境
  6. 【C++深度剖析教程40】使用数值型模板技术计算1+2+3+...+N的值
  7. vue-day04-vue前端交互
  8. 【超详细】在Linux上远程登录遇到的若干问题及解决方法(一)
  9. php微信jsapi支付小结,ThinkPHP接入微信支付 - JSAPI支付
  10. python安装django模块_在您的(Django)项目中使用setup.py
  11. webpack3 css,webpack3之处理css文件
  12. NYOJ题目1045看美女
  13. ffplay视频播放原理分析
  14. Photoshop插件-奥顿效果(梦幻柔焦)-脚本开发-PS插件
  15. 开发人员必备的四象限壁纸
  16. (4.6.14)android 插桩基本概念plugging or Swap
  17. 按Backspace键删除时,会出现^H
  18. excel怎么合并数据?
  19. 【坑】MySQL数据库对于毫秒大于500的数据会进位
  20. linux极点五笔无法输入词组_ibus设置

热门文章

  1. RocketMQ类关系图之NameServer
  2. docker设置国内镜像加速的坑
  3. MariaDB Spider:实现MySQL横纵向扩展的小能手
  4. Self Service Password (SSP)
  5. Mongodb安装搭建Replica Set+Sharding集群
  6. 创业必看:中国八大草根富豪发家史
  7. BPSK,QPSK,2FSK,16QAM,64QAM信号在高斯信道与瑞利信道下的误码率性能仿真
  8. 直播原理----协议
  9. matlab点云配准(总结性)
  10. nysql collation