配置

配置请看我的其他文章 点击跳转

spark官方文档

点击跳转官方文档

其它文章

推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转

数据

训练数据

预测数据


实体类

用了swagger和lombok 不需要的可以删掉


import io.swagger.annotations.ApiModelProperty;
import lombok.Data;import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;/*** 线性回归参数** @author teler* @date 2020-09-21*/
@Data
public class LinearRegressionEntity {/*** 训练数据集路径*/@ApiModelProperty("训练数据集路径")@NotEmpty(message = "必须有样本集")private String trainFilePath;/*** 预测数据集路径*/@ApiModelProperty("预测数据集路径")@NotEmpty(message = "必须有预测集")private String dataFilePath;/*** 用于测试模型的数据比例,范围[0,1]*/@ApiModelProperty("用于测试模型的数据比例,范围[0,1]")@Max(value = 1L, message = "数据比例最大值为1.0")@Min(value = 0L, message = "数据比例最小值为0.0")private double testDataPct;/*** 迭代次数*/@ApiModelProperty("迭代次数")@NotNull(message = "迭代次数必填")@Min(value = 0, message = "迭代次数最小值为0")private Integer iter;/*** 正则化参数,范围[0,1]*/@NotNull(message = "正则化参数必填")@Max(value = 1L, message = "正则化参数最大值为1.0")@Min(value = 0L, message = "正则化参数最小值为0.0")@ApiModelProperty("正则化参数,范围[0,1]")private double regParam;/*** 弹性网络混合参数,范围[0,1]*/@ApiModelProperty("弹性网络混合参数,范围[0,1]")@Max(value = 1L, message = "弹性网络混合参数最大值为1.0")@Min(value = 0L, message = "弹性网络混合参数最小值为0.0")private double elasticNetParam;
}

算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;@Overridepublic Map<String, Object> linearRegression(LinearRegressionEntity record) {log.info("========== 线性回归计算开始 ==========");Map<String, Object> map = new HashMap<>(16);Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());List<Map<String, String>> sourceList = toList(source);//训练数据map.put("training", sourceList);//根据比例从数据源中随机抽取数据 /训练数据和测试数据比例 建议设为0.8Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},1234L);//训练数据Dataset<Row> trainingData = splits[0].cache();// 10 / 0.3 / 0.8LinearRegression lr = new LinearRegression().setMaxIter(record.getIter()).setRegParam(record.getRegParam()).setElasticNetParam(record.getElasticNetParam());LinearRegressionModel lrModel = lr.fit(trainingData);//系数map.put("coefficients", Arrays.stream(lrModel.coefficients().toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));//截距map.put("intercept", NumberUtil.roundDown(lrModel.intercept(), 3));//训练数据结果集LinearRegressionTrainingSummary trainingSummary = lrModel.summary();//迭代次数map.put("numIterations", trainingSummary.totalIterations());//损失率,一般会逐渐减小map.put("objectiveHistory", Arrays.stream(trainingSummary.objectiveHistory()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));//均方根误差map.put("rmse", NumberUtil.roundDown(trainingSummary.rootMeanSquaredError(), 3));//真实误差map.put("mae", NumberUtil.roundDown(trainingSummary.meanAbsoluteError(), 3));//r平方 越接近1说明效果越好map.put("r2", NumberUtil.roundDown(trainingSummary.r2(), 3));//预测数据Dataset<Row> predictionData = getDataSetByHdfs(record.getDataFilePath());Dataset<Row> predictionResult = lrModel.transform(predictionData).selectExpr("label", "features", "round(prediction,3) as prediction");predictionResult.show();List<Object> predictionFeaturesVal = dataSetToString(predictionResult.select("features"));map.put("data", toList(predictionResult));log.info("========== 线性回归计算结束 ==========");return map;}

getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/*** 从hdfs中取数据** @param dataFilePath 数据路径* @return 数据集合*/private Dataset<Row> getDataSetByHdfs(String dataFilePath) {//屏蔽日志Logger.getLogger("org.apache.spark").setLevel(Level.WARN);Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);Dataset<Row> dataset;try {//我这里的数据是libsvm格式的 如果是其他格式请自行更改dataset = sparkSession.read().format("libsvm").load(dataFilePath);log.info("获取数据结束 ");} catch (Exception e) {log.info("读取失败:{} ", e.getMessage());}return dataset;}

toList

/*** dataset数据转化为list数据** @param record 数据* @return 数据集合*/private List<Map<String, String>> toList(Dataset<Row> record) {log.info("格式化结果数据集===============================");List<Map<String, String>> list = new ArrayList<>();String[] columns = record.columns();List<Row> rows = record.collectAsList();for (Row row : rows) {Map<String, String> obj = new HashMap<>(16);for (int j = 0; j < columns.length; j++) {String col = columns[j];Object rowAs = row.getAs(col);String val = "";//如果是数组 //这一段不需要的可以只留下else的内容if (rowAs instanceof DenseVector) {if (((DenseVector) rowAs).values() instanceof double[]) {val = ArrayUtil.join(Arrays.stream(((DenseVector) rowAs).values()).map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray(), ",");} else {val = rowAs.toString();}} else {val = rowAs.toString();}obj.put(col, val);log.info("列:{},名:{},值:{}", j, col, val);}list.add(obj);}return list;}

java spark一元(多元)线性回归相关推荐

  1. TensorFlow基础3-机器学习基础知识(解析法实现一元线性回归、多元线性回归)

    记录TensorFlow听课笔记 文章目录 记录TensorFlow听课笔记 一,机器学习基础 1.1一元线性回归 1.2解析法实现一元线性回归 1.3解析法实现多元线性回归 二,代码实现一元/多元回 ...

  2. 利用梯度下降法求解一元线性回归和多元线性回归

    文章目录 原理以及公式 [1]一元线性回归问题 [2]多元线性回归问题 [3]学习率 [4]流程分析(一元线性回归) [5]流程分析(多元线性回归) 归一化原理以及每种归一化适用的场合 一元线性回归代 ...

  3. 解析法实现一元线性回归、多元线性回归以及数据模型可视化操作

    目录 [1]解析法实现一元线性回归 python列表实现 利用Numpy实现 利用TensorFlow实现 数据和模型可视化 [2]解析法实现多元线性回归 利用Numpy实现 需要用到的NumPy数组 ...

  4. 多元线性回归数据集_TensorFlow学习Program1——13.实现一元、多元线性回归(基于房价数据集)...

    本节将针对波士顿房价数据集的房间数量(RM)采用简单一元线性回归,目标是预测在最后一列(MEDV)给出的房价.波士顿房价数据集可从http://lib.stat.cmu.edu/datasets/bo ...

  5. matlab重复线性回归,(MATLAB)一元线性回归和多元线性回归

    (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归1.一元线性回归 2.多元线性回归2.1数据说明 2.2程序运行结果 ...

  6. java算多元回归方程_java多元线性回归算法

    多元线性回归的计算方法摘要 在实际经济问题中,一个变量往往受到多个变量的影响.例... STLyy UQ 在多元线性回归分析中,回归平方和表示的是所有 k 个自变量对 y 的变差的总影响,它可以 按公 ...

  7. 一元线性回归与多元线性回归理论及公式推导

    一元线性回归 回归分析只涉及到两个变量的,称一元回归分析. 一元回归的主要任务是从两个相关变量中的一个变量去估计另一个变量,被估计的变量,称因变量,可设为Y:估计出的变量,称自变量,设为X.回归分析就 ...

  8. 一元线性回归与多元线性回归

    线性回归action精讲 线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,运用十分广泛.其表达形式为y = w'x+e,e为误差服从均值为0的正态分布 ...

  9. java实现一元线性回归算法

    网上看一个达人用java写的一元线性回归的实现,我觉得挺有用的,一些企业做数据挖掘不是用到了,预测运营收入的功能吗?采用一元线性回归算法,可以计算出类似的功能.直接上代码吧: 1.定义一个DataPo ...

最新文章

  1. python如何调用c++numpy.ndarray代码_python – 在cython中使用numpy:定义ndarray数据类型/ ndims...
  2. 强软弱虚引用,只有体会过了,才能记住
  3. mysql经典面试题
  4. 高端网站建设css3动画响应式模板_网站建设中整站定制与模板建站存在着什么差异...
  5. Python3.x和Python2.x的区别[转]
  6. CAP 理论、BASE 理论、FLP 理论
  7. Redis Save 命令
  8. [Deprecated( please use panBy and panTo APIs )]
  9. 史上最全android分辨率
  10. OpenStack安装流程(juno版)- 添加镜像服务(glance)
  11. 完美代码,让你的代码无懈可击
  12. 带蓝色的紫罗兰色——三色配色篇
  13. HDU4622- Reincarnation(后缀自动机)
  14. kafka系列文章四(Consumer Group)
  15. 搭建Domoticz智能家居服务器实现外网控制ESP8266
  16. envi反演水质参数_遥感干旱反演方法汇总
  17. 这个社会最大的现实是“大鱼吃小鱼,小鱼吃虾米”
  18. 计算机b级考试题型分值分布,英语b级多少分算过-英语B级考试分值分布告诉我一下 – 手机爱问...
  19. scanline_p8
  20. 前端基础 es6、vue

热门文章

  1. 计算机在生态学的应用,应用生态学
  2. VC++ Tab Control控件的基本用法
  3. 竹炭纤维集成墙面板装修的缺点是什么,有哪些弊端
  4. 孙正义万字访谈实录:AI是我现在唯一关注的事情,我是科技的绝对信徒
  5. label smooth/mixup——深度学习中的一种防止过拟合方法
  6. Win10系统中临时文件夹位置及临时文件的删除
  7. 可口可乐和Tafi合作铸造NFT
  8. SpringCloudAlibaba踩坑日记(二)Relying upon circular references is discouraged and they are prohibited by
  9. java合肥工业大学考试题库_合肥工业大学java程序设计实验二
  10. Unity画线之GL