基于LSTM-CNN-CBAM模型的股票预测研究
1.摘要
为了更好地对股票价格进行预测,进而为股民提供合理化的建议,提出了一种在结合长短期记忆网络 (LSTM)和卷积神经网络(CNN)的基础上引入注意力机制的股票预测混合模型(LSTM-CNN-CBAM),该模型采用 的是端到端的网络结构,使用LSTM来提取数据中的时序特征,利用CNN挖掘数据中的深层特征,通过在网络结构中加入注意力机制——Convolutional Attention Block Module(CBAM)卷积模块,可以有效地提升网络的特征提取 能力。基于苹果公司2010-01-04至2018-12-31的股票进行实验,通过对比实验预测结果和评价指标,验证了在LSTM与CNN结合的网络模型中加入CBAM模块的预测有效性及可行,基于实验结果我们将给出合理结果。本实验中使用tensorflow、keras 作为神经网络的框架,使用Python 语言进行了网络的代码实现。使用苹果公司2010-01-04至2018-12-31的股票数据进行分析预测实验,将真实值和预测值进行对比,并且进行预测结果图形拟合和误差评估,通过与 ATT-LSTM 模型的对比实验,最后验证了在LSTM与CNN结合的网络模型中加入CBAM模块预测模型的有效性。
2. 相关技术
2.1 LSTM的结构和原理介绍
详细内容查看:https://haosen.blog.csdn.net/article/details/80847806
2.2 注意力模型CBAM
最近几年注意力模型在深度学习的各个领域被广泛使用,深度学习中的注意力机制的核心目标是从众多信息中选择出对当前任务目标更关键的信息。本文中,采用 Convolutional Block Attention Module(CBAM)去实现 attention 机制。CBAM 表示卷积模块的注意力机制模块,它是一种为卷积神经网络设计的,简单有效的注意力模块,结合了空间和通道的注意力模块,相对于SENet多了一个空间attention,可以取得更好的效果。CBAM 使得模型拥有了重视关键特征忽视无用特征的能力。对于卷积神经网络生成的特征图,CBAM 从通道和空间两个维度计算特征图的权重图,然后将权重图与输入的特征图相乘来进行特征的自适应学习。CBAM是一个轻量的通用模块,可以将其融入到各种卷积神经网络中进行端到端的训练。图 为CBAM网络结构图,其中Channel attention module主要关注于输入数据中有意义的内容。
其中:Channel attention module:
将输入的featuremap,分别经过基于width和height的global max pooling 和global average pooling,然后分别经过MLP。将MLP输出的特征进行基于elementwise的加和操作,再经过sigmoid激活操作,生成最终的channel attention featuremap。将该channel attention featuremap和input featuremap做elementwise乘法操作,生成Spatial attention模块需要的输入特征。
其中,seigema为sigmoid操作,r表示减少率,其中W0后面需要接RELU激活。
Spatial attention module:
将Channel attention模块输出的特征图作为本模块的输入特征图。首先做一个基于channel的global max pooling 和global average pooling,然后将这2个结果基于channel 做concat操作。然后经过一个卷积操作,降维为1个channel。再经过sigmoid生成spatial attention feature。最后将该feature和该模块的输入feature做乘法,得到最终生成的特征。
其中,seigema为sigmoid操作,7*7表示卷积核的大小,7*7的卷积核比3*3的卷积核效果更好。
CBAM其keras代码实现:
def cbam_block(cbam_feature, ratio=8):"""Contains the implementation of Convolutional Block Attention Module(CBAM) block.As described in https://arxiv.org/abs/1807.06521."""cbam_feature = channel_attention(cbam_feature, ratio)cbam_feature = spatial_attention(cbam_feature)return cbam_feature
def channel_attention(input_feature, ratio=8):channel_axis = 1 if K.image_data_format() == "channels_first" else -1channel = input_feature._keras_shape[channel_axis]shared_layer_one = Dense(channel // ratio,activation='relu',kernel_initializer='he_normal',use_bias=True,bias_initializer='zeros')shared_layer_two = Dense(channel,kernel_initializer='he_normal',use_bias=True,bias_initializer='zeros')avg_pool = GlobalAveragePooling2D()(input_feature)avg_pool = Reshape((1, 1, channel))(avg_pool)assert avg_pool._keras_shape[1:] == (1, 1, channel)avg_pool = shared_layer_one(avg_pool)assert avg_pool._keras_shape[1:] == (1, 1, channel // ratio)avg_pool = shared_layer_two(avg_pool)assert avg_pool._keras_shape[1:] == (1, 1, channel)max_pool = GlobalMaxPooling2D()(input_feature)max_pool = Reshape((1, 1, channel))(max_pool)assert max_pool._keras_shape[1:] == (1, 1, channel)max_pool = shared_layer_one(max_pool)assert max_pool._keras_shape[1:] == (1, 1, channel // ratio)max_pool = shared_layer_two(max_pool)assert max_pool._keras_shape[1:] == (1, 1, channel)cbam_feature = Add()([avg_pool, max_pool])cbam_feature = Activation('sigmoid')(cbam_feature)if K.image_data_format() == "channels_first":cbam_feature = Permute((3, 1, 2))(cbam_feature)return multiply([input_feature, cbam_feature])def spatial_attention(input_feature):kernel_size = 7if K.image_data_format() == "channels_first":channel = input_feature._keras_shape[1]cbam_feature = Permute((2, 3, 1))(input_feature)else:channel = input_feature._keras_shape[-1]cbam_feature = input_featureavg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)assert avg_pool._keras_shape[-1] == 1max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)assert max_pool._keras_shape[-1] == 1concat = Concatenate(axis=3)([avg_pool, max_pool])assert concat._keras_shape[-1] == 2cbam_feature = Conv2D(filters=1, kernel_size=kernel_size, strides=1,padding='same',activation='sigmoid',kernel_initializer='he_normal',use_bias=False)(concat)assert cbam_feature._keras_shape[-1] == 1if K.image_data_format() == "channels_first":cbam_feature = Permute((3, 1, 2))(cbam_feature)return multiply([input_feature, cbam_feature])
3. 本文模型设计
基于 LSTM-CNN-CBAM 的股票预测网络模型是在 win操作系统下搭建的,使用的是 CPU 版本的
tensorflow框架。通过在结合长短时记忆神经网络和卷积神经网络的长记忆性分析的时间序列分类模型中加入了 CBAM 注意力机制,使模型自动学习和提取时间序列中的局部特征和长记忆性特征,模型展开如图所示。
首先是 LSTM 模块,使用了 1 层 LSTM 神经网络学习数据中的时序特征,每层 LSTM 有 128 个隐藏神经元,学习率为0.001,迭代次数(epochs)为100次,然后将学习到的特征通过卷积神经网络进行特征学习和提取,并且加入了注意力机制,最后通过5层反向传播神经网络输出预测价格,每个全连接层的神经元个数依次为256、128、64、20、1,激活函数使用ReLu函数。
该模型代码实现为:
inputs = Input(shape=(train_X.shape[1], train_X.shape[-1])) # 输入特征接收维度lstm1=LSTM(100, return_sequences=True)(inputs)print(lstm1._keras_shape[0])lstm1=Reshape((-1,lstm1._keras_shape[1],lstm1._keras_shape[-1]))(lstm1)input0 = Conv2D(10, (4, 4), activation='relu', padding='same', strides=(1, 1))(lstm1)input = Conv2D(10, (4, 4), activation='relu', padding='same', strides=(1, 1))(input0)x = BatchNormalization()(input)cbam_block = "se_block"x1 = attach_attention_module(x, cbam_block)input2 = Flatten()(x1)dense_1 = Dense(100, activation='sigmoid')(input2)dense_2 = Dropout(0.5)(dense_1)output = Dense(1, activation='linear')(dense_2) model = Model(inputs=inputs, outputs=output) # 初始命名训练的模型为modelmodel.summary()model.compile(loss="mse", optimizer='adam') # optimizer='adam',
该模型代码结构为:可以查看图像形式的:https://haosen.blog.csdn.net/article/details/118097751
_________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
=========================================================================================
input_1 (InputLayer) (None, 10, 5) 0
_________________________________________________________________________________________
lstm_1 (LSTM) (None, 10, 100) 42400 input_1[0][0]
_________________________________________________________________________________________reshape_1 (Reshape) (None, 1, 10, 100) 0 lstm_1[0][0]
_________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 1, 10, 10) 16010 reshape_1[0][0]
_________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 1, 10, 10) 1610 conv2d_1[0][0]
_________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 1, 10, 10) 40 conv2d_2[0][0]
_________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 10) 0 batch_normalization_1[0][0]
_________________________________________________________________________________________
reshape_2 (Reshape) (None, 1, 1, 10) 0 global_average_pooling2d_1[0][0]
_________________________________________________________________________________________
dense_1 (Dense) (None, 1, 1, 1) 11 reshape_2[0][0]
_________________________________________________________________________________________
dense_2 (Dense) (None, 1, 1, 10) 20 dense_1[0][0]
_________________________________________________________________________________________
multiply_1 (Multiply) (None, 1, 10, 10) 0 batch_normalization_1[0][0] dense_2[0][0]
_________________________________________________________________________________________
flatten_1 (Flatten) (None, 100) 0 multiply_1[0][0]
_________________________________________________________________________________________
dense_3 (Dense) (None, 100) 10100 flatten_1[0][0]
_________________________________________________________________________________________
dropout_1 (Dropout) (None, 100) 0 dense_3[0][0]
_________________________________________________________________________________________
dense_4 (Dense) (None, 1) 101 dropout_1[0][0]
=========================================================================================
Total params: 70,292
Trainable params: 70,272
Non-trainable params: 20
_________________________________________________________________________________________
4.实验及结果分析
4.1 数据集预处理
数据预处理:由于获取到的原始数据集存在缺值和乱序等情况,所以要先对下载的数据集进行插值和按日期进行排序等操作,获得一个无乱序的完整数据集。数据标准化:由于数据集的数据之间量级不一样,例如开盘价、收盘价与成交量、成交额等数据量级之间存在着巨大的差异,为了消除数据之间不同量级的影响,将不同量级的数据统一转化为同一个量级,所以本模型对这些数据进行了 z-score 标准化处理,它是将观测值减去该组观测值的中值 (μ) ,再除以标准差 (σ) 得到的,有利于提高模型的训练速度和预测精度。
4.2 实验设置
因为LSTM神经网络具有时间序列的特性,本文将数据集的前85%作为训练集数据,后15%作为测试集数据。在 LSTM-CNN-CBAM 股票预测网络模型中,通过设置时间步长为10进行实验对比。
对比模型:LSTM和ATT-LSTM(ATT为注意力机制,其实现代码为:https://haosen.blog.csdn.net/article/details/118097281
4.3 实验结果
LSTM模型:
ATT-LSTM模型:
LSTM-CNN-CBAM模型:
其测试集的实验结果如下:
模型 | RMSE | R2 |
LSTM | 6.07 | 0.967 |
ATT-LSTM | 5.49 | 0.973 |
LSTM-CNN-CBAM | 13.94 | 0.827 |
结果分析
通 过 观 察 对 比 实 验 拟 合 图 形 可 以 发 现 单 一 的LSTM 网络对于股票价格的波动不敏感,而 LSTM 与注意力机制的结合模型有能力学习到股票价格波动的特征。然而,本文提出的LSTM-CNN-CBAM模型对股票价格的波动不敏感,可能原因如下:
1.LSTM网络到CNN网络的特征张量经过了reshape,reshape函数毕竟改变了数据的维度,也就改变了数据的组成方式,这可能会对后续的拟合预测产生影响。
2.可能CBAM机制更适用于图像领域用于提高目标检测和物体分类的精度,因此可以在图像领域及非线性特征数据中引入这一机制。
3.可能实现模型的代码出现问题。
基于LSTM-CNN-CBAM模型的股票预测研究相关推荐
- 基于泰尔森回归的股票预测研究
基于泰尔森回归的股票预测研究 绪论 背景 目的 流程 主要内容 数据获取与数据存储 数据调取以及案例数据分析 模型比较分析 2.3.1 模型初始化 2.3.2模型创建 2.3.3 模型可视化 2.3. ...
- DL之Keras:基于Keras框架建立模型实现【预测】功能的简介、设计思路、案例分析、代码实现之详细攻略(经典,建议收藏)
DL之Keras:基于Keras框架建立模型实现[预测]功能的简介.设计思路.案例分析.代码实现之详细攻略(经典,建议收藏) 目录 Keras框架使用分析 Keras框架设计思路 案例分析 代码实现 ...
- 【电力预测】基于matlab GUI灰色模型电力负荷预测【含Matlab源码 769期】
一.获取代码方式 获取代码方式1: 完整代码已上传我的资源: [电力负荷预测]基于matlab GUI灰色模型电力负荷预测[含Matlab源码 769期] 获取代码方式2: 通过订阅紫极神光博客付费专 ...
- 径向基函数神经网络_基于RBF神经网络的网络安全态势感知预测研究
点击上方"网络空间安全学术期刊"关注我们 基于RBF神经网络的网络安全态势 感知预测研究 钱建, 李思宇 摘要 针对网络安全态势的感知问题,结合巨龙山和者磨山风电场的运行情况,文章 ...
- LSTM 长短期记忆神经网络及股票预测实现
一.介绍 我们知道RNN(循环神经网络)可以通过时间序列预测输出,LSTM也具有同样的功能,那么为什么需要LSTM呢? 由于RNN在参数更新过程中参数矩阵更新可能会造成梯度消失的问题,这才演化出了具有 ...
- keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测
一.概述 传统循环网络RNN可以通过记忆体实现短期记忆进行连续数据的预测,但是,当连续数据的序列边长时,会使展开时间步过长,在反向传播更新参数的过程中,梯度要按时间步连续相乘,会导致梯度消失或者梯度爆 ...
- 文献综述--------山东某地区基于深度学习神经网络的配电网负荷预测研究
摘 要:地区电网负荷预测是供电企业在电网建设.运营过程中一项十分要的基础性的工作.小到一个企业的负荷预测,大到全国性电网的负荷预测研究,它的应用结果都会对适用范围内的企业经营管理.电力设施(电网)的 ...
- libsvm java下载_一个基于LIBSVM(JAVA)的股票预测demo
[实例简介] 一个基于LIBSVM的股票价格预测程序,采用随机森林算法对样本进行训练和预测,使用的编程语言为JAVA. [实例截图] [核心代码] stock-master └── stock-mas ...
- 基于LSTM算法的风电功率区间预测【含源代码】
基本情况: 数据集为2018年6月2日-2018年8月9日的6624个样本,采样间隔为15 分钟.选取2018年6月2日-2018年8月4日为训练集,剩下作为测试集.采用过去5个小时的特征作为输入,目 ...
最新文章
- android中常见的内存泄漏和解决的方法
- 将标签one-hot化的方法
- git--分支管理策略
- VTK:图像投射用法实战
- 已知线性表最多可能有20个元素,存储每个元素需要8字节,存储每个指针需要4字节。当元素个数为( )时使用单链表比使用数组存储此线性表更加节约空间。
- 吐血整理!近二十年全国数学联赛赛题大全,烧脑全集来啦!
- c# 微服务学习_资深架构师学习笔记:什么是微服务?
- 免费报名通道限时开启!解锁QCon「AI 时代下的融合通信技术」专场
- python下载官网-Python2.7.10
- 读书:鲁迅的《呐喊》和《彷徨》
- 关于html5中a链接的download属性
- MySQL · 性能优化 · SQL错误用法详解
- 【ArcGIS|空间分析|网络分析】5 计算服务区和创建 OD 成本矩阵
- 对于tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,labels=tf.argmax(y_,1))的研究
- oracle9i阻塞,Oracle 9i 整体性能优化概述(zt)
- android 外接u盘格式化,u盘格式(安卓u盘格式化工具apk)
- js-showModalDialog和dialogArguments
- matlab 符号函数 syms
- 平庸前端码农之蜕变 — AST
- IDEA-局部变量、成员变量
热门文章
- 做一个简单的java小游戏--贪吃蛇
- Merge into的使用 之 Where 条件
- 原生微信小程序圆形倒计时svg组件
- c语言课程设计输入模块,【图片】发几个C语言课程设计源代码(恭喜自己当上技术小吧主)【东华理工大学吧】_百度贴吧...
- 常系数微分方程组的V函数构造定理的解释
- 关于配置远程仓库gitee无法连接配置微服务的问题
- 夜深人静写算法(四十三)- 线性DP
- 九、键盘控制无人机 · 中(multirotor_communication.py解读)
- Android中shell控制cpu,常用ADB指令控制手机
- WIFEXITED WEXITSTATUS WIFSIGNALED