spark上运行xgboost-scala接口
概述
xgboost可以在spark上运行,我用的xgboost的版本是0.7的版本,目前只支持spark2.0以上版本上运行,
编译好jar包,加载到maven仓库里面去:
mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar
添加依赖:
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.7</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.10</artifactId>
<version>2.0.0</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.10</artifactId>
<version>2.0.0</version>
</dependency>
</dependencies>
RDD接口:
package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.{ SparkSession, Row }
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
object XgboostR {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
val spark = SparkSession.builder.master("local").appName("example").
config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
config("spark.sql.shuffle.partitions", "20").getOrCreate()
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
val trainString = "agaricus.txt.train"
val testString = "agaricus.txt.test"
val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
val traindata = train.map { x =>
val f = x.features.toArray
val v = x.label
LabeledPoint(v, Vectors.dense(f))
}
val testdata = test.map { x =>
val f = x.features.toArray
val v = x.label
Vectors.dense(f)
}
val numRound = 15
//"objective" -> "reg:linear", //定义学习任务及相应的学习目标
//"eval_metric" -> "rmse", //校验数据所需要的评价指标 用于做回归
val paramMap = List(
"eta" -> 1f,
"max_depth" ->5, //数的最大深度。缺省值为6 ,取值范围为:[1,∞]
"silent" -> 1, //取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0
"objective" -> "binary:logistic", //定义学习任务及相应的学习目标
"lambda"->2.5,
"nthread" -> 1 //XGBoost运行时的线程数。缺省值是当前系统可以获得的最大线程数
).toMap
println(paramMap)
val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)
print("sucess")
val result=model.predict(testdata)
result.take(10).foreach(println)
spark.stop();
}
}
DataFrame接口:
package com.meituan.spark_xgboost
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{ SparkSession, Row }
object XgboostD {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
val spark = SparkSession.builder.master("local").appName("example").
config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
config("spark.sql.shuffle.partitions", "20").getOrCreate()
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
val trainString = "agaricus.txt.train"
val testString = "agaricus.txt.test"
val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")
val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")
val numRound = 15
//"objective" -> "reg:linear", //定义学习任务及相应的学习目标
//"eval_metric" -> "rmse", //校验数据所需要的评价指标 用于做回归
val paramMap = List(
"eta" -> 1f,
"max_depth" -> 5, //数的最大深度。缺省值为6 ,取值范围为:[1,∞]
"silent" -> 1, //取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0
"objective" -> "binary:logistic", //定义学习任务及相应的学习目标
"lambda" -> 2.5,
"nthread" -> 1 //XGBoost运行时的线程数。缺省值是当前系统可以获得的最大线程数
).toMap
val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")
val predict = model.transform(test)
val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
.rdd
.map { case Row(score: Double, label: Double) => (score, label) }
//get the auc
val metric = new BinaryClassificationMetrics(scoreAndLabels)
val auc = metric.areaUnderROC()
println("auc:" + auc)
}
}
---------------------
作者:旭旭_哥
来源:CSDN
原文:https://blog.csdn.net/luoyexuge/article/details/71422270
版权声明:本文为博主原创文章,转载请附上博文链接!
spark上运行xgboost-scala接口相关推荐
- xgboost之spark上运行-scala接口
概述 xgboost可以在spark上运行,我用的xgboost的版本是0.7的版本,目前只支持spark2.0以上版本上运行, 编译好jar包,加载到maven仓库里面去: mvn install: ...
- 在Spark上运行WordCount程序
1.编写程序代码如下: Wordcount.scala package Wordcount import org.apache.spark.SparkConf import org.apache.sp ...
- Flink学习笔记04:将项目打包提交到Flink集群上运行(Scala版)
文章目录 一.创建Maven项目 - ScalaWordCount 三.利用mvn命令打包Maven项目 三.上传项目jar包到Flink集群主节点 四.启动Flink Standalone集群 五. ...
- Intellij IDEA使用Maven搭建spark开发环境(scala)
如何一步一步地在Intellij IDEA使用Maven搭建spark开发环境,并基于scala编写简单的spark中wordcount实例. 1.准备工作 首先需要在你电脑上安装jdk和scala ...
- hortonworks_具有在IBM POWER8上运行的Hortonworks Data Platform(HDP)的SAS软件
Hadoop的SAS / ACCESS接口 Hadoop的SAS / ACCESS接口提供了访问SAS本机中Hadoop中存储的数据集的功能. 通过SAS / ACCESS到Hadoop: LIBNA ...
- Spark利用(idea+maven+scala)创建wordcount打包jar并在spark on yarn上运行——Spark的开发
今天的你不付昨天的辛苦,今天的辛苦定会拥抱明日的幸福, 每一次的成长,都要给自己以鼓励,每一次的突破,都给自己以信心,万花丛中我不是最美,但我有我的自信 ------------送给一直努力的你 今天 ...
- 在local模式下的spark程序打包到集群上运行
一.前期准备 前期的环境准备,在Linux系统下要有Hadoop系统,spark伪分布式或者分布式,具体的教程可以查阅我的这两篇博客: Hadoop2.0伪分布式平台环境搭建 Spark2.4.0伪分 ...
- spark在集群上运行
1.spark在集群上运行应用的详细过程 (1)用户通过spark-submit脚本提交应用 (2)spark-submit脚本启动驱动器程序,调用用户定义的main()方法 (3)驱动器程序与集群管 ...
- DolphiScheduler平台上运行spark程序时,外部参数设置
DolphiScheduler平台上运行spark程序时,外部参数设置 近期使用DS平台执行spark程序,遇到了部分参数设置的问题,代码中需要外部传入一个参数procDate(处理日期),具体设置如 ...
最新文章
- RabbitMQ系列之【启动过程中遇到问题及解决方案】
- 性能建议(这里只针对单机版redis持久化做性能建议)
- XGBoost算法概述
- SAP S/4HANA生产订单释放后自动同步到MES系统
- 飞机游戏跟踪导弹的算法[C#源码]
- 猿创征文|C++软件开发值得推荐的十大高效软件分析工具
- 串口IEC103协议
- IBM 存储管理软件IBM DS Linux Storage Manager安装(Linux)
- 开源的项目管理软件——OpenProj
- linux redis玂家链接不上,Unicode编码的熟悉与研究过程(内附全部汉字编码列表)...
- 修改Window的hosts文件提示“该文件被其他程序占用,无法修改问题”解决方案
- 如何搭建个人视频点播服务器
- java 读取手机sd卡_获取Android手机中SD卡内存信息
- 湖南大学ACM程序设计新生杯大赛(同步赛)L-Liao Han【打表规律+二分】
- Directx11学习笔记【十】 画一个简单的三角形
- BUUCTF WEB Easy Calc
- 在线文档编辑插件——KindEditor
- ASCII码和汉字码
- mac下面安装破解版UltraEdit
- Nutch 分布式运行模式 (v1.14)
热门文章
- 思科bfd静态路由切换_思科路由器与华为路由器静态路由关联双向BFD配置
- mysql内外链接图_图解MySQL 内连接、外连接、左连接、右连接、全连接
- 苹果mac电脑修改并快速linux网络配置
- JSP中动态添加 “添加附件选择框”
- 按照月的第几周统计_商标评审案件审理情况月报(2020年第11期)
- python调用shell用什么类_python脚本中调用shell命令
- python温度转换代码分析_Python温度转换实例分析
- java 矩阵题目_一些数学分析不错的题目
- python内核大小_关于keras.layers.Conv1D的kernel_size参数使用介绍
- 人工智能 ppt_【138期】厉害了!人工智能高清大图+PPT模板全集系列!