目录

  • 基于Spark的GBDT + LR模型实现

    • 数据预处理部分
    • GBDT模型部分(省略调参部分)
    • GBDT与LR混合部分

基于Spark的GBDT + LR模型实现

测试数据来源http://archive.ics.uci.edu/ml/machine-learning-databases/adult/

该模型利用Spark mllib的GradientBoostedTrees作为GBDT部分,因为ml模块的GBTClassifier对所生成的模型做了相当严密的封装,导致难以获取某些类或方法。而GradientBoostedTrees所需的训练数据为mllib下的LabeledPoint,所以下面的数据预处理的目标是将cat数据进行编码并生成LabeledPoint。

数据预处理部分

import org.apache.spark.mllib.linalg.{SparseVector => OldSparseVector}
import org.apache.spark.sql.functions._
import spark.implicits._val path = ""val manualSchema = StructType(Array(StructField("age", IntegerType, true),StructField("workclass", StringType, true),StructField("fnlwgt", IntegerType, true),StructField("education", StringType, true),StructField("education-num", IntegerType, true),StructField("marital-status", StringType, true),StructField("occupation", StringType, true),StructField("relationship", StringType, true),StructField("race", StringType, true),StructField("sex", StringType, true),StructField("capital-gain", IntegerType, true),StructField("capital-loss", IntegerType, true),StructField("hours-per-week", IntegerType, true),StructField("native-country", StringType, true),StructField("label", StringType, true)))val df = spark.read.option("header", false).option("delimiter", ",").option("nullValue", "?").schema(manualSchema).format("csv").load(path + "adult.data.txt")
//      .limit(1000)// 去掉代表序列号的col
var df1 = df.drop("fnlwgt").na.drop()val allFeature = df1.columns.dropRight(1)// colName和index的映射
val colIdx = new util.HashMap[String, Int](allFeature.length)
var idx = 0
while (idx < allFeature.length){colIdx.put(allFeature(idx), idx)idx += 1
}val numCols = Array("age", "education-num", "capital-gain", "capital-loss", "hours-per-week")
val catCols = df1.columns.dropRight(1).diff(numCols)
val numLen = numCols.length
val catLen = catCols.length// 处理label
def labeludf(elem: String):Int = {if (elem == "<=50K") 0else 1
}val labelIndexer = udf(labeludf(_:String):Int)// 也可以用 when 函数
// val labelIndexer = when($"lable" === "<=50K", 0).otherwise(1)df1 = df1.withColumn("indexed_label", labelIndexer(col("label"))).drop("label")// 处理cat列
// 所有cat列统一编码,例如有两列cat,第一列为性别,第二列为早、午、晚,那么第一列的编码为0或1,而第二列的编码为2、3或4。下面实现仿照StringIndexer,可能更高效
val inderMap: util.HashMap[String, util.HashMap[String, Int]] = new util.HashMap(catCols.length)
var i = numCols.length
for (column <- catCols) {val uniqueElem = df1.select(column).groupBy(column).agg(count(column)).select(column).map(_.getAs[String](0)).collect()val len = uniqueElem.lengthvar index = 0val freqMap = new util.HashMap[String, Int](len)while (index < len) {freqMap.put(uniqueElem(index), i)index += 1i += 1}inderMap.put(column, freqMap)
}val bcMap = spark.sparkContext.broadcast(inderMap)val d = i// 合并为LabeledPoint
val df2 = df1.rdd.map { row =>val indics = new Array[Int](numLen + catLen)val value = new Array[Double](numLen + catLen)var i = 0for (col <- numCols) {indics(i) = ivalue(i) = row.getAs[Int](colIdx.get(col)).toDoublei += 1}for (col <- catCols) {indics(i) = bcMap.value.get(col).get(row.getAs[String](colIdx.get(col)))value(i) = 1i += 1}new LabeledPoint(row.getAs[Int](numLen + catLen), new OldSparseVector(d, indics, value))
}
val ds = df2.toDF("label", "feature")
ds.write.save(path + "processed")

GBDT模型部分(省略调参部分)

val path = ""
val df = spark.read.load(path).rdd.map(row => LabeledPoint(row.getAs[Double](0), row.getAs[OldSparseVector](1)))// Train a GradientBoostedTrees model.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 10
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 3
boostingStrategy.learningRate = 0.3
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()val model = GradientBoostedTrees.train(df, boostingStrategy)model.save(spark.sparkContext, path + "GBDTmodel")

GBDT与LR混合部分

object GBTLRTraining {// 遍历一颗决策树,找出其出口的叶子节点iddef predictModify(node: OldNode, features: OldSparseVector): Int = {val split = node.splitif (node.isLeaf) {node.id - 1 // 改为0-base} else {if (split.get.featureType == FeatureType.Continuous) {if (features(split.get.feature) <= split.get.threshold) {predictModify(node.leftNode.get, features)} else {predictModify(node.rightNode.get, features)}} else {if (split.get.categories.contains(features(split.get.feature))) {predictModify(node.leftNode.get, features)} else {predictModify(node.rightNode.get, features)}}}}// 获取每棵树的出口叶子节点id数组def getGBTFeatures(gbtModel: GradientBoostedTreesModel, oldFeatures: OldSparseVector): Array[Int] = {val GBTMaxIter = gbtModel.trees.lengthval leafIdArray = new Array[Int](GBTMaxIter)for (i <- 0 until GBTMaxIter) {val treePredict = predictModify(gbtModel.trees(i).topNode, oldFeatures)leafIdArray(i) = treePredict}leafIdArray}def main(args: Array[String]): Unit = {val spark = SparkSession.builder().master("local[*]").appName("TEST")// 本地配置.config("spark.sql.shuffle.partitions", 12).config("spark.default.parallelism", 12).config("spark.memory.fraction", 0.75)//      .config("spark.memory.ofHeap.enabled", true)//      .config("spark.memory.ofHeapa.size", "2G").config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")//      .config("spark.executor.memory", "2G").getOrCreate()spark.sparkContext.setLogLevel("ERROR")import org.apache.spark.sql.functions._import spark.implicits._val path = ""val df = spark.read.load(path)val model = GradientBoostedTreesModel.load(spark.sparkContext, path +  "GBDTmodel")val bcmodel = spark.sparkContext.broadcast(model)var treeNodeNum = 0var treeDepth = 0// 获取最大的树的数据for (elem <- model.trees) {if (treeNodeNum < elem.numNodes){treeNodeNum = elem.numNodestreeDepth = elem.depth}}val leafNum = math.pow(2, treeDepth).toIntval nonLeafNum = treeNodeNum - leafNumval totalColNum = leafNum * model.trees.length//    print(leafNum + " " + nonLeafNum + " " + tree.numNodes + " " + totalColNum)// 利用之前训练好的GBT模型进行特征提取,并把原特征OldSparseVector转化为ml的SparseVector,让后续的LR使用val addFeatureUDF = udf { features: OldSparseVector =>val gbtFeatures = getGBTFeatures(bcmodel.value, features)var i = 0while (i < gbtFeatures.length){val leafIdx = gbtFeatures(i) - nonLeafNum// 有些树可能没有生长完全,leafIdx没有达到最大的树的最后一层,这里将这些情况默认为最大的树的最后一层的第一个叶子节点。gbtFeatures(i) = (if (leafIdx < 0) 0 else leafIdx) + i * leafNumi += 1}val idx = gbtFeaturesval values = Array.fill[Double](idx.length)(1.0)Vectors.sparse(totalColNum, idx, values)}val dsWithCombinedFeatures = df.withColumn("lr_feature", addFeatureUDF(col("feature")))//    dsWithCombinedFeatures.show(false)val lr = new LogisticRegression().setMaxIter(500).setFeaturesCol("lr_feature").setLabelCol("label")val lrmodel = lr.fit(dsWithCombinedFeatures)val res = lrmodel.transform(dsWithCombinedFeatures)//    res.show(false)val evaluator1 = new MulticlassClassificationEvaluator().setMetricName("accuracy").setLabelCol("label").setPredictionCol("prediction")println("ACC:" + evaluator1.evaluate(res))val evaluator2 = new BinaryClassificationEvaluator().setMetricName("areaUnderROC").setLabelCol("label").setRawPredictionCol("prediction")println("AUC:" + evaluator2.evaluate(res))}
}

参考资料:

https://github.com/wzhe06/CTRmodel

转载于:https://www.cnblogs.com/code2one/p/10366183.html

基于Spark的GBDT + LR模型实现相关推荐

  1. 推荐系统 | 基础推荐模型 | GBDT+LR模型 | Python实现

    基础推荐模型--传送门: 推荐系统 | 基础推荐模型 | 协同过滤 | UserCF与ItemCF的Python实现及优化 推荐系统 | 基础推荐模型 | 矩阵分解模型 | 隐语义模型 | PyTor ...

  2. 推荐系统(二)GBDT+LR模型

    推荐系统(二)GBDT+LR模型 推荐系统系列博客: 推荐系统(一)推荐系统整体概览 在写这篇博客之前,一度纠结许久,到底该不该起这个标题,因为把GBDT+LR模型放在推荐系统系列里,似乎有些不妥,如 ...

  3. Facebook的GBDT+LR模型python代码实现

    承接上篇讲解,本文代码,讲解看上篇 目标:GBDT+LR模型 步骤:GBDT+OneHot+LR 测试数据:iris 代码: 结果比较:与直接GBDT模型的比较 目标:GBDT+LR模型 实现GBDT ...

  4. 传统推荐算法Facebook的GBDT+LR模型深入理解

    目标: 深入理解Facebook 2014年提出的的GBDT+LR模型. CSDN上泛泛而谈的文章很多,真正讲解透彻的没几篇.争取我这篇能讲解透彻. 今晚又想了许久,想通了一些原理.也分享出来. 算法 ...

  5. 推荐系统与深度学习(十四)——GBDT+LR模型原理

    公众号后台回复"图书",了解更多号主新书内容 作者:livan 来源:数据python与算法 模型原理 与GBDT+LR模型结缘是刚开始学习推荐系统的过程中,FaceBook一推出 ...

  6. GBDT + LR模型融合

    n)[1],LR是广义线性模型,与传统线性模型相比,LR使用了Logit变换将函数值映射到0~1区间[2],映射后的函数值就是CTR的预估值.LR这种线性模型很容易并行化,处理上亿条训练样本不是问题, ...

  7. 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码)

    CTR 系列文章: 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码) CTR经典模型串讲:FM / FFM / 双线性 FFM 相关推导与理解 CTR深度学习模型之 ...

  8. AI上推荐 之 逻辑回归模型与GBDT+LR(特征工程模型化的开端)

    1. 前言 随着信息技术和互联网的发展, 我们已经步入了一个信息过载的时代,这个时代,无论是信息消费者还是信息生产者都遇到了很大的挑战: 信息消费者:如何从大量的信息中找到自己感兴趣的信息? 信息生产 ...

  9. 【Spark】Spark训练Lr模型,并保存为Pmml

    scala版本spark构建的Lr模型: 一.问题背景   需要构建一个Lr模型来进行物品的Ctr预测. 二.解决方案   由于我们训练的数据量较多,所以首先考虑采用spark来构建模型并测试训练,这 ...

  10. 推荐系统入门(五):GBDT+LR(附代码)

    推荐系统入门(五):GBDT+LR(附代码) 目录 推荐系统入门(五):GBDT+LR(附代码) 引言 1. GBDT模型 2. LR模型 3. GBDT+LR模型 4. 编程实践 实战 思考 参考资 ...

最新文章

  1. 如何将UI5应用部署到Fiori On-Premise和On-Cloud的Launchpad上去
  2. linux arm gcc 内联汇编参考手册
  3. 两边双虚线是什么意思_单黄线和双黄线有什么不同?很多人都记不对,被扣分都不知道...
  4. centos7设置时间为日本东京时间
  5. 编译php的时候,报configure: error: mcrypt.h not found. Please reinstall libmcrypt.错误的解决办法...
  6. 5G为何采纳华为力挺的Polar码?一个通信工程师的大实话
  7. Zynq7000开发系列-5(OpenCV开发环境搭建:Ubuntu、Zynq)
  8. 计算机维修队,浙江万里学院计算机维修队
  9. edu汇编语言——实训课程
  10. NUC980开发板Linux系统EC20模块 移植 串口 PPP拨号
  11. SAP License:税额保留小数位差异处理
  12. AcWing 1055. 股票买卖 II
  13. 【ModBus】基础使用:【01】MThings国产调试工具
  14. html文件怎么转换为swf文件,在html里怎么添加flash视频格式(flv、swf)文件
  15. 和机器人问问题的软件_如何开发一个特定领域的自动问答机器人(Chat Bot)?
  16. PyCharm配置Virtual Environment
  17. 2020移动apn接入点哪个快_最新联通上网卡APN的设置方法
  18. 等了3个月终于来啦!传智播客C/C++视频教程开始更新喽~
  19. 拓展KubeVela模块,看addon如何助力开放生态
  20. php返回值乱码,php中文返回乱码怎么办

热门文章

  1. 【控制】反馈控制入门,PID控制
  2. 都2022了,我为什么还要写博客?
  3. utc时间 单位换算_utc时间(utc时间转换北京时间)
  4. 机械工程和人工智能关系
  5. 2021年JAVA多线程并发编程面试题(持续更新)
  6. 实战 用Python放一场浪漫的烟花秀
  7. 《数据资产管理实践白皮书3.0》发布!(附全文下载)
  8. 让ffmpeg支持输出h264格式
  9. linux开发屏幕保护代码,使用xscreensaver编写屏幕保护程序的提示和技巧?
  10. Windows系统以及office等一键激活