一、库准备

import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import akshare as ak
import torch
from torch import nn

二、构造

1.把茅台2017年到今天的股价画出来

share_prices=ak.stock_zh_a_hist(symbol='600519',start_date='20170101',end_date='20220410',adjust='qfq')['收盘'].valuesshare_prices = share_prices.astype('float32')  # 转换数据类型: obj ->float
plt.plot(share_prices)

2.数据normalization

# 将数据集标准化到 [-1,1] 区间
scaler = MinMaxScaler(feature_range=(-1, 1))  # train data normalized
share_prices = scaler.fit_transform(share_prices.reshape(-1, 1))

3.构造一个数据切分函数

def create_dataset(data, days_for_train=5) -> (np.array, np.array):"""根据给定的序列data,生成数据集。数据集分为输入和输出,每一个输入的长度为days_for_train,每一个输出的长度为1。也就是说用days_for_train天的数据,对应下一天的数据。若给定序列的长度为d,将输出长度为(d-days_for_train)个输入/输出对"""dataset_x, dataset_y = [], []for i in range(len(data) - days_for_train):_x = data[i:(i + days_for_train)]dataset_x.append(_x)dataset_y.append(data[i + days_for_train])return (np.array(dataset_x), np.array(dataset_y))
dataset_x, dataset_y = create_dataset(share_prices, DAYS_FOR_TRAIN)

以5个数为例,相当于X=[[0,1,2,3,4],[1,2,3,4,5]........[n-5,n-4,n-3,n-2,n-1]],Y=[[5],[6],[7],.....[n]]

在这里我们一共有1279个数,因此最后就切分成:

dataset_x:[1274,5,1]——即1274个五行一列的数组

dataset_y:[1274,1]——1274行一列的数组

4.分训练和验证

train_size = int(len(dataset_x) * 0.7)
train_x = dataset_x[:train_size]
train_y = dataset_y[:train_size]
test_x = dataset_x[train_size:]
test_y = dataset_y[train_size:]

5.train改成[891,1,5]的size

seq代表单个序列的长度,batch_size代表一次喂入的序列个数,feature_size代表特征维度,

# 改变数据集形状,RNN 读入的数据维度是 (seq_size, batch_size, feature_size)
train_x = train_x.reshape( -1, 1,DAYS_FOR_TRAIN)
train_y = train_y.reshape(-1, 1, 1)
# 数据集转为pytorch的tensor对象
train_x = torch.from_numpy(train_x)
train_y = torch.from_numpy(train_y)

6.用lstm训练

# train model
model = LSTM_Regression(DAYS_FOR_TRAIN, 32, output_size=1, num_layers=3)  # 网络初始化
loss_function = nn.MSELoss()  # 损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # 优化器
for epoch in range(EPOCHS):out = model(train_x)loss = loss_function(out, train_y)loss.backward()optimizer.step()optimizer.zero_grad()if (epoch + 1) % 100 == 0:print('Epoch: {}, Loss:{:.5f}'.format(epoch + 1, loss.item()))
class LSTM_Regression(nn.Module):"""使用LSTM进行回归参数:- input_size: feature size- hidden_size: number of hidden units- output_size: number of output- num_layers: layers of LSTM to stack"""def __init__(self, input_size, hidden_size, output_size=1, num_layers=2):super().__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers)self.fc = nn.Linear(hidden_size, output_size)def forward(self, _x):x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)x = x.view(s * b, h)x = self.fc(x)x = x.view(s, b, -1)  # 把形状改回来return x

Epoch: 100, Loss:0.00137
Epoch: 200, Loss:0.00094
Epoch: 300, Loss:0.00067
Epoch: 400, Loss:0.00048
Epoch: 500, Loss:0.00036
Epoch: 600, Loss:0.00072
Epoch: 700, Loss:0.00026
Epoch: 800, Loss:0.00024
Epoch: 900, Loss:0.00023
Epoch: 1000, Loss:0.00022
Epoch: 1100, Loss:0.00034
Epoch: 1200, Loss:0.00022
Epoch: 1300, Loss:0.00055
Epoch: 1400, Loss:0.00020
Epoch: 1500, Loss:0.00048
Epoch: 1600, Loss:0.00020
Epoch: 1700, Loss:0.00021
Epoch: 1800, Loss:0.00019
Epoch: 1900, Loss:0.00026
Epoch: 2000, Loss:0.00019

7.数据验证

model = model.eval()  # 转换成测试模式
test_x = test_x.reshape(-1,1, DAYS_FOR_TRAIN)
pred_y = model(torch.from_numpy(test_x))
pred_y = pred_y.view(-1).data.numpy()

1)只喂test数据(红色为预测,蓝色为真实)

误差:

2)喂全量数据

dataset_x = dataset_x.reshape(-1, 1, DAYS_FOR_TRAIN)  # (seq_size, batch_size, feature_size)
dataset_x = torch.from_numpy(dataset_x)  # 转为pytorch的tensor对象
pred_y = model(dataset_x)  # 全量数据集的模型输出 (seq_size, batch_size, output_size)
pred_y = pred_y.view(-1).data.numpy()
# 对标准化数据进行还原
actual_pred_y = scaler.inverse_transform(pred_y.reshape(-1, 1))
actual_pred_y = actual_pred_y.reshape(-1, 1).flatten()
test_y = scaler.inverse_transform(test_y.reshape(-1, 1))
test_y = test_y.reshape(-1, 1).flatten()
actual_pred_y = actual_pred_y[-len(test_y):]
test_y = test_y.reshape(-1, 1)
assert len(actual_pred_y) == len(test_y)

表现:

误差:

三、结论

1.lstm再单特征预测上作用有限

2.后续预测的时候,也是根据前五预报第六个数, 有数据泄露之嫌

3.趋势看似能预测对,但是幅度预测基本上也不准,下一步倒是可以考虑预测每日涨跌,或者换成预测波动率靠谱一点?(暂时没想明白)

Pytorch下基于lstm的股价预测相关推荐

  1. 基于pytorch下用LSTM做股票预测——超详细

    理论 LSTM理论详解 代码 请转到链接:文章详情 另外,欢迎大家打赏!

  2. lstm 根据前文预测词_干货 | Pytorch实现基于LSTM的单词检测器

    Pytorch实现 基于LSTM的单词检测器 字幕组双语原文: Pytorch实现基于LSTM的单词检测器 英语原文: LSTM Based Word Detectors 翻译: 雷锋字幕组(Icar ...

  3. 基于LSTM的序列预测: 飞机月流量预测

    基于LSTM的序列预测: 飞机月流量预测 循环神经网络,如RNN,LSTM等模型,比较适合用于序列预测,下面以一个比较经典的飞机月流量数据集,介绍LSTM的使用方法和训练过程. 完整的项目代码下载:h ...

  4. 基于Informer的股价预测(量化交易综述)

    摘要 股票市场是金融市场中不可或缺的组成部分.准确预测股票趋势对于投资者和市场参与者具有重要意义,因为它们可以指导投资决策.优化投资组合以及降低金融风险.而且可以提升国家国际地位以及金融风险控制能力, ...

  5. 利用LSTM进行股价预测

    利用LSTM进行股价预测 效果 原理 代码 应用 效果 原理 LSTM即长短记忆网络,是一种很强的RNN,这种网络的特性是以前的输入会影响现在的输出,具体原理请自行搜索. 算法流程: 获取yahoo财 ...

  6. 【ML】基于LSTM的心脏病预测研究(附代码和数据集,系列1)

    写在前面: 首先感谢兄弟们的订阅,让我有创作的动力,在创作过程我会尽最大努力,保证作品的质量,如果有问题,可以私信我,让我们携手共进,共创辉煌. 之前创作过心脏病预测研究文章如下: [ML]基于机器学 ...

  7. 手把手教你:基于LSTM的股票预测系统

    系列文章 第七章.手把手教你:基于深度残差网络(ResNet)的水果分类识别系统 第六章.手把手教你:人脸识别的视频打码 第五章.手把手教你:基于深度学习的滚动轴承故障诊断 目录 系列文章 一.项目简 ...

  8. 交通状态预测 | Python实现基于LSTM的客流量预测方法

    交通状态预测 | Python实现基于LSTM的客流量预测方法 目录 交通状态预测 | Python实现基于LSTM的客流量预测方法 基本介绍 研究回顾 模型结构 程序设计 参考资料 基本介绍 Pyt ...

  9. 基于LSTM时间序列分析预测拉尼娜年天气

    基于LSTM时间序列分析预测拉尼娜年天气 天气预测 Import all necessary libraries Replace all column names by overwritting on ...

  10. pytorch LSTM的股价预测

    股价预测一直以来都是幻想能够被解决的问题,本文中主要使用了lstm模型去对股价做一个大致的预测,数据来源是tushare,非常感谢tushare的数据!! 为什么要用LSTM? LSTM是一种序列模型 ...

最新文章

  1. 安卓如何实现多级结构树_数据结构-树(树基本实现C++)
  2. 基于 Wasm 和 ORAS 简化扩展服务网格功能
  3. 彩票假设 (Lottery Ticket Hypothesis) 在CV、NLP和OOD领域的应用
  4. AFURLRequestSerialization
  5. qq红包领取支付宝红包功能,qqxml跳转技术
  6. html网页制作比赛要求,校园网页设计大赛活动方案
  7. Python 学习:函数
  8. CSS中设置页面背景图片
  9. js实现点击保存图片
  10. 对于position定位的认识
  11. 四川嘉庆恒运:拼多多上买化妆品靠谱不
  12. HTML简洁单页网址导航模板
  13. Python 个性计算器(可不断加挂功能模块实现个性化)
  14. 2019计算机一级msoffice真题,【2019年整理】全国计算机一级MSOFFICE练习题带题解.pdf...
  15. 谁在押注“脱口秀直播带货”?
  16. 基于Vue的医院内部管理系统(医生、患者、挂号、药房)文档+答辩PPt+项目源码+演示视频
  17. python:tzinfo 对象
  18. 浅谈Jmeter性能测试流程
  19. Unity开发——The associated script can not be loaded.
  20. LeetCode 1937. 扣分后的最大得分(动态规划)

热门文章

  1. Bloodsucker 【 ZOJ - 3551】
  2. 2018年Android面试题含答案
  3. 计算机设备选型原则,计算机中通信技术的选用应遵循什么原则
  4. 基于JDE的目标跟踪算法前沿研究跟进
  5. 给人工智能初学者看的5本入门书 | 附下载链接
  6. DC-DC buck降压电路 电压电流双闭环PI控制matlab仿真模型
  7. 百度首页代码(HTML+CSS+jQuery)
  8. An Introduction to Be-trees and Write Optimization 学习笔记
  9. xp系统如何启用服务器服务,xp系统怎么样启用远程服务器
  10. JS - 解决鼠标单击、双击事件冲突问题(同时实现两种事件响应)