说到时间序列预测,我想一定首先想到RNN,然后想到LSTM,LSTM原理就不说了,网上有很多相关文章。

下面使用tensorflow2.0来实现预测

不得不说tensorflow2.0 太香了,太简单了,真的是有手就行

在tensorflow中只需要调用已经tensorflow的LSTM模块就行了,比如下面的代码

from tensorflow.keras.layers import Dense,LSTM,Dropoutmodel = tf.keras.Sequential([LSTM(80, return_sequences=True),Dropout(0.2),LSTM(80),Dropout(0.2),Dense(1)
])
model.compile(optimizer='adam',loss='mse',)

这样就创建了一个2层LSTM,每层80个神经元的;同时添加了Droopout函数防止过拟合;使用adam激活函数;使用mse作为损失误差的神经网络。真的炒鸡简单。
主要问题是数据的处理,要做时间序列预测,原理应该是使用前n个时间去预测下一个时间,也就是所模型训练的数据应该是下面这个图这样的数据
所以处理数据才是困难的地方。

下面我使用的数据是在上一个文章中提到的英国站点数据。其他的数据也是大同小异。

百度网盘: https://pan.baidu.com/s/19vKN2eZZPbOg36YEWts4aQ
密码 4uh7

在导入数据时,不知道为什么如果有标红的这一列,就会提示错误,所以我把这个数据直接删了,这列数据对预测也没有影响

然后通过下面代码就可以得到一个包含,日期、流量的的数据

f = pd.read_csv('..\Desktop\AE86.csv')
# 从新设置列标
def set_columns():columns = []for i in f.loc[2]:columns.append(i.strip())return columns
f.columns = set_columns()
f.drop([0,1,2], inplace = True)# data 包含要操作的列
data = pd.DataFrame()
# 想留下哪一行数据,就在这里添加到data中
data['datetime'] = f['Local Date']+' '+f['Local Time']
data['total_flow'] = f['Total Carriageway Flow']
# data['speed'] = f['Speed Value']  速度本文没用到
data['datetime'] = pd.to_datetime(data['datetime'])data['month'] = data['datetime'].apply(lambda date: date.month)
data['day'] = data['datetime'].apply(lambda date: date.day)
data['hour'] = data['datetime'].apply(lambda date:date.hour)
data['minute'] = data['datetime'].apply(lambda date: date.minute)# 数据转格式
data['total_flow'] = np.array(data['total_flow']).astype(np.float64)

处理后的数据如下

之后就是划分训练集和测试集,归一化

# 一月第25天第一个时间的索引值
d25 = data.query('day==25').index[0]
# 训练集  2211个数据,2018年一月前三周
train_set = data.iloc[:d25,1:2]
# 检测集  669个数据,2018年最后一周
test_set = data.iloc[d25:,1:2]# 归一化
sc = MinMaxScaler(feature_range=(0, 1))
train_set_sc = sc.fit_transform(train_set)
test_set_sc = sc.transform(test_set)

下面就是创建LSTM的输入数据,以time_step=5为预测间隔,即使用前5个时间段,预测下一个时间段

time_step = 5
# 按照time_step划分时间步长
x_train = []
y_train = []
x_test = []
y_test = []
for i in range(time_step, len(train_set_sc)):  x_train.append(train_set_sc[i - time_step:i])y_train.append(train_set_sc[i:i + 1])
for i in range(time_step, len(test_set_sc)):  x_test.append(test_set_sc[i - time_step:i])y_test.append(test_set_sc[i:i + 1])
x_test, y_test = np.array(x_test), np.array(y_test)# 随机化,这部分可以不要
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)# 转为array格式
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
x_train = np.reshape(x_train, (x_train.shape[0], time_step, 1))
x_test = np.reshape(x_test, (x_test.shape[0], time_step, 1))

下面就是构建模型,预测,误差分析,可视化之类的了

总体代码如下

import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense,LSTM,Dropout,Flatten
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
from matplotlib.font_manager import FontProperties  # 画图时可以使用中文
f = pd.read_csv('..\Desktop\AE86.csv')
# 从新设置列标
def set_columns():columns = []for i in f.loc[2]:columns.append(i.strip())return columns
f.columns = set_columns()
f.drop([0,1,2], inplace = True)# data 包含要操作的列
data = pd.DataFrame()
data['datetime'] = f['Local Date']+' '+f['Local Time']
data['total_flow'] = f['Total Carriageway Flow']
# data['speed'] = f['Speed Value']
data['datetime'] = pd.to_datetime(data['datetime'])data['month'] = data['datetime'].apply(lambda date: date.month)
data['day'] = data['datetime'].apply(lambda date: date.day)
data['hour'] = data['datetime'].apply(lambda date:date.hour)
data['minute'] = data['datetime'].apply(lambda date: date.minute)# 数据转格式
data['total_flow'] = np.array(data['total_flow']).astype(np.float64)
# 一月第25天第一个时间的索引值
d25 = data.query('day==25').index[0]
# 训练集  2211个数据,2018年一月前三周
train_set = data.iloc[:d25,1:2]
# 检测集  669个数据,2018年最后一周
test_set = data.iloc[d25:,1:2]
# 归一化
sc = MinMaxScaler(feature_range=(0, 1))
train_set_sc = sc.fit_transform(train_set)
test_set_sc = sc.transform(test_set)# 按照time_step划分时间步长
time_step = 5
x_train = []
y_train = []
x_test = []
y_test = []
for i in range(time_step, len(train_set_sc)):  x_train.append(train_set_sc[i - time_step:i])y_train.append(train_set_sc[i:i + 1])
for i in range(time_step, len(test_set_sc)):  x_test.append(test_set_sc[i - time_step:i])y_test.append(test_set_sc[i:i + 1])
x_test, y_test = np.array(x_test), np.array(y_test)# 随机化,这部分可以不要
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)# 转为array格式
x_train, y_train = np.array(x_train), np.array(y_train)
x_test, y_test = np.array(x_test), np.array(y_test)
x_train = np.reshape(x_train, (x_train.shape[0], time_step, 1))
x_test = np.reshape(x_test, (x_test.shape[0], time_step, 1))# LSTM模型
model = tf.keras.Sequential([LSTM(80, return_sequences=True),Dropout(0.2),LSTM(80),Dropout(0.2),Dense(1)
])
model.compile(optimizer='adam',loss='mse',)# 训练模型, 其中epochs,batch_size 可以自己更改
history = model.fit(x_train, y_train,epochs=5,validation_data=(x_test, y_test))
# 模型预测
pre_flow = model.predict(x_test)
# 反归一化
pre_flow = sc.inverse_transform(pre_flow)
real_flow = sc.inverse_transform(y_test.reshape(y_test.shape[0], 1))# 计算误差
mse = mean_squared_error(pre_flow, real_flow)
rmse = math.sqrt(mean_squared_error(pre_flow, real_flow))
mae = mean_absolute_error(pre_flow, real_flow)
print('均方误差---', mse)
print('均方根误差---', rmse)
print('平均绝对误差--', mae)# 画出预测结果图
font_set = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=15)    # 中文字体使用宋体,15号
plt.figure(figsize=(15,10))
plt.plot(real_flow, label='Real_Flow', color='r', )
plt.plot(pre_flow, label='Pre_Flow')
plt.xlabel('测试序列', fontproperties=font_set)
plt.ylabel('交通流量/辆', fontproperties=font_set)
plt.legend()
# 预测储存图片
# plt.savefig('...\Desktop\123.jpg')
上面代码是最简单的,只是用流量,同时单一节点进行流量预测。

也可以使用速度,占有率等信息,加入到模型中对流量进行预测。真要认真做起来是比较难的,但是如果只是应付应付,提供一个思虑:

可以把另外几种特征也按照time_step=5,进行划分,直接传入到模型中,只不过在模型的最后一层加一个Flatten层(将所有数据拉直成一维),这样就可以大大方方的说”本文考虑了,流量、速度、车道占有率等多种因数,相对于以前文章具有重大改进“

交通流预测爬坑记(二):最简单的LSTM预测交通流,使用tensorflow2实现相关推荐

  1. 交通流预测爬坑记(一):交通流数据集,原始数据

    目录 主要数据类型 个人出行数据,轨迹数据 高速公路观察点数据集 其他 出行数据集 高速公路数据集 其他 赠人玫瑰 如今网上有非常多的数据集,在CSDN,知乎什么搜一下可以找到一大堆,在收集数据时,发 ...

  2. 小小甜菜Movidius爬坑记

    小小甜菜Movidius爬坑记 我是在神经计算棒+树莓派3B上看到实际效果后决定使用本方案的,实际项目中使用树莓派CM3作为核心板卡,使用Movidius 2或Movidius X(具体版本看项目需求 ...

  3. 小白爬坑记:C语言学习点滴——我对单、双引号的理解

    小白爬坑记:C语言学习点滴--我对单.双引号的理解 一.单引号的作用: 二.双引号的作用: 三.字符或字符串容易犯的错误: 三.做个小题: 一.单引号的作用: 将单引号中间的所有符号直接转换为ASCI ...

  4. 我在「小米爬坑记」里,看到的 3 个创业真相

    小米把10年的创业经历,做了一次「开源」. 就在昨晚,雷军做了小米十周年的公开演讲.在演讲中,雷军对小米10年的发展做了一次大梳理,既有成绩,也有反思,还有小米历史上一些非常关键的发展节点.其实上个月 ...

  5. 小小甜菜深度学习爬坑记

    小小甜菜深度学习爬坑记 主要目的是整理一路学习的技术分享贴.我的电脑是笔记本联想拯救者R720,显卡是GTX1050Ti.其它环境仅供参考. 安装ubuntu+win10双系统 详情见技术贴用 Eas ...

  6. Spring Cloud Contract 爬坑记

    前言:spring-cloud-starter-contract-verifier:2.1.1.RELEASE. spring-cloud-contract-maven-plugin:2.2.1.RE ...

  7. 小小甜菜百度AI爬坑记

    小小甜菜百度AI爬坑记 作为国内较好的深度学习平台,我是在2018百度开发者大会拉入坑的.其可分为语音识别,语音合成,文字识别,图像识别,人体分析,人脸识别,理解与互交技术,AR增强现实等多个方向.我 ...

  8. npm run build 打包爬坑记(1)

    npm run build 打包爬坑记 先说说打包过程,npm run build 后放入phpstudy里面,访问本地ip,查看phpstudy的端口号,就能访问页面了(访问地址:http://19 ...

  9. 小小甜菜keras爬坑记

    小小甜菜keras爬坑记 前言 安装(参考深度学习爬坑记安装ubuntu16.04+cuda9+cudnn+tensorflow1.9.0) 视频资料 相关模型资料 学习笔记 ssd模型 yolov3 ...

最新文章

  1. pytorch中的transpose()
  2. 中国IT潜在的巨大希望
  3. Zedboard学习(五):MIO与EMIO操作
  4. 【原创】PostgreSQL 增量备份详解以及相关示例
  5. 【转】MySQL实现Oracle里的 rank()over(ORDER BY) 功能
  6. opencv3.10加入OPENCV_contrib模块
  7. BugkuCTF-Reverse题不好用的ce
  8. 需求获取的三阶段:需求背景、需求调研、需求分析 (2)
  9. Leetcode每日一题:659.split-array-into-consecutive-subsequences(分割数组为连续子序列)
  10. hal库开启中断关中断_STM32 HAL库学习系列第9篇---NVIC按键外部中断函数
  11. sqlserver 软件授权
  12. html 预加载图片,实现网页图片预加载的几个方法
  13. android 车牌输入键盘
  14. 推荐使用Windows10企业版LTSC的理由
  15. Genero BDL 数据类型(1)
  16. 打印图案系列(菱形、X形、箭形、空心正方形)
  17. ERP 系统的核心是什么?有什么作用?
  18. Thinkpad X270上用U深度优盘还原安装win7无法启动
  19. Matlab进行彩色图像直方图匹配(不用histeq函数)
  20. Andoid 手机安装 Google 应用套件

热门文章

  1. 大连将在东京建设软件园
  2. Vue+style 动态样式绑定(收藏图标)
  3. bzero和memset函数区别联系
  4. 说说wps jsa的ListBox控件的数组写入方法
  5. [编译原理学习笔记2-2] 程序语言的语法描述
  6. HTTP Status 404(The requested resource is not available)(转)
  7. 移动终端安全 顶级会议_顶级移动应用开发公司
  8. 新职业人才缺口近千万,90后最担心失业;字节跳动回应TikTok被收购传闻;Twitter公布账号劫持事故细节 | EA周报...
  9. .CreateFeatureClass报错原因解析
  10. 巴什博弈--Nim游戏