文章来源:https://www.deeplearn.me/1797.html

GBDT+LR 的特征组合方案是工业界经常使用的组合,尤其是计算广告 CTR 中应用比较广泛,方案的提出者是 Facebook 2014 的一篇论文。

相关的开发工具包,sklearn 和 xgboost(ps:xgboost 是一个大杀器,并且支持 hadoop 分布式,你可以部署实现分布式操作,博主部署过,布置过程较为负责,尤其是环境变量的各种设置)


特征决定模型性能上界,例如深度学习方法也是将数据如何更好的表达为特征。如果能够将数据表达成为线性可分的数据,那么使用简单的线性模型就可以取得很好的效果。GBDT 构建新的特征也是使特征更好地表达数据。

主要参考 Facebook[1],原文提升效果:

在预测 Facebook 广告点击中,使用一种将决策树与逻辑回归结合在一起的模型,其优于其他方法,超过 3%。

主要思想:GBDT 每棵树的路径直接作为 LR 输入特征使用。

用已有特征训练 GBDT 模型,然后利用 GBDT 模型学习到的树来构造新特征,最后把这些新特征加入原有特征一起训练模型。构造的新特征向量是取值 0/1 的,向量的每个元素对应于 GBDT 模型中树的叶子结点。当一个样本点通过某棵树最终落在这棵树的一个叶子结点上,那么在新特征向量中这个叶子结点对应的元素值为 1,而这棵树的其他叶子结点对应的元素值为 0。新特征向量的长度等于 GBDT 模型里所有树包含的叶子结点数之和。

上图为混合模型结构。输入特征通过增强的决策树进行转换。 每个单独树的输出被视为稀疏线性分类器的分类输入特征。 增强的决策树被证明是非常强大的特征转换。

例子 1:上图有两棵树,左树有三个叶子节点,右树有两个叶子节点,最终的特征即为五维的向量。对于输入 x,假设他落在左树第一个节点,编码[1,0,0],落在右树第二个节点则编码[0,1],所以整体的编码为[1,0,0,0,1],这类编码作为特征,输入到线性分类模型(LR or FM)中进行分类。

需要注意的是在 sklearn 或者 xgboost 输出的结果都是叶子节点的 index,所以需要自己动手去做 onehot 编码,然后交给 lr 训练,onehot 你可以在 sklearn 的预处理包中调用即可

论文中 GBDT 的参数,树的数量最多 500 颗(500 以上就没有提升了),每棵树的节点不多于 12。

下面给出二者相结合的代码演示

  1. # -*- coding: utf-8 -*-
  2. # @Time : 2018/2/27 上午 10:39
  3. # @Author : Tomcj
  4. # @File : gbdt_lr.py
  5. # @Software: PyCharm
  6. import xgboost as xgb
  7. from sklearn.datasets import load_svmlight_file
  8. from sklearn.model_selection import train_test_split
  9. from sklearn.linear_model import LogisticRegression
  10. from sklearn.metrics import roc_curve, auc, roc_auc_score
  11. from sklearn.externals import joblib
  12. from sklearn.preprocessing import OneHotEncoder
  13. import numpy as np
  14. from scipy.sparse import hstack
  15. def xgb_feature_encode(libsvmFileNameInitial):
  16. # load 样本数据
  17. X_all, y_all = load_svmlight_file(libsvmFileNameInitial)
  18. # 训练/测试数据分割
  19. X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size = 0.3, random_state = 42)
  20. # 定义模型
  21. xgboost = xgb.XGBClassifier(nthread=4, learning_rate=0.08,
  22. n_estimators=50, max_depth=5, gamma=0, subsample=0.9, colsample_bytree=0.5)
  23. # 训练学习
  24. xgboost.fit(X_train, y_train)
  25. # 预测及 AUC 评测
  26. y_pred_test = xgboost.predict_proba(X_test)[:, 1]
  27. xgb_test_auc = roc_auc_score(y_test, y_pred_test)
  28. print('xgboost test auc: %.5f' % xgb_test_auc)
  29. # xgboost 编码原有特征
  30. X_train_leaves = xgboost.apply(X_train)
  31. X_test_leaves = xgboost.apply(X_test)
  32. # 训练样本个数
  33. train_rows = X_train_leaves.shape[0]
  34. # 合并编码后的训练数据和测试数据
  35. X_leaves = np.concatenate((X_train_leaves, X_test_leaves), axis=0)
  36. X_leaves = X_leaves.astype(np.int32)
  37. (rows, cols) = X_leaves.shape
  38. # 记录每棵树的编码区间
  39. cum_count = np.zeros((1, cols), dtype=np.int32)
  40. for j in range(cols):
  41. if j == 0:
  42. cum_count[0][j] = len(np.unique(X_leaves[:, j]))
  43. else:
  44. cum_count[0][j] = len(np.unique(X_leaves[:, j])) + cum_count[0][j-1]
  45. print('Transform features genenrated by xgboost...')
  46. # 对所有特征进行 ont-hot 编码,注释部分是直接使用 onehot 函数,结果输出保证是 libsvm 格式也可以使用
  47. #sklearn 中的 dump_svmlight_file 操作,这个文件代码是参考别人的代码,这些点都是可以优化的。
  48. # onehot=OneHotEncoder()
  49. # onehot.fit(X_leaves)
  50. # x_leaves_encode=onehot.transform(X_leaves)
  51. for j in range(cols):
  52. keyMapDict = {}
  53. if j == 0:
  54. initial_index = 1
  55. else:
  56. initial_index = cum_count[0][j-1]+1
  57. for i in range(rows):
  58. if X_leaves[i, j] not in keyMapDict:
  59. keyMapDict[X_leaves[i, j]] = initial_index
  60. X_leaves[i, j] = initial_index
  61. initial_index = initial_index + 1
  62. else:
  63. X_leaves[i, j] = keyMapDict[X_leaves[i, j]]
  64. # 基于编码后的特征,将特征处理为 libsvm 格式且写入文件
  65. print('Write xgboost learned features to file ...')
  66. xgbFeatureLibsvm = open('xgb_feature_libsvm', 'w')
  67. for i in range(rows):
  68. if i < train_rows:
  69. xgbFeatureLibsvm.write(str(y_train[i]))
  70. else:
  71. xgbFeatureLibsvm.write(str(y_test[i-train_rows]))
  72. for j in range(cols):
  73. xgbFeatureLibsvm.write(' '+str(X_leaves[i, j])+':1.0')
  74. xgbFeatureLibsvm.write('\n')
  75. xgbFeatureLibsvm.close()
  76. def xgboost_lr_train(xgbfeaturefile, origin_libsvm_file):
  77. # load xgboost 特征编码后的样本数据
  78. X_xg_all, y_xg_all = load_svmlight_file(xgbfeaturefile)
  79. X_train, X_test, y_train, y_test = train_test_split(X_xg_all, y_xg_all, test_size = 0.3, random_state = 42)
  80. # load 原始样本数据
  81. X_all, y_all = load_svmlight_file(origin_libsvm_file)
  82. X_train_origin, X_test_origin, y_train_origin, y_test_origin = train_test_split(X_all, y_all, test_size = 0.3, random_state = 42)
  83. # lr 对原始特征样本模型训练
  84. lr = LogisticRegression(n_jobs=-1, C=0.1, penalty='l1')
  85. lr.fit(X_train_origin, y_train_origin)
  86. joblib.dump(lr, 'lr_orgin.m')
  87. # 预测及 AUC 评测
  88. y_pred_test = lr.predict_proba(X_test_origin)[:, 1]
  89. lr_test_auc = roc_auc_score(y_test_origin, y_pred_test)
  90. print('基于原有特征的 LR AUC: %.5f' % lr_test_auc)
  91. # lr 对 load xgboost 特征编码后的样本模型训练
  92. lr = LogisticRegression(n_jobs=-1, C=0.1, penalty='l1')
  93. lr.fit(X_train, y_train)
  94. joblib.dump(lr, 'lr_xgb.m')
  95. # 预测及 AUC 评测
  96. y_pred_test = lr.predict_proba(X_test)[:, 1]
  97. lr_test_auc = roc_auc_score(y_test, y_pred_test)
  98. print('基于 Xgboost 特征编码后的 LR AUC: %.5f' % lr_test_auc)
  99. # 基于原始特征组合 xgboost 编码后的特征
  100. X_train_ext = hstack([X_train_origin, X_train])
  101. del(X_train)
  102. del(X_train_origin)
  103. X_test_ext = hstack([X_test_origin, X_test])
  104. del(X_test)
  105. del(X_test_origin)
  106. # lr 对组合后的新特征的样本进行模型训练
  107. lr = LogisticRegression(n_jobs=-1, C=0.1, penalty='l1')
  108. lr.fit(X_train_ext, y_train)
  109. joblib.dump(lr, 'lr_ext.m')
  110. # 预测及 AUC 评测
  111. y_pred_test = lr.predict_proba(X_test_ext)[:, 1]
  112. lr_test_auc = roc_auc_score(y_test, y_pred_test)
  113. print('基于组合特征的 LR AUC: %.5f' % lr_test_auc)
  114. if __name__ == '__main__':
  115. xgb_feature_encode("/Users/leiyang/xgboost/demo/data/agaricus.txt.train")
  116. xgboost_lr_train("xgb_feature_libsvm","/Users/leiyang/xgboost/demo/data/agaricus.txt.train")

下面给出一个 ipynb 文件,也是从官方的文件改过来的,主要是对 GBDT 输出到 lr 部分数据观察

view rawgbdt_lr.ipynb hosted with ❤ by GitHub

GBDT和LR结合使用分析相关推荐

  1. CTR预估中GBDT与LR融合方案

    1. 背景 CTR预估(Click-Through Rate Prediction)是互联网计算广告中的关键环节,预估准确性直接影响公司广告收入.CTR预估中用的最多的模型是LR(Logistic R ...

  2. 杨鹏谈世纪佳缘推荐算法:基于Spark GraphX,弃GBDT和LR用FM

     杨鹏谈世纪佳缘推荐算法:基于Spark GraphX,弃GBDT和LR用FM 发表于2015-09-30 09:53| 1447次阅读| 来源CSDN| 2 条评论| 作者杨鹏 机器学习推荐算法 ...

  3. GBDT+LR记录- 9.7代码训练GBDT与LR混合模型

    GBDT+LR记录 9.7代码训练GBDT与LR混合模型 在上一节课的train.py中,新建一个函数train_tree_and_lr_model def train_tree_and_lr_mod ...

  4. 【机器学习基础】GBDT 与 LR 的区别总结

    作者:杜博亚,阿里算法工程师,复旦大学计算机硕士,BDKE 之光. 1.从机器学习三要素的角度 1.1 模型 本质上来说,他们都是监督学习,判别模型,直接对数据的分布建模,不尝试挖据隐含变量,这些方面 ...

  5. 【机器学习】GBDT 与 LR 的区别总结

    作者:杜博亚,阿里算法工程师,复旦大学计算机硕士,BDKE 之光. 1.从机器学习三要素的角度 1.1 模型 本质上来说,他们都是监督学习,判别模型,直接对数据的分布建模,不尝试挖据隐含变量,这些方面 ...

  6. SCALA下的GBDT与LR融合实现

    我们直接使用的ML的包对GBDT/LR进行融合 首先我们需要导入的包如下所示: import org.apache.spark.sql. Row import scala.collection.mut ...

  7. 决策树(十)--GBDT及OpenCV源码分析

    一.原理 梯度提升树(GBT,Gradient Boosted Trees,或称为梯度提升决策树)算法是由Friedman于1999年首次完整的提出,该算法可以实现回归.分类和排序.GBT的优点是特征 ...

  8. LR 杂记--nmon 分析 AIX 和 Linux 性能

    用法说明:这个 nmon 工具并未受到正式支持.没有提供或隐含任何保证,并且您无法从 IBM 获取相关的帮助. nmon 工具运行于: AIX® 4.1.5.4.2.0.4.3.2 和 4.3.3(n ...

  9. LR测试结果分析参数说明

    平均事务响应时间 AverageTransation Response Time 优秀:<2s 良好:2-5s 及格:6-10s 不及格:>10s 每秒点击率 Hits perSecond ...

最新文章

  1. OpenCV 相机校正过程中,calibrateCamera函数projectPoints函数的重投影误差的分析
  2. 【JVM】StackOverflowError与OutOfMemoryError
  3. 免费阿里云服务器超爽体验(为阿里做个广告吧)
  4. C++ Primer 5th笔记(2)chapter 2变量和基本类型:引用、const
  5. 系统中常用操作基类(SSH项目中)非常非常经典的部分
  6. db2 oracle mysql sqlserver_mysql、sqlserver、db2、oracle、hsql数据库获取数据库连接方法及分页函数...
  7. 《Redis核心技术与实战》学习总结(2)
  8. PRML-系列一之1.3~1.4
  9. 设计灵感|什么样的登录页能让用户感到体贴?
  10. OpenShift 4 - 用自定义的TLS证书对访问OpenShift的用户认证身份
  11. 谷歌再修复已遭利用的两个高危 Chrome 0day
  12. 读取npy格式的文件
  13. Tomcat+Nginx动静分离
  14. 解决Mac电脑开机无法自动连接蓝牙音箱问题!
  15. Oracle数据库块之旅
  16. C++ Primer Plus 第六版编程练习——第6章
  17. string密钥转PrivateKey和PublicKey
  18. Jenson不等式及其在EM估计与KL散度中的应用
  19. dm8148 开发只boot启动参数vram=128简介
  20. 电源模块KIM-3R35L 超越KIS3R33S YEC-SD200 KIW3312S

热门文章

  1. 安徽计算机应用基础高考试题,安徽省对口高考试题(计算机应用基础部分)
  2. 无盘机服务器,无盘服务器操作系统
  3. 无符号数的算术四则运算中的各类单词的识别_文本反垃圾在花椒直播中的应用概述...
  4. python selenium定位元素方法,python + selenium 练习篇 - 定位元素的方法
  5. php申请证书,用phpstudy来申请SSL证书
  6. redis配置_Redis配置大全(三)
  7. matlab字母随机排列,matlab实现1n整数的一个随机排列
  8. java scanner字符串_Java Scanner toString()用法及代码示例
  9. java 引用队列_Java中的方法队列
  10. win7查看隐藏文件_隐藏在电脑里の秘密,放在你眼前,你也发现不了,就是这么奥给力...