Tensorflow LSTM时间序列预测的尝试
一、网上的资源
网上有不少用LSTM来预测时间序列的资源,如下面:
深度学习(08)_RNN-LSTM循环神经网络-03-Tensorflow进阶实现
http://blog.csdn.net/u013082989/article/details/73693392
Applying Deep Learning to Time Series Forecasting with TensorFlow
https://mapr.com/blog/deep-learning-tensorflow/
Tensorflow 笔记 RNN 预测时间序列
https://www.v2ex.com/t/339544
tf19: 预测铁路客运量
http://blog.csdn.net/u014365862/article/details/53869802
但是调试起来,都很困难!借鉴比较多的是tf19:预测铁路客运量这篇博文。这篇博文首先是基本上可以运行的。但是训练模型和测试模型分开,需要通过文件来传递模型参数。而且训练和测试不能同时运行。因此调试起来也费了不少功夫!
二、LSTM时间序列预测
1. 用namedtuple来配置模型的超参数。
HParams = namedtuple('HParams', 'seq_size, hidden_size, learning_rate')
这种方式比定义一个Config类好。
2. 构建时间序列预测模型类TS_LSTM
class TS_LSTM(object):def __init__(self, hps):self._X = X = tf.placeholder(tf.float32, [None, hps.seq_size, 1]) self._Y = Y = tf.placeholder(tf.float32, [None, hps.seq_size]) W = tf.Variable(tf.random_normal([hps.hidden_size, 1]), name='W') b = tf.Variable(tf.random_normal([1]), name='b') lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hps.hidden_size) #测试cost 1.3809outputs, states = tf.nn.dynamic_rnn(lstm_cell, X, dtype=tf.float32) W_repeated = tf.tile(tf.expand_dims(W, 0), [tf.shape(X)[0], 1, 1]) output = tf.nn.xw_plus_b(outputs, W_repeated, b) self._output = output = tf.squeeze(output) self._cost = cost = tf.reduce_mean(tf.square(output - Y)) self._train_op = tf.train.AdamOptimizer(hps.learning_rate).minimize(cost) @propertydef X(self):return self._X@propertydef Y(self):return self._Y @propertydef cost(self):return self._cost@propertydef output(self):return self._output@propertydef train_op(self):return self._train_op
这种方式比用函数定义模型更加方便。@property的设计使得模型用起来更加方便!
模型的关键就是:
1). 设定BasicLSTMCell的隐藏节点个数
2). 调用dynamic_rnn(lstm_cell,X)来计算输出outputs
3). 调用xw_plus_b将outputs计算为单个的output
模型中各变量的维度如下:(batch_size=100, seq_size=3, hidden_size=6)
- X定义为[None, hps.seq_size, 1]是因为dynamic_rnn的输入针对的是二维图像样本的输入,因此,必须多定义一个1的维度,传入的实际应该为100*3*1。
- Y的维度维持与图像标签输入数据维度相同,传入的实际应该为100*3。
- W为6*1
- b为1*1
- outputs为100*3*6
- W_repeated为100*6*1,其变化过程6*11*6*1100*6*1。
- output在squeeze之前为100*3*1,squeeze后为100*3
- cost为1*1
3. 训练和测试函数train_test
def train_test(hps, data):#训练数据准备train_data_len = len(data)*2//3train_x, train_y = [], [] for i in range(train_data_len - hps.seq_size - 1): train_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist()) train_y.append(data[i + 1 : i + hps.seq_size + 1].tolist()) #测试数据准备 test_data_len = len(data)//3test_x, test_y = [], [] for i in range(train_data_len,train_data_len+test_data_len - hps.seq_size - 1): test_x.append(np.expand_dims(data[i : i + hps.seq_size], axis=1).tolist()) test_y.append(data[i + 1 : i + hps.seq_size + 1].tolist()) with tf.Graph().as_default(), tf.Session() as sess: with tf.variable_scope('model',reuse=None):m_train = TS_LSTM(hps) #训练tf.global_variables_initializer().run()for step in range(20000): _, train_cost = sess.run([m_train.train_op, m_train.cost], feed_dict={m_train.X: train_x, m_train.Y: train_y}) #预测 test_cost, output = sess.run([m_train.cost, m_train.output],feed_dict={m_train.X: test_x, m_train.Y: test_y}) #print(hps, train_cost, test_cost)return train_cost, test_cost
这里的关键是测试用是训练模型,我也不知道为什么好多网络资源都将训练模型和测试模型分离开来。测试不就是用测试数据来测试训练模型的效果吗?因此这里把2/3的数据划给训练,1/3的数据用于测试。自己动手编代码时一定要对session.run函数用法和原理熟悉。
4. 主函数(对超参数组合的测试误差进行比较)
def main():#读取原始数据f=open('铁路客运量.csv') df=pd.read_csv(f) data = np.array(df['铁路客运量_当期值(万人)']) normalized_data = (data - np.mean(data)) / np.std(data) #测试不同组合的超参数对测试误差的影响costs =[]for seq_size in [4,6,12,16,24]:for hidden_size in [6,10,20,30]:print(seq_size, hidden_size)hps = HParams(seq_size, hidden_size, 0.003)train_cost, test_cost = train_test(hps, normalized_data) costs.append([train_cost,test_cost])
进行了初步比较,感觉有两个:
1)同一个超参数,测试误差相差挺大。
2)不同超参数,训练时误差基本都很小,但是测试误差相差很大,如何限制学习过程中的过拟合是一个很大的问题。
可以看看我运行的训练误差和测试误差的比较。代码已放到csdn下载资源,csdn下载代码来!
训练误差 测试误差
[[ 4.04044241e-02 4.97651482e+00][ 3.57200466e-02 6.96304381e-01][ 2.97380015e-02 1.77482967e+01][ 3.09452992e-02 2.62166214e+00][ 3.62494551e-02 2.53422332e+00][ 2.57663596e-02 1.44900203e+00][ 2.24006996e-02 2.28607416e+00][ 2.28729844e-02 1.12727535e+00][ 2.58173030e-02 1.43265343e+00][ 1.48035632e-02 1.05281734e+00][ 1.24982912e-02 6.59598827e+00][ 1.27354050e-02 1.69984627e+00][ 1.60749555e-02 4.03962803e+00][ 1.18473349e-02 7.92685986e-01][ 7.39684049e-03 6.16959620e+00][ 7.60479691e-03 3.01771784e+00][ 1.40351299e-02 4.48093843e+00][ 7.94599950e-03 3.78614712e+00][ 5.50406286e-03 5.83478451e-01][ 4.54067113e-03 8.15259743e+00]]
Tensorflow LSTM时间序列预测的尝试相关推荐
- 时序预测 | python实现仿生算法优化LSTM时间序列预测(全网最全仿生算法)
** 时序预测 | python实现仿生算法优化LSTM时间序列预测(全网最全仿生算法) ** 多变量/单变量预测程序 多变量/单变量预测程序 多变量/单变量预测程序 A ABC-LSTM--人工蜂群 ...
- LSTM 时间序列预测+股票预测案例(Pytorch版)
文章目录 LSTM 时间序列预测 股票预测案例 数据特征 对收盘价(Close)单特征进行预测 1. 导入数据 2. 将股票数据收盘价(Close)进行可视化展示 3. 特征工程 4. 数据集制作 5 ...
- python时间序列分析航空旅人_Python深度学习教程:LSTM时间序列预测小练习—国航乘客数量预测...
Python深度学习教程:LSTM时间序列预测小练习-国航乘客数量预测 参考数据: 数据一共两列,左边是日期,右边是乘客数量 对数据做可视化:import math import numpy as n ...
- LSTM时间序列预测代码超通俗解释(MATLAB)
数据在评论区 可以查看这一篇博客有更好的代码和可视化: 多序列:http://t.csdn.cn/a4pM0 单序列:https://blog.csdn.net/m0_62526778/article ...
- Kesci:Tensorflow 实现 LSTM——时间序列预测(超详细)
云脑项目3 -真实业界数据的时间序列预测挑战 https://www.kesci.com/home/project/5a391c670e1fc52691fde623 这篇文章将讲解如何使用lstm进行 ...
- Kesci:Tensorflow 实现 LSTM——时间序列预测
LSTM https://www.kesci.com/home/project/5a38a9c00e1fc52691fd9c72 这篇文章将讲解如何使用lstm进行时间序列方面的预测,重点讲lstm的 ...
- [转载] lstm时间序列预测_pytorch入门使用PyTorch进行LSTM时间序列预测
参考链接: 在Python中使用LSTM和PyTorch进行时间序列预测 想了解更多好玩的人工智能应用,请关注公众号"机器AI学习 数据AI挖掘","智能应用" ...
- 大数据毕业设计 LSTM时间序列预测算法 - 股票预测 天气预测 房价预测
文章目录 0 简介 1 基于 Keras 用 LSTM 网络做时间序列预测 2 长短记忆网络 3 LSTM 网络结构和原理 3.1 LSTM核心思想 3.2 遗忘门 3.3 输入门 3.4 输出门 4 ...
- Kesci: Keras 实现 LSTM——时间序列预测
博主之前参与的一个科研项目是用 LSTM 结合 Attention 机制依据作物生长期内气象环境因素预测作物产量.本篇博客将介绍如何用 keras 深度学习的框架搭建 LSTM 模型对时间序列做预测. ...
最新文章
- 波卡链Substrate (3)SRML框架
- 【转】PYTHON open/文件操作
- 雌性激素过高怎么办?
- RKNN Toolkit使用教程
- linux上安装java环境
- 【C++深度剖析教程40】使用数值型模板技术计算1+2+3+...+N的值
- vue-day04-vue前端交互
- 【超详细】在Linux上远程登录遇到的若干问题及解决方法(一)
- php微信jsapi支付小结,ThinkPHP接入微信支付 - JSAPI支付
- python安装django模块_在您的(Django)项目中使用setup.py
- webpack3 css,webpack3之处理css文件
- NYOJ题目1045看美女
- ffplay视频播放原理分析
- Photoshop插件-奥顿效果(梦幻柔焦)-脚本开发-PS插件
- 开发人员必备的四象限壁纸
- (4.6.14)android 插桩基本概念plugging or Swap
- 按Backspace键删除时,会出现^H
- excel怎么合并数据?
- 【坑】MySQL数据库对于毫秒大于500的数据会进位
- linux极点五笔无法输入词组_ibus设置