在上一篇(https://blog.csdn.net/baymax_007/article/details/82748544)中,利用逻辑回归实现资讯多分类。本文在之前基础上,又引入决策树、随机森林、多层感知分类器、xgboost和朴素贝叶斯分类算法,并对验证集和测试集分类正确率和耗时进行性能对比。

ml支持决策树、随机森林、梯度提升决策树(GBDT)、线性支持向量机(LSVC)、多层感知分类器(MPC,简单神经网络)和朴素贝叶斯分类,可以直接使用。需要注意,梯度提升决策树和线性支持向量机在spark mllib2.3.1版本中暂时不支持多分类,本文先不对两者作对比。xgboost4j-spark中封装支持java和scala版本的xgboost,可以直接使用。

一、环境

java 1.8.0_172+scala 2.11.8+spark 2.3.1+HanLP portable-1.6.8+xgboost-spark 0.80

 <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-core --><dependency><groupId>org.apache.spark</groupId><artifactId>spark-core_2.11</artifactId><version>2.3.1</version></dependency><!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql --><dependency><groupId>org.apache.spark</groupId><artifactId>spark-sql_2.11</artifactId><version>2.3.1</version></dependency><!-- https://mvnrepository.com/artifact/org.apache.spark/spark-hive --><dependency><groupId>org.apache.spark</groupId><artifactId>spark-hive_2.11</artifactId><version>2.3.1</version></dependency><!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib --><dependency><groupId>org.apache.spark</groupId><artifactId>spark-mllib_2.11</artifactId><version>2.3.1</version></dependency><!-- https://mvnrepository.com/artifact/com.hankcs/hanlp --><dependency><groupId>com.hankcs</groupId><artifactId>hanlp</artifactId><version>portable-1.6.8</version></dependency><!-- https://mvnrepository.com/artifact/ml.dmlc/xgboost4j-spark --><dependency><groupId>ml.dmlc</groupId><artifactId>xgboost4j-spark</artifactId><version>0.80</version></dependency>

二、实验设计

spark ml支持pipeline,可以将特征提取转换、分类模型一起组装到pipeline中,通过对pipeline的训练建模,并在统一分类评估准则下,进行算法对比,从而简化代码冗余,如下图所示。

而上一篇HanLP分词无法组装到pipeline中,因此需要自定义ml包Tokenizer继承UnaryTransformer类,并重写UnaryTransformFunc,实现基于HanLP分词、可以组装到pipeline功能。

三、代码实现

1. 自定义HanLP实现pipeline封装

ml自带Tokenizer可以封装到pipeline中,查看代码,发现其继承UnaryTransformer类,并重写UnaryTransformFunc方法,这也是分词的核心方法,outputDataType和valiadateInputType可以约束输出类型和输入类型校验。

class HanLPTokenizer(override val uid:String) extends UnaryTransformer[String, Seq[String], HanLPTokenizer] {private var segmentType = "StandardTokenizer"private var enableNature = falsedef setSegmentType(value:String):this.type = {this.segmentType = valuethis}def enableNature(value:Boolean):this.type  = {this.enableNature = valuethis}def this() = this(Identifiable.randomUID("HanLPTokenizer"))override protected def createTransformFunc: String => Seq[String] = {hanLP}private def hanLP(line:String): Seq[String] = {var terms: Seq[Term] = Seq()import collection.JavaConversions._segmentType match {case "StandardSegment" =>terms = StandardTokenizer.segment(line)case "NLPSegment" =>terms = NLPTokenizer.segment(line)case "IndexSegment" =>terms = IndexTokenizer.segment(line)case "SpeedSegment" =>terms = SpeedTokenizer.segment(line)case "NShortSegment" =>terms = new NShortSegment().seg(line)case "CRFlexicalAnalyzer" =>terms = new CRFLexicalAnalyzer().seg(line)case _ =>println("分词类型错误!")System.exit(1)}val termSeq = terms.map(term =>if(this.enableNature) term.toString else term.word)termSeq}override protected def validateInputType(inputType: DataType): Unit = {require(inputType == DataTypes.StringType,s"Input type must be string type but got $inputType.")}override protected def outputDataType: DataType = new ArrayType(StringType, true)}

2. 特征工程代码

主要包含有:标签索引转换,本文分词,去除停用词、关键词频数特征提取和预测索引标签还原。

val indexer = new StringIndexer().setInputCol("tab").setOutputCol("label").fit(peopleNews)val segmenter = new HanLPTokenizer().setInputCol("content").setOutputCol("tokens").enableNature(false).setSegmentType("StandardSegment")val stopwords = spark.read.textFile("/opt/data/stopwordsCH.txt").collect()val remover = new StopWordsRemover().setStopWords(stopwords).setInputCol("tokens").setOutputCol("removed")val vectorizer = new CountVectorizer().setVocabSize(1000).setInputCol("removed").setOutputCol("features")val converts = new IndexToString().setInputCol("prediction").setOutputCol("predictionTab").setLabels(indexer.labels)

3.逻辑回归代码

    val lr = new LogisticRegression().setMaxIter(40).setTol(1e-7).setLabelCol("label").setFeaturesCol("features")val lrStartTime = new Date().getTimeval lrPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,lr,converts))val Array(train,test) = peopleNews.randomSplit(Array(0.8,0.2),12L)val lrModel = lrPipeline.fit(train)val lrValiad = lrModel.transform(train)val lrPredictions = lrModel.transform(test)val evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("accuracy")val accuracyLrt = evaluator.evaluate(lrValiad)println(s"逻辑回归验证集分类准确率 = $accuracyLrt")val accuracyLrv = evaluator.evaluate(lrPredictions)println(s"逻辑回归测试集分类准确率 = $accuracyLrv")val lrEndTime = new Date().getTimeval lrCostTime = lrEndTime - lrStartTimeprintln(s"逻辑回归分类耗时:$lrCostTime")

4.决策树代码

//    训练决策树模型val dtStartTime = new Date().getTimeval dt = new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features").setImpurity("entropy") // 不纯度.setMaxBins(1000) // 离散化"连续特征"的最大划分数.setMaxDepth(10) // 树的最大深度.setMinInfoGain(0.01) //一个节点分裂的最小信息增益,值为[0,1].setMinInstancesPerNode(5) //每个节点包含的最小样本数.setSeed(123456L)val dtPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,dt,converts))val dtModel = dtPipeline.fit(train)val dtValiad = dtModel.transform(train)val dtPredictions = dtModel.transform(test)val accuracyDtt = evaluator.evaluate(dtValiad)println(s"决策树验证集分类准确率 = $accuracyDtt")val accuracyDtv = evaluator.evaluate(dtPredictions)println(s"决策树测试集分类准确率 = $accuracyDtv")val dtEndTime = new Date().getTimeval dtCostTime = dtEndTime - dtStartTimeprintln(s"决策树分类耗时:$dtCostTime")

5. 随机森林代码

//    训练随机森林模型val rfStartTime = new Date().getTimeval rf = new RandomForestClassifier().setLabelCol("label").setFeaturesCol("features").setImpurity("entropy") // 不纯度.setMaxBins(1000) // 离散化"连续特征"的最大划分数.setMaxDepth(10) // 树的最大深度.setMinInfoGain(0.01) //一个节点分裂的最小信息增益,值为[0,1].setMinInstancesPerNode(5) //每个节点包含的最小样本数.setNumTrees(100).setSeed(123456L)val rfPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,rf,converts))val rfModel = rfPipeline.fit(train)val rfValiad = rfModel.transform(train)val rfPredictions = rfModel.transform(test)val accuracyRft = evaluator.evaluate(rfValiad)println(s"随机森林验证集分类准确率为:$accuracyRft")val accuracyRfv = evaluator.evaluate(rfPredictions)println(s"随机森林测试集分类准确率为:$accuracyRfv")val rfEndTime = new Date().getTimeval rfCostTime = rfEndTime - rfStartTimeprintln(s"随机森林分类耗时:$rfCostTime")

6. 多层感知分类器代码

多层感知分类器(简单神经网络)网络节点设置可以参考:

m:输入层节点个数,n:输出层节点个数,h1:第一层隐含层节点个数=log2(m),h2:第一层隐含层节点个数=sqrt(m+n)+a,其中a取1-10

 //    多层感知分类器val inputLayers = vectorizer.getVocabSizeval hideLayer1 = Math.round(Math.log(inputLayers)/Math.log(2)).toIntval outputLayers = peopleNews.select("tab").distinct().count().toIntval hideLayer2 = Math.round(Math.sqrt(inputLayers + outputLayers) + 1).toIntval layers = Array[Int](inputLayers, hideLayer1, hideLayer2, outputLayers)val mpcstartTime = new Date().getTimeval mpc = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setTol(1e-7).setMaxIter(100).setLabelCol("label").setFeaturesCol("features").setSeed(1234L)val mpcPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,mpc,converts))val mpcModel = mpcPipeline.fit(train)val mpcValiad = mpcModel.transform(train)val mpcPredictions = mpcModel.transform(test)val accuracyMpct = evaluator.evaluate(mpcValiad)println(s"多层感知分类器验证集分类准确率:$accuracyMpct")val accuracyMpcv = evaluator.evaluate(mpcPredictions)println(s"多层感知分类器测试集分类准确率:$accuracyMpcv")val mpcEndTime = new Date().getTimeval mpcCostTime = mpcEndTime - mpcstartTimeprintln(s"多层感知分类器分类耗时:$mpcCostTime")

7. XGBOOST代码

//    xgboost训练模型val xgbParam = Map("eta" -> 0.1f,"max_depth" -> 10, //数的最大深度。缺省值为6 ,取值范围为:[1,∞]"objective" -> "multi:softprob",  //定义学习任务及相应的学习目标"num_class" -> outputLayers,"num_round" -> 10,"num_workers" -> 1)val xgbStartTime = new Date().getTimeval xgb = new XGBoostClassifier(xgbParam).setFeaturesCol("features").setLabelCol("label")val xgbPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,xgb,converts))val xgbModel = xgbPipeline.fit(train)val xgbValiad = xgbModel.transform(train)val xgbPredictions = xgbModel.transform(test)val accuracyXgbt = evaluator.evaluate(xgbValiad)println(s"xgboost验证集分类准确率为:$accuracyXgbt")val accuracyXgbv = evaluator.evaluate(xgbPredictions)println(s"xgboost测试集分类准确率为:$accuracyXgbv")val xgbEndTime = new Date().getTimeval xgbCostTime = xgbEndTime - xgbStartTimeprintln(s"xgboost分类耗时:$xgbCostTime")

8. 朴素贝叶斯代码

//    朴素贝叶斯分类val nvbStartTime = new Date().getTimeval nvb = new NaiveBayes()val nvbPipeline = new Pipeline().setStages(Array(indexer,segmenter,remover,vectorizer,nvb,converts))val nvbModel = nvbPipeline.fit(train)val nvbValiad = nvbModel.transform(train)val nvbPredictions = nvbModel.transform(test)val accuracyNvbt = evaluator.evaluate(nvbValiad)println(s"朴素贝叶斯验证集分类准确率:$accuracyNvbt")val accuracyNvbv = evaluator.evaluate(nvbPredictions)println(s"朴素贝叶斯测试集分类准确率:$accuracyNvbv")val nvbEndTime = new Date().getTimeval nvbCostTime = nvbEndTime - nvbStartTimeprintln(s"朴素贝叶斯分类耗时:$nvbCostTime")

四、性能对比

  验证集分类准确率 测试集分类准确率 耗时(ms)
逻辑回归(LR) 100% 79.53% 44697
决策树(DT) 81.58% 73.68% 34597
随机森林(RF) 94.24% 73.68% 56608
多层感知分类器(MPC) 97.98% 68.42% 30801
XGBOOST 99.71% 79.53% 31947
朴素贝叶斯分类(NVB) 83.74% 71.34% 11510

以上算法,设计参数调优会在稍后进行尝试。

参考文献

https://blog.csdn.net/baymax_007/article/details/82748544

https://blog.csdn.net/liam08/article/details/79184159

http://spark.apache.org/docs/latest/ml-classification-regression.html

https://blog.csdn.net/u013421629/article/details/78329191

https://xgboost.readthedocs.io/en/latest/jvm/

基于逻辑回归/决策树/随机森林/多层感知分类器/xgboost/朴素贝叶斯分类的资讯多分类性能对比相关推荐

  1. 基于6种监督学习(逻辑回归+决策树+随机森林+SVM+朴素贝叶斯+神经网络)的毒蘑菇分类

    公众号:尤而小屋 作者:Peter 编辑:Peter 大家好,我是Peter~ 本文是kaggle案例分享的第3篇,赛题的名称是:Mushroom Classification,Safe to eat ...

  2. ML之mlxtend:基于iris鸢尾花数据集利用逻辑回归LoR/随机森林RF/支持向量机SVM/集成学习算法结合mlxtend库实现模型可解释性(决策边界可视化)

    ML之mlxtend:基于iris鸢尾花数据集利用逻辑回归LoR/随机森林RF/支持向量机SVM/集成学习算法结合mlxtend库实现模型可解释性(决策边界可视化) 目录 相关文章 ML之mlxten ...

  3. 数据挖掘算法(logistic回归,随机森林,GBDT和xgboost)-腾讯云社区

    机器学习就是样本中有大量的x(特征量)和y(目标变量)然后求这个function.(了解更多可以看: https://zhuanlan.zhihu.com/p/21340974?refer=mlear ...

  4. XGBoost 、逻辑回归、随机森林 模型实战对比

    目录 引言 一. 数据的特征处理 二.导入XGBoost模型 三. 使用其他模型于XGBoost进行对比 引言 在XGBoost基本原理博文中我们介绍了XGBoost的基本原理,本篇博文我们将介绍XG ...

  5. 基于逻辑回归算法的心脏病不平衡数据分类代码实现

    1.数据说明 数据来源Kaggle网站中引用的CDC数据,原数据大概有300个变量,这里大约使用20个,严重不平衡数据,约为1:9.该代码通过对进行清洗,标准化,欠采样(Undersampling)进 ...

  6. 数据代码分享|Python用NLP自然语言处理LSTM神经网络Twitter推特灾难文本数据、词云可视化与SVM,KNN,多层感知器,朴素贝叶斯,随机森林,GBDT对比

    作者:Yunfan Zhang Twitter是一家美国社交网络及微博客服务的网站,致力于服务公众对话.迄今为止,Twitter的日活跃用户达1.86亿.与此同时,Twitter也已成为突发紧急情况时 ...

  7. 15- 决策回归树, 随机森林, 极限森林 (决策树优化) (算法)

    1.  决策回归树: from sklearn.tree import DecisionTreeRegressor model = DecisionTreeRegressor(criterion='m ...

  8. 决策树 随机森林 xgboost_一文看懂随机森林-RandomForest(附4个构造步骤+4种实现方式评测+10个优缺点)...

    随机森林是一种由决策树构成的集成算法,他在很多情况下都能有不错的表现.本文将介绍随机森林的基本概念.4 个构造步骤.4 种方式的对比评测.10 个优缺点和 4 个应用方向. 什么是随机森林? 随机森林 ...

  9. 决策树随机森林GBDTXGBoost学习笔记以及代码实现

    文章目录 1. 引入 1.1 决策树 1.2 随机森林 1.3 GBDT(Gradient Boosting Decision Tree)梯度提升决策树 1.4 XGBoost(eXtreme Gra ...

最新文章

  1. 单片机C语言知识用法之,单片机C语言知识用法之define
  2. 小白也能看懂的教程:微信小程序在线支付功能开通详细流程(图文介绍)
  3. 一文带你重温去年最难忘的10个数据泄露事件
  4. 5G 信令流程 — 5GC 的连接管理(CM,Connection Management)
  5. C# winform 编写记事本
  6. 相继平均法matlab代码_模式识别matlab编程:用k次平均法将20个样本分成2类
  7. 信息学奥赛一本通(C++)在线评测系统——基础(二)基础算法 —— 1312:【例3.4】昆虫繁殖
  8. html遮罩实例,给原生html中添加水印遮罩层的实现示例
  9. 【Linux就该这么学 20期培训笔记 01】部署虚拟环境安装linux系统
  10. Oracle入门(十四.20)之创建DML触发器:第一部分
  11. 背景选择器selector替换按钮默认背景
  12. 使用docker安装fastDFS
  13. java写 excel
  14. 让Office无处不在——Office Web App初体验
  15. MyBatis源码阅读(一) --- 源码阅读环境搭建
  16. BScroll切换才能滚动,刷新一下就不能滚动
  17. 制作京东登陆页面 HTML+CSS
  18. lae界面开发工具入门之介绍二--渲染组件篇
  19. mysql 合并两个update_如何将多条update语句合并为一条
  20. Pytorch使用GPU加速

热门文章

  1. 无头结点单链表的逆置_解析单链表逆置的多种方法 | 术与道的分享
  2. 使用HighCharts绘制bar形柱状图
  3. RGB-D图像(深度图像)的Surface编码
  4. java.lang.RuntimeException: java.io.IOException: Couldn‘t create proxy provider null错误解决
  5. java gui的文本框_GUI编程笔记(java)07:GUI把文本框的值移到文本域案例
  6. Bilibili直播弹幕抓取(1):WebSocket
  7. 手把手教你制作一款iOS越狱App,伪装微信位置
  8. 9年,我从小白到大厂测开工程师,从单身汉到迎娶白富美···
  9. LSTM和双向LSTM讲解及实践
  10. vue实现歌词与播放时间同步