使用DeepAR实现股价预测

文章目录

  • 使用DeepAR实现股价预测
  • 获取股票列表
    • 从众多股票中采样100支
      • 日期处理函数
      • 拉取等长度的股票,并保存
  • 各指标解释
  • 预测区间长度及上下文选取
    • 给这78支股票所在行业进行归类
  • 目标变量处理
  • 协变量处理
    • 协变量归一化操作
  • 训练、测试数据划分
  • 训练模型
  • 预测过程
  • 模型评估
    • 结果查看
  • 绘图结果

以往的RNN时间序列预测往往是强调一支股票的股价预测,当提取的一支其他股票的特征时,用于另一支股票预测时就显得捉襟见肘了;当需要对多只股票进行训练及预测时,通常的做法是将他们归类,再进行分别的预测及训练,更重要的是,以往的RNN神经网络(如LSTM等),给出的都是单点预测,在结果服从连续分布的情形下,单点预测的概率其实是0的,我们更希望知道结果的走向,或者框定一个结果走向的范围;

本实验采用DeepAR这个新兴的时间序列预测算法,对78支上市公司股票价格进行训练,训练好的结果可以应用在任意一支股票的预测上(但是本文未给出相关过程),测试集上的表现比较理想。本文仅供学习参考,不作为投资依据;完全原创,转载请注明出处

本文的数据采用了Tushare的大数据接口,感谢Tushare的开发者,为Quanters提供了持续精良的服务

本文的模型采用了mxnet的Deepar模型, deepar模型已经为我们封装好了大多数处理方法,这使得我们的分析过程更加简单快捷,在此一并感谢

import pandas as pd
import tushare as ts
import numpy as np
# 初始化pro接口(该tokens请在tushare个人主页获取)
pro = ts.pro_api('xxx')
np.random.seed(42)

获取股票列表


# 拉取数据
df = pro.stock_basic(**{"ts_code": "","name": "","exchange": "","market": "","is_hs": "","list_status": "L","limit": "","offset": ""
}, fields=["ts_code","symbol","name","area","industry","market","list_date"
])
df.to_csv('./Stock-data/股票代码.csv')
df.head()

从众多股票中采样100支

stock_code = pd.read_csv('./Stock-data/股票代码.csv')name = []
ts_code = []
i = 0
while i < 100:sample = stock_code.sample()if sample['list_date'].values < 20150731 and 'ST' not in sample['name'].values[0]:ts_code.append(sample['ts_code'].values[0])name.append(sample['name'].values[0])i += 1
print(len(name),name)

日期处理函数

def deal_date(date):temp = [date[0:4],date[4:6],date[6:]]new_date = '-'.join(temp)return new_date

拉取等长度的股票,并保存

stock_list = []
for i,j in zip(name,ts_code):# 拉取数据df = pro.daily(**{"ts_code": f"{j}","trade_date": "","start_date": "20190731","end_date": "20220404","offset": "","limit": ""}, fields=["ts_code","trade_date","open","high","low","close","pre_close","change","pct_chg","vol","amount"])if len(df) == 649:df['Date'] = df['trade_date'].apply(deal_date)stock_list.append(i)df['name'] = f'{i}'df.to_csv(f'./Stock-data/{i}.csv')
print(stock_list)
df.head()

各指标解释

  • open 开盘价
  • high 最高价
  • low 最低价
  • close 收盘价
  • pre_close 昨收价
  • change 涨跌额
  • pct_chg 涨跌幅
  • vol 成交量
  • amount 成交额

我将采用open,high,low,close,change,pct_chg,vol,amount及公司所属行业进行时间序列预测

%matplotlib inline
import mxnet as mx
from mxnet import gluon
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
from tqdm.autonotebook import tqdm
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

预测区间长度及上下文选取

prediction_length = 10
context_length = 20
stock = ['荣盛石化', '新联电子', '栖霞建设', '星徽股份', '中国武夷',
'飞凯材料', '河钢资源', '平高电气', '新北洋', '亚太科技', '杭州高新',
'海立股份', '燕塘乳业', '杰瑞股份', '广电电气', '赛摩智能', '成都路桥',
'恒邦股份', '石化油服', '金隅集团', '青龙管业', '同德化工', '科华数据',
'中国东航', '澳柯玛', '亚太药业', '冠城大通', '白云机场', '华东医药',
'全筑股份', '菲利华', '和而泰', '潮宏基', '岭南股份', '索菲亚', '长江证券',
'炬华科技', '嘉事堂', '西藏珠峰', '聆达股份', '北大荒', '七匹狼', '先河环保',
'中国汽研', '鸿博股份', '合金投资', '华银电力', '世纪瑞尔', '东方日升', '新开普',
'亚光科技', '电科院', '粤电力A', '东方雨虹', '普莱柯', '上海机电', '天利科技',
'奥维通信', '华邦健康', '春秋航空', '杰瑞股份', '海峡股份', '京蓝科技', '中海油服',
'温州宏丰', '御银股份', '芒果超媒', '太平洋', '泰豪科技', '申达股份', '众合科技',
'华帝股份', '财信发展', '大金重工', '协鑫集成', '保利联合', '平高电气', '黄河旋风',
'凌云股份',]
stock = list(set(stock))
data = pd.read_csv('./Stock-data/上海梅林.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])
print(len(data))
for i in stock:temp = pd.read_csv(f'./Stock-data/{i}.csv', index_col=False,usecols=['Date','high','low','open','close','change','pct_chg','vol','amount','name'])data = pd.concat([data,temp],ignore_index=True)
print(len(data))
data.head(10)

给这78支股票所在行业进行归类

total = data.copy()
stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
industry = {}
for i in stock_list:temp = stock_code[stock_code['name'] == i]['industry'].values[0]if temp not in industry:industry[temp] = [i]else:industry[temp].append(i)
industry
stat_cat_features = []
company_ind = {}
for i,key in enumerate(industry):for com in industry[key]:company_ind[com] = i
cat_cardinality = [i+1]
print(company_ind)
for i in stock_list:stat_cat_features.append([company_ind[i]])
print(stat_cat_features)

目标变量处理

stock_list = sorted(list(set(total["name"])))
date_list = sorted(list(set(total["Date"])))
data_dic = {"name": stock_list}
for date in date_list:tmp = total[total["Date"]==date][["name", "Date", "close"]]tmp = tmp.pivot(index="name", columns="Date", values="close")tmp_values = tmp[date].valuesdata_dic[date] = tmp_values
new_df = pd.DataFrame(data_dic)
new_df.head()

协变量处理

def deal_cov_variables(date_list,var_name):feature_dict = {}for date in date_list:tmp = total[total["Date"]==date][["name", "Date", var_name]]tmp = tmp.pivot(index="name", columns="Date", values=var_name)tmp_values = tmp[date].valuesfeature_dict[date] = tmp_valuesfeature_df = pd.DataFrame(feature_dict)return feature_df
cov_variables = ['high','low','open','close','change','pct_chg','vol','amount']
feature_df_list = []for i in cov_variables:feature_df_list.append(deal_cov_variables(date_list,i))feature_df_list[0].head()

协变量归一化操作

def min_max_scale(lst):'''# 基于日期级别的归一化:input shape (bank_num,days):output shape (bank_num,days)'''new = []for i in range(len(lst[0])):minimum = min(lst[:,i])maximum = max(lst[:,i])new.append((lst[:,i] - minimum) / (maximum - minimum))return np.array(new).T
dynamic_feats = []
for i in range(len(feature_df_list)):one_feature = min_max_scale(np.array(feature_df_list[i]))dynamic_feats.append(one_feature)
print(one_feature.shape)
dynamic_feats = np.array(dynamic_feats).reshape(-1,len(feature_df_list),len(date_list))
print(dynamic_feats.shape) # (stock_num, feature_num, date_num)

训练、测试数据划分

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName# test_target_values是649天的实际结果y
train_df = new_df.drop(["name"], axis=1).values
train_df.reshape(-1,len(date_list))
test_target_values = train_df.copy()
print(len(train_df[0]))
# train_target_values是639天的实际结果y,不能让模型训练到后10天,这样才能看出效果 (将649天shift10天)
train_target_values = [ts[:-prediction_length] for ts in train_df]
print(len(train_target_values[0]))
start_date = [pd.Timestamp("2019-07-31", freq='B') for _ in range(len(new_df))]
train_ds = ListDataset([{FieldName.TARGET: target,FieldName.START: start,FieldName.FEAT_DYNAMIC_REAL: dynamic_feat[:,:-prediction_length],FieldName.FEAT_STATIC_CAT:cat_feature,}for (target, start,dynamic_feat,cat_feature) in zip(train_target_values,start_date,dynamic_feats,stat_cat_features)
], freq="1B")test_ds = ListDataset([{FieldName.TARGET: target,FieldName.START: start,FieldName.FEAT_DYNAMIC_REAL: dynamic_feat,FieldName.FEAT_STATIC_CAT:cat_feature,}for (target, start,dynamic_feat,cat_feature) in zip(test_target_values,start_date,dynamic_feats,stat_cat_features)
], freq="1B")
sample_trian = next(iter(train_ds))

训练模型

from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.distribution.gaussian import GaussianOutput
from gluonts.mx.trainer import Trainern = 100
estimator = DeepAREstimator(prediction_length=prediction_length,context_length=context_length,freq="1B",distr_output = GaussianOutput(),use_feat_dynamic_real=True,dropout_rate=0.1,use_feat_static_cat=True,cardinality=cat_cardinality,trainer=Trainer(learning_rate=1e-3,epochs=n,num_batches_per_epoch=50,batch_size=32)
)
predictor = estimator.train(train_ds)

预测过程

from gluonts.evaluation.backtest import make_evaluation_predictionsforecast_it, ts_it = make_evaluation_predictions(dataset=test_ds,predictor=predictor,num_samples=100
)print("Obtaining time series conditioning values ...")
tss = list(tqdm(ts_it, total=len(test_ds)))
print("Obtaining time series predictions ...")
forecasts = list(tqdm(forecast_it, total=len(test_ds)))

模型评估

from gluonts.evaluation import Evaluatorclass CustomEvaluator(Evaluator):def get_metrics_per_ts(self, time_series, forecast):successive_diff = np.diff(time_series.values.reshape(len(time_series)))successive_diff = successive_diff ** 2successive_diff = successive_diff[:-prediction_length]denom = np.mean(successive_diff)pred_values = forecast.samples.mean(axis=0)true_values = time_series.values.reshape(len(time_series))[-prediction_length:]num = np.mean((pred_values - true_values) ** 2)rmsse = num / denommetrics = super().get_metrics_per_ts(time_series, forecast)metrics["RMSSE"] = rmssereturn metricsdef get_aggregate_metrics(self, metric_per_ts):wrmsse = metric_per_ts["RMSSE"].mean()agg_metric, _ = super().get_aggregate_metrics(metric_per_ts)agg_metric["MRMSSE"] = wrmssereturn agg_metric, metric_per_tsevaluator = CustomEvaluator(quantiles=[0.5, 0.67, 0.95, 0.99])
agg_metrics, item_metrics = evaluator(iter(tss), iter(forecasts), num_series=len(test_ds))
print(json.dumps(agg_metrics, indent=4))

结果查看

a = forecasts[0]
print(a.mean)
print(a.quantile(0.95))
import warnings
warnings.filterwarnings("ignore")
plot_log_path = "./plots/"
directory = os.path.dirname(plot_log_path)
if not os.path.exists(directory):os.makedirs(directory)def plot_prob_forecasts(ts_entry, forecast_entry, path, sample_id, name, inline=True):plot_length = 150prediction_intervals = (50, 67, 95, 99)legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]_, ax = plt.subplots(1, 1, figsize=(10, 7))ts_entry[-plot_length:].plot(ax=ax)forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')ax.axvline(ts_entry.index[-prediction_length], color='r')plt.legend(legend, loc="upper left")plt.title(f'{name} Price series and predict results')if inline:plt.show()plt.clf()else:plt.savefig('{}forecast_{}.pdf'.format(path, sample_id))plt.close()print("Plotting time series predictions ...")
for i in tqdm(range(20,30)):ts_entry = tss[i]forecast_entry = forecasts[i]name = stock_list[i]plot_prob_forecasts(ts_entry, forecast_entry, plot_log_path, i, name)

绘图结果





欢迎交流,实验不易,转载请注明出处!!!

使用DeepAR实现股价预测相关推荐

  1. 【时间序列预测】股价预测零售预测

    股价预测.零售时间序列预测 1.什么是时间序列预测 时间序列(time series)是一组按照时间发生先后顺序进行排列的数据 时间序列(time series forecaing,简称时序预测.预估 ...

  2. 基于python的马尔科夫链在股价预测中的应用(基于Tushare)

    TushareID:503535 文章目录 前言 一.马尔科夫链是什么 二.代码如下 三.马尔科夫预测模型在股价预测中的应用 1.数据来源及状态划分 2.状态转移概率矩阵 P 及初始状态概率向量 图片 ...

  3. MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测

    MATLAB-基于长短期记忆网络(LSTM)的SP500的股票价格预测 股价预测 matlab实战 数据分析 数据可视化 时序数据预测 变种RNN 股票预测 摘要 近些年,随着计算机技术的不断发展,神 ...

  4. pytorch LSTM的股价预测

    股价预测一直以来都是幻想能够被解决的问题,本文中主要使用了lstm模型去对股价做一个大致的预测,数据来源是tushare,非常感谢tushare的数据!! 为什么要用LSTM? LSTM是一种序列模型 ...

  5. Python使用GARCH,EGARCH,GJR-GARCH模型和蒙特卡洛模拟进行股价预测

    全文下载链接:http://tecdat.cn/?p=20678 在本文中,预测股价已经受到了投资者,政府,企业和学者广泛的关注.然而,数据的非线性和非平稳性使得开发预测模型成为一项复杂而具有挑战性的 ...

  6. 银行股价预测——基于pytorch框架RNN神经网络

    银行股价预测--基于pytorch框架RNN神经网络 任务目标 数据来源 完整代码 流程分析 1.导包 2.读入数据并做预处理 3.构建单隐藏层Rnn模型 4.设计超参数,训练模型 5.加载模型,绘图 ...

  7. 利用LSTM进行股价预测

    利用LSTM进行股价预测 效果 原理 代码 应用 效果 原理 LSTM即长短记忆网络,是一种很强的RNN,这种网络的特性是以前的输入会影响现在的输出,具体原理请自行搜索. 算法流程: 获取yahoo财 ...

  8. 【python量化】大幅提升预测性能,将NSTransformer用于股价预测

    写在前面 NSTransformer模型来自NIPS 2022的一篇paper<Non-stationary Transformers: Exploring the Stationarity i ...

  9. [Kaggle比赛] 高频股价预测小结

    高频股价预测 文章目录 高频股价预测 问题描述 问题分析 数据分析 数据集 数据清洗 解决方案 数据预处理 归一化 Prices Volume 时间信息 对于预测值的处理 噪声 模型探索 基于LSTM ...

  10. 【股价预测】基于matlab SVM股票价格预测【含Matlab源码 180期】

    一.获取代码方式 获取代码方式1: 完整代码已上传我的资源:[股价预测]基于matlab SVM股票价格预测[含Matlab源码 180期] 点击上面蓝色字体,直接付费下载,即可. 获取代码方式2: ...

最新文章

  1. C++内存对象大会战
  2. unix 查询进程并中止
  3. SQL Server 获取表或视图结构信息
  4. Android 程序适应多种多分辨率
  5. 利用 51 定时器生成 PWM
  6. hdu 1564 Play a game
  7. pip 安装 scipy 出现错 no lapack/blas resources
  8. laravel order 按时间升序_Cache and Related Part3: Coherence amp; Order
  9. 防火墙的三种工作模式介绍(路由模式、透明模式(网桥)、混合模式)
  10. 单点漫延问题(水陆判断、洪水漫延、无权最小路径)
  11. Fine Dining G
  12. html显示和隐藏文字特效,14款震撼人心的HTML5文字特效
  13. Gronwall 不等式
  14. 新栏目上线|我是戴小乐-集美貌与才华于一身~
  15. RR RC 隔离级别
  16. MFC Group Box 组合框的简单使用 笔记
  17. c# record的使用场景
  18. Metasploit 渗透测试手册第三版 第三章 服务端漏洞利用(翻译)
  19. m无线通信信道matlab仿真,包括自由空间损耗模型,Okumura-Hata模型,COST231 Hata模型,SUI信道模型
  20. 双线双路网络路由如何设置?

热门文章

  1. 轻健身餐的市场前景如何?如何选择一个投资小、美食和健身餐清淡的品牌?
  2. [VS2010]逸雨清风 永久稳定音乐外链生成软件V0.1
  3. usb右下角有显示,计算机没显示,U盘显示在计算机的右下角,但无法打开
  4. Zencart模板结构和设计详解
  5. UE4场景流程规范-纹理压缩(美术版/程序版/太长不看版)
  6. 网络会议openmeetings下的openmeetings-util文件分析7
  7. 008_SSSS_ Improved Denoising Diffusion Probabilistic Models
  8. java学习day10(Java基础)特殊类
  9. 裸设备和Oracle问答20例
  10. 解决DELL WIN7 bootmgr is missing