import pandas as pd
import matplotlib.pyplot as plt
import datetime
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader# 确定每月日期 2014-01-01~2016-01-01
dates = get_trading_dates(start_date="2018-11-01", end_date="2021-12-15")
#stock = index_components("000065.XSHE")x = get_price(["000065.XSHE"], start_date=dates[0], end_date=dates[-1], fields='close')
T = len(x)//用到的函数
def generate_df_affect_by_n_days(series, n, index=False):if len(series) <= n:raise Exception("The Length of series is %d, while affect by (n=%d)." % (len(series), n))df = pd.DataFrame()for i in range(n):df['c%d' % i] = series.tolist()[i:-(n - i)]df['y'] = series.tolist()[n:]if index:df.index = series.index[n:]return dfdef readData(column='close', n=30, all_too=True, index=False, train_end=-300):#df = pd.read_csv("sh.csv", index_col=0)#df.index = list(map(lambda x: datetime.datetime.strptime(x, "%Y-%m-%d"), df.index))#df.index = x["date"]global xx = x.reset_index()x = x.set_index(['date'])df = xdf_column = df[column].copy()df_column_train, df_column_test = df_column[:train_end], df_column[train_end - n:]df_generate_from_df_column_train = generate_df_affect_by_n_days(df_column_train, n, index=index)if all_too:return df_generate_from_df_column_train, df_column, df.index.tolist()return df_generate_from_df_column_trainclass RNN(nn.Module):def __init__(self, input_size):super(RNN, self).__init__()self.rnn = nn.LSTM(input_size=input_size,hidden_size=64,num_layers=1,batch_first=True)self.out = nn.Sequential(nn.Linear(64, 1))def forward(self, x):r_out, (h_n, h_c) = self.rnn(x, None)  # None 表示 hidden state 会用全0的 stateout = self.out(r_out)return outclass TrainSet(Dataset):def __init__(self, data):# 定义好 image 的路径self.data, self.label = data[:, :-1].float(), data[:, -1].float()def __getitem__(self, index):return self.data[index], self.label[index]def __len__(self):return len(self.data)//提取数据和数据处理
n = 30
LR = 0.0001
EPOCH = 100
train_end = -300
# 数据集建立
df, df_all, df_index = readData('close', n=n, train_end=train_end)df_all = np.array(df_all.tolist())
plt.plot(df_index, df_all, label='real-data')df_numpy = np.array(df)df_numpy_mean = np.mean(df_numpy)
df_numpy_std = np.std(df_numpy)df_numpy = (df_numpy - df_numpy_mean) / df_numpy_std
df_tensor = torch.Tensor(df_numpy)trainset = TrainSet(df_tensor)
trainloader = DataLoader(trainset, batch_size=10, shuffle=True)//模型训练
rnn = RNN(n)
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.MSELoss()for step in range(EPOCH):for tx, ty in trainloader:output = rnn(torch.unsqueeze(tx, dim=0))loss = loss_func(torch.squeeze(output), ty)optimizer.zero_grad()  # clear gradients for this training steploss.backward()  # back propagation, compute gradientsoptimizer.step()print(step, loss)if step % 10:torch.save(rnn, 'rnn.pkl')
torch.save(rnn, 'rnn.pkl')//股价预测
generate_data_train = []
generate_data_test = []test_index = len(df_all) + train_enddf_all_normal = (df_all - df_numpy_mean) / df_numpy_std
df_all_normal_tensor = torch.Tensor(df_all_normal)
for i in range(n, len(df_all)):x = df_all_normal_tensor[i - n:i]x = torch.unsqueeze(torch.unsqueeze(x, dim=0), dim=0)y = rnn(x)if i < test_index:generate_data_train.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)else:generate_data_test.append(torch.squeeze(y).detach().numpy() * df_numpy_std + df_numpy_mean)
plt.plot(df_index[n:train_end], generate_data_train, label='generate_train')
plt.plot(df_index[train_end:], generate_data_test, label='generate_test')
plt.legend()
plt.show()
plt.cla()
plt.plot(df_index[train_end:-400], df_all[train_end:-400], label='real-data')
plt.plot(df_index[train_end:-400], generate_data_test[:-400], label='generate_test')
plt.legend()
plt.show()

效果图

基于ricequant的lstm时间序列股价预测(pytorch)相关推荐

  1. 【LSTM时间序列数据】基于matlab LSTM时间序列数据预测【含Matlab源码 1949期】

    ⛄一.获取代码方式 获取代码方式1: 完整代码已上传我的资源:[LSTM时间序列数据]基于matlab LSTM时间序列数据预测[含Matlab源码 1949期] 获取代码方式2: 付费专栏Matla ...

  2. 利用LSTM进行股价预测

    利用LSTM进行股价预测 效果 原理 代码 应用 效果 原理 LSTM即长短记忆网络,是一种很强的RNN,这种网络的特性是以前的输入会影响现在的输出,具体原理请自行搜索. 算法流程: 获取yahoo财 ...

  3. 基于CCI与SVM的股价预测(菜鸟版)

    基于CCI与SVM的股价预测(菜鸟版) 一.前言 二.工作原理 1.CCI 解读 2.支持向量机原理 三.系统工作流程 1.获取用户输入,同时获取股票或指数列表,判断该证券是否存在. 2.获取日线数据 ...

  4. 时间序列预测 | Python实现Prophet、ARIMA、LSTM时间序列数据预测

    时间序列预测 | Python实现Prophet.ARIMA.LSTM时间序列数据预测 目录 时间序列预测 | Python实现Prophet.ARIMA.LSTM时间序列数据预测 数据描述 特征工程 ...

  5. pytorch LSTM的股价预测

    股价预测一直以来都是幻想能够被解决的问题,本文中主要使用了lstm模型去对股价做一个大致的预测,数据来源是tushare,非常感谢tushare的数据!! 为什么要用LSTM? LSTM是一种序列模型 ...

  6. 【基于MATLAB实现LSTM光伏输出功率预测】

    基于MATLAB实现LSTM(长短记忆网络)光伏输出功率预测 背景 近年来,光伏市场在全世界的规模迅速增长.2020 年中国宣布采取新政策来提高国家的贡献力,目标是在2030 年前实现碳达峰,在206 ...

  7. LSTM实现股票预测--pytorch版本【120+行代码】

    简述 网上看到有人用Tensorflow写了的但是没看到有用pytorch写的. 所以我就写了一份.写的过程中没有参照任何TensorFlow版本的(因为我对TensorFlow目前理解有限),所以写 ...

  8. 基于LASSO分位数回归时间序列区间预测

    1.LASSO LASSO回归的特点是在拟合广义线性模型的同时进行变量筛选(variable selection)和复杂度调整(regularization).因此,不论因变量是连续的(continu ...

  9. 基于keras 搭建LSTM GRU模型预测 共享单车使用情况 完整代码+数据 数据分析 计算机毕设

    项目运行教程:https://www.bilibili.com/video/BV1nT411k7dT/?spm_id_from=333.999.0.0 附完整代码数据:

  10. 如何建立Multi-Step(多步预测)的LSTM时间序列模型(以对家庭用电预测为例)

    译自How to Develop LSTM Models for Multi-Step Time Series Forecasting of Household Power Consumption~ ...

最新文章

  1. VMware Workstation 15 Pro 永久激活密钥 下载
  2. 从互联网大脑模型看腾讯与今日头条之争
  3. select模型详解
  4. Python爬虫项目,获取所有网站上的新闻,并保存到数据库中,解析html网页等(未完待续)
  5. python 判断类是否存在某个属性或方法
  6. pdfminer3k 官方文档_IPFS官方周报112期
  7. 2016030207 - sql50题练习(脚本)
  8. linux进程理解,进程资源 - 进程基础 - [ 理解Linux进程 ] - 手册网
  9. Bootstrap框架常用总结
  10. いちゃコミュ+~いちゃいちゃコミュニケーション プラス 汉化补丁
  11. iOS修改手游服务器数据,iOS 教你修改运动步数(基于Healthkit)
  12. C#中汉字按照首字拼音排序
  13. 安装itunes需要管理员身份_iTunes安装失败 iTunes安装出错解决方法
  14. 产品经理入门攻略(3岁的PM成长分享)
  15. ppt制作的一些要点
  16. 硬盘柱面损坏怎么办_硬盘在坏道检测中出现了要多少个损坏柱面才说明这个硬盘废了?...
  17. C++实现双人枪战游戏
  18. java 获取图片像素_转:java提取图片中的像素
  19. java 调用matlab rank_科学网—Matlab: X is rank deficient - 李旭的博文
  20. Ubuntu 14.04 Linux 3D桌面完全教程,显卡驱动安装方法,compiz特效介绍,常见问题解答

热门文章

  1. 《Adobe Photoshop CS6中文版经典教程(彩色版)》目录—导读
  2. 安装Apache服务器
  3. oracle清理磁盘空间
  4. 富士通Fujitsu DPK9500GA Pro 打印机驱动
  5. 开始使用Mac OS X——写给Mac新人
  6. js锅打灰太狼小游戏
  7. Procdump+Mimikatz获取Windows明文密码
  8. sudo rosdep init命令报错ERROR: cannot download default sources list from:....Website may be down【绝对有用】
  9. 什么是0day漏洞?
  10. (窗口隐藏工具 3.40)自动隐藏指定的应用窗口及托盘图标