前言

昨天媛媛说,你是不是很久没写博客了。我说上一篇1.26号,昨天3.26号,刚好两个月,心中也略微有些愧疚。今天正好有个好朋友问,怎么在Java应用里集成Spark MLlib训练好的模型。在StreamingPro里其实都有实际的使用例子,但是如果有一篇文章讲述下,我觉得应该能让更多人获得帮助

追本溯源

记得我之前吐槽过Spark MLlib的设计,也是因为一个朋友使用了spark MLlib的pipeline做训练,然后他把这个pipeline放到了spring boot里,结果做预测的时候奇慢无比,一条记录inference需要30多秒。为什么会这么慢呢?原因是Spark MLlib 是以批处理为核心设计理念的。比如上面朋友遇到的坑是有一部分原因来源于word2vec的transform方法:

@Since("2.0.0")override def transform(dataset: Dataset[_]): DataFrame = {transformSchema(dataset.schema, logging = true)val vectors = wordVectors.getVectors.mapValues(vv => Vectors.dense(vv.map(_.toDouble))).map(identity) // mapValues doesn't return a serializable map (SI-7005)val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)val d = $(vectorSize)

来一条数据(通常API应用都是如此),他需要先获得vectors(词到vector的映射)对象,假设你有十万个词,

def getVectors: Map[String, Array[Float]] = {wordIndex.map { case (word, ind) =>(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))}}

每次请求他都要做如上调用和计算。接着还需要把这些东西(这个可能就比较大了,几百M或者几个G都有可能)广播出去。

所以注定快不了。

把model集成到Java 服务里实例

假设你使用贝叶斯训练了一个模型,你需要保存下这个模型,保存的方式如下:

val nb = new NaiveBayes()
//做些参数配置和训练过程
.....
//保存模型
nb.write.overwrite().save(path + "/" + modelIndex)

接着,在你的Java/scala程序里,引入spark core,spark mllib等包。加载模型:

val model = NaiveBayesModel.load(tempPath)

这个时候因为要做预测,我们为了性能,不能直接调用model的transform方法,你仔细观察发现,我们需要通过反射调用两个方法,就能实现分类。第一个是predictRaw方法,该方法输入一个向量,输出也为一个向量。我们其实不需要向量,我们需要的是一个分类的id。predictRaw 方法在model里,但是没办法直接调用,因为是私有的:

  override protected def predictRaw(features: Vector): Vector = {$(modelType) match {case Multinomial =>multinomialCalculation(features)case Bernoulli =>bernoulliCalculation(features)case _ =>// This should never happen.throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")}}

所以我们需要通过反射来完成:

val predictRaw = model.getClass.getMethod("predictRaw", classOf[Vector]).invoke(model, vec).asInstanceOf[Vector]

现在我们已经得到了predctRaw的结果,接着我们要用raw2probability 把向量转化为一个概率分布,因为spark 版本不同,该方法的签名也略有变化,所以可能要做下版本适配:

val raw2probabilityMethod = if (sparkSession.version.startsWith("2.3")) "raw2probabilityInPlace" else "raw2probability"
val raw2probability = model.getClass.getMethod(raw2probabilityMethod, classOf[Vector]).invoke(model, predictRaw).asInstanceOf[Vector]

raw2probability 其实也还是一个向量,这个向量的长度是分类的数目,每个位置的值是概率。所以所以我们只要拿到最大的那个概率值所在的位置就行:

val categoryId = raw2probability.argmax

这个时候categoryId 就是我们预测的分类了。

截止到目前我们已经完成了作为一个普通java/scala 方法的调用流程。如果我不想用在应用程序里,而是放到spark 流式计算里呢?或者批处理也行,那么这个时候你只需要封装一个UDF函数即可:

val models = sparkSession.sparkContext.broadcast(_model.asInstanceOf[ArrayBuffer[NaiveBayesModel]])
val f2 = (vec: Vector) => {models.value.map { model =>val predictRaw = model.getClass.getMethod("predictRaw", classOf[Vector]).invoke(model, vec).asInstanceOf[Vector]val raw2probability = model.getClass.getMethod(raw2probabilityMethod, classOf[Vector]).invoke(model, predictRaw).asInstanceOf[Vector]//model.getClass.getMethod("probability2prediction", classOf[Vector]).invoke(model, raw2probability).asInstanceOf[Vector]raw2probability}}sparkSession.udf.register(name , f2)

上面的例子可以参考StreamingPro 中streaming.dsl.mmlib.algs.SQLNaiveBayes的代码。不同的算法因为内部实现不同,我们使用起来也会略微有些区别。

总结

Spark MLlib学习了SKLearn里的transform和fit的概念,但是因为设计上还是遵循批处理的方式,实际部署后会有很大的性能瓶颈,不适合那种数据一条一条过来需要快速响应的预测流程,所以需要调用一些内部的API来完成最后的预测。

作者:祝威廉
链接:https://www.jianshu.com/p/3c038027ff61
來源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

如何在Java应用里集成Spark MLlib训练好的模型做预测相关推荐

  1. 分享Spark MLlib训练的广告点击率预测模型

    2015年,全球互联网广告营收接近600亿美元,比2014年增长了近20%.多家互联网巨头都依赖于广告营收,如谷歌,百度,Facebook,互联网新贵们也都开始试水广告业,如Snapchat, Pin ...

  2. Spark MLlib实现的广告点击预测–Gradient-Boosted Trees

    关键字:spark.mllib.Gradient-Boosted Trees.广告点击预测 本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点 ...

  3. spark mllib实现 广告点击率预测

    本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点击广告. 训练和测试数据使用Kaggle Avazu CTR 比赛的样例数据,下载地址:ht ...

  4. R语言使用caretEnsemble包的caretStack函数把多个机器学习模型融合成一个模型、构建融合(集成)预测模型、使用融合模型进行预测推理

    R语言使用caretEnsemble包的caretStack函数把多个机器学习模型融合成一个模型.构建融合(集成)预测模型.自定义融合模型的trainControl参数.method参数.评估指标参数 ...

  5. python 决策模型_【Spark MLlib速成宝典】模型篇05决策树【Decision Tree】(Python版)...

    #-*-coding=utf-8 -*- from pyspark importSparkConf, SparkContext sc= SparkContext('local')from pyspar ...

  6. Java加载sklearn训练好的模型进行预测(无法搞定)

    参考文献主要是[1][2] [2]中代码各种类函数都是自定义的,放弃吧 转攻向[1] --------------------------------------------------------- ...

  7. android支付宝支付微信支付封装,如何在Android App中集成支付宝和微信支付功能

    前言 本文主要介绍如何在 Android App 里集成支付宝和微信支付的功能,文中将实现的步骤一步步介绍的非常详细,对同样遇到这个问题的朋友相信会是一个很好的参考,下面话不多说了,来一起看看详细的介 ...

  8. 如何在Java中生成比特币钱包地址

    让我们通过学习比特币(Bitcoin)如何实施该技术的各个方面来工作,好吗?该技术包括以下几个方面: 比特币地址bitcoin address是用来发送和接收比特币的. 交易transaction是比 ...

  9. 分布式机器学习之——Spark MLlib并行训练原理

    这里是 王喆的机器学习笔记 的第二十五篇文章.接下来的几篇文章希望与大家一同讨论一下机器学习模型的分布式训练的问题.这个问题在推荐.广告.搜索领域尤为突出,因为在互联网场景下,动辄TB甚至PB级的数据 ...

最新文章

  1. (面试)java基础-== 和 equals 的区别?
  2. 英特尔用英伟达显卡,给GTA5打了个超强画质补丁
  3. Android Bundle类,通过bundle实现在两个activity之间的通讯
  4. Spring教程:使用Spring框架和Spring Boot创建Hello World REST API
  5. 母牛的故事(HDU-2018)
  6. JavaScript玩转机器学习:模型和层
  7. R语言转换并保存json文件--使用jsonlite包
  8. NWT纪事:说是闷声发大财,不干活就发财了?
  9. 关于2020idea不能创建web项目问题
  10. Windows | 常用软件
  11. proccessing 中的 port busy
  12. win11 删除不要的输入法,删除阿尔巴尼亚语输入法
  13. 学霸,顾名思义,就是成绩非常好
  14. Java字符拼成图片
  15. (姊妹仨)BlazePalm: 先检手掌再检骨架,虚拟合成数据助力 2.5D 信息输出
  16. awk 分隔符_awk 命令快速入门
  17. java零到一:Servlet和JSP-7:转发、重定向的区别 和状态管理
  18. 网友观点:IT售前6式
  19. linux下编译dbus源码,ubuntu安装dbus
  20. 《推箱子》python小游戏(第二期)

热门文章

  1. go语言爬虫教程python_Go语言爬虫 - Go语言中文网 - Golang中文社区
  2. BGP小实验(二)——还是他,继续第二波走起来
  3. 树莓派各版本配置对比
  4. FatFs源码剖析(2)
  5. Android textview字体颜色显示和图片显示
  6. android studio开关按钮,Android studio实现滑动开关
  7. @override报错_C++ 多态性:C++11:override 与 final(学习笔记:第8章 09)
  8. sql能查到数据 dataset对象里面没有值_DataSet
  9. android combobox控件,Android中的组合框
  10. html鼠标点击有手势出来,用原生js+css3撸的一个下拉手势事件插件