天池客流预测–GBDT
前记 之前有参加天池的比赛,后面也会分享这个代码,用到过sklearn重的GBDT这个工具,效果还很不错,但是其实一直没有对它的原理搞通,最近花了点时间,好好研究了下GBDT这个东西,感觉很有意思。
基本介绍 有这样一个场景,训练集只有4个人,A,B,C,D,他们年龄分别是14,16,24,26,其中A、B分别是高一与高三的学生;C,D分别是迎接毕业生和工作两类的员工。如果使用传统的回归决策树来train他们:
图1 普通回归树
但是使用GBDT,这时候我们生成两棵树:
图2 GBDT训练
很明显,GBDT不是依赖原始数据来生成第二棵树,而是用预测值与实际值之间的残差来当做输入数据源进行树的训练,而普通回归树是通过分支后的数据来进行下一层次的分支。 图1和图2在这个场景中,最终都能合理地对用户A,B,C,D进行预测,那为什么还需要GBDT,为什么说GBDT会比普通的回归树性能更高呢? 由图1分类使用了3个feature,并且在上网时长这个feature上分类时,恰好可能AB中A每天上网1.09h,B上网1.11h,显然过拟合,极有可能不适合其他数据;另外就是特征数来比,图1大于图2,图1的模型复杂性比较大。 有人可能说,这其实只是我这边参考的一个例子,这个例子无法具有真实性,无法从理论上说明真实情况就是这样 这部分同学确实想的比较正确,这里举得例子只是介绍两种算法的不同,没有严谨的解释,但是其实GBDT基于Boosting的最大好处就是,每一步残差其实就是变相地增加了分错的那些实例的权重,已经分对在经过残差计算后就不会再被考虑,后面生成的树会越来越关注与分错的实例。相对于普通回归树而言,这种方法更加有效,并且会减少模型复杂性,减少过拟合的可能性
算法原理
损失函数 Gradient Boosting和一般的组合方法一样,迭代产生弱分类器集合,组合成为具有强学习能力的分类器。 GBDT的原理是:假定经过前面多次迭代已经产生一个不太完美的分类器
,GBDT不会去改变原来已经生成的弱分类器集合(AdaBoost就是改变权重),而是加入一个新的分类器h,使得性能会更好。然而,如何去找这样的函数h呢? 假定理想状态下,h的加入能够消除的误差,即:
则可以知道:
那么,我们Gradient Boosting的任务就是将h拟合
,通过train来修正之前的误差。
算法推导 GBDT抽样出来最终的目标是在已有训练集
上使整体的损失函数的期望最小:
根据经验风险最小化原则,Gradient Boosting每一步求出的函数都需要最小化训练集数据上的损失函数,迭代构造这个模型,初始化
为一个常数函数:其中f限定为从基函数类H当中选取的函数,即上一节当中的h。 这里把看做一个关于向量的函数,而不是关于f的函数,则:这样就求出f,不过f会在限定为基函数H中,故会在这个类中来选取最近的梯度f。 求出f后,对应的系数:
假设在这里,我们指定使用的损失函数h为square-loss:,那么GBDT的基本流程如下:
通常,我们将损失函数写成indicator notation形式:
假设J为叶子数,输入特征空间为,每一个空间有对应的常数作为预测值,故函数h(x)可写为:
根据Boosting的方法,h(x)乘上一个权重,然后加入原先已生成的分类器中来最小化Loss function:
故按indicator notation后的表示,公式可写为:
GBDT如何调参
- GBDT中决策树的叶子数,通常GBDT中决策树的叶子数控制在4-8之间,效果比较好
- GBDT正则化,涉及到过拟合问题,正则化减小模型复杂度,防止过拟合
- 迭代次数 M太小,学习效果会有提升的空间,M太大导致过拟合,通常使用CV来检测M是否为有效地迭代次数
- Shrinkage,通常,比较小的学习率通常来带来不错的模型泛化能力,但是训练与预测时间会增加
- 限制每个叶子的数据数量:类似于决策树的过拟合方法,如果某个条件下的数据个数小于我们规定的值,那么就不会被分支,减少树的数量,减少模型复杂性
GBDT算法实践 在前段时间的天池的一个关于客流预测的比赛中,用了GBDT来对公交车某天某时间段的乘客数来进行预测,核心代码如下:
#-*-coding:utf-8-*- ''' 这个脚本用来训练为经过dummies的模型,并且保存 ''' __author__ = 'burness' import pandas as pd from add_holiday import add_holiday from compute_error import compute_error gd_lines_info = pd.read_csv('./data/gd_line_desc.txt') gd_lines_info.columns=['line_name','stop_cnt','line_type']gd_weather_info = pd.read_csv('./data/gd_weather_report.txt') gd_weather_info.columns=['date','weather','temperature','wind_direction_force']gd_lines_info['line_type_val']=gd_lines_info['line_type'].map({'广州市内':0,'广佛跨区域':1})gd_lines_info=gd_lines_info.drop(['line_type'],axis=1) for line in ['线路6','线路11']: # for line in ['线路11']:train_count_file = './data/count/final_%s_count.txt'%linegd_train_line_pd = pd.read_csv(train_count_file)gd_train_line_pd.columns = ['date','time','cnt']# print gd_train_line_pd.count()gd_train_line_pd.head()# join the weatherprint gd_weather_info.columnsgd_train_line_pd_weather = pd.merge(gd_train_line_pd,gd_weather_info,on='date')# print gd_train_line_pd_weather.count()gd_train_line_pd_weather['date_val']=pd.to_datetime(gd_train_line_pd_weather['date'])gd_train_line_pd_weather.head()gd_train_line_pd_weather['dayofweek']=gd_train_line_pd_weather['date_val'].apply(lambda x: x.dayofweek)gd_train_line_pd_weather['weatherA']=gd_train_line_pd_weather['weather'].str.split('/').str[0]gd_train_line_pd_weather['weatherB']=gd_train_line_pd_weather['weather'].str.split('/').str[1]gd_train_line_pd_weather['weatherA_val']=gd_train_line_pd_weather['weatherA'].map({'大到暴雨':0,'大雨':1,'中到大雨':2,'中雨':3,'小到中雨':4,'雷阵雨':5,'阵雨':6,'小雨':7,'阴':8,'多云':9,'晴':10})gd_train_line_pd_weather['weatherB_val']=gd_train_line_pd_weather['weatherB'].map({'大到暴雨':0,'大雨':1,'中到大雨':2,'中雨':3,'小到中雨':4,'雷阵雨':5,'阵雨':6,'小雨':7,'阴':8,'多云':9,'晴':10})gd_train_line_pd_weather[['weatherA_val','weatherB_val']]=gd_train_line_pd_weather[['weatherA_val','weatherB_val']].astype(int)gd_train_line_pd_weather['weatherPeriod']=abs(gd_train_line_pd_weather['weatherA_val']-gd_train_line_pd_weather['weatherB_val'])gd_train_line_pd_weather['weatherE']=(gd_train_line_pd_weather['weatherA_val']+gd_train_line_pd_weather['weatherB_val'])/2.0# gd_train_line_pd_weather.dtypesgd_train_line_pd_weather=gd_train_line_pd_weather.drop(['weather','weatherA','weatherB'],axis=1)gd_train_line_pd_weather['temperatureA']=gd_train_line_pd_weather['temperature'].str.split('/').str[0].str.extract('(\d+)')gd_train_line_pd_weather['temperatureB']=gd_train_line_pd_weather['temperature'].str.split('/').str[1].str.extract('(\d+)')gd_train_line_pd_weather[['temperatureA','temperatureB']]=gd_train_line_pd_weather[['temperatureA','temperatureB']].astype(float)gd_train_line_pd_weather['temperaturePeriod']=abs(gd_train_line_pd_weather['temperatureA']-gd_train_line_pd_weather['temperatureB'])gd_train_line_pd_weather['temperatureE']=(gd_train_line_pd_weather['temperatureA']+gd_train_line_pd_weather['temperatureB'])/2.0# 增加bus数量bus_line_name = './data/bus_count/final_'+line+'_bus_count.txt'bus_line = pd.read_csv(bus_line_name)bus_line.columns=['date','time','bus_cnt']bus_line['date_val'] = pd.to_datetime(bus_line['date'])bus_line = bus_line.drop('date',axis=1)# print bus_line.dtypes# print gd_train_line_pd_weather.dtypesgd_train_line_pd_weather = pd.merge(gd_train_line_pd_weather,bus_line,on=['date_val','time'])# print gd_train_line_pd_weather.columns# 滤除非6点到21点得数据# print gd_train_line_pd_weather[gd_train_line_pd_weather['time']==23]gd_train_line_pd_weather = gd_train_line_pd_weather[gd_train_line_pd_weather['time']>=6]gd_train_line_pd_weather = gd_train_line_pd_weather[gd_train_line_pd_weather['time']<=21]# 加上holiday信息gd_train_line_pd_weather = add_holiday(gd_train_line_pd_weather)print gd_train_line_pd_weather.head()print gd_train_line_pd_weather.count()gd_train_line_pd_final = gd_train_line_pd_weather[['cnt','time','dayofweek','weatherA_val','weatherB_val','weatherPeriod','weatherE','temperatureA','temperatureB','temperaturePeriod','temperatureE','holiday','bus_cnt']]from sklearn.ensemble import GradientBoostingRegressorfrom sklearn import grid_searchdata=gd_train_line_pd_final[['time','dayofweek','weatherA_val','weatherB_val','weatherPeriod','weatherE','temperatureA','temperatureB','temperaturePeriod','temperatureE','holiday','bus_cnt']]labels = gd_train_line_pd_final['cnt']# # # print data.head(40)from sklearn.cross_validation import train_test_splittrain_data,test_data,train_labels,test_labels=train_test_split(data,labels,test_size=7*15)est = GradientBoostingRegressor()parameters={'loss':('ls', 'lad', 'huber', 'quantile'),'learning_rate':[0.04*(i+1) for i in range(25)],'n_estimators':[75,100,125,150],'max_depth':[2,3,4]}clf=grid_search.GridSearchCV(est,parameters)print 'performing grid_searching...'print 'parameters:'from time import timet0=time()clf.fit(train_data,train_labels)print 'grid_searching takes %0.3fs'%(time()-t0)best_parameters=clf.best_params_for para_name in sorted(parameters.keys()):print para_nameprint best_parameters[para_name]###est.set_params(learning_rate=best_parameters['learning_rate'],loss=best_parameters['loss'],max_depth=best_parameters['max_depth'],n_estimators=best_parameters['n_estimators'])est.fit(train_data,train_labels)print '保存model....'from sklearn.externals import joblibmodel_name = './model/2015-11-26/traffic_GBDT_'+line+'.model'joblib.dump(est,model_name)# validation proceeest = joblib.load('./model/2015-11-26/traffic_GBDT_'+line+'.model')sum = 0.0for i in range(200):val_train_data,val_test_data,val_train_labels,val_test_labels=train_test_split(data,labels,test_size=7*15)predict_labels = est.predict(val_test_data)print predict_labelserror = compute_error(predict_labels,val_test_labels)print 'val error: %f '% errorsum+=errorprint 'averge error: %f'%(sum/200)
代码当中包括gridserach找最优参数,以及一个简单地本地评测,具体代码可见我的github上:
天池客流预测
天池客流预测–GBDT相关推荐
- 公交线路客流预测——手把手教你玩数据(一)
目录 引言 背景 说明 How Do it? 看数据的容颜 了解性格 恋爱之baseline 恋爱之调优 结婚 总结 关于数据和代码 作者:徐国功 2018.9.7 转载请注明出处:https://b ...
- 基于物理-虚拟协同图网络的客流预测
1.文章信息 本周阅读的论文是题目为<Physical-Virtual Collaboration Modeling for Intra- and Inter-Station Metro Rid ...
- 阿里云天池蒸汽预测(一)
阿里云天池蒸汽预测(一) 数据探索 1.查看数据 2.可视化数据分布 2.1.箱型图 2.2.获取异常数据 2.3.直方图和QQ图 2.4.KDE分布图 2.5线性回归关系图 3.查看数据相关性 3. ...
- 基于小波分解与LSTM的城市轨道短时客流预测
1.文章信息 文章题为<A novel prediction model for the inbound passenger flow of urban rail transit>,是一篇 ...
- 多变量干扰事件发生下的地铁客流预测
文章信息 本周阅读的论文是题目为<Forecasting the subway passenger flow under event occurrences with multivariate ...
- 基于智能卡数据的特殊事件地铁客流预测
文章信息 <Subway Passenger Flow Prediction for Special Events Using Smart Card Data>是2020年发表在期刊IEE ...
- K-means聚类后的LSTM-CNN出租车热点区域客流预测
近些年,随着社会的发展和城市规模的不断拓展,城市人口不断增多,出租车与网约车的数量呈现直线上升的趋势,居民对出租车和网约车的需求量也越来越大.在出租车市场规模快速增长的同时,行业内的竞争也日益激烈,因 ...
- 基于OD吸引度的城市轨道交通OD客流预测方法
摘要 本发明提供一种基于OD吸引度的城市轨道交通OD客流预测方法.该方法包括:根据历史数据统计轨道交通网中不同时段OD对间的吸引度值和对应的OD对间的客流,分别表示为OD吸引度矩阵和第一OD矩阵,其中 ...
- 天池“幸福感预测”比赛-2019
"幸福感预测"Project报告 1 赛题简介 本赛题是天池上的一个数据挖掘类型的比赛--快来一起挖掘幸福感.比赛的数据使用的是官方的<中国综合社会调查(CGSS)>文 ...
- 阿里天池--工业蒸汽预测
赛题描述: 经脱敏后的锅炉传感器采集的数据(采集频率是分钟级别),根据锅炉的工况,预测产生的蒸汽量. 数据说明: 数据分成训练数据(train.txt)和测试数据(test.txt),其中字段 V0- ...
最新文章
- Spring Cloud第十三篇: 断路器聚合监控(Hystrix Turbine)
- Python 正则表达式
- 哥们,你侵权了,哥有权告你去!
- 单片机为什么不到一年时间涨这么多?
- 怎么把php查询到的值显示到下拉框中_RazorSQL for Mac(数据库工具查询) v8.5.0
- C#LeetCode刷题之#55-跳跃游戏(Jump Game)
- 用递归方法实现读取文件夹下所有文件信息
- 老板突然出现,游戏飞速隐藏,开源神器在手,摸鱼不怕被抓包丨不会写代码也能用...
- Acer4552G双硬盘
- List集合去重的常见及常用的四种方式
- ArcGIS 矢量数据的合并
- 三极管作开关应用及详解
- vlan的基本指令_vlan划分命令
- 悉尼大学计算机硕士健康科技,悉尼大学健康科学学院
- 睿智的目标检测28——YoloV4当中的Mosaic数据增强方法
- android ogg转mp3,MP3提取转换器
- Recurrent Neural Network(循环神经网络)
- 棋牌算法——“贰柒拾”(字牌)
- 类名.class 类名.this 详解
- 【KVM相关】kvm虚拟化部署配置
热门文章
- 关于Scala和面向对象的函数式编程
- MySQL 按照拼音给中文字段排序
- 使Ubuntu登陆时默认开NumLock灯
- 飞利浦 f718 java 微信_第一次使用飞利浦F718手机感觉怎么样及优缺点
- STC学习:振动声光报警器
- sed 追加文本类容_浅谈Linux三剑客中的sed命令之篇二
- python中list是链表吗_Python
- 【算法原理+洛谷P6114+HDU6761】Lyndon分解
- 【POJ2007】Scrambled Polygon(点集逆时针排序--极角排序/凸包--只适用于凸多边形)
- 【2018焦作网络赛】Strings and Times(出现次数在[L,R]的子串数目---后缀数组+st表)