目录:

  1. 业务场景
  2. 环境搭建
  3. 数据及目录结构
  4. 模型
  5. 代码(建模、训练)
  6. 预测及结果呈现


文章只是对模型的学习与实践做简要记录,以免日后给忘了,并没有对模型优劣、应用的场景等理论方面有过多分析。适合快速动手搭建,成功运行、分析代码,并学习怎样用keras实现模型的同学。为更好的阅读下文,需提前下载模型代码(https://github.com/lucktroy/DeepST,不包含预测代码和分析参数的代码)和数据集(北京出租车:http://pan.baidu.com/s/1qYq7ja8,BikeNYC:http://pan.baidu.com/s/1mhIPrRE)。

业务场景: 

模型主体是基于残差神经网络,同时加上在时间维度的数据采样,主要的适用场景应是具有空间关系、并具有时间规律的数据。例如预测某区域人群流动、密度,交通流动等,其数据特点:在空间维度上,会受到周围区域的影响,在时间维度上,每天、每周甚至月、年会出现一定规律的变化。模型、数据以及研究成果来源于微软研究院 郑宇教授,相关学习资料https://github.com/lucktroy/DeepST,http://www.weiot.net/article-95817.html,https://www.microsoft.com/en-us/research/publication/deep-spatio-temporal-residual-networks-for-citywide-crowd-flows-prediction/。

环境搭建:

Windows、Linux都可搭建运行环境,建议linux。

(1)linux

Anaconda2-5.0.1-Linux-x86_64.sh对应python 2.7、keras 2.0.6、TensorFlow1.2.1(或theano 0.9.0)

(2)Windows

Anaconda 5.0.1 python 3.6、keras 2.0.8、theano 0.9.0(对应TensorFlow版本不太好找)

注:1. 源代码是python2.7版本的,python3.6需要修改部分代码

2. keras配置文件.keras.json可修改backend后台支持(theano /TensorFlow),具体可百度

3. .keras.json中修改image_data_format为channels_first

4. 添加数据的路径到环境变量(诡异的很):建立DATAPATH,对应的值为数据路径(如:E:\project\peopleFlowPredict\data)

(3)deepst包安装

除上面基础包安装外,还需安装模型自带的包deepst,当前目录下python setup.py install 和python setup.py develop

数据及目录结构:

目录:

data:数据目录,deepst程序依赖包目录,scripts主体程序入口,script中分为北京出租车、纽约自行车、**省人群流动预测:exptTaxiBJ.py模型建模训练程序,HB_analyse.py测试集分析呈现,HB_prediction.py模型预测函数,HB_weightShow.py模型参数分析呈现,MODEL训练生成的模型,RET模型训练过程记录,preprocessing.pkl保存归一化对象,数据集data:分别是北京出租车、纽约自行车、**省人群流动数据集,CATCH为第一次运行程序读取原数据集后,自动生成模型的输入数据集。

数据

数据是以图的形式,逻辑上为矩阵,如北京出租车数据,首先将北京地区栅格化为32*32的逻辑矩阵,每个样本为某个时间点的出租车在整个北京区域栅格的流动分布。数据分为原数据集和缓存数据集,程序第一次运行会读取原始数据,程序会根据设置的时间采样参数读取位置数据,结合天气、节假日等特征生成模型所需要的输入数据。数据集分为训练集和测试集,例如北京出租车测试集选取后四月的数据。

原始数据集中:分为位置数据集、气象数据集和节假日数据集。位置数据集为shape(n,2,32,32),n为样本个数,2为输入输出两维度,32*32map流动值。external feature维度28维,其中星期7维+1维是否为工作日+1维是否为节假日+19维气象,温度、风速归一化,其他为0-1编码。

CACHE:数据集中含有11对键值对,训练数据和测试数据各5对+1对external feature的维度。5对训练数据中包含近(shape(n,6,32,32))、中(shape(n,2,32,32))、远(shape(n,2,32,32))、external 特征(shape(n,28))、对应的时间序列(shape(n,)),其中6和2代表设定近中远的取样参数为3和1,再乘以输入输出两维度。测试集类似。

模型输入数据结构:训练集和测试集结构类似,只是样本个数不同。模型所需要的输入数据的维度为shape(4,n,m,32,32),其,4代表近中远+external ,n为样本个数,m对应的近中远参数值*2,4拆开分别对应是近(shape(n,6,32,32))、中(shape(n,2,32,32))、远(shape(n,2,32,32))、external 特征(shape(n,28))。

模型

三个角度出发,空间、时间、额外因素建模,以残差神经网络(DNN改进,可让网络层数变得更深)为模型基础。时间影响:时间轴上分段采样,近(时刻)、中(天)、远(周)。空间影响:在近中远三段数据上分别采用多层残差神经网络。额外因素:利用全连接神经网络建模。近、中、远三个模型输出经过加权融合后生成图中Xrex,并和Xext激活后生成X't。结合下图理解。

上面近中远三段数据分布残差神经网络建模,其模型结构类似。模型先是通过一层卷积,然后通过n层的残差(每个残差单元含有两个卷积层),最后再通过一层卷积。模型adam训练,学习率0.0002,近、中、远为3、1、1,每天时刻划分T为48(每个24/T生成一个流动数据图),代价函数和检验标准为rmse,模型训练停止方法有两种:一个是early-stopping 一个是设置最大迭代次数。卷积核为3*3,个数为64。

代码(建模及训练)

模型的代码实现是基于keras,keras的banckend可以选择tensorflow或者theano

代码执行:比如北京出租车,在目录scripts\papers\AAAI17\TaxiBJ中执行 python exptTaxiBJ.py 2,其中2代表残差神经网络层数。

exptTaxiBJ.py中,主要函数:def build_model(建模)def cache(生成数据缓存)def read_cache(读取缓存)def main(程序入口函数)。

mmn = pickle.load(open('preprocessing.pkl', 'rb'))获取归一化对象

model = stresnet(c_conf=c_conf...)具体建模函数

TaxiBJ.load_data(...)按照时间段采样原始数据

early_stopping = EarlyStopping最早停止条件定义

model_checkpoint = ModelCheckpoint模型训练条件

history = model.fit模型训练

model.save_weights保存模型参数

model.load_weights载入模型参数

score = model.evaluate模型评估

注:存在两次模型训练,两次的截止条件不一样

STResNet.py建模文件,def stresnet建模主函数,def _residual_unit残差单元构造,其中具体语法参照keras管网

TaxiBJ.py:北京出租车的载入数据函数文件,def load_holiday(读取节假日)def load_meteorol(读取气象)def load_data(读取位置数据)

STMatrix.py:按照设置的长中短时间参数采样,create_dataset主要的采样原数据函数,其中i代表当前要预测的数据,while循环读取每条i的长中短的数据(X)及当前i的数据(Y)

预测及结果呈现

HB_prediction.py(预测函数):模型训练会生产模型参数,保存在modle文件夹中。预测函数首先按照定义的模型建模,并加载训练好的参数,输入测试数据便可得到预测值y,此函数可预测未来一段时间段的y'。其中,predictNext()函数是用预测值作为输入预测下一刻的值。predicVal()函数是获取真实值和真实值作为输入的预测值。

HB_weightShow.py(参数分析函数),此函数是读取keras生成的参数文件,并将远中近的融合时的参数呈现,可以分别看出在远中近三个时间段不同栅格位置的权重影响。

Deep Spatio-Temporal Residual Networks(深度时空残差神经网络)相关推荐

  1. 深度时空残差网络在城市人流量预测中的应用

    文章目录 摘要 简介 预备知识 人流量问题的制定 深度残差学习 深度时空残差网络 前三个成分的结构 外部组件的结构 融合 算法和优化 实验 设置 结果TaxiBJ 结果BikeNYC 相关工作 总结及 ...

  2. 深度学习——残差神经网络ResNet在分别在Keras和tensorflow框架下的应用案例

    原文链接:https://blog.csdn.net/loveliuzz/article/details/79117397 一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别 ...

  3. 深度学习理论——残差神经网络

    大家好,继续理论学习,今天学习了何凯明大神的又一力作,残差神经网络. 我们知道,网络层数越深,其表达能力越强,但之前人们并不敢把网络设计的太深,因为会有梯度衰减等各种问题让网络的性能大幅度下滑,用一些 ...

  4. 【Pytorch】残差神经网络(Residual Networks)

    一.背景 传统的神经网络,由于网络层数增加,会导致梯度越来越小,这样会导致后面无法有效的训练模型,这样的问题成为梯度消弭.为了解决这样的问题,引入残差神经网络(Residual Networks),残 ...

  5. CNN 常用的几个模型 LeNet5 AlexNet VGGNet Google Inception Net 微软ResNet残差神经网络

    LeNet5 LeNet-5:是Yann LeCun在1998年设计的用于手写数字识别的卷积神经网络,当年美国大多数银行就是用它来识别支票上面的手写数字的,它是早期卷积神经网络中最有代表性的实验系统之 ...

  6. 深度学习笔记:Deep Residual Networks with Dynamically Weighted Wavelet Coefficients for Fault Diagnosis of

    深度学习笔记:Deep Residual Networks with Dynamically Weighted Wavelet Coefficients for Fault Diagnosis of ...

  7. Identity Mappings in Deep Residual Networks

    论文地址:Identity Mappings in Deep Residual Networks 译文地址:http://blog.csdn.net/wspba/article/details/607 ...

  8. 残差网络(Residual Networks, ResNets)

    1. 什么是残差(residual)? "残差在数理统计中是指实际观察值与估计值(拟合值)之间的差.""如果回归模型正确的话, 我们可以将残差看作误差的观测值." ...

  9. 《Enhanced Deep Residual Networks for Single Image Super-Resolution》论文阅读之EDSR

    导读 韩国首尔大学的研究团队提出用于图像超分辨率任务的新方法,分别是增强深度超分辨率网络 EDSR 和一种新的多尺度深度超分辨率 MDSR,在减小模型大小的同时实现了比当前其他方法更好的性能,分别赢得 ...

  10. 12.深度学习练习:Residual Networks(注定成为经典)

    本文节选自吴恩达老师<深度学习专项课程>编程作业,在此表示感谢. 课程链接:https://www.deeplearning.ai/deep-learning-specialization ...

最新文章

  1. 2017 CIO展望:新IT运营模式的5大元素
  2. linux lvm snapshot lvm 快照 逻辑卷 快照
  3. Django报错:mysql ImproperlyConfigured: mysqlclient 1.3.13 or newer is required, you have 0.9.3的解决办法
  4. JavaScript基础一
  5. ACM-ICPC国际大学生程序设计竞赛北京赛区(2017)网络赛 A题 Visiting Peking University
  6. CodeForces - 1321B Journey Planning(思维)
  7. 机器学习-回归之逻辑回归算法原理及实战
  8. 14-angular.isDefined
  9. 【转】C# 中@符号在字符串中的作用
  10. ldconfig清理缓存
  11. python打印列出目录及其子目录里面的内容
  12. android revre view,MK802 4.0.4 CWM Recovery
  13. PDA开发从入门到精通
  14. Cadence PSpice 仿真0:绘制电路图方法图文教程
  15. F2FS源码分析-1.6 [F2FS 元数据布局部分] Segment Summary Area-SSA结构
  16. UI行业就业前景怎样 如何成为合格的UI设计师
  17. 什么叫磁场强度、磁通势、磁阻、导磁率、电磁力、涡流?
  18. 传阿里巴巴集团推迟上市至2015年底
  19. 使用vCenter Converter 工具P2V迁移windos server 2003 sp1-sp2操作系统遇到的问题
  20. 掌财社财经_美债收益率提高 黄金持续承压

热门文章

  1. 配置淘宝Maven镜像仓库
  2. 5G新型调制技术FBMC【5G】
  3. blender基本翻译+快捷键
  4. faketime实现游戏服务器时间定制
  5. 三极管作开关应用及详解
  6. Scintilla开源库使用指南(一)
  7. c4d流体插件_C4D的Jet Fluids免费流体插件
  8. 华为交换机配置Vlan
  9. 软考软件设计师考试总结(2019下半年)
  10. C语言简单连点器网课必备