多元线性回归(java实现)
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;public class LinearRegression {/** 训练数据示例:* x0 x1 x2 y1.0 1.0 2.0 7.21.0 2.0 1.0 4.91.0 3.0 0.0 2.61.0 4.0 1.0 6.31.0 5.0 -1.0 1.01.0 6.0 0.0 4.71.0 7.0 -2.0 -0.6注意!!!!x1,x2,y三列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。x0,x1,x2是“特征”,y是结果h(x) = theta0 * x0 + theta1* x1 + theta2 * x2theta0,theta1,theta2 是想要训练出来的参数此程序采用“梯度下降法”**/private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 yprivate int row;//训练数据 行数private int column;//训练数据 列数private double [] theta;//参数thetaprivate double alpha;//训练步长private int iteration;//迭代次数public LinearRegression(String fileName){int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1this.row=rowoffile;this.column=columnoffile+1;this.alpha = 0.001;//步长默认为0.001this.iteration=100000;//迭代次数默认为 100000theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......initialize_theta();loadTrainDataFromFile(fileName,rowoffile,columnoffile);}public LinearRegression(String fileName,double alpha,int iteration){int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1this.row=rowoffile;this.column=columnoffile+1;this.alpha = alpha;this.iteration=iteration;theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......initialize_theta();loadTrainDataFromFile(fileName,rowoffile,columnoffile);}private int getRowNumber(String fileName){int count =0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));while ( reader.readLine() != null)count++;reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return count;}private int getColumnNumber(String fileName){int count =0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));String tempString = reader.readLine();if(tempString!=null)count = tempString.split(" ").length;reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return count;}private void initialize_theta()//将theta各个参数全部初始化为1.0{for(int i=0;i<theta.length;i++)theta[i]=1.0;}public void trainTheta(){int iteration = this.iteration;while( (iteration--)>0 ){//对每个theta i 求 偏导数double [] partial_derivative = compute_partial_derivative();//偏导数//更新每个thetafor(int i =0; i< theta.length;i++)theta[i]-= alpha * partial_derivative[i];}}private double [] compute_partial_derivative(){double [] partial_derivative = new double[theta.length];for(int j =0;j<theta.length;j++)//遍历,对每个theta求偏导数{partial_derivative[j]= compute_partial_derivative_for_theta(j);//对 theta i 求 偏导}return partial_derivative;}private double compute_partial_derivative_for_theta(int j){double sum=0.0;for(int i=0;i<row;i++)//遍历 每一行数据{sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);}return sum/row;}private double h_theta_x_i_minus_y_i_times_x_j_i(int i,int j){double[] oneRow = getRow(i);//取一行数据,前面是feature,最后一个是ydouble result = 0.0;for(int k=0;k< (oneRow.length-1);k++)result+=theta[k]*oneRow[k];result-=oneRow[oneRow.length-1];result*=oneRow[j];return result;}private double [] getRow(int i)//从训练数据中取出第i行,i=0,1,2,。。。,(row-1){return trainData[i];}private void loadTrainDataFromFile(String fileName,int row, int column){for(int i=0;i< row;i++)//trainData的第一列全部置为1.0(feature x0)trainData[i][0]=1.0;File file = new File(fileName);BufferedReader reader = null;try {reader = new BufferedReader(new FileReader(file));String tempString = null;int counter = 0;while ( (counter<row) && (tempString = reader.readLine()) != null) {String [] tempData = tempString.split(" ");for(int i=0;i<column;i++)trainData[counter][i+1]=Double.parseDouble(tempData[i]);counter++;}reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}}public void printTrainData(){System.out.println("Train Data:\n");for(int i=0;i<column-1;i++)System.out.printf("%10s","x"+i+" ");System.out.printf("%10s","y"+" \n");for(int i=0;i<row;i++){for(int j=0;j<column;j++){System.out.printf("%10s",trainData[i][j]+" ");}System.out.println();}System.out.println();}public void printTheta(){for(double a:theta)System.out.print(a+" ");}}
测试类:
public class TestLinearRegression {public static void main(String[] args) {// TODO Auto-generated method stubLinearRegression m = new LinearRegression("trainData",0.001,1000000);m.printTrainData();m.trainTheta();m.printTheta();}}
参考地址:
https://www.bbsmax.com/A/xl563Y21dr/
多元线性回归(java实现)相关推荐
- java 线性回归_多元线性回归----Java简单实现
import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOExce ...
- 基于多元线性回归去除图片水印(Java版)
前提 采集的图片有淡淡的水印,为了避免不必要的麻烦,需要淡化或去除水印.图片如下所示: Java自带的工具可以对图片指定位置(x,y)的颜色(r,g,b)进行替换.如果图片上下左右颜色一致,则可进行颜 ...
- java算多元回归方程_java多元线性回归算法
多元线性回归的计算方法摘要 在实际经济问题中,一个变量往往受到多个变量的影响.例... STLyy UQ 在多元线性回归分析中,回归平方和表示的是所有 k 个自变量对 y 的变差的总影响,它可以 按公 ...
- 机器学习第3天:多元线性回归
文章目录 一.具体实现步骤 第1步:数据预处理 导入库 导入数据集 将类别数据数字化 躲避虚拟变量陷阱 拆分数据集为训练集和测试集 第2步: 在训练集上训练多元线性回归模型 第3步:在测试集上预测结果 ...
- android 揭示动画_遗传编程揭示具有相互作用的多元线性回归
android 揭示动画 We all had some sort of experience with linear regression. It's one of the most used re ...
- 机器学习多元线性回归_过度简化的机器学习(1):多元回归
机器学习多元线性回归 The term machine learning may sound provocative. Machines do not learn like humans do. Ho ...
- 多元线性回归分析c语言,多元线性回归公式推导及R语言实现
多元线性回归 多元线性回归模型 实际中有很多问题是一个因变量与多个自变量成线性相关,我们可以用一个多元线性回归方程来表示. 为了方便计算,我们将上式写成矩阵形式: Y = XW 假设自变量维度为N W ...
- 多元线性回归之Spss实现
我们还以上篇中的数据来进行多元线性回归的分析及说明: 首先来看引入一个变量时候,也就是分析不良贷款和各项贷款余额的回归分析,这里只做对比用,详细的分析结果可以上上一篇博客: 引入两个自变量时,各年贷款 ...
- 多元线性回归最小二乘法及其应用
Cholesky分解求系数参考: [1]冯天祥. 多元线性回归最小二乘法及其经济分析[J]. 经济师,2003,11:129. 还可以采用最小二乘法来估计参数: 算法设计也可以参考两种系数最终公式设计 ...
- TensorFlow多元线性回归实现
TensorFlow实现多元线性回归 多元线性回归的具体实现 导入需要的所有软件包: 因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏 ...
最新文章
- Linux 安装iostat命令
- tomcat的jdbc连接池PoolExhaustedException
- python telnet模块 more_[宜配屋]听图阁 - python 处理telnet返回的More,以及get想要的那个参数方法...
- 对学校的希望和寄语_新年元旦寄语【三篇】
- 监控 SQL Server 的运行状况
- Webpack —— tree-starking 解析
- (29)System Verilog进程间同步(旗语semaphore)
- 明年起网剧可参评白玉兰奖 你期待哪部网剧上榜?
- C 语言:我不是针对谁,我是说在座的都是乐色
- 完美数及寻找完美数的算法(Perfect number‘s algorithm)
- 配置光盘镜像YUM源和阿里YUM源(纯干货,建议收藏)
- UniBeast:在任何支持基于英特尔处理器的PC上安装OS X优胜美地
- 华为携手Work Shift Calendar (Shifter),将工作效率提升至更高水平
- html自动拿微信扫描,HTML——微信浏览器H5页面调用微信扫一扫
- java 读取文件inputstream_使用FileInputStream读取本地文件
- 海康威视监控下载下来的mp4格式的视频,小类别MPEG-PS格式
- 计算机专业保研英语自我介绍,计算机专业保研面试英文自我介绍
- 关于声子和热输运计算中BORN电荷和non-analytic修正的问题
- 北京翰鑫信息科技有限公司怎么样
- 【最新版】友价T5交易商城源码 ,10月更新新增自助交易
热门文章
- ES6-18:class类及其继承
- Web Service 移除 xmlns
- 一个flash前后台开源框架的的站点
- mongodb java 嵌套查询_MongoDB java 查询嵌套JSON数据?
- 京瓷 打印 打印机 账户_UV打印机买回来成废铁?不是选择不对,而是你想太多了……...
- 如何去选取第一批要阅读的论文?_顶会最佳论文奖得主:初入科研领域,如何正确做科研?...
- img 显示base64_用 PySimpleGUI 做程序(7)--显示图片
- android 坐标布局变形,Android:scrollBy实现view随意移动并显示坐标
- 从右边开始放_幸福的生活,从入住新房开始,效果很漂亮,忍不住给大家晒晒全屋...
- python的django框架是干嘛的_Django框架在Python开发很重要为什么?