Spark网格搜索——训练验证集拆分
前言
Spark内的网格搜索主要有两种评估流程,分别是 交叉验证 和 训练验证集拆分,这篇文章主要介绍训练验证集拆分
的具体流程
数据集划分
训练集、验证集、测试集
训练集(train):训练模型
验证集(val):评估模型
测试集(test):一旦找到了最佳参数,就开始最终训练
使用训练集训练多个网络模型,再使用验证集测试这些网络,找到得分最高的那个网络作为我们选择的最佳网络,再将训练集和验证集合并,重新训练这个最佳网络,得到最佳网络参数。
Spark ML模型评估算法
回归评估指标
1.RegressionEvaluator
用于验证回归模型的评估算法,如:ALS,线性回归等……
val metricName: Param[String]"rmse" (default): root mean squared error
"mse": mean squared error
"r2": R2 metric
"mae": mean absolute error
分类评估指标
1.BinaryClassificationEvaluator
用于验证二分分类模型的评估算法:如判断(1,0)或(是,否)
val metricName: Param[String]
param for metric name in evaluation (supports "areaUnderROC" (default), "areaUnderPR")
2.MulticlassClassificationEvaluator
用于判断多分类,当然适用于上面的二分类
val metricName: Param[String]
param for metric name in evaluation (supports "f1" (default), "weightedPrecision", "weightedRecall", "accuracy")
聚类评估指标
ClusteringEvaluator
用于聚类模型的评估 越接近1,表明效果越好
val metricName: Param[String]
param for metric name in evaluation (supports "silhouette" (default))
模型评估Example
记录一次较为简易的模型训练过程:
训练模型分为三步:
- 1.训练集训练模型
- 2.验证集评估模型
- 3.测试集最终训练
import org.apache.spark.ml.clustering.KMeansimport org.apache.spark.ml.evaluation.ClusteringEvaluator//数据集//dataset:测试集 training:训练集 vali:验证集val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")//将测试集按7:3的比例切分为 训练集和验证集val Array(training,vali)=dataset.randomSplit(Array(0.7,0.3))// 训练Kmeans模型//Kmeans超参数val kmeans = new KMeans().setK(2).setSeed(1L)// 使用训练集训练模型val model = kmeans.fit(training)// Kmeans为聚类模型,使用聚类指标评估val evaluator = new ClusteringEvaluator()// 使用验证集参与评估val predictions = model.transform(vali)val silhouette = evaluator.evaluate(predictions)println(silhouette)//若评估效果符合预期,即silhouette接近1val real_model=kmeans.fit(dataset)
参数探索
根据上文所说的模型评估Example
,我们可以通过变量silhouette的值,来不断调整模型的参数,使其接近于1。这里有个较为方便的方法,快速找到较为合适的参数——网格搜索
网格搜索
网格搜索算法是一种通过遍历给定的参数组合来优化模型表现的方法
为何使用:超参数选择不恰当,就会出现欠拟合或者过拟合的问题
内容: 网格搜索,搜索的是参数,即在指定的参数范围内,按步长依次调整参数,利用调整的参数训练学习器,从所有的参数中找到在验证集上精度最高的参数,这其实是一个训练和比较的过程。
Grid Search:一种调参手段;穷举搜索:在所有候选的参数选择中,通过循环遍历,尝试每一种可能性,表现最好的参数就是最终的结果
用法:网格搜索适用于三四个(或者更少)的超参数(当超参数的数量增长时,网格搜索的计算复杂度会呈现指数增长,这时候则使用随机搜索),用户列出一个较小的超参数值域,这些超参数至于的笛卡尔积(排列组合)为一组组超参数。网格搜索算法使用每组超参数训练模型并挑选验证集误差最小的超参数组合
缺点:遍历所有组合,比较耗时
网格搜索Example:
import org.apache.spark.ml.clustering.KMeansimport org.apache.spark.ml.evaluation.ClusteringEvaluatorimport org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}//数据集val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")// 训练Kmeans模型//Kmeans超参数val kmeans = new KMeans()/*** 网格搜索:* 对所有addGrid()内的超参数数组进行排列组合,rmse越小,模型精确度越高* 排列组合的参数不建议太多,网格搜索相当于所有组合遍历一遍* 这里会对maxIter和k进行排列组合 如:* 第一次训练 maxIter=200,k=5* 第二次训练 maxIter=200,k=10* ……* 所有排列组合训练完后,根据评估模型,筛选出最合适的模型*/val paramGrid = new ParamGridBuilder().addGrid(kmeans.maxIter, Array(200, 400, 600)).addGrid(kmeans.k, Array(5, 10, 20)).build()// Kmeans为聚类模型,使用聚类指标评估val evaluator = new ClusteringEvaluator()val trainValidationSplit = new TrainValidationSplit()//设置预测模型.setEstimator(kmeans)//设置评估模型.setEvaluator(evaluator)//训练集、验证集划分 训练集为$ratio 验证集为1-$ratio.setTrainRatio(0.7)//网格搜索参数.setEstimatorParamMaps(paramGrid)//预测seed.setSeed(1L)//训练//该方法将自动完成`模型评估Example`中的一二步,找到最适合的评估模型后,用测试集dataset训练最终模型val final_model=trainValidationSplit.fit(dataset)//打印参数列表println(final_model.bestModel.parent.extractParamMap())
TrainValidationSplit
ALS模型网格调参时遇到了一些坑,这里列举一下有坑的地方(其实都是同一个原因造成的)
1.模型的最优参数,每次都是网格搜索排列组合的第一个
如:
val paramGrid = new ParamGridBuilder().addGrid(als.maxIter, Array(500,800,1000)).addGrid(als.rank, Array(5,10,15)).build()
上述代码设置的网格参数,在使用网格搜索遍历后,最优参数必是 maxIter=500,rank=5
2.查看rmse时,全是NaN
model.validationMetrics=Array(NaN,NaN,Nan)
先说结论:
造成这些结果的主要原因,还是ALS冷启动策略设置错误的缘故。ALS模型默认遇到未知UserCol的用户时(即没参与过运算的userId),会将prediction置为NaN。而评估模型进行计算时,若prediction的值有Nan数据,会导致最后的评估结果值也为NaN。如上述第二点所示。
设一个评分表,有userCol,itemCol,rating三个字段,且全表数据不会重复。
UserCol | ItemCol | Rating |
---|---|---|
A | a | 5.0 |
B | b | 5.0 |
C | c | 1.0 |
D | d | 2.0 |
E | e | 2.0 |
TrainValidationSplit方法在遍历最优参数时,是将训练集和验证集是按照setTrainRatio($ratio)的比例随机分配,假设ratio=0.8,则训练集与验证集的比例则为8:2,上表将有四条数据(ABCD)参与训练,一条数据(E)参与验证。因ALS模型只能预测参与计算的数据,验证集用户E的prediction=NaN。
TrainValidationSplit遍历过程的大致代码:
……val est = $(estimator)val eval = $(evaluator)val epm = $(estimatorParamMaps)val Array(trainingDataset, validationDataset) =dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))trainingDataset.cache()validationDataset.cache()……val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>Future[Double] {val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]if (collectSubModelsParam) {subModels.get(paramIndex) = model}// TODO: duplicate evaluator to take extra params from inputval metric = eval.evaluate(model.transform(validationDataset, paramMap))logDebug(s"Got metric $metric for model trained with $paramMap.")metric}(executionContext)}……val (bestMetric, bestIndex) =if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)else metrics.zipWithIndex.minBy(_._1)logInfo(s"Best set of parameters:\n${epm(bestIndex)}")logInfo(s"Best train validation split metric: $bestMetric.")val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
Spark网格搜索——训练验证集拆分相关推荐
- 决策树结合网格搜索交叉验证的例子
决策树结合网格搜索交叉验证 如下是常见的模型评估的指标定义及决策树结合网格搜索交叉验证的例子.详见下文: 混淆矩阵: 准确率: 精准率(预测为正样本真实也是正例的比值,又称为查准率): 召回率(真实为 ...
- sklearn学习-SVM例程总结3(网格搜索+交叉验证——寻找最优超参数)
网格搜索+交叉验证--寻找最优超参数 1548962898@qq.com 连续三天写了三篇博客,主要是为了尽快了解机器学习中算法以外的重要知识,这些知识可以迁移到每一个算法中,或许说,这些知识是学习并 ...
- 时间序列预测:SVR用于时间序列预测代码+模型保存+模型加载+网格搜索+交叉验证
本文关于SVR时间序列的预测,详细步骤如下: 1.数据读取 2.数据集的划分(采用滑动窗口重叠切片) 3.训练数据集掷乱 4.SVR参数设置(网格搜索+交叉验证) 5.SVR模型训练+模型保存 6.S ...
- 【转载】如何理解数据集中【训练集】、【验证集】和【测试集】
转自<吴恩达深度学习笔记(28)-网络训练验证测试数据集的组成介绍> 训练,验证,测试集(Train / Dev / Test sets) 在配置训练.验证和测试数据集的过程中做出正确决策 ...
- 训练集 验证集_训练与验证、测试集数据分布不同的情况
在不同分布的数据集上进行训练与验证.测试 深度学习需要大量的数据,但是有时我们可获得的满足我们真实需求分布的数据并不是那么多,不足以对我们的模型进行训练.这时我们就会收集大量相关的数据加入到训练集中, ...
- 训练集、验证集以及测试集的区别
1.数据集的划分: 训练集:含有参考答案的数据,用来训练模型的已标注数据,用来建立模型,发现规律 验证集:模型训练过程中单独留出的样本集,用于调整模型的超参数和用于对模型的能力进行初步评估 测试集:用 ...
- 3.2 实战项目二(手工分析错误、错误标签及其修正、快速地构建一个简单的系统(快速原型模型)、训练集与验证集-来源不一致的情况(异源问题)、迁移学习、多任务学习、端到端学习)
手工分析错误 手工分析错误的大多数是什么 猫猫识别,准确率90%,想提升,就继续猛加材料,猛调优? --应该先做错误分析,再调优! 把识别出错的100张拿出来, 如果发现50%是"把 ...
- 2022-1-17第三章机器学习基础--网格搜索超参数优化、决策树、随机森林
交叉验证与网格搜索 ①交叉验证(训练集划分-训练集.验证集)–将所有数据分成n等分-并不具备调参能力 4等分就是4折交叉验证:一般采用10折交叉验证 ②网格搜索-调参数(与交叉验证一同使用) 如果有多 ...
- 机器学习之网格搜索调参sklearn
网格搜索 网格搜索 GridSearchCV我们在选择超参数有两个途径:1凭经验:2选择不同大小的参数,带入到模型中,挑选表现最好的参数.通过途径2选择超参数时,人力手动调节注意力成本太高,非常不值得 ...
最新文章
- 排序与查找实验报告java,查找排序实验报告 - 范文大全 - 作文仓库-www.zuowencangku.com...
- numpy生成随机数
- 适配器模式和外观模式
- ***S 2012 交互式报表 -- 钻取式报表
- 天融信安全接入客户端_天融信提示您警惕物联网设备Ripple20漏洞风险
- 2013年,6月20日,今天非常无聊啊。
- CAN总线电平(隐性与显性)
- 小米(MIUI)系统组招聘职位
- 知道路程时间求加速度_凸轮分割器的出力轴加速度是怎么算的
- 威胁情报基础:爬取、行走、分析
- 跟我学Spring Cloud(Finchley版)-04-服务注册与服务发现-原理剖析
- 26.go test
- 网页连接服务器失败是怎么回事,网页怎么连接服务器失败是怎么回事
- 求web嘎嘎厉害的朋友
- DLNA介绍(包括UPnP)
- Kaggle:Quora Question Pairs
- ImportError: packaging>=20.0 is required for a normal functioning of this mo
- clickhouse 常用函数 算数函数 时间函数 日期函数 字符串函数 比较函数 数据类型函数 逻辑函数 类型转换函数 替换函数 数组函数 随机函数 编码函数 UUID URL IP 函数
- 读手语图像识别论文笔记
- Python Matplotlib 花式绘图和中文字符显示、散点图、设置网格和散点函数拟合