LSTM是递归神经网络(RNN)的一个变种,相较于RNN而言,解决了记忆消失的问题,用来处理序列问题是一个很好的选择。本文主要介绍如何使用DL4J中的LSTM来执行回归分析。如果不清楚RNN和LSTM,可以先阅读 LSTM和递归网络教程 以及 通过DL4J使用递归网络 ,特别是不熟悉RNN输入和预测方式的强烈建议先阅读这两个教程。如果不太会建立DL4J的工程,建议在其样例工程中进行本实验。

言归正传,文本通过使用 LSTM对上证指数历史数据进行回归学习,并给出一个初始序列预测之后20天的大盘收盘价格来演示如何使用LSTM处理简单的序列回归问题。首先是准备数据,可以下载例子中我使用的数据集。那么接下来的问题就分成如下几步:

1. 读入训练数据,并处理成一个DataIterator;

2. 构建一个LSTM的递归神经网络;

3. 迭代训练,并输出预测结果;

4. 调参和优化。

一.处理训练数据

我们的数据是上证指数每个交易日的基本数据,格式为:

股票代码 日期开盘价 收盘价最高价  最低价成交量  成交额涨跌幅

这个文件中的数据是倒序的,也就是说新的数据在最前面,因此在读取数据时需要做一次倒转。我将读取文件的方法放在Dataiterator中。DL4J给出了序列数据处理的DataIterator,但是在本例中我们是自己实现一个DataIterator。代码如下:

package edu.zju.cst.krselee.example.stock;import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;/*** Created by kexi.lkx on 2016/8/23.*/
public class StockDataIterator  implements DataSetIterator {private static final int VECTOR_SIZE = 6;//每批次的训练数据组数private int batchNum;//每组训练数据长度(DailyData的个数)private int exampleLength;//数据集private List<DailyData> dataList;//存放剩余数据组的index信息private List<Integer> dataRecord;private double[] maxNum;/*** 构造方法* */public StockDataIterator(){dataRecord = new ArrayList<>();}/*** 加载数据并初始化* */public boolean loadData(String fileName, int batchNum, int exampleLength){this.batchNum = batchNum;this.exampleLength = exampleLength;maxNum = new double[6];//加载文件中的股票数据try {readDataFromFile(fileName);}catch (Exception e){e.printStackTrace();return false;}//重置训练批次列表resetDataRecord();return true;}/*** 重置训练批次列表* */private void resetDataRecord(){dataRecord.clear();int total = dataList.size()/exampleLength+1;for( int i=0; i<total; i++ ){dataRecord.add(i * exampleLength);}}/*** 从文件中读取股票数据* */public List<DailyData> readDataFromFile(String fileName) throws IOException{dataList = new ArrayList<>();FileInputStream fis = new FileInputStream(fileName);BufferedReader in = new BufferedReader(new InputStreamReader(fis,"UTF-8"));String line = in.readLine();for(int i=0;i<maxNum.length;i++){maxNum[i] = 0;}System.out.println("读取数据..");while(line!=null){String[] strArr = line.split(",");if(strArr.length>=7) {DailyData data = new DailyData();//获得最大值信息,用于归一化double[] nums = new double[6];for(int j=0;j<6;j++){nums[j] = Double.valueOf(strArr[j+2]);if( nums[j]>maxNum[j] ){maxNum[j] = nums[j];}}//构造data对象data.setOpenPrice(Double.valueOf(nums[0]));data.setCloseprice(Double.valueOf(nums[1]));data.setMaxPrice(Double.valueOf(nums[2]));data.setMinPrice(Double.valueOf(nums[3]));data.setTurnover(Double.valueOf(nums[4]));data.setVolume(Double.valueOf(nums[5]));dataList.add(data);}line = in.readLine();}in.close();fis.close();System.out.println("反转list...");Collections.reverse(dataList);return dataList;}public double[] getMaxArr(){return this.maxNum;}public void reset(){resetDataRecord();}public boolean hasNext(){return dataRecord.size() > 0;}public DataSet next(){return next(batchNum);}/*** 获得接下来一次的训练数据集* */public DataSet next(int num){if( dataRecord.size() <= 0 ) {throw new NoSuchElementException();}int actualBatchSize = Math.min(num, dataRecord.size());int actualLength = Math.min(exampleLength,dataList.size()-dataRecord.get(0)-1);INDArray input = Nd4j.create(new int[]{actualBatchSize,VECTOR_SIZE,actualLength}, 'f');INDArray label = Nd4j.create(new int[]{actualBatchSize,1,actualLength}, 'f');DailyData nextData = null,curData = null;//获取每批次的训练数据和标签数据for(int i=0;i<actualBatchSize;i++){int index = dataRecord.remove(0);int endIndex = Math.min(index+exampleLength,dataList.size()-1);curData = dataList.get(index);for(int j=index;j<endIndex;j++){//获取数据信息nextData = dataList.get(j+1);//构造训练向量int c = endIndex-j-1;input.putScalar(new int[]{i, 0, c}, curData.getOpenPrice()/maxNum[0]);input.putScalar(new int[]{i, 1, c}, curData.getCloseprice()/maxNum[1]);input.putScalar(new int[]{i, 2, c}, curData.getMaxPrice()/maxNum[2]);input.putScalar(new int[]{i, 3, c}, curData.getMinPrice()/maxNum[3]);input.putScalar(new int[]{i, 4, c}, curData.getTurnover()/maxNum[4]);input.putScalar(new int[]{i, 5, c}, curData.getVolume()/maxNum[5]);//构造label向量label.putScalar(new int[]{i, 0, c}, nextData.getCloseprice()/maxNum[1]);curData = nextData;}if(dataRecord.size()<=0) {break;}}return new DataSet(input, label);}public int batch() {return batchNum;}public int cursor() {return totalExamples() - dataRecord.size();}public int numExamples() {return totalExamples();}public void setPreProcessor(DataSetPreProcessor preProcessor) {throw new UnsupportedOperationException("Not implemented");}public int totalExamples() {return (dataList.size()) / exampleLength;}public int inputColumns() {return dataList.size();}public int totalOutcomes() {return 1;}@Overridepublic List<String> getLabels() {throw new UnsupportedOperationException("Not implemented");}@Overridepublic void remove() {throw new UnsupportedOperationException();}
}

StockDataIterator实现了DataIterator接口,于是需要实现几个必须的方法,例如hasNext、next、reset……用来进行每一批次DataSet的获取,loadData和readDataFromFile用来获取数据,并保存在一个DailyData类型的List中,每次调用next方法时,就会从List取出当前需要的数据,并构造成DataSet,返回给调用者。DailyData的实现如下:

package edu.zju.cst.krselee.example.stock;/*** Created by kexi.lkx on 2016/8/23.*/
public class DailyData {//开盘价private double openPrice;//收盘价private double closeprice;//最高价private double maxPrice;//最低价private double minPrice;//成交量private double turnover;//成交额private double volume;public double getTurnover() {return turnover;}public double getVolume() {return volume;}public DailyData(){}public double getOpenPrice() {return openPrice;}public double getCloseprice() {return closeprice;}public double getMaxPrice() {return maxPrice;}public double getMinPrice() {return minPrice;}public void setOpenPrice(double openPrice) {this.openPrice = openPrice;}public void setCloseprice(double closeprice) {this.closeprice = closeprice;}public void setMaxPrice(double maxPrice) {this.maxPrice = maxPrice;}public void setMinPrice(double minPrice) {this.minPrice = minPrice;}public void setTurnover(double turnover) {this.turnover = turnover;}public void setVolume(double volume) {this.volume = volume;}@Overridepublic String toString(){StringBuilder builder = new StringBuilder();builder.append("开盘价="+this.openPrice+", ");builder.append("收盘价="+this.closeprice+", ");builder.append("最高价="+this.maxPrice+", ");builder.append("最低价="+this.minPrice+", ");builder.append("成交量="+this.turnover+", ");builder.append("成交额="+this.volume);return builder.toString();}
}

代码中对数据的各个维度进行了归一化处理,方法是记录每个维度的最大值,构造特征向量与标签时用原始数值除以最大值,得到0-1之间的数,归一化的好处在于使训练过程收敛变快,读者也可以试试不归一化的情况,比较两者的差别。

二.构建LSTM网络

       本例中我构造了一个两个隐含层的LSTM网络,隐含层激活函数是tanh,输出层使用identity函数来执行回归。输入单元数为6,因为单个向量是6维的(开盘价、收盘价、最高价、最低价、成交量、成交额);输出单元数为1,用于预测第二天收盘价,代码如下:
    private static final int IN_NUM = 6;private static final int OUT_NUM = 1;private static final int Epochs = 100;private static final int lstmLayer1Size = 50;private static final int lstmLayer2Size = 100;public static MultiLayerNetwork getNetModel(int nIn,int nOut){MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).learningRate(0.1).rmsDecay(0.5).seed(12345).regularization(true).l2(0.001).weightInit(WeightInit.XAVIER).updater(Updater.RMSPROP).list().layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayer1Size).activation("tanh").build()).layer(1, new GravesLSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size).activation("tanh").build()).layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("identity").nIn(lstmLayer2Size).nOut(nOut).build()).pretrain(false).backprop(true).build();MultiLayerNetwork net = new MultiLayerNetwork(conf);net.init();net.setListeners(new ScoreIterationListener(1));return net;}

这段代码中有很多参数可以进行调整来寻找最优的拟合效果或调整训练速率,比如隐含层单元数目、激活函数、学习速率、正则化因子……构造好网络后加入一个ScoreIterationListener来监听每次迭代训练后的得分。

三.执行迭代训练

第二部分里面我们设置了完整训练集的迭代次数Epochs为100,表示用整个数据集反复训练100次,训练部分代码如下:
 public static void train(MultiLayerNetwork net,StockDataIterator iterator){//迭代训练for(int i=0;i<Epochs;i++) {DataSet dataSet = null;while (iterator.hasNext()) {dataSet = iterator.next();net.fit(dataSet);}iterator.reset();System.out.println();System.out.println("=================>完成第"+i+"次完整训练");INDArray initArray = getInitArray(iterator);System.out.println("预测结果:");for(int j=0;j<20;j++) {INDArray output = net.rnnTimeStep(initArray);System.out.print(output.getDouble(0)*iterator.getMaxArr()[1]+" ");}System.out.println();net.rnnClearPreviousState();}}private static INDArray getInitArray(StockDataIterator iter){double[] maxNums = iter.getMaxArr();INDArray initArray = Nd4j.zeros(1, 6, 1);initArray.putScalar(new int[]{0,0,0}, 3433.85/maxNums[0]);initArray.putScalar(new int[]{0,1,0}, 3445.41/maxNums[1]);initArray.putScalar(new int[]{0,2,0}, 3327.81/maxNums[2]);initArray.putScalar(new int[]{0,3,0}, 3470.37/maxNums[3]);initArray.putScalar(new int[]{0,4,0}, 304197903.0/maxNums[4]);initArray.putScalar(new int[]{0,5,0}, 3.8750365e+11/maxNums[5]);return initArray;}

每当进行一次完整集的训练之后,我们初始化了一个初始序列进行预测之后20个序列的输出。整个程序主函数如下:

    public static void main(String[] args) {String inputFile = StockRnnPredict.class.getClassLoader().getResource("stock/sh000001.csv").getPath();int batchSize = 1;int exampleLength = 30;//初始化深度神经网络StockDataIterator iterator = new StockDataIterator();iterator.loadData(inputFile,batchSize,exampleLength);MultiLayerNetwork net = getNetModel(IN_NUM,OUT_NUM);train(net, iterator);}
      迭代100次后,得到的输出序列如下:
3489.9679512619973 3516.991701169014 3510.4443733012677 3490.410951650143 3476.138713735342 3469.275475754738 3466.278687063456 3464.9017547094822 3464.2161934530736 3463.8574357616903 3463.670068384409 3463.582194536925 3463.5545977914335 3463.5658543586733 3463.6010765206815 3463.650460170508 3463.7067430067063 3463.764115188122 3463.8196717941764 3463.8705079042916 

四.进一步优化

本文主要介绍LSTM的使用方法,并不是真的说如此就能准确的预测大盘走势了(当然,网络是有可能真的学习到一些大盘走势特征进行),想要做到这一点需要对本例进行许多调整,比如获取更全面的每日大盘信息,选取更多合适的维度来构建特征向量,当然也可以调整预测值,不仅仅是预测收盘价而已。另外可以调整在第二部分提到的那些参数。

DeepLearning4J入门——使用LSTM进行大盘回归相关推荐

  1. Deeplearning4j - 入门视频

    作者:寒沧 链接:https://www.jianshu.com/p/566fc3db676b 来源:简书 DeepLearning4J(DL4J)是一套基于Java语言的神经网络工具包,可以构建.定 ...

  2. Python 数据科学入门教程:机器学习:回归

    Python 数据科学入门教程:机器学习:回归 原文:Regression - Intro and Data 译者:飞龙 协议:CC BY-NC-SA 4.0 引言和数据 欢迎阅读 Python 机器 ...

  3. 热门数据挖掘模型应用入门(一): LASSO回归

    热门数据挖掘模型应用入门(一): LASSO回归 目录: • 模型简介 • 线性回归 • Logistic回归 • Elstic Net理论简介 • 学习资料 模型简介 Kaggle网站(https: ...

  4. Keras入门教程 3.波士顿房价回归 (MPL)

    Keras入门教程 1.线性回归建模(快速入门) 2.线性模型的优化 3.波士顿房价回归 (MPL) 4.卷积神经网络(CNN) 5.使用LSTM RNN 进行时间序列预测 6.Keras 预训练模型 ...

  5. 入门 | 如何为你的回归问题选择最合适的机器学习方法?

    点击"小詹学Python","星标"或"置顶" 关键时刻,第一时间送达 本文转载自"AI算法之心" 在目前的机器学习领域 ...

  6. DeepLearning4J入门——让计算机阅读《天龙八部》

    很早在实验室就看见钱宝宝用Google的Word2Vector来阅读<天龙八部>并找出与指定词最相关的几个词,最近正好学习新出的深度学习开源项目DeepLearning4J,于是就拿这个例 ...

  7. 【MATLAB第1期】LSTM/GRU网络回归/分类预测改进与优化合集(含录屏操作,持续更新)

    文章目录 一.智能进化算法-LSTM(优化超参数) 1.金枪鱼算法TSO-LSTM--案例1 2.孔雀优化算法(POA)-LSTM--案例1 3.猎人优化算法(HPO)-LSTM--案例1 4.人工大 ...

  8. 机器学习入门|快速掌握逻辑回归模型

    http://blog.itpub.net/29829936/viewspace-2558236/ 2019-01-14 17:30:19 主要内容: 一.逻辑回归的原理 二.极大似然估计 三.逻辑回 ...

  9. python数据预测模型算法_Python AI极简入门:4、使用回归模型预测房价

    一.回归预测 在前面的文章中我们介绍了机器学习主要解决分类.回归和聚类三大问题.今天我们来具体了解一下使用机器学习算法进行回归预测. 回归预测主要用于预测与对象关联的连续值属性,得到数值型的预测数据. ...

最新文章

  1. 【机器学习】基于人工鱼群算法的多元非线性函数寻优
  2. 虚拟主机传奇服务器,虚拟主机市场迅猛发展 演绎网络产品传奇
  3. Dubbo-go 发布 1.5 版,朝云原生迈出关键一步
  4. 阿里DataV可视化大屏基本操作
  5. MySQL COUNT函数优化及count(1)/count(*)/count(列名)的区别
  6. Organizational Data assignment block里value help的determine逻辑
  7. java代码实现解压文件_Java压缩/解压文件的实现代码
  8. 终端服务器安全层在协议流中检测到错误,终端服务器安全层在协议流中检测到错误,并已取消客户端连接...
  9. 跨考计算机报班,考研跨考的经验总结与分享
  10. TensorFlow 学习指南 二、线性模型
  11. win7备份工具_一键重装win7系统教程,如何重装win7系统
  12. utilities(C++)——错误提示
  13. Mac 调整磁盘分区:调整本地与虚拟机内存分区占比
  14. ImageJ-计算创面面积 此博文包含图片 (2014-01-28 15:59:14)
  15. R语言安装包的几种方法
  16. 清理C盘空间,无需命令行,可清理几十G内存,实测有效
  17. U3D场景制作规范(转)
  18. destoon ajax_area_select,destoon城市地区两级联动
  19. Spring Boot项目使用Graphics2D 生成二维码海报图片流返回给前端
  20. c# 屏幕取词的方法

热门文章

  1. 微擎安装遇到一个问题,大佬救救我
  2. 从零维到十维空间如何在纸上用手绘出来
  3. 机器学习第四课:SVM前置知识点(凸优化问题)
  4. 小白学习:李航《统计学习方法》第二版第11章 条件随机场
  5. 好书推荐:《黑客秘笈:渗透测试实用指南》
  6. 数理逻辑 形式可推演与逻辑推论
  7. 卡牌游戏战斗系统的设计和实现二
  8. 计算机桌面怎么情理,关于电脑桌面上的流氓图标要怎么清理???
  9. HTML期末学生大作业 响应式动漫网页作业 html+css+javascript
  10. carton num_Carton先生–世界上第一个卡通系列MadeWithUnity