一、概述:

传统的线性模型难以解决多变量或多输入问题,而神经网络如LSTM则擅长于处理多个变量的问题,该特性使其有助于解决时间序列预测问题。

本文将初步探究 LSTM 在股票市场的应用。通过使用LSTM对股票收益的预测,可以了解到:

(1)如何将原始数据集转换为可用于时间序列预测的数据。

(2)如何准备数据并使LSTM适合多变量时间序列预测问题。

(3)如何进行预测并将结果重新调整回原始数据。

二 数据选择和处理:

input的时间跨度为30天,每天的features为['close','open','high','low','amount','volume']共6个,因此每个input为30×6的二维向量。
output为未来5日收益future_return_5(future_return_5>0.2,取0.2;future_return_5<-0.2,取-0.2),为使训练效果更加明显,output=future_return_5×10; features均经过标准化处理(在每个样本内每个feature标准化处理一次)。
训练数据:沪深300 2005-01-01至2014-12-31时间段的数据;测试数据:沪深300 2015-01-01至2017-05-01时间段数据。

三 代码例子参考

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense, LSTM
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import mathmaotai = pd.read_csv('./SH600519.csv')  # 读取股票文件training_set = maotai.iloc[0:2426 - 300, 2:3].values  # 前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
test_set = maotai.iloc[2426 - 300:, 2:3].values  # 后300天的开盘价作为测试集# 归一化
sc = MinMaxScaler(feature_range=(0, 1))  # 定义归一化:归一化到(0,1)之间
training_set_scaled = sc.fit_transform(training_set)  # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化
test_set = sc.transform(test_set)  # 利用训练集的属性对测试集进行归一化x_train = []
y_train = []x_test = []
y_test = []# 测试集:csv表格中前2426-300=2126天数据
# 利用for循环,遍历整个训练集,提取训练集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建2426-300-60=2066组数据。
for i in range(60, len(training_set_scaled)):x_train.append(training_set_scaled[i - 60:i, 0])y_train.append(training_set_scaled[i, 0])
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
# 将训练集由list格式变为array格式
x_train, y_train = np.array(x_train), np.array(y_train)# 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
# 此处整个数据集送入,送入样本数为x_train.shape[0]即2066组数据;输入60个开盘价,预测出第61天的开盘价,循环核时间展开步数为60; 每个时间步送入的特征是某一天的开盘价,只有1个数据,故每个时间步输入特征个数为1
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
# 测试集:csv表格中后300天数据
# 利用for循环,遍历整个测试集,提取测试集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建300-60=240组数据。
for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])
# 测试集变array并reshape为符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))model = tf.keras.Sequential([LSTM(80, return_sequences=True),Dropout(0.2),LSTM(100),Dropout(0.2),Dense(1)
])model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')  # 损失函数用均方误差
# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值checkpoint_save_path = "./checkpoint/LSTM_stock.ckpt"if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True,monitor='val_loss')history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])model.summary()file = open('./weights.txt', 'w')  # 参数提取
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()loss = history.history['loss']
val_loss = history.history['val_loss']plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()################## predict ######################
# 测试集输入模型进行预测
predicted_stock_price = model.predict(x_test)
# 对预测数据还原---从(0,1)反归一化到原始范围
predicted_stock_price = sc.inverse_transform(predicted_stock_price)
# 对真实数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])
# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='MaoTai Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()##########evaluate##############
# calculate MSE 均方误差 ---> E[(预测值-真实值)^2] (预测值减真实值求平方后求均值)
mse = mean_squared_error(predicted_stock_price, real_stock_price)
# calculate RMSE 均方根误差--->sqrt[MSE]    (对均方误差开方)
rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price))
# calculate MAE 平均绝对误差----->E[|预测值-真实值|](预测值减真实值求绝对值后求均值)
mae = mean_absolute_error(predicted_stock_price, real_stock_price)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)

[深度学习] 使用LSTM实现股票预测相关推荐

  1. 浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现)

    浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现) 总包含文章: 一个完整的机器学习模型的流程 浅谈深度学习:了解RNN和构建并预测 浅谈深度学习:基于对LS ...

  2. 解析KDTCN:知识图谱和深度学习模型联合实现股票预测

    背景概述 今天看了一篇论文我觉得挺有意思,一方面是讲的股票预测相关,另一方面是把深度学习和知识图谱相结合解决一个问题.通常知识图谱和深度学习很少有交集,一般是独立发展的两个人工智能领域解决问题的手段, ...

  3. 【深度学习】LSTM神经网络解决COVID-19预测问题(二)

    [深度学习]LSTM神经网络解决COVID-19预测问题(二) 文章目录 1 概述 2 模型求解和检验 3 模型代码 4 模型评价与推广 5 参考 1 概述 建立一个普适性较高的模型来有效预测疫情的达 ...

  4. 【深度学习】LSTM神经网络解决COVID-19预测问题(一)

    [深度学习]LSTM神经网络解决COVID-19预测问题 文章目录 1 概述 2 数据分析 3 SIR模型和LSTM网络的对比 4 LSTM神经网络的建立 5 参考 1 概述 我们将SIR传播模型和L ...

  5. 手把手教你:基于LSTM的股票预测系统

    系列文章 第七章.手把手教你:基于深度残差网络(ResNet)的水果分类识别系统 第六章.手把手教你:人脸识别的视频打码 第五章.手把手教你:基于深度学习的滚动轴承故障诊断 目录 系列文章 一.项目简 ...

  6. tensorflow深度学习之LSTM(变种RNN)的原理

    在tensorflow深度学习之描述循环计算层与循环计算过程(超详细)中,我们已经学习到了传统循环网络RNN的原理. 传统循环网络RNN的缺点 传统循环网络RNN通过记忆体实现短期记忆进行连续数据的预 ...

  7. MATLAB深度学习之LSTM

    MATLAB深度学习之LSTM 深度学习工具箱 net = trainNetwork(sequences,Y,layers,options) clc clear %% 训练深度学习 LSTM 网络,进 ...

  8. 深度学习之LSTM案例分析(三)

    #背景 来自GitHub上<tensorflow_cookbook>[https://github.com/nfmcclure/tensorflow_cookbook/tree/maste ...

  9. HyperAttentionDTI:基于注意机制的序列深度学习改进药物-蛋白质相互作用预测

    题目: HyperAttentionDTI: improving drug–protein interaction prediction by sequence-based deep learning ...

最新文章

  1. Linux系统快速安装JDK
  2. 扎克伯格|在美国国会数据门听证会上的证词-中英文全文
  3. SAP Fiori Launchpad shell.handleGoHome() - home按钮的实现
  4. Jquery通过遍历数组给checkbox赋默认值
  5. [MS Sql Server术语解释]预读,逻辑读,物理读
  6. 【kafka】 kafka如何设置指定分区进行发送和消费
  7. psd文件用什么打开?如何修改psd文件?psd样式怎么修改文字?
  8. 一阶惯性环节如何实现跟踪性能与滤波性能共存(总)
  9. 学习3DMAX的几点收获
  10. 制作传播超级手机病毒嫌犯被抓
  11. react-ssr之路由配置
  12. 对论文中模型进行编程实现时的注意要求和总结
  13. 像把大象放入冰箱那样制造芯片
  14. 解决EZP_XPro无法烧录
  15. 手写一个抖音视频去水印工具,千万别刚一个程序员
  16. mansory自适应label高度
  17. 【转】Laravel+Angularjs+D3打造可视化数据,RESTful+Ajax
  18. USB4是什么?与USB 3.2有什么差异?
  19. 网管系统主流技术及其应用
  20. 利用owncloud搭建私人云盘

热门文章

  1. 更换钢丝绳,为何选择“国标锻造”钢丝绳夹(非铸件)
  2. 如何选择视频聊天程序搭建视频聊天网站
  3. 03 CoCos Creator 偏好设置中ndk配置
  4. Inception模块
  5. mysql 时间 本周 本月_日本人脑洞最大的奇葩恋爱游戏,本周上架Steam,别在吃饭时玩...
  6. 遮天 | 实战绕过卡巴斯基、Defender上线CS和MSF及动态命令执行...
  7. python中tell_Python中tell()方法的使用详解
  8. 感知机算法(一)PLA
  9. 最全面详细(小白)的 filebrowser 搭建部署教程
  10. android移动点餐系统内容和要求,基于Android云计算的移动点餐系统