前记 之前有参加天池的比赛,后面也会分享这个代码,用到过sklearn重的GBDT这个工具,效果还很不错,但是其实一直没有对它的原理搞通,最近花了点时间,好好研究了下GBDT这个东西,感觉很有意思。

基本介绍 有这样一个场景,训练集只有4个人,A,B,C,D,他们年龄分别是14,16,24,26,其中A、B分别是高一与高三的学生;C,D分别是迎接毕业生和工作两类的员工。如果使用传统的回归决策树来train他们:

图1 普通回归树

但是使用GBDT,这时候我们生成两棵树:

图2 GBDT训练

很明显,GBDT不是依赖原始数据来生成第二棵树,而是用预测值与实际值之间的残差来当做输入数据源进行树的训练,而普通回归树是通过分支后的数据来进行下一层次的分支。 图1和图2在这个场景中,最终都能合理地对用户A,B,C,D进行预测,那为什么还需要GBDT,为什么说GBDT会比普通的回归树性能更高呢? 由图1分类使用了3个feature,并且在上网时长这个feature上分类时,恰好可能AB中A每天上网1.09h,B上网1.11h,显然过拟合,极有可能不适合其他数据;另外就是特征数来比,图1大于图2,图1的模型复杂性比较大。 有人可能说,这其实只是我这边参考的一个例子,这个例子无法具有真实性,无法从理论上说明真实情况就是这样 这部分同学确实想的比较正确,这里举得例子只是介绍两种算法的不同,没有严谨的解释,但是其实GBDT基于Boosting的最大好处就是,每一步残差其实就是变相地增加了分错的那些实例的权重,已经分对在经过残差计算后就不会再被考虑,后面生成的树会越来越关注与分错的实例。相对于普通回归树而言,这种方法更加有效,并且会减少模型复杂性,减少过拟合的可能性

算法原理

损失函数 Gradient Boosting和一般的组合方法一样,迭代产生弱分类器集合,组合成为具有强学习能力的分类器。 GBDT的原理是:假定经过前面多次迭代已经产生一个不太完美的分类器

,GBDT不会去改变原来已经生成的弱分类器集合(AdaBoost就是改变权重),而是加入一个新的分类器h,使得性能会更好。然而,如何去找这样的函数h呢? 假定理想状态下,h的加入能够消除的误差,即:

则可以知道:

那么,我们Gradient Boosting的任务就是将h拟合

,通过train来修正之前的误差。

算法推导 GBDT抽样出来最终的目标是在已有训练集

上使整体的损失函数的期望最小:

根据经验风险最小化原则,Gradient Boosting每一步求出的函数都需要最小化训练集数据上的损失函数,迭代构造这个模型,初始化

为一个常数函数:其中f限定为从基函数类H当中选取的函数,即上一节当中的h。 这里把看做一个关于向量的函数,而不是关于f的函数,则:这样就求出f,不过f会在限定为基函数H中,故会在这个类中来选取最近的梯度f。 求出f后,对应的系数

假设在这里,我们指定使用的损失函数h为square-loss:,那么GBDT的基本流程如下:

通常,我们将损失函数写成indicator notation形式:

假设J为叶子数,输入特征空间为,每一个空间有对应的常数作为预测值,故函数h(x)可写为:

根据Boosting的方法,h(x)乘上一个权重,然后加入原先已生成的分类器中来最小化Loss function:

故按indicator notation后的表示,公式可写为:

GBDT如何调参

  • GBDT中决策树的叶子数,通常GBDT中决策树的叶子数控制在4-8之间,效果比较好
  • GBDT正则化,涉及到过拟合问题,正则化减小模型复杂度,防止过拟合
  • 迭代次数 M太小,学习效果会有提升的空间,M太大导致过拟合,通常使用CV来检测M是否为有效地迭代次数
  • Shrinkage,通常,比较小的学习率通常来带来不错的模型泛化能力,但是训练与预测时间会增加
  • 限制每个叶子的数据数量:类似于决策树的过拟合方法,如果某个条件下的数据个数小于我们规定的值,那么就不会被分支,减少树的数量,减少模型复杂性

GBDT算法实践 在前段时间的天池的一个关于客流预测的比赛中,用了GBDT来对公交车某天某时间段的乘客数来进行预测,核心代码如下:

#-*-coding:utf-8-*-
'''
这个脚本用来训练为经过dummies的模型,并且保存
'''
__author__ = 'burness'
import pandas as pd
from add_holiday import add_holiday
from compute_error import compute_error
gd_lines_info = pd.read_csv('./data/gd_line_desc.txt')
gd_lines_info.columns=['line_name','stop_cnt','line_type']gd_weather_info = pd.read_csv('./data/gd_weather_report.txt')
gd_weather_info.columns=['date','weather','temperature','wind_direction_force']gd_lines_info['line_type_val']=gd_lines_info['line_type'].map({'广州市内':0,'广佛跨区域':1})gd_lines_info=gd_lines_info.drop(['line_type'],axis=1)
for line in ['线路6','线路11']:
# for line in ['线路11']:train_count_file = './data/count/final_%s_count.txt'%linegd_train_line_pd = pd.read_csv(train_count_file)gd_train_line_pd.columns = ['date','time','cnt']# print gd_train_line_pd.count()gd_train_line_pd.head()# join the weatherprint gd_weather_info.columnsgd_train_line_pd_weather = pd.merge(gd_train_line_pd,gd_weather_info,on='date')# print gd_train_line_pd_weather.count()gd_train_line_pd_weather['date_val']=pd.to_datetime(gd_train_line_pd_weather['date'])gd_train_line_pd_weather.head()gd_train_line_pd_weather['dayofweek']=gd_train_line_pd_weather['date_val'].apply(lambda x: x.dayofweek)gd_train_line_pd_weather['weatherA']=gd_train_line_pd_weather['weather'].str.split('/').str[0]gd_train_line_pd_weather['weatherB']=gd_train_line_pd_weather['weather'].str.split('/').str[1]gd_train_line_pd_weather['weatherA_val']=gd_train_line_pd_weather['weatherA'].map({'大到暴雨':0,'大雨':1,'中到大雨':2,'中雨':3,'小到中雨':4,'雷阵雨':5,'阵雨':6,'小雨':7,'阴':8,'多云':9,'晴':10})gd_train_line_pd_weather['weatherB_val']=gd_train_line_pd_weather['weatherB'].map({'大到暴雨':0,'大雨':1,'中到大雨':2,'中雨':3,'小到中雨':4,'雷阵雨':5,'阵雨':6,'小雨':7,'阴':8,'多云':9,'晴':10})gd_train_line_pd_weather[['weatherA_val','weatherB_val']]=gd_train_line_pd_weather[['weatherA_val','weatherB_val']].astype(int)gd_train_line_pd_weather['weatherPeriod']=abs(gd_train_line_pd_weather['weatherA_val']-gd_train_line_pd_weather['weatherB_val'])gd_train_line_pd_weather['weatherE']=(gd_train_line_pd_weather['weatherA_val']+gd_train_line_pd_weather['weatherB_val'])/2.0# gd_train_line_pd_weather.dtypesgd_train_line_pd_weather=gd_train_line_pd_weather.drop(['weather','weatherA','weatherB'],axis=1)gd_train_line_pd_weather['temperatureA']=gd_train_line_pd_weather['temperature'].str.split('/').str[0].str.extract('(\d+)')gd_train_line_pd_weather['temperatureB']=gd_train_line_pd_weather['temperature'].str.split('/').str[1].str.extract('(\d+)')gd_train_line_pd_weather[['temperatureA','temperatureB']]=gd_train_line_pd_weather[['temperatureA','temperatureB']].astype(float)gd_train_line_pd_weather['temperaturePeriod']=abs(gd_train_line_pd_weather['temperatureA']-gd_train_line_pd_weather['temperatureB'])gd_train_line_pd_weather['temperatureE']=(gd_train_line_pd_weather['temperatureA']+gd_train_line_pd_weather['temperatureB'])/2.0# 增加bus数量bus_line_name = './data/bus_count/final_'+line+'_bus_count.txt'bus_line = pd.read_csv(bus_line_name)bus_line.columns=['date','time','bus_cnt']bus_line['date_val'] = pd.to_datetime(bus_line['date'])bus_line = bus_line.drop('date',axis=1)# print bus_line.dtypes# print gd_train_line_pd_weather.dtypesgd_train_line_pd_weather = pd.merge(gd_train_line_pd_weather,bus_line,on=['date_val','time'])# print gd_train_line_pd_weather.columns# 滤除非6点到21点得数据# print gd_train_line_pd_weather[gd_train_line_pd_weather['time']==23]gd_train_line_pd_weather = gd_train_line_pd_weather[gd_train_line_pd_weather['time']>=6]gd_train_line_pd_weather = gd_train_line_pd_weather[gd_train_line_pd_weather['time']<=21]# 加上holiday信息gd_train_line_pd_weather = add_holiday(gd_train_line_pd_weather)print gd_train_line_pd_weather.head()print gd_train_line_pd_weather.count()gd_train_line_pd_final = gd_train_line_pd_weather[['cnt','time','dayofweek','weatherA_val','weatherB_val','weatherPeriod','weatherE','temperatureA','temperatureB','temperaturePeriod','temperatureE','holiday','bus_cnt']]from sklearn.ensemble import GradientBoostingRegressorfrom sklearn import grid_searchdata=gd_train_line_pd_final[['time','dayofweek','weatherA_val','weatherB_val','weatherPeriod','weatherE','temperatureA','temperatureB','temperaturePeriod','temperatureE','holiday','bus_cnt']]labels = gd_train_line_pd_final['cnt']# # # print data.head(40)from sklearn.cross_validation import train_test_splittrain_data,test_data,train_labels,test_labels=train_test_split(data,labels,test_size=7*15)est = GradientBoostingRegressor()parameters={'loss':('ls', 'lad', 'huber', 'quantile'),'learning_rate':[0.04*(i+1) for i in range(25)],'n_estimators':[75,100,125,150],'max_depth':[2,3,4]}clf=grid_search.GridSearchCV(est,parameters)print 'performing grid_searching...'print 'parameters:'from time import timet0=time()clf.fit(train_data,train_labels)print 'grid_searching takes %0.3fs'%(time()-t0)best_parameters=clf.best_params_for para_name in sorted(parameters.keys()):print para_nameprint best_parameters[para_name]###est.set_params(learning_rate=best_parameters['learning_rate'],loss=best_parameters['loss'],max_depth=best_parameters['max_depth'],n_estimators=best_parameters['n_estimators'])est.fit(train_data,train_labels)print '保存model....'from sklearn.externals import joblibmodel_name = './model/2015-11-26/traffic_GBDT_'+line+'.model'joblib.dump(est,model_name)# validation proceeest = joblib.load('./model/2015-11-26/traffic_GBDT_'+line+'.model')sum = 0.0for i in range(200):val_train_data,val_test_data,val_train_labels,val_test_labels=train_test_split(data,labels,test_size=7*15)predict_labels = est.predict(val_test_data)print predict_labelserror = compute_error(predict_labels,val_test_labels)print 'val error: %f '% errorsum+=errorprint 'averge error: %f'%(sum/200)

代码当中包括gridserach找最优参数,以及一个简单地本地评测,具体代码可见我的github上:

天池客流预测

天池客流预测–GBDT相关推荐

  1. 公交线路客流预测——手把手教你玩数据(一)

    目录 引言 背景 说明 How Do it? 看数据的容颜 了解性格 恋爱之baseline 恋爱之调优 结婚 总结 关于数据和代码 作者:徐国功 2018.9.7 转载请注明出处:https://b ...

  2. 基于物理-虚拟协同图网络的客流预测

    1.文章信息 本周阅读的论文是题目为<Physical-Virtual Collaboration Modeling for Intra- and Inter-Station Metro Rid ...

  3. 阿里云天池蒸汽预测(一)

    阿里云天池蒸汽预测(一) 数据探索 1.查看数据 2.可视化数据分布 2.1.箱型图 2.2.获取异常数据 2.3.直方图和QQ图 2.4.KDE分布图 2.5线性回归关系图 3.查看数据相关性 3. ...

  4. 基于小波分解与LSTM的城市轨道短时客流预测

    1.文章信息 文章题为<A novel prediction model for the inbound passenger flow of urban rail transit>,是一篇 ...

  5. 多变量干扰事件发生下的地铁客流预测

    文章信息 本周阅读的论文是题目为<Forecasting the subway passenger flow under event occurrences with multivariate ...

  6. 基于智能卡数据的特殊事件地铁客流预测

    文章信息 <Subway Passenger Flow Prediction for Special Events Using Smart Card Data>是2020年发表在期刊IEE ...

  7. K-means聚类后的LSTM-CNN出租车热点区域客流预测

    近些年,随着社会的发展和城市规模的不断拓展,城市人口不断增多,出租车与网约车的数量呈现直线上升的趋势,居民对出租车和网约车的需求量也越来越大.在出租车市场规模快速增长的同时,行业内的竞争也日益激烈,因 ...

  8. 基于OD吸引度的城市轨道交通OD客流预测方法

    摘要 本发明提供一种基于OD吸引度的城市轨道交通OD客流预测方法.该方法包括:根据历史数据统计轨道交通网中不同时段OD对间的吸引度值和对应的OD对间的客流,分别表示为OD吸引度矩阵和第一OD矩阵,其中 ...

  9. 天池“幸福感预测”比赛-2019

    "幸福感预测"Project报告 1 赛题简介 本赛题是天池上的一个数据挖掘类型的比赛--快来一起挖掘幸福感.比赛的数据使用的是官方的<中国综合社会调查(CGSS)>文 ...

  10. 阿里天池--工业蒸汽预测

    赛题描述: 经脱敏后的锅炉传感器采集的数据(采集频率是分钟级别),根据锅炉的工况,预测产生的蒸汽量. 数据说明: 数据分成训练数据(train.txt)和测试数据(test.txt),其中字段 V0- ...

最新文章

  1. Spring Cloud第十三篇: 断路器聚合监控(Hystrix Turbine)
  2. Python 正则表达式
  3. 哥们,你侵权了,哥有权告你去!
  4. 单片机为什么不到一年时间涨这么多?
  5. 怎么把php查询到的值显示到下拉框中_RazorSQL for Mac(数据库工具查询) v8.5.0
  6. C#LeetCode刷题之#55-跳跃游戏(Jump Game)
  7. 用递归方法实现读取文件夹下所有文件信息
  8. 老板突然出现,游戏飞速隐藏,开源神器在手,摸鱼不怕被抓包丨不会写代码也能用...
  9. Acer4552G双硬盘
  10. List集合去重的常见及常用的四种方式
  11. ArcGIS 矢量数据的合并
  12. 三极管作开关应用及详解
  13. vlan的基本指令_vlan划分命令
  14. 悉尼大学计算机硕士健康科技,悉尼大学健康科学学院
  15. 睿智的目标检测28——YoloV4当中的Mosaic数据增强方法
  16. android ogg转mp3,MP3提取转换器
  17. Recurrent Neural Network(循环神经网络)
  18. 棋牌算法——“贰柒拾”(字牌)
  19. 类名.class 类名.this 详解
  20. 【KVM相关】kvm虚拟化部署配置

热门文章

  1. 关于Scala和面向对象的函数式编程
  2. MySQL 按照拼音给中文字段排序
  3. 使Ubuntu登陆时默认开NumLock灯
  4. 飞利浦 f718 java 微信_第一次使用飞利浦F718手机感觉怎么样及优缺点
  5. STC学习:振动声光报警器
  6. sed 追加文本类容_浅谈Linux三剑客中的sed命令之篇二
  7. python中list是链表吗_Python
  8. 【算法原理+洛谷P6114+HDU6761】Lyndon分解
  9. 【POJ2007】Scrambled Polygon(点集逆时针排序--极角排序/凸包--只适用于凸多边形)
  10. 【2018焦作网络赛】Strings and Times(出现次数在[L,R]的子串数目---后缀数组+st表)