import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;public class LinearRegression {/** 训练数据示例:*   x0        x1        x2        y1.0       1.0       2.0       7.21.0       2.0       1.0       4.91.0       3.0       0.0       2.61.0       4.0       1.0       6.31.0       5.0      -1.0       1.01.0       6.0       0.0       4.71.0       7.0      -2.0      -0.6注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。x0,x1,x2是“特征”,y是结果h(x) = theta0 * x0 + theta1* x1 + theta2 * x2theta0,theta1,theta2 是想要训练出来的参数此程序采用“梯度下降法”**/private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 yprivate int row;//训练数据  行数private int column;//训练数据 列数private double [] theta;//参数thetaprivate double alpha;//训练步长private int iteration;//迭代次数public LinearRegression(String fileName){int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1this.row=rowoffile;this.column=columnoffile+1;this.alpha = 0.001;//步长默认为0.001this.iteration=100000;//迭代次数默认为 100000theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......initialize_theta();loadTrainDataFromFile(fileName,rowoffile,columnoffile);}public LinearRegression(String fileName,double alpha,int iteration){int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1this.row=rowoffile;this.column=columnoffile+1;this.alpha = alpha;this.iteration=iteration;theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......initialize_theta();loadTrainDataFromFile(fileName,rowoffile,columnoffile);}private int getRowNumber(String fileName){int count =0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));while ( reader.readLine() != null)count++;reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return count;}private int getColumnNumber(String fileName){int count =0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));String tempString = reader.readLine();if(tempString!=null)count = tempString.split(" ").length;reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return count;}private void initialize_theta()//将theta各个参数全部初始化为1.0{for(int i=0;i<theta.length;i++)theta[i]=1.0;}public void trainTheta(){int iteration = this.iteration;while( (iteration--)>0 ){//对每个theta i 求 偏导数double [] partial_derivative = compute_partial_derivative();//偏导数//更新每个thetafor(int i =0; i< theta.length;i++)theta[i]-= alpha * partial_derivative[i];}}private double [] compute_partial_derivative(){double [] partial_derivative = new double[theta.length];for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数{partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta i 求 偏导}return partial_derivative;}private double compute_partial_derivative_for_theta(int j){double sum=0.0;for(int i=0;i<row;i++)//遍历 每一行数据{sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);}return sum/row;}private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j){double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是ydouble result = 0.0;for(int k=0;k< (oneRow.length-1);k++)result+=theta[k]*oneRow[k];result-=oneRow[oneRow.length-1];result*=oneRow[j];return result;}private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1){return trainData[i];}private void loadTrainDataFromFile(String fileName,int row, int column){for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)trainData[i][0]=1.0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));String tempString = null;int counter = 0;while ( (counter<row) && (tempString = reader.readLine()) != null) {String [] tempData = tempString.split(" ");for(int i=0;i<column;i++)trainData[counter][i+1]=Double.parseDouble(tempData[i]);counter++;}reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}}public void printTrainData(){System.out.println("Train Data:\n");for(int i=0;i<column-1;i++)System.out.printf("%10s","x"+i+" ");System.out.printf("%10s","y"+" \n");for(int i=0;i<row;i++){for(int j=0;j<column;j++){System.out.printf("%10s",trainData[i][j]+" ");}System.out.println();}System.out.println();}public void printTheta(){for(double a:theta)System.out.print(a+" ");}}

测试类:

public class TestLinearRegression {public static void main(String[] args) {// TODO Auto-generated method stubLinearRegression m = new LinearRegression("trainData",0.001,1000000);m.printTrainData();m.trainTheta();m.printTheta();}}

参考地址:

https://www.bbsmax.com/A/xl563Y21dr/

多元线性回归(java实现)相关推荐

  1. java 线性回归_多元线性回归----Java简单实现

    import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOExce ...

  2. 基于多元线性回归去除图片水印(Java版)

    前提 采集的图片有淡淡的水印,为了避免不必要的麻烦,需要淡化或去除水印.图片如下所示: Java自带的工具可以对图片指定位置(x,y)的颜色(r,g,b)进行替换.如果图片上下左右颜色一致,则可进行颜 ...

  3. java算多元回归方程_java多元线性回归算法

    多元线性回归的计算方法摘要 在实际经济问题中,一个变量往往受到多个变量的影响.例... STLyy UQ 在多元线性回归分析中,回归平方和表示的是所有 k 个自变量对 y 的变差的总影响,它可以 按公 ...

  4. 机器学习第3天:多元线性回归

    文章目录 一.具体实现步骤 第1步:数据预处理 导入库 导入数据集 将类别数据数字化 躲避虚拟变量陷阱 拆分数据集为训练集和测试集 第2步: 在训练集上训练多元线性回归模型 第3步:在测试集上预测结果 ...

  5. android 揭示动画_遗传编程揭示具有相互作用的多元线性回归

    android 揭示动画 We all had some sort of experience with linear regression. It's one of the most used re ...

  6. 机器学习多元线性回归_过度简化的机器学习(1):多元回归

    机器学习多元线性回归 The term machine learning may sound provocative. Machines do not learn like humans do. Ho ...

  7. 多元线性回归分析c语言,多元线性回归公式推导及R语言实现

    多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...

  8. 多元线性回归之Spss实现

    我们还以上篇中的数据来进行多元线性回归的分析及说明: 首先来看引入一个变量时候,也就是分析不良贷款和各项贷款余额的回归分析,这里只做对比用,详细的分析结果可以上上一篇博客: 引入两个自变量时,各年贷款 ...

  9. 多元线性回归最小二乘法及其应用

    Cholesky分解求系数参考: [1]冯天祥. 多元线性回归最小二乘法及其经济分析[J]. 经济师,2003,11:129. 还可以采用最小二乘法来估计参数: 算法设计也可以参考两种系数最终公式设计 ...

  10. TensorFlow多元线性回归实现

    TensorFlow实现多元线性回归 多元线性回归的具体实现 导入需要的所有软件包: 因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏 ...

最新文章

  1. Linux 安装iostat命令
  2. tomcat的jdbc连接池PoolExhaustedException
  3. python telnet模块 more_[宜配屋]听图阁 - python 处理telnet返回的More,以及get想要的那个参数方法...
  4. 对学校的希望和寄语_新年元旦寄语【三篇】
  5. 监控 SQL Server 的运行状况
  6. Webpack —— tree-starking 解析
  7. (29)System Verilog进程间同步(旗语semaphore)
  8. 明年起网剧可参评白玉兰奖 你期待哪部网剧上榜?
  9. C 语言:我不是针对谁,我是说在座的都是乐色
  10. 完美数及寻找完美数的算法(Perfect number‘s algorithm)
  11. 配置光盘镜像YUM源和阿里YUM源(纯干货,建议收藏)
  12. UniBeast:在任何支持基于英特尔处理器的PC上安装OS X优胜美地
  13. 华为携手Work Shift Calendar (Shifter),将工作效率提升至更高水平
  14. html自动拿微信扫描,HTML——微信浏览器H5页面调用微信扫一扫
  15. java 读取文件inputstream_使用FileInputStream读取本地文件
  16. 海康威视监控下载下来的mp4格式的视频,小类别MPEG-PS格式
  17. 计算机专业保研英语自我介绍,计算机专业保研面试英文自我介绍
  18. 关于声子和热输运计算中BORN电荷和non-analytic修正的问题
  19. 北京翰鑫信息科技有限公司怎么样
  20. 【最新版】友价T5交易商城源码 ,10月更新新增自助交易

热门文章

  1. ES6-18:class类及其继承
  2. Web Service 移除 xmlns
  3. 一个flash前后台开源框架的的站点
  4. mongodb java 嵌套查询_MongoDB java 查询嵌套JSON数据?
  5. 京瓷 打印 打印机 账户_UV打印机买回来成废铁?不是选择不对,而是你想太多了……...
  6. 如何去选取第一批要阅读的论文?_顶会最佳论文奖得主:初入科研领域,如何正确做科研?...
  7. img 显示base64_用 PySimpleGUI 做程序(7)--显示图片
  8. android 坐标布局变形,Android:scrollBy实现view随意移动并显示坐标
  9. 从右边开始放_幸福的生活,从入住新房开始,效果很漂亮,忍不住给大家晒晒全屋...
  10. python的django框架是干嘛的_Django框架在Python开发很重要为什么?