今天跟大家一起学习机器学习比较简单的一个算法,也就是线性回归算法。

让我们通过一个例子开始:这个例子就是预测住房价格的,我们要使用一个数据集,数据集包含一个地方的住房价格,这里我们要根据不同房屋尺寸所售出的价格,画出我们的数据集,比方说,如果你朋友的房子是1250平方尺大小,你要告诉他这个房子能卖多少钱。那么,你可以做的一件事就是构建一个模型,也许是条直线,从这个数据模型上来看,也许你可以告诉你的朋友,他能以大约220000(美元)左右的价格卖掉这个房子。这就是监督学习算法的一个例子。

它被称作监督学习是因为对于每个数据来说,我们给出了“正确的答案”,即告诉我们:根据我们的数据来说,房子实际的价格是多少,而且,更具体来说,这是一个回归问题。回归一词指的是,我们根据之前的数据预测出一个准确的输出值,对于这个例子就是价格。同时,还有另外一种最常见的监督学洗方式,叫做分类问题,当我们想要预测离散的输出值,例如,我们正在薛兆癌症肿瘤,并想要确定肿瘤是良性还是恶性的,这就是0/1离散输出的问题。更近一步说,在监督学习中我们有一个数据集,这个数据集被称作是训练集。

下面就是实现一元线性回归模型的Java版本的代码,其中绘制数据集,和绘制回归模型使用的是JfreeChart,核心代码如下:

package cn.rocket.ml;import java.io.IOException;
import java.util.ArrayList;
import java.util.List;import cn.rocket.data.DataSet;
import cn.rocket.utils.ScatterPlot;public class LinearRegression {private double theta0 = 0.0 ;  //截距private double theta1 = 0.0 ;  //斜率private double alpha = 0.01 ;  //学习速率private int max_itea = 20000 ; //最大迭代步数private DataSet dataSet = new DataSet() ;public  LinearRegression() throws IOException{dataSet.loadDataFromTxt("datas/house_price.txt", ",",1);}public double predict(double x){return theta0+theta1*x ;}public double calc_error(double x, double y) {return predict(x)-y;}public void gradientDescient(){double sum0 =0.0 ;double sum1 =0.0 ;for(int i = 0 ; i < dataSet.getSize() ;i++) {sum0 += calc_error(dataSet.getDatas().get(i).get(0), dataSet.getLabels().get(i)) ;sum1 += calc_error(dataSet.getDatas().get(i).get(0), dataSet.getLabels().get(i))*dataSet.getDatas().get(i).get(0) ;}this.theta0 = theta0 - alpha*sum0/dataSet.getSize() ; this.theta1 = theta1 - alpha*sum1/dataSet.getSize() ; }public void lineGre() {int itea = 0 ;while( itea< max_itea){//System.out.println(error_rate);System.out.println("The current step is :"+itea);System.out.println("theta0 "+theta0);System.out.println("theta1 "+theta1);System.out.println();gradientDescient();itea ++ ;}} ;public static void main(String[] args) throws IOException {LinearRegression linearRegression = new LinearRegression() ;linearRegression.lineGre();List<Double> list = new ArrayList<Double>() ;for(int i = 0 ; i < linearRegression.dataSet.getSize() ;i++) {list.add(linearRegression.dataSet.getDatas().get(i).get(0));}ScatterPlot.data("Datas", list, linearRegression.dataSet.getLabels(),linearRegression.theta0,linearRegression.theta1);}}

这段代码值得我们注意的问题有很多,一个是学习步长alpha的设置,如果设置的太大最后结果会无法收敛,但是如果设置的太小训练会非常缓慢。

下面是结果,我们可以看到,散点图是训练数据,红色的直线表示我们训练出来的一元线性模型,我们可以看出该模型能对训练数据做一个较好的线性拟合。

该项目的项目源码我已经放在GitHub上。项目地址:https://github.com/ShengPengYu/MachineLearning

Java实现线性回归模型算法相关推荐

  1. 线性回归模型算法原理及Python实现

    文章内容主要来自Aurelien Geron<Hands-on Machine Learning withi Scikit-Learn&TensorFlow> 目录 线性回归方程 ...

  2. java 一元线性回归_算法笔记:一元线性回归及Java实现

    这是voidAlex原创的第四篇博文. 源码在我的GitHub 回归问题 回归问题是研究自变量和因变量之间关系的一种预测模型技术.例如我们可以通过回归模型去预测房价与房子面积之间的关系,一个人每周花在 ...

  3. 通过线性回归模型及优化实现AQI分析与预测

    目录 1 项目背景与分析说明 1.1 项目背景 1.2 数据说明 1.3 分析说明 2 数据预处理 2.1 导入相关库 2.2 导入数据 2.3 数据预处理 2.3.1 缺失值 2.3.2 异常值 2 ...

  4. 基于机器学习梯度下降优化算法来寻找最佳的线性回归模型

    https://www.toutiao.com/a6638782437587419652/ 幻风的AI之路 2018-12-25 18:12:27 线性回归模型 线性回归模型是一个非常简单的算法模型, ...

  5. 【Android 内存优化】Java 内存模型 ( Java 虚拟机内存模型 | 线程私有区 | 共享数据区 | 内存回收算法 | 引用计数 | 可达性分析 )

    文章目录 一. Java 虚拟机内存模型 二. 程序计数器 ( 线程私有区 ) 三. 虚拟机栈 ( 线程私有区 ) 四. 本地方法栈 ( 线程私有区 ) 五. 方法区 ( 共享数据区 ) 1. 方法区 ...

  6. (二十二)用RANSAC算法来求线性回归模型的参数

    线性回归模型 一.什么是线性回归? 举个例子, 某商品的利润在售价为2 元. 5 元. 10 元时分别为 4 元. 11 元. 20 元, 我们很容易得出商品的利润与售价的关系符合直线:y=2x. 在 ...

  7. 常用的三种线性模型算法--线性回归模型、岭回归模型、套索回归模型

    常用的三种线性模型算法–线性回归模型.岭回归模型.套索回归模型 线性模型基本概念 线性模型的一般预测模型是下面这个样子的,一般有多个变量,也可以称为多个特征x1.x2.x3 - 最简单的线性模型就是一 ...

  8. 【项目实战】Python实现多元线性回归模型(statsmodels OLS算法)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 回归问题是一类预测连续值的问题,而能满足这样要求的数学模型称作回 ...

  9. 机器学习算法——线性回归的详细介绍 及 利用sklearn包实现线性回归模型

    目录 1.线性回归简介 1.1 线性回归应用场景 1.2 什么是线性回归 1.2.1 定义与公式 1.2.2 线性回归的特征与目标的关系分析 2.线性回归api初步使用 2.1 线性回归API 2.2 ...

最新文章

  1. webdiyer aspnet pager最近又用这个。还是记录下。
  2. 阻塞与非阻塞个人小结
  3. 获取本机IP地址[JavaScript / Node.js]
  4. 使用ganglia监控hadoop及hbase集群
  5. 中英翻译(基于百度翻译)
  6. 2.6宽带接入技术ADSL
  7. python阈值计算_python – 在numpy中计算超过阈值的数组值的最快方法
  8. java过滤器对ext异步,拦截EXT请求的过滤器
  9. 科学家预测:未来100万年人类将变成半机械人类
  10. Html加jq实现5星好评效果,关于jquery实现五星好评的方法
  11. Windows压力测试工具SuperBenchmarker
  12. elasticjob 源码分析
  13. Shell中uniq命令的用法
  14. Access-Control-Allow- 跨域CORS 的使用
  15. 文件系统读写性能测试实战
  16. 136 137 260只出现一次的数字【我亦无他唯手熟尔】
  17. 输入三个整数a,b,c。并进行两两相加,最后比较相加和的最大值。
  18. Texture Haven Spider
  19. 下载国外软件慢(例如:Python安装包)
  20. Matlab R2012a/b反复激活无效+license checkout failed解决方案

热门文章

  1. 怎么设置linux系统定时关机,Linux系统定时关机
  2. ORA-01031:insufficient privileges 解决方法
  3. Linux常用命令——rsync命令
  4. Web开发应了解的5种设计模式(转)
  5. CSDN Blogger小工具
  6. 【C语言】 Static关键字的用法和详解(太细了!!!)
  7. 网站关键词排名:关键词排名提升的5个方法!
  8. Microsoft Office Document Image Writer 和 Microsoft XPS Document Writer (Office组件轻松把PDF文件转成Word文档)...
  9. 经方败案群20150303李小荣讲桂枝芍药知母汤
  10. fstream ,ifstream,ofstream的用法详解