scala+spark+randomForests

The implementation is split into three parts: reading the data from Hive, training the random forest model, and running prediction on new data.
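
At a high level the flow is: pull labelled samples out of Hive, train an MLlib random forest over a small grid of tree counts and depths, keep the best model on HDFS, then reload it to score new records. As a compressed, illustrative sketch of the core training call only (the object name PipelineSketch and the literal arguments are made up; the real jobs below read all settings from RandomForestConf/AppConf):

// Illustrative skeleton only; not the full job.
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD

object PipelineSketch {
  def train(samples: RDD[LabeledPoint], numClasses: Int, numTrees: Int, maxDepth: Int): RandomForestModel = {
    val strategy = Strategy.defaultStrategy("Classification")
    strategy.setNumClasses(numClasses)
    strategy.setMaxDepth(maxDepth)
    // "auto" lets MLlib choose the feature subset size per split; 10 is the random seed.
    RandomForest.trainClassifier(samples, strategy, numTrees, "auto", 10)
  }
}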

  1. Main class for random forest model training (implementation flow)
package com.inspur.mr.InspurMr.Classification
import java.io.File
import java.io.PrintWriter
import java.util.ArrayList
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.linalg.{ Vector, Vectors }
import com.inspur.mr.InspurMr.conf.RandomForestConf
import com.inspur.mr.InspurMr.Util.Quota
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.hadoop.fs.{FileStatus, FileSystem, FileUtil, Path}
import java.text.SimpleDateFormat
import java.util.Date

object RandomWithHive extends RandomForestConf {
  def main(args: Array[String]): Unit = {
    import hc.implicits._

    // Read the training data from Hive.
    val database = paraproperties.getProperty("database")
    val null_fill = paraproperties.getProperty("null_fill")
    val eare_lon_left = paraproperties.getProperty("eare_lon_left")
    val eare_lat_left = paraproperties.getProperty("eare_lat_left")
    val eare_lon_right = paraproperties.getProperty("eare_lon_right")
    val eare_lat_right = paraproperties.getProperty("eare_lat_right")
    val grid_length = paraproperties.getProperty("grid_length")
    val grid_num = paraproperties.getProperty("grid_num").toInt
    val disgrid = grid_length.toDouble * 0.000009

    hc.sql(s"use $database")
    val trainQuery =
      s"""select floor(($eare_lat_left-lat_uri)/$disgrid)*$grid_num+floor((long_uri-$eare_lon_left)/$disgrid) as llgridid,cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3 from dw_pods_mro_eutrancell_yyyymmdd where lat_uri<$eare_lat_left and lat_uri>$eare_lat_right and long_uri>$eare_lon_left and long_uri<$eare_lon_right and pow(long_uri-LON0,2)+pow(lat_uri-LAT0,2)<0.00002025 order by hour_id desc limit 30000000""".stripMargin
    val data1 = hc.sql(trainQuery)
    // val pathpath = "file:///C:\\Users\\wangkai01\\Desktop\\data\\csvtest.csv"
    val data = data1.na.fill(null_fill.toDouble).cache()
    println(trainQuery)
    println("run here1 !!!!!!!!")
    // data.show()

    // Feature columns.
    val featInd = List("cellid", "ltesctadv", "ltescaoa", "ltescphr", "ltescrip", "ltescsinrul", "ltescearfcn", "ltescpci", "LON0", "LAT0", "azimuth0", "coverflag0", "nettype0", "ltescrsrp", "ltescrsrq", "ltencrsrp1", "ltencrsrq1", "ltencearfcn1", "ltencpci1", "ltencrsrp2", "ltencrsrq2", "ltencearfcn2", "ltencpci2", "ltencrsrp3", "ltencrsrq3", "ltencearfcn3", "ltencpci3").map(data.columns.indexOf(_))
    println(featInd)
    // Label column index.
    val Label = data.columns.indexOf("llgridid")
    val datause = data.map { x =>
      val label = x(0).toString().toInt
      val feature = Vectors.dense(featInd.map(x.getDouble(_)).toArray)
      // println(feature)
      LabeledPoint(label, feature)
    }
    println("run here2 !!!!!!!!")

    // Split into training and test sets.
    val splits = datause.randomSplit(Array(tarining_rate, test_rate))
    val (trainingData, testData) = (splits(0), splits(1))

    // Number of classes; it could also be derived from the training labels.
    // val numClasses = (datause.map { x => x.label }.max() + 1).toInt
    val numClasses = class_num
    // Sizes of the training and test samples.
    val train_sample = trainingData.count()
    val test_sample = testData.count()
    println("run here3 !!!!!!!!")

    // If output from a previous run exists, delete it, then recreate the model directory.
    val path = new Path(model_out_path)
    val hdfs = org.apache.hadoop.fs.FileSystem.get(new java.net.URI(model_out_path), new org.apache.hadoop.conf.Configuration())
    if (!hdfs.exists(path)) {
      hdfs.mkdirs(path)
    } else {
      hdfs.delete(path, true)
      hdfs.mkdirs(path)
    }

    var bestscore = 0.0
    for (numTrees <- treeList; maxDepth <- depthList) {
      val s = Strategy.defaultStrategy("Classification")
      s.setMaxMemoryInMB(2048)
      s.setNumClasses(numClasses)
      s.setMaxDepth(maxDepth)
      s.setMaxBins(maxBins)
      val model = RandomForest.trainClassifier(trainingData, s, numTrees, featureSubsetStrategy, 10)

      // Evaluate the trained classifier on the test data and compute the error rate.
      val labelAndPreds = testData.map { point =>
        val prediction = model.predict(point.features)
        (point.label, prediction)
      }
      val quota = Quota.calculate(labelAndPreds, testData)
      val testErr = quota._1
      // val testRecall = quota._3
      // val f1_score = quota._4
      println("Test Error = " + testErr)
      // println("Learned classification forest model:\n" + model.toDebugString)
      // hdfs.createNewFile(new Path(describe + s"result-$numTrees-$maxDepth-$testErr.txt"))
      // val dirfile = new File(describe)
      // if (!dirfile.isDirectory()) {
      //   dirfile.mkdirs()
      // }
      // val resultfile = new File(describe + s"result-$numTrees-$maxDepth-$testErr.txt")
      // if (!resultfile.isFile()) {
      //   val writer = new PrintWriter(resultfile)
      //   // writer.println("train pos count:" + pos_sample + "\n")
      //   // writer.println("train neg count:" + neg_sample + "\n")
      //   writer.println("train count:" + train_sample + "\n")
      //   writer.println("test count:" + test_sample + "\n")
      //   writer.println("Test Error = " + testErr + "\n")
      //   writer.println(model.toDebugString)
      //   writer.close()
      // }
      println(s"model-$numTrees-$maxDepth:" + (1 - testErr))
      println(model.toDebugString)

      // Persist the trained random forest model.
      val now: Date = new Date()
      val dateFormat: SimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss")
      val date = dateFormat.format(now)
      val path = new Path(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date")
      // Save this parameter combination's model if it has not been saved yet.
      val hdfs = org.apache.hadoop.fs.FileSystem.get(new java.net.URI(model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date"), new org.apache.hadoop.conf.Configuration())
      if (!hdfs.exists(path)) {
        model.save(sc, model_out_path + s"model-$numTrees-$maxDepth-$testErr-$date")
      }
      if (1 - testErr >= bestscore) {
        // New best score: replace the previously saved best model with this one.
        val path = new Path(model_file)
        val hdfs = org.apache.hadoop.fs.FileSystem.get(new java.net.URI(model_file), new org.apache.hadoop.conf.Configuration())
        if (hdfs.exists(path)) hdfs.delete(path, true)
        model.save(sc, model_out_path + "model-RF-best")
        bestscore = 1 - testErr
      }
    }
    sc.stop()
    println("best score:" + bestscore)
    println("run done !!!!!!!!")
  }
}
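
The evaluation above goes through Quota.calculate from com.inspur.mr.InspurMr.Util, a project-internal helper whose source is not part of this post; the loop only uses the first tuple element as the test error, and the commented-out lines read a recall and an F1 slot. A minimal sketch of what such a helper might look like, assuming it simply derives the misclassification rate from the (label, prediction) pairs (the actual return shape may differ), is:

package com.inspur.mr.InspurMr.Util

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical sketch of the Quota utility; only the test-error slot (_1) is used above.
object Quota {
  def calculate(labelAndPreds: RDD[(Double, Double)], testData: RDD[LabeledPoint]): (Double, Double, Double, Double) = {
    val total = testData.count().toDouble
    // Misclassification rate: fraction of test points whose predicted grid id differs from the label.
    val testErr = labelAndPreds.filter { case (label, pred) => label != pred }.count() / total
    val accuracy = 1 - testErr
    // Placeholders for the recall and F1 slots (_3, _4) read by the commented-out code above.
    (testErr, accuracy, 0.0, 0.0)
  }
}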

  2. Random forest prediction code

package com.inspur.mr.InspurMr.Classification

import com.inspur.mr.InspurMr.conf.AppConf
import org.apache.spark.mllib.tree.model.RandomForestModel
import com.inspur.mr.InspurMr.Util.MLUtils
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors}

object RandomPredict extends AppConf {

  case class TableMrPre(msisdn: String, imsi: String, imei: String, begintime: String, tac: String, eci: String, nettype0: String, long_uri: Double, lat_uri: Double)

  def main(args: Array[String]): Unit = {
    val database = paraproperties.getProperty("database")
    val null_fill = paraproperties.getProperty("null_fill")
    val eare_lon_left = paraproperties.getProperty("eare_lon_left").toDouble
    val eare_lat_left = paraproperties.getProperty("eare_lat_left").toDouble
    val eare_lon_right = paraproperties.getProperty("eare_lon_right")
    val eare_lat_right = paraproperties.getProperty("eare_lat_right")
    val grid_length = paraproperties.getProperty("grid_length")
    val grid_num = paraproperties.getProperty("grid_num").toDouble
    val disgrid = grid_length.toDouble * 0.000009
    val disgridhalf = grid_length.toDouble * 0.000009 / 2

    // Partition to read and to write, taken from the hour id passed on the command line.
    var HOUR_ID = args(0)
    var MONTH_ID = HOUR_ID.substring(0, 6)
    var DAY_ID = HOUR_ID.substring(0, 8)
    val write_partition = "month_id=" + MONTH_ID + "," + "day_id=" + DAY_ID + "," + "hour_id=" + HOUR_ID
    val read_partition = "month_id=" + MONTH_ID + " and " + "day_id=" + DAY_ID + " and " + "hour_id=" + HOUR_ID

    conf.setAppName("family_test")
    val pModlePath = postgprop.getProperty("model_file")

    hc.sql(s"use $database")
    val predictQuery =
      s"""select cellid*1.0 as cellid,ltesctadv*1.0 as ltesctadv,ltescaoa*1.0 as ltescaoa,ltescphr*1.0 as ltescphr,ltescrip*1.0 as ltescrip,ltescsinrul*1.0 as ltescsinrul,ltescearfcn*1.0 as ltescearfcn,ltescpci*1.0 as ltescpci,LON0*1.0 as LON0,LAT0*1.0 as LAT0,azimuth0*1.0 as azimuth0,coverflag0*1.0 as coverflag0,nettype0*1.0 as nettype0,ltescrsrp*1.0 as ltescrsrp,ltescrsrq*1.0 as ltescrsrq,ltencrsrp1*1.0 as ltencrsrp1,ltencrsrq1*1.0 as ltencrsrq1,ltencearfcn1*1.0 as ltencearfcn1,ltencpci1*1.0 as ltencpci1,ltencrsrp2*1.0 as ltencrsrp2,ltencrsrq2*1.0 as ltencrsrq2,ltencearfcn2*1.0 as ltencearfcn2,ltencpci2*1.0 as ltencpci2,ltencrsrp3*1.0 as ltencrsrp3,ltencrsrq3*1.0 as ltencrsrq3,ltencearfcn3*1.0 as ltencearfcn3,ltencpci3*1.0 as ltencpci3,msisdn,imsi,imei,begintime,tac,eci,nettype0 from dw_pods_mro_eutrancell_pre_yyyymmdd where $read_partition""".stripMargin
    val data = hc.sql(predictQuery)
    println(predictQuery)
    println("run here1 !!!!!!!!")
    // data.show()

    // Load the persisted random forest model.
    val sameModel = RandomForestModel.load(sc, pModlePath)
    println("run here2!!!!!")

    val labelAndPreds = data.map { row =>
      def isNull(xarr: Any): String = if (null == xarr) "-2" else xarr.toString()
      // Assemble the 27 feature columns into a space-separated string, filling nulls with -2.
      val rowStr = isNull(row(0)) + " " + isNull(row(1)) + " " + isNull(row(2)) + " " + isNull(row(3)) + " " + isNull(row(4)) + " " + isNull(row(5)) + " " + isNull(row(6)) + " " + isNull(row(7)) + " " + isNull(row(8)) + " " + isNull(row(9)) + " " + isNull(row(10)) + " " + isNull(row(11)) + " " + isNull(row(12)) + " " + isNull(row(13)) + " " + isNull(row(14)) + " " + isNull(row(15)) + " " + isNull(row(16)) + " " + isNull(row(17)) + " " + isNull(row(18)) + " " + isNull(row(19)) + " " + isNull(row(20)) + " " + isNull(row(21)) + " " + isNull(row(22)) + " " + isNull(row(23)) + " " + isNull(row(24)) + " " + isNull(row(25)) + " " + isNull(row(26))
      val prediction = sameModel.predict(Vectors.dense(rowStr.split(' ').map { _.toDouble }))
      // Decode the predicted grid id back into the longitude/latitude of the grid centre.
      val glong = prediction % grid_num
      val glat = prediction / grid_num
      val lonPre = glong * disgrid + eare_lon_left + disgridhalf
      val latPre = eare_lat_left - glat * disgrid - disgridhalf
      TableMrPre(isNull(row(27)), isNull(row(28)), isNull(row(29)), isNull(row(30)), isNull(row(31)), isNull(row(32)), isNull(row(33)), lonPre, latPre)
    }.cache
    println("run here4!!!!!")

    import hc.implicits._
    val tabledf = labelAndPreds.toDF()
    // tabledf.show(100)
    tabledf.registerTempTable("TempTableMrPre")
    hc.sql("insert OVERWRITE table dw_mr_mme_position_pre partition(" + write_partition + ") select * from TempTableMrPre")
    hc.dropTempTable("TempTableMrPre")
    sc.stop()
    println("run done!!!!!")
  }
}
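
Both jobs rely on the same gridding scheme: training builds the class label as llgridid = floor((eare_lat_left - lat_uri)/disgrid)*grid_num + floor((long_uri - eare_lon_left)/disgrid), and prediction maps the predicted id back to the centre of that grid cell. A small, self-contained sketch of this encode/decode round trip with made-up parameter values (50 m grids, 1000 grids per row, an arbitrary reference corner; in the jobs these come from paraproperties) shows the idea; here the decode uses integer division, so glat is the whole grid row:

object GridIdDemo {
  // Hypothetical parameter values for illustration only.
  val eareLonLeft = 116.0
  val eareLatLeft = 40.0
  val gridLength = 50.0                 // metres
  val gridNum = 1000                    // grids per row
  val disgrid = gridLength * 0.000009   // grid edge in degrees, same constant as the jobs use
  val disgridhalf = disgrid / 2

  // Encode a position into a grid id (the training label llgridid).
  def encode(lon: Double, lat: Double): Int =
    math.floor((eareLatLeft - lat) / disgrid).toInt * gridNum +
      math.floor((lon - eareLonLeft) / disgrid).toInt

  // Decode a grid id back to the grid-cell centre, mirroring the decode step in RandomPredict.
  def decode(id: Int): (Double, Double) = {
    val glong = id % gridNum
    val glat = id / gridNum
    val lonPre = glong * disgrid + eareLonLeft + disgridhalf
    val latPre = eareLatLeft - glat * disgrid - disgridhalf
    (lonPre, latPre)
  }

  def main(args: Array[String]): Unit = {
    val id = encode(116.3512, 39.9870)
    val (lon, lat) = decode(id)
    println(s"grid id = $id, centre = ($lon, $lat)")  // centre lies within disgridhalf of the input
  }
}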

This is a working record of random forest classification code on Spark MLlib; it will be refined in follow-up posts.
