一.概述

传统循环网络RNN可以通过记忆体实现短期记忆进行连续数据的预测,但是,当连续数据的序列边长时,会使展开时间步过长,在反向传播更新参数的过程中,梯度要按时间步连续相乘,会导致梯度消失或者梯度爆炸。

LSTM是RNN的变体,通过门结构,有效的解决了梯度爆炸或者梯度消失问题。

LSTM在RNN的基础上引入了三个门结构和记录长期记忆的细胞态以及归纳出新知识的候选态。

二.LSTM结构

1.短期记忆

短期记忆即为RNN中的记忆体,在LSTM中,它的通过输出门

和经过tanh函数的长期记忆的哈达玛积得到:

2.细胞态(长期记忆)

长期记忆记录了当前时刻的历史信息:

其中,

为上一时刻的长期记忆,
为遗忘门,
为输入门,
为候选状态,表示在本时间段归纳出的新知识:

3.输入门、遗忘门、输出门

它们三个都是当前时刻的输入特征

和上个时刻的短期记忆
的函数。

遗忘门通过sigmod函数,将上一层隐藏状态

和本层输入
映射到[0,1],表示上一层的内部状态
需要遗忘多少信息,公式为下:

输入门

控制当前候选状态
有多少信息需要保存。

输出门

控制当前时刻的内部状态
有多少信息传递给隐藏信息

三.LSTM过程

1.先利用上一时刻的隐藏状态

和当前输入计算出三个门和候选状态:

2.结合遗忘门

和输入门
更新长期记忆:

3.结合输出门和内部状态更新隐藏状态:

4.反向传播,利用梯度下降等优化方法更新参数矩阵和偏置。

四.keras+LSTM实现股票预测

导入依赖包

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.layers import Dense,Dropout,LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error,mean_squared_error

读取数据

maotai = pd.read_csv('./SH600519.csv')
training_set = maotai.iloc[0:2126,2:3].values
test_set = maotai.iloc[2126:,2:3].values
print(training_set.shape,test_set.shape)
输出:
(2126, 1) (300, 1)

归一化

sc = MinMaxScaler(feature_range=(0,1))
training_set = sc.fit_transform(training_set)
test_set = sc.fit_transform(test_set)

划分训练数据和测试数据

x_train,y_train,x_test,y_test=[],[],[],[]
for i in range(60,len(training_set)):x_train.append(training_set[i-60:i,0])y_train.append(training_set[i,0])
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
x_train,y_train = np.array(x_train),np.array(y_train)
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
for i in range(60, len(test_set)):x_test.append(test_set[i - 60:i, 0])y_test.append(test_set[i, 0])
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))

搭建网络

model = tf.keras.Sequential([LSTM(80,return_sequences=True),Dropout(0.2),LSTM(100),Dropout(0.2),Dense(1)
])

配置网络

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss='mean_squared_error')

开始训练

history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1)

训练过程

Epoch 1/50
33/33 [==============================] - 4s 114ms/step - loss: 0.0135 - val_loss: 0.0110
Epoch 2/50
33/33 [==============================] - 3s 95ms/step - loss: 0.0013 - val_loss: 0.0049
Epoch 3/50
33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0051
Epoch 4/50
33/33 [==============================] - 3s 98ms/step - loss: 0.0013 - val_loss: 0.0057
Epoch 5/50
33/33 [==============================] - 3s 95ms/step - loss: 0.0011 - val_loss: 0.0047
Epoch 6/50
33/33 [==============================] - 3s 99ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 7/50
33/33 [==============================] - 3s 92ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 8/50
33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0049
Epoch 9/50
33/33 [==============================] - 3s 84ms/step - loss: 0.0010 - val_loss: 0.0051
Epoch 10/50
33/33 [==============================] - 3s 86ms/step - loss: 0.0010 - val_loss: 0.0051
Epoch 11/50
33/33 [==============================] - 3s 86ms/step - loss: 9.7592e-04 - val_loss: 0.0044
Epoch 12/50
33/33 [==============================] - 3s 87ms/step - loss: 9.6163e-04 - val_loss: 0.0043
Epoch 13/50
33/33 [==============================] - 3s 88ms/step - loss: 0.0011 - val_loss: 0.0041
Epoch 14/50
33/33 [==============================] - 3s 89ms/step - loss: 9.1143e-04 - val_loss: 0.0042
Epoch 15/50
33/33 [==============================] - 3s 89ms/step - loss: 0.0011 - val_loss: 0.0046
Epoch 16/50
33/33 [==============================] - 3s 89ms/step - loss: 8.8493e-04 - val_loss: 0.0040
Epoch 17/50
33/33 [==============================] - 3s 90ms/step - loss: 9.2448e-04 - val_loss: 0.0042
Epoch 18/50
33/33 [==============================] - 3s 91ms/step - loss: 8.7795e-04 - val_loss: 0.0038
Epoch 19/50
33/33 [==============================] - 3s 91ms/step - loss: 7.1217e-04 - val_loss: 0.0045
Epoch 20/50
33/33 [==============================] - 3s 91ms/step - loss: 0.0012 - val_loss: 0.0038
Epoch 21/50
33/33 [==============================] - 3s 93ms/step - loss: 8.5274e-04 - val_loss: 0.0037
Epoch 22/50
33/33 [==============================] - 3s 92ms/step - loss: 9.9773e-04 - val_loss: 0.0052
Epoch 23/50
33/33 [==============================] - 3s 93ms/step - loss: 9.0810e-04 - val_loss: 0.0046
Epoch 24/50
33/33 [==============================] - 3s 93ms/step - loss: 8.4353e-04 - val_loss: 0.0041
Epoch 25/50
33/33 [==============================] - 3s 95ms/step - loss: 8.7846e-04 - val_loss: 0.0037
Epoch 26/50
33/33 [==============================] - 3s 94ms/step - loss: 7.2408e-04 - val_loss: 0.0035
Epoch 27/50
33/33 [==============================] - 3s 95ms/step - loss: 7.8355e-04 - val_loss: 0.0059
Epoch 28/50
33/33 [==============================] - 3s 96ms/step - loss: 8.1942e-04 - val_loss: 0.0035
Epoch 29/50
33/33 [==============================] - 3s 96ms/step - loss: 7.7674e-04 - val_loss: 0.0033
Epoch 30/50
33/33 [==============================] - 3s 95ms/step - loss: 7.3867e-04 - val_loss: 0.0037
Epoch 31/50
33/33 [==============================] - 3s 97ms/step - loss: 7.2609e-04 - val_loss: 0.0033
Epoch 32/50
33/33 [==============================] - 3s 96ms/step - loss: 6.9374e-04 - val_loss: 0.0033
Epoch 33/50
33/33 [==============================] - 3s 96ms/step - loss: 6.3776e-04 - val_loss: 0.0050
Epoch 34/50
33/33 [==============================] - 3s 97ms/step - loss: 7.6443e-04 - val_loss: 0.0036
Epoch 35/50
33/33 [==============================] - 3s 98ms/step - loss: 7.9301e-04 - val_loss: 0.0032
Epoch 36/50
33/33 [==============================] - 3s 97ms/step - loss: 7.7646e-04 - val_loss: 0.0036
Epoch 37/50
33/33 [==============================] - 3s 99ms/step - loss: 8.3467e-04 - val_loss: 0.0033
Epoch 38/50
33/33 [==============================] - 3s 99ms/step - loss: 7.6392e-04 - val_loss: 0.0032
Epoch 39/50
33/33 [==============================] - 3s 99ms/step - loss: 6.3954e-04 - val_loss: 0.0047
Epoch 40/50
33/33 [==============================] - 3s 99ms/step - loss: 7.3498e-04 - val_loss: 0.0034
Epoch 41/50
33/33 [==============================] - 3s 99ms/step - loss: 5.8371e-04 - val_loss: 0.0031
Epoch 42/50
33/33 [==============================] - 3s 99ms/step - loss: 5.7156e-04 - val_loss: 0.0034
Epoch 43/50
33/33 [==============================] - 3s 100ms/step - loss: 6.2417e-04 - val_loss: 0.0030
Epoch 44/50
33/33 [==============================] - 3s 101ms/step - loss: 6.8761e-04 - val_loss: 0.0035
Epoch 45/50
33/33 [==============================] - 4s 108ms/step - loss: 6.7483e-04 - val_loss: 0.0031
Epoch 46/50
33/33 [==============================] - 4s 113ms/step - loss: 6.2236e-04 - val_loss: 0.0031
Epoch 47/50
33/33 [==============================] - 4s 115ms/step - loss: 6.4746e-04 - val_loss: 0.0034
Epoch 48/50
33/33 [==============================] - 4s 112ms/step - loss: 7.4622e-04 - val_loss: 0.0029
Epoch 49/50
33/33 [==============================] - 3s 101ms/step - loss: 6.8864e-04 - val_loss: 0.0028
Epoch 50/50
33/33 [==============================] - 3s 101ms/step - loss: 5.6762e-04 - val_loss: 0.0028

loss曲线

loss =  history.history['loss']
val_loss = history.history['val_loss']
plt.plot(loss,label='Training Loss')
plt.plot(val_loss,label='Validation Loss')
plt.legend()
plt.title('Loss')
plt.show()

预测结果与真实值比较

predict_price = model.predict(x_test)
predict_price = sc.inverse_transform(predict_price)
real_price = sc.inverse_transform(test_set[60:])
plt.plot(real_price, color='red', label='MaoTai Stock Price')
plt.plot(predict_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()

查看评价指标(均方误差和均方根差)

mse=mean_squared_error(predict_price,real_price)
mae = mean_absolute_error(predict_price,real_price)
print('mean_squared_error',mse)
print('mean_absolute_error',mae)
输出:
mean_squared_error 922.6493975725148
mean_absolute_error 23.789508666992194

keras实现简单lstm_四十二.长短期记忆网络(LSTM)过程和keras实现股票预测相关推荐

  1. 动手学深度学习(四十)——长短期记忆网络(LSTM)

    文章目录 一.长短期记忆网络(LSTM) 1.1 门控记忆单元 1.2 输入门.遗忘门与输出门 1.3候选记忆单元 1.4 记忆单元 1.5 隐藏状态 二.从零实现LSTM 2.1 初始化模型参数 2 ...

  2. 基于长短期记忆网络(LSTM)对股票价格的涨跌幅度进行预测

    完整代码:https://download.csdn.net/download/qq_38735017/87536579 为对股票价格的涨跌幅度进行预测,本文使用了基于长短期记忆网络(LSTM)的方法 ...

  3. 简单介绍长短期记忆网络 - LSTM

    文章目录 一.引言 1.1 什么是LSTM 二.循环神经网络RNN 2.1 为什么需要RNN 三.长短时记忆神经网络LSTM 3.1 为什么需要LSTM 3.2 LSTM结构分析 3.3 LSTM背后 ...

  4. 『NLP学习笔记』长短期记忆网络LSTM介绍

    长短期记忆网络LSTM介绍 文章目录 一. 循环神经网络 二. 长期依赖问题 三. LSTM 网络 四. LSTM 背后的核心理念 4.1 忘记门 4.2 输入门 4.3 输出门 五. LSTM总结( ...

  5. MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测

    MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测 摘要 近些年,随着计算机技术的不断发展,神 ...

  6. 1014长短期记忆网络(LSTM)

    长短期记忆网络(LSTM) 长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题,解决这个问题最早的方法之一就是 LSTM 发明于90年代 使用的效果和 GRU 相差不大,但是使用的东西更加复杂 ...

  7. 长短期记忆网络(LSTM)学习笔记

    文章目录 0 前言 1 LSTM与RNN的异同 2 LSTM结构细节 2.1 细胞状态 2.2 遗忘门 2.3 输入门 2.4 输出门 3 总结 4 LSTM的变体 4.1 Adding " ...

  8. 论文解读:《一种基于长短期记忆网络深度学习的药物靶相互作用预测方法》

    论文解读:<A deep learning-based method for drug-target interaction prediction based on long short-ter ...

  9. 长短期记忆网络LSTM

    1. LSTM是循环神经网络的一个变体可以有效的解决简单循环神经网络的梯度消失和梯度爆炸的问题. 2. 改进方面: 新的内部状态 Ct专门进行线性的循环信息传递,同时(非线性的)输出信息给隐藏层的外部 ...

最新文章

  1. 有望取代Spark,Michael Jordan和Ion Stoica提出下一代分布式实时机器学习框架Ray牛在哪?...
  2. visual studio odbc数据源设计器_NEW!WinForm界面开发设计时正式支持.NET 5
  3. [转]VS2010+MFC解析Excel文件中数据
  4. lofter 爬虫_Lofter德赫标签日榜 | 200703
  5. dubbo/dubbox部署资料收集
  6. 【运筹学】匈牙利法 ( 匈牙利法示例 )
  7. 博途v15程序监视无法使用_博途V15打开应用程序失败,应用程序的并行配置不正确-工业支持中心-西门子中国...
  8. 来电弹屏功能在呼叫中心的应用
  9. 【数据预处理】Pandas缺失的数据处理
  10. 全民一起玩Python提高篇第十二课:面向对象基本原理与语法(三)
  11. mysql多字段in用法
  12. TheProjetXXXXXneedstobedeployedbeforeitanbestarted
  13. python是一种跨平台开源解释型的_Python是一种跨平台、开源、免费的动态编程语言。...
  14. 第一人称游戏与第三人称游戏的区别
  15. 如何多方位布局程序化购买生态链条?
  16. mysql_connect()不支持解决方法
  17. es6(二) 解构赋值
  18. Failed to execute 'toDataURL' on 'HTMLCanvasElement,在canvas.toDataURL()执行时候报错解决方案...
  19. Apache POI 实现报表导入和导出
  20. java iso8859 转utf8

热门文章

  1. C语言单片机数码管a段亮,各位大神,如何用C语言实现在数码管上实现1234同时亮...
  2. java returnAddres_JavaClient 查询ES-(重要)
  3. Linux驱动编程 step-by-step (七) 并发 竞态 (信号量与自旋锁)
  4. Linux2.6内核驱动与2.4的区别 .
  5. OPenCV膨胀函数dilate()的使用
  6. First Missing Positive
  7. 5.6 前端开发日报
  8. 嵌入式Linux上通过boa服务器实现cgi/html的web上网【转】
  9. 碰碰车司机教你Linux下使用nmon分析系统性能
  10. 寻找最大的K个数,Top K问题的堆实现