朴素贝叶斯模型简述:
贝叶斯模型通过使用后验概率和类的概率分布来估计先验概率,具体的以公式表达为

P(Y)可以使用训练样本的类分布进行估计。如果X是单特征也很好估计,但如果X={x1,x2,..,xn}等n个特征构成,那估计n个特征的联合概率分布P(X)=P(x1,x2,...,xn)将变得非常困难。由于贝叶斯模型的参数难于估计,限制了其的应用。

朴素贝叶斯模型是贝叶斯模型的简化版本,通过假设特征之间独立不相关,那么

通过求解每个特征的分布和每个特征的后验概率来近似特征的联合概率分布和特征的后验概率。当然,通常情况下,特征相互独立的假设不会成立,这里只是模型复杂度和模型精度的一个权衡。这样样本x属于第i类和第k类的概率分别为:

由于P(X)联合概率分布对每个类都是相同的,可以不求。

Spark下朴素贝叶斯的具体实现:
NaiveBayesModel

NaiveBayesModel保存了朴素贝叶斯模型的参数,继承自ClassificationModel,并重写了predict方法。

先看看NaiveBayesModel的贝叶斯模型参数,

labels:类别编号

pi: 类的先验概率P(Y)的对数值

theta: 条件概率P(X|Y)的对数值

modelType:Multinomial,Bernoulli:实际上依据特征的分布不同,朴素贝叶斯又划分为多个子类别,这实际上又是对特征的一种假设来简化建模。

如果特征近似服从多项式分布,即特征只能取N个值,取到每个值的概率为pi,则p1+p2+..+pn=1。基于此假设构建的贝叶斯分类称为Multinomial NaiveBayesModel,典型的例子是基于词频向量的文本分类。

如果特征服从伯努利分布,基于此假设构建的贝叶斯分类称为 Bernoulli NaiveBayesModel,典型的例子是基于one-hot构建的文本分类。

不同的特征分布假设,将调用不同的概率计算函数:

private val (thetaMinusNegTheta, negThetaSum) = modelType match {
  case Multinomial => (None, None)
  case Bernoulli =>
    val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))//事件失败的概率
    val ones = new DenseVector(Array.fill(thetaMatrix.numCols) {1.0})
    //事件不失败的概率
    val thetaMinusNegTheta = thetaMatrix.map { value =>
      value - math.log(1.0 - math.exp(value))
    }
    (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
  case _ =>
    // This should never happen.
    throw new UnknownError(s"Invalid modelType: $modelType.")
}
//特征的分布不一样,其估计也不一样
@Since("1.0.0")
override def predict(testData: Vector): Double = {
  modelType match {
    case Multinomial =>
      labels(multinomialCalculation(testData).argmax)
    case Bernoulli =>
      labels(bernoulliCalculation(testData).argmax)
  }
}
另外为方便计算和避免小数值,根据对数运算法则,可将乘积运算转换为加法运算,后验概率的估计值为

//计算每个类别的概率p(yi|X1,x2)=p(yi)*p(x1|yi)*p(x2|yi)*.../P(X)  实际计算的是log(p(yi)) + log(p(x1|yi)) +***
//全概率P(X)对每个类别一致,可以不算
private def multinomialCalculation(testData: Vector) = {
  val prob = thetaMatrix.multiply(testData)//求出
  BLAS.axpy(1.0, piVector, prob)
  prob
}
private def bernoulliCalculation(testData: Vector) = {
  testData.foreachActive((_, value) =>
    if (value != 0.0 && value != 1.0) {//伯努利事件的结果只有两种状态
      throw new SparkException(
        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
    }
  )
  val prob = thetaMinusNegTheta.get.multiply(testData)
  BLAS.axpy(1.0, piVector, prob)
  BLAS.axpy(1.0, negThetaSum.get, prob)
  prob
}
NaiveBayes
再来看NaiveBayesModel的参数估计,参数估计由NaiveBayes类开始。

NaiveBayes构造函数有个lambda参数,一般在估计P(Xi|Y)时,对于在训练数据中没有出现的Xi,会得到其估计P(Xij|Y)=0

在实际应用中,对于某个类别没有出现在样本集中或者某个特征没有出现在某类样本集中,这个时候就需要加入平滑因子lambda去调整,一般常用拉普拉斯平滑进行处理。

类的分布估计调整为

多项式模型下的参数估计调整为:

伯努力模型下参数估计调整为:

朴素贝叶斯模型的训练是在mllib.NaiveBayes中由用户调用其run来完成训练的。run方法调用了ml.NaiveBayes类的trainWithLabelCheck方法来完成参数估计的。

接下来看看trainWithLabelCheck进行参数估计的过程

private[spark] def trainWithLabelCheck(
      dataset: Dataset[_],
      positiveLabel: Boolean): NaiveBayesModel = {
    if (positiveLabel) {...}
    val modelTypeValue = $(modelType)
    val requireValues: Vector => Unit = {...}
    //估算argmax p(yi)*p(X|Yi) ==> argmax log(p(yi)) + log(p(X|Yi))
    //p(yi) = numDocuments in lable i / numDocuments all
    //p(X|Yi) = p(X1|Yi)*p(X2|Yi)... ==> log(p(X1|Yi)) + log(p(X2|Yi))
    //p(X1|Yi) = featureSum in lable i / featureSum all
    //特征数量
    val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
    //特征权重
    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
    // Aggregates分布式计算 进行文档和特征的统计计数
    //aggregateByKey再collect等价于aggregateByKeyLocally,返回的是一个HashMap<lable id,object>
    //aggregated具体形式为 [lable i, numDocuments in lable i, a vector contains <feature1Sum,feature2Sum,..>]
    val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
      .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2)))
      }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))(//分类别统计featureSum
      seqOp = {
         case ((weightSum: Double, featureSum: DenseVector), (weight, features)) =>
           requireValues(features)
           BLAS.axpy(weight, features, featureSum)//常数乘以向量加另一个向量
           (weightSum + weight, featureSum)
      },
      combOp = {
         case ((weightSum1, featureSum1), (weightSum2, featureSum2)) =>
           BLAS.axpy(1.0, featureSum2, featureSum1)//featureSum2 + featureSum1
           (weightSum1 + weightSum2, featureSum1)
      }).collect().sortBy(_._1)//sortBy lable index
    //分类数
    val numLabels = aggregated.length
    //总的样本数量
    val numDocuments = aggregated.map(_._2._1).sum
    val labelArray = new Array[Double](numLabels)
    //初始化存储p(yi)的数组
    val piArray = new Array[Double](numLabels)
    //用于计算p(xi|yk)的参数,类别数numLabels*特征数量numFeatures大小的数组
    val thetaArray = new Array[Double](numLabels * numFeatures)
    val lambda = $(smoothing)//平滑参数
    val piLogDenom = math.log(numDocuments + numLabels * lambda)//这个是估计p(yi)的分母,见公式。。。
    var i = 0
    //迭代aggregated这个存在本地的HashMap
    aggregated.foreach { case (label, (n, sumTermFreqs)) =>
      labelArray(i) = label
      piArray(i) = math.log(n + lambda) - piLogDenom //计算log(p(yi)) , 是(numDocuments in lable i + lambda)/(numDocuments + numLabels * lambda)的对数形式
      val thetaLogDenom = $(modelType) match {//这个是计算公式。。。的分母部分
        case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)//实际上上加了一个平滑因子的
        case Bernoulli => math.log(n + 2.0 * lambda)
        case _ =>
          throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
      }
      var j = 0
      while (j < numFeatures) {
        //第i类别第j个特征的参数估计
        thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom //计算log(p(Xk|Yi))
        j += 1
      }
      i += 1
    }
    val pi = Vectors.dense(piArray) //存储log(p(yi))的数组
    val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
    new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
  }
Spark目前只实现了基于伯努利分布和二项分布的朴素贝叶斯算法,对于诸如高斯分布的朴素贝叶斯目前还没有实现,在需要时可参照上述两个模型的过程来自己实现(重写NaiveBayesModel的predict方法和NaiveBayes的参数估计方法)。
————————————————
版权声明:本文为CSDN博主「大愚若智_」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/zbc1090549839/article/details/68067460

spark.mllib:NaiveBayes相关推荐

  1. spark.mllib:GradientBoostedTrees

    Gradient-Boosted Trees(GBT或者GBDT) 和 RandomForests 都属于集成学习的范畴,相比于单个模型有限的表达能力,组合多个base model后表达能力更加丰富. ...

  2. spark.mllib:回归算法

    Spark实现了三类线性回归方法: 1.LinearRegression:普通线性回归模型 2.LassoRegression:加L1正则化的线性回归 3.RidgeRegression:加L2正则化 ...

  3. spark.mllib:Optimizer

    Spark中的求解器,根据输入的训练数据及设定的迭代次数.正则化项.参数收敛精度等进行迭代求解模型的参数.Spark内部实现来两类求解器,基于随机梯度下降(miniBatch选取样本)的Gradien ...

  4. spark.mllib:bagging方法

    在训练集成分类器时,关键的一步在于如何从全训练样本集中构建子样本集提供给子分类器进行训练.目前主流的两种子样本集构造方式是bagging方法和boosting方法.bagging方法的思想是从全样本集 ...

  5. Spark MLlib实现的中文文本分类–Naive Bayes

    2019独角兽企业重金招聘Python工程师标准>>> 中文分词 对于中文文本分类而言,需要先对文章进行分词,我使用的是IKAnalyzer中文分析工具,其中自己可以配置扩展词库来使 ...

  6. Spark MLlib 机器学习

    本章导读 机器学习(machine learning, ML)是一门涉及概率论.统计学.逼近论.凸分析.算法复杂度理论等多领域的交叉学科.ML专注于研究计算机模拟或实现人类的学习行为,以获取新知识.新 ...

  7. Spark MLlib实现的广告点击预测–Gradient-Boosted Trees

    关键字:spark.mllib.Gradient-Boosted Trees.广告点击预测 本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点 ...

  8. python spark MLlib

    window系统 1. anaconda 或python spark环境变量 2. 配置spark home D:\Develop\spark-1.6.0-bin-hadoop2.6\spark-1. ...

  9. spark mllib实现 广告点击率预测

    本文尝试使用Spark提供的机器学习算法 Gradient-Boosted Trees来预测一个用户是否会点击广告. 训练和测试数据使用Kaggle Avazu CTR 比赛的样例数据,下载地址:ht ...

最新文章

  1. MySQL面试题 | 附答案解析(十五)
  2. mysql100万数据一键下载csv_使用PHP来导入包含100万条数据的csv文件,请问你最快多久能全部导入mysql 数据库?...
  3. Document builder customizing
  4. Oracle 11g DRCP连接方式——基本原理
  5. YOLO,VOC数据集标注格式解析
  6. 手把手教你使用Python网络爬虫获取招聘信息
  7. 微信搜一搜又推出了新功能!搜“医保码”直达医保页面
  8. Linux系统学习: 用户和权限管理指令: 请简述 Linux 权限划分的原则
  9. 趋势安全软件卸载:如何不需要密码或忘记密码卸载Trend Micro OfficeScan Agent?
  10. jflash添加芯片_【原创】巧用J-Link+J-Flash给Kinesis烧写序列号
  11. snb处理器hd3000显卡专用extra_最强i9-9900K处理器造就最强主机!华硕ROG GL12CX评测...
  12. ProcessingJS介绍
  13. 【软件定义汽车】【场景篇】智能座舱
  14. 5种常用的四轴飞行器PID算法讲解集合
  15. 成都开发者看过来!百度资深研发工程师将出席超级账本成都见面会
  16. 这次把怎么做好一个PPT讲清-其他技巧篇
  17. SAP 科目的 未清项管理的理解
  18. VIRTUALPLANETBUILDER (vpb) osgdem 用法
  19. linux可以玩什么游戏_为什么我们要在Linux上玩游戏,与Icculus聊天等等
  20. 基于socket实现FTP服务器

热门文章

  1. 使用nmap扫描提示utf-8编码错误_Web漏洞扫描神器Nikto使用指南
  2. 在Apache上配置防盗链功能和隐藏版本号
  3. python笔记之利用scrapy框架爬取糗事百科首页段子
  4. 关于 $ Super $ $ 和 $ Sub $ $ 的用法
  5. OBD技术速成——J1850协议概述
  6. php群发不用foreach,如何在没有foreach的情况下使用PHP生成器?
  7. pxe安装系统 ip获取错误_【图说】消防系统安装典型错误举例
  8. kafka partition分配_Kafka架构原理,也就这么回事
  9. oracle查看数据库剩余空间,Oracle 查看数据库空间使用情况
  10. 录像带转存电脑的方法_误删微信记录别着急,教你几招可靠的恢复微信记录方法...