系列文章

  • 第七章、手把手教你:基于深度残差网络(ResNet)的水果分类识别系统
  • 第六章、手把手教你:人脸识别的视频打码
  • 第五章、手把手教你:基于深度学习的滚动轴承故障诊断

目录

  • 系列文章
  • 一、项目简介
  • 二、数据集介绍
  • 三、环境安装
    • 1.环境要求
  • 四、重要代码介绍
    • 1.数据预处理
    • 2.预测模型构建
    • 3.模型训练
      • 3.1 训练参数定义
      • 3.2 训练loss及MSE
  • 五、完整代码地址

一、项目简介

本文主要介绍如何使用python搭建:一个基于长短期记忆网络(LSTM:Long Short-Term Memory, 简称 LSTM)的股票、大宗商品预测系统

项目只是用股票预测作为抛砖引玉,其中包含了使用LSTM进行时序预测的相关代码。主要功能如下:

  • 数据预处理。
  • 模型构建及训练,使用tensorflow构建LSTM网络。
  • 预测股票时序走向并进行模型评估。

如各位童鞋需要更换训练数据,完全可以根据源码将图像和标注文件更换即可直接运行。

博主也参考过网上图像分类的文章,但大多是理论大于方法。很多同学肯定对原理不需要过多了解,只需要搭建出一个预测系统即可。

本文只会告诉你如何快速搭建一个基于LSTM的股票预测系统并运行,原理的东西可以参考其他博主

也正是因为我发现网上大多的帖子只是针对原理进行介绍,功能实现的相对很少。

如果您有以上想法,那就找对地方了!


不多废话,直接进入正题!

二、数据集介绍

首先我们这次工作主要是针对,大宗商品指数的一个预测,分别为:化工、贵金属、有色。


  • 接下来是模型预测的结果,这里我用:化工商品,来观测模型预测的时序结果:

  • 可以看到其中红色曲线为化工商品的时序情况、绿色曲线为预测情况。

三、环境安装

1.环境要求

本项目开发IDE使用的是:Pycharm,大家可以直接csdn搜索安装指南非常多,这里就不再赘述。

因为本项目基于TensorFlow因此需要以下环境:

  • tensorflow >= 2.0
  • pandas
  • scikit-learn
  • numpy
  • matplotlib
  • joblib

四、重要代码介绍

环境安装好后就可以打开pycharm开始愉快的执行代码了。由于代码众多,博客中就不放入最终代码了,有需要的童鞋可以在博客最下方找到下载地址

1.数据预处理

  • 首先我们需要将时序问题转换为监督学习,才能进行训练。下方代码将输入的时序的收盘价转化为每日收益率并将收益率中滞后一天(默认为一天)的观测值作为监督学习值。
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):"""将时间序列转换为监督学习问题Arguments:data: 输入数据需要是列表或二维的NumPy数组的观察序列。n_in: 输入的滞后观察数(X)。值可以在[1..len(data)]之间,可选的。默认为1。n_out: 输出的观察数(y)。值可以在[0..len(data)-1]之间,可选的。默认为1。dropnan: Bool值,是否删除具有NaN值的行,可选的。默认为True。Returns:用于监督学习的Pandas DataFrame。"""# 定义series_to_supervised()函数# 将时间序列转换为监督学习问题n_vars = 1 if type(data) is list else data.shape[1]df = DataFrame(data)cols, names = list(), list()# input sequence (t-n, ... t-1)for i in range(n_in, 0, -1):cols.append(df.shift(i))names += [('var%d(t-%d)' % (j + 1, i)) for j in range(n_vars)]# forecast sequence (t, t+1, ... t+n)for i in range(0, n_out):cols.append(df.shift(-i))if i == 0:names += [('var%d(t)' % (j + 1)) for j in range(n_vars)]else:names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]# put it all togetheragg = concat(cols, axis=1)agg.columns = names# drop rows with NaN valuesif dropnan:agg.dropna(inplace=True)# 删除多余列agg.drop(agg.columns[[6, 8, 10]], axis=1, inplace=True)print("*" * 20)print("完成监督学习转换:")print(agg.head())return agg
  • 其二就是在数据构建完成后,以一定比率将训练数据和测试数据分离。

2.预测模型构建

  • 因为使用的是LSTM做回归预测,因此模型输出应该不是分类的类别,而是回归值。模型构建代码如下:
def model_create(train_X):"""搭建LSTM模型:param train_X::return:"""model = Sequential()model.add(LSTM(64, input_shape=(train_X.shape[1], train_X.shape[2])))model.add(Dropout(0.5))model.add(Dense(1, activation='relu'))model.compile(loss='mae', optimizer='adam', metrics=['mse'])return model

3.模型训练

3.1 训练参数定义

  • 设置批处理batch_size:100,博主总共跑了100个epoch。
    # 定义callbacks参数callbacks = [TensorBoard(log_dir=my_log_dir)]# 贵金属模型训练history1 = lstm_gjs.fit(train_x_gjs, train_y_gjs, epochs=100, batch_size=100,validation_data=(test_x_gjs, test_y_gjs), callbacks=callbacks,verbose=2, shuffle=False)# 保存最终模型lstm_gjs.save_weights('models/' + 'model_lstm_gjs.tf')

3.2 训练loss及MSE

  • 训练和测试集的loss,可以看到训练至30个epoch左右,loss已经收敛,同时MSE也较低。

  • 贵金属训练曲线:

  • 有色金属训练曲线:

  • 化工商品训练曲线:


五、完整代码地址

由于项目代码量和数据集较大,感兴趣的同学可以下载完整代码,使用过程中如遇到任何问题可以在评论区进行评论,我都会一一解答。

完整代码下载:
【代码分享】手把手教你:基于LSTM的股票预测系统

手把手教你:基于LSTM的股票预测系统相关推荐

  1. 基于循环神经网络股票预测系统

    循环神经网络-Simple RNN RNN神经网络模型是一种常用的深度神经网络模型,已成功应用于语言识别.文本分类等多个研究领域.相比其他网络模型,RNN最大的优势在于引入了时序与定向循环的概念,能够 ...

  2. 基于LSTM实现股票预测

    LSTM原理 前面的博客中使用传统的循环神经网络RNN,可以通过记忆体实现短期记忆进行连续数据的预测,但是当连续数据的序列变长时,会使展开时间步过长,在反向传播更新参数时,梯度要按照时间步连续相乘,会 ...

  3. garch预测 python_【2019年度合辑】手把手教你用Python做股票量化分析

    引言 不知不觉,2019年已接近尾声,Python金融量化公众号也有一年零两个月.公众号自设立以来,专注于分享Python在金融量化领域的应用,发布了四十余篇原创文章,超过两万人关注.这一路走来,有过 ...

  4. 毕业设计之 --- 基于大数据分析的股票预测系统

    文章目录 0 前言 1 课题意义 1.1 股票预测主流方法 2 什么是LSTM 2.1 循环神经网络 2.1 LSTM诞生 2 如何用LSTM做股票预测 2.1 算法构建流程 2.2 部分代码 3 实 ...

  5. 基于LSTM的序列预测: 飞机月流量预测

    基于LSTM的序列预测: 飞机月流量预测 循环神经网络,如RNN,LSTM等模型,比较适合用于序列预测,下面以一个比较经典的飞机月流量数据集,介绍LSTM的使用方法和训练过程. 完整的项目代码下载:h ...

  6. 基于线性回归的股票预测案例

    基于线性回归的股票预测案例 本次的案例使用的是股票数据,数据源从www.quandl.com 获取.本次案例主要是为了练习线性回归 pip install quandl 安装quandl库. 在使用的 ...

  7. 【毕业设计】LSTM股票预测系统 - python 深度学习

    文章目录 0 前言 1 课题意义 1.1 股票预测主流方法 2 什么是LSTM 2.1 循环神经网络 2.1 LSTM诞生 3 如何用LSTM做股票预测 3.1 算法构建流程 3.2 部分代码 4 实 ...

  8. 机器学习:回归分析——基于线性回归的股票预测

    基于线性回归的股票预测 数据获取 数据预处理 编码实现 数据获取 我们可以从https://data.nasdaq.com/ 获取股票数据集,每个ip访问quandl有次数(50次)的限制,如果访问次 ...

  9. 大数据毕业设计 基于时间序列的股票预测与分析系统 - 大数据分析

    文章目录 1 简介 2 时间序列的由来 2.1 四种模型的名称: 3 数据预览 4 理论公式 4.1 协方差 4.2 相关系数 4.3 scikit-learn计算相关性 5 金融数据的时序分析 5. ...

最新文章

  1. 艾伟:控件之ViewState
  2. 自己的 sublime text 配置
  3. 【转载】关于阿里巴巴的问题
  4. CCNA Cisco 端口配置(上)
  5. 普通 项目打包包含第三方jar包
  6. 如何设置Windows版Go —快速简便的指南
  7. 最小生成树唯一吗_最小生成树 - 齐芒
  8. 程序员简历的10不要与7要
  9. 1_反向代理【跨域】
  10. 指针变量的声明、地址相关运算--“*”和“”
  11. DirectShow 简介
  12. 友华PT921G光猫破解获取超级密码和更改桥接模式
  13. 知识图谱 helloword
  14. 视觉SLAM十四讲学习笔记——第八讲 视觉里程计(3)
  15. 金融学系列之 Inflation Money Remit
  16. 代码揭秘:从c/c++的角度探秘计算机系统,【C++】[代码揭秘:从C/C的角度探秘计算机系统]左飞.pdf...
  17. 如何在Mac电脑中使用键盘移动操作鼠标焦点?如何在Mac中打开辅助键盘?
  18. ue4 改变枢轴位置_UE4实时渲染深入探究----学习总结【上篇】
  19. ESP8266安卓TCP客户端开发(含全部源码)
  20. Mysql动态sql之mybatis动态sql

热门文章

  1. 如何让暴风影音播放flv文件
  2. My MPC·暴风影音 5.00_Stable
  3. 网络安全工程师要学习哪些编程语言?哪里学网络安全知识可靠?
  4. FOB指定货操作标准流程及相关经验
  5. 第25届ccf-csp认证赛后
  6. 质量管理中的“二八法则”
  7. 《计算广告》第一部分计算广告关键技术——笔记
  8. ENVI-IDL中国官方微博
  9. 量化分析入门2:上证指数走势图及移动平均线
  10. RxSwift使用教程