线性回归(Linear Regression)

梯度下降算法在机器学习方法分类中属于监督学习。利用它可以求解线性回归问题,计算一组二维数据之间的线性关系,假设有一组数据如下下图所示

其中X轴方向表示房屋面积、Y轴表示房屋价格。我们希望根据上述的数据点,拟合出一条直线,能跟对任意给定的房屋面积实现价格预言,这样求解得到直线方程过程就叫线性回归,得到的直线为回归直线,数学公式表示如下:

二:梯度下降 (Gradient Descent)




三:代码实现

数据读入

public List<DataItem> getData(String fileName) {List<DataItem> items = new ArrayList<DataItem>();File f = new File(fileName);try {if (f.exists()) {BufferedReader br = new BufferedReader(new FileReader(f));String line = null;while((line = br.readLine()) != null) {String[] data = line.split(",");if(data != null && data.length == 2) {DataItem item = new DataItem();item.x = Integer.parseInt(data[0]);item.y = Integer.parseInt(data[1]);items.add(item);}}br.close();}} catch (IOException ioe) {System.err.println(ioe);}return items;
}

归一化处理

public void normalization(List<DataItem> items) {float min = 100000;float max = 0;for(DataItem item : items) {min = Math.min(min, item.x);max = Math.max(max, item.x);}float delta = max - min;for(DataItem item : items) {item.x = (item.x - min) / delta;}
}

梯度下降

public float[] gradientDescent(List<DataItem> items) {int repetion = 1500;float learningRate = 0.1f;float[] theta = new float[2];Arrays.fill(theta, 0);float[] hmatrix = new float[items.size()];Arrays.fill(hmatrix, 0);int k=0;float s1 = 1.0f / items.size();float sum1=0, sum2=0;for(int i=0; i<repetion; i++) {for(k=0; k<items.size(); k++ ) {hmatrix[k] = ((theta[0] + theta[1]*items.get(k).x) - items.get(k).y);}for(k=0; k<items.size(); k++ ) {sum1 += hmatrix[k];sum2 += hmatrix[k]*items.get(k).x;}sum1 = learningRate*s1*sum1;sum2 = learningRate*s1*sum2;// 更新 参数thetatheta[0] = theta[0] - sum1;theta[1] = theta[1] - sum2;}return theta;
}

价格预言

public float predict(float input, float[] theta) {float result = theta[0] + theta[1]*input;return result;
}

线性回归图

public void drawPlot(List<DataItem> series1, List<DataItem> series2, float[] theta) {int w = 500;int h = 500;BufferedImage plot = new BufferedImage(w, h, BufferedImage.TYPE_INT_ARGB);Graphics2D g2d = plot.createGraphics();g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);g2d.setPaint(Color.WHITE);g2d.fillRect(0, 0, w, h);g2d.setPaint(Color.BLACK);int margin = 50;g2d.drawLine(margin, 0, margin, h);g2d.drawLine(0, h-margin, w, h-margin);float minx=Float.MAX_VALUE, maxx=Float.MIN_VALUE;float miny=Float.MAX_VALUE, maxy=Float.MIN_VALUE;for(DataItem item : series1) {minx = Math.min(item.x, minx);maxx = Math.max(maxx, item.x);miny = Math.min(item.y, miny);maxy = Math.max(item.y, maxy);}for(DataItem item : series2) {minx = Math.min(item.x, minx);maxx = Math.max(maxx, item.x);miny = Math.min(item.y, miny);maxy = Math.max(item.y, maxy);}// draw X, Y Title and Aixesg2d.setPaint(Color.BLACK);g2d.drawString("价格(万)", 0, h/2);g2d.drawString("面积(平方米)", w/2, h-20);// draw labels and legendg2d.setPaint(Color.BLUE);float xdelta = maxx - minx;float ydelta = maxy - miny;float xstep = xdelta / 10.0f;float ystep = ydelta / 10.0f;int dx = (w - 2*margin) / 11;int dy = (h - 2*margin) / 11;// draw labelsfor(int i=1; i<11; i++) {g2d.drawLine(margin+i*dx, h-margin, margin+i*dx, h-margin-10);g2d.drawLine(margin, h-margin-dy*i, margin+10, h-margin-dy*i);int xv = (int)(minx + (i-1)*xstep);float yv = (int)((miny + (i-1)*ystep)/10000.0f);g2d.drawString(""+xv, margin+i*dx, h-margin+15);g2d.drawString(""+yv, margin-25, h-margin-dy*i);}// draw pointg2d.setPaint(Color.BLUE);for(DataItem item : series1) {float xs = (item.x - minx) / xstep + 1;float ys = (item.y - miny) / ystep + 1;g2d.fillOval((int)(xs*dx+margin-3), (int)(h-margin-ys*dy-3), 7,7);}g2d.fillRect(100, 20, 20, 10);g2d.drawString("训练数据", 130, 30);// draw regression lineg2d.setPaint(Color.RED);for(int i=0; i<series2.size()-1; i++) {float x1 = (series2.get(i).x - minx) / xstep + 1;float y1 = (series2.get(i).y - miny) / ystep + 1;float x2 = (series2.get(i+1).x - minx) / xstep + 1;float y2 = (series2.get(i+1).y - miny) / ystep + 1;g2d.drawLine((int)(x1*dx+margin-3), (int)(h-margin-y1*dy-3), (int)(x2*dx+margin-3), (int)(h-margin-y2*dy-3));}g2d.fillRect(100, 50, 20, 10);g2d.drawString("线性回归", 130, 60);g2d.dispose();saveImage(plot);
}

四:总结

本文通过最简单的示例,演示了利用梯度下降算法实现线性回归分析,使用更新收敛的算法常被称为LMS(Least Mean Square)又叫Widrow-Hoff学习规则,此外梯度下降算法还可以进一步区分为增量梯度下降算法与批量梯度下降算法,这两种梯度下降方法在基于神经网络的机器学习中经常会被提及,对此感兴趣的可以自己进一步探索与研究。

只分享干货,不止于代码

基于梯度下降算法求解线性回归相关推荐

  1. 基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比

    基于jupyter notebook的python编程-----利用梯度下降算法求解多元线性回归方程,并与最小二乘法求解进行精度对比目录 一.梯度下降算法的基本原理 1.梯度下降算法的基本原理 二.题 ...

  2. 梯度下降法求解线性回归

    梯度下降法求解线性回归 通过梯度下降法求解简单的一元线性回归 分别通过梯度下降算法和sklearn的线性回归模型(即基于最小二乘法)解决简单的一元线性回归实际案例,通过结果对比两个算法的优缺. 通过最 ...

  3. 基于梯度下降算法自建一种短期有效的自回归模型

    基于梯度下降算法自建一种短期有效的自回归模型 前言 一:移动平均模型 二:基于自适应滤波思想的权重优化 三:代码实现 四:实验分析 五:总结与展望 前言 基于时间序列自回归预测模型还是比较多的,简单的 ...

  4. 梯度下降算法以及线性回归模型

    版权声明:本文为原创文章:http://blog.csdn.net/programmer_wei/article/details/51941358 梯度下降算法是一个很基本的算法,在机器学习和优化中有 ...

  5. Python使用tensorflow中梯度下降算法求解变量最优值

    TensorFlow是一个用于人工智能的开源神器,是一个采用数据流图(data flow graphs)用于数值计算的开源软件库.数据流图使用节点(nodes)和边线(edges)的有向图来描述数学计 ...

  6. 基于梯度下降法的——线性回归拟合

    点击"小詹学Python",选择"置顶"公众号 重磅干货,第一时间送达 本文转载自数据分析挖掘与算法,禁二次转载 阅读本文需要的知识储备: 高等数学 运筹学 P ...

  7. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)...

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  8. loss下降auc下降_梯度下降算法 线性回归拟合(附Python/Matlab/Julia源代码)

    梯度下降 梯度下降法的原理 梯度下降法(gradient descent)是一种常用的一阶(first-order)优化方法,是求解无约束优化问题最简单.最经典的方法之一. 梯度下降最典型的例子就是从 ...

  9. python多元线性回归代码_Python实现梯度下降算法求多元线性回归(一)

    预备知识及相关文档博客 学习吴恩达机器学习课程笔记,并用python实现算法 python numpy基本教程: numpy相关教程 数据来自于UCI的机器学习数据库: UCI的机器学习数据库 pyt ...

  10. 干货|简单理解梯度下降及线性回归

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 一.线性回归 在回归分析中,一个自变量和一个因变量的关系可用一条直 ...

最新文章

  1. explicit specialization of ‘Race‘ after instantiation ,implicit instantiation first required here。
  2. windows phone发布时其他注意事项
  3. 【Android RTMP】音频数据采集编码 ( FAAC 头文件与静态库拷贝到 AS | CMakeList.txt 配置 FAAC | AudioRecord 音频采样 PCM 格式 )
  4. Javascript添加收藏夹和设为首页兼容写法
  5. python笔记 print+‘\r‘ (打印新内容时删除打印的旧内容)
  6. 蓝桥杯-11-1实现strcmp函数(java)
  7. PHP使用session_set_save_handler陷阱
  8. android 购物车抛物线,添加到购物车抛物线动画
  9. Frida之安装和使用教程
  10. 判断包含字符当中包含小数点_Python|提取包含指定文本的行到一个文本文件(字符串或正则)...
  11. Bailian4120 硬币【0-1背包】
  12. X264结构体中的变量解释
  13. spring aop和事务同时开启带来的一些问题
  14. 中国无线城市市场发展策略及未来前景规划报告2022年版
  15. 信息系统项目管理师自学笔记(二十二)——布线工程、网络规划与设计
  16. 体验谷歌菜市场镜像版
  17. 为什么越来越多的人选择FUP T10S系列超声波探伤仪
  18. 克鲁斯卡尔(Kruskal)算法(严蔚敏C语言)
  19. 手机测试用例-多媒体测试用例
  20. 输入一个18位的身份证号码,从中提取出生日期

热门文章

  1. leetcode 5230 Check If It Is a Straight Line
  2. SQL Server Always Encrypted
  3. php 图片 圆角,php使用gd库在图片中画圆角矩形
  4. python编程心得体会800字_学习python的心得体会
  5. 从零开始搭建terraria(泰拉瑞亚)服务器
  6. 【设计模式】装饰者与继承装饰者与代理间的小九九
  7. vscode更改配置文件路径_Visual Studio Code安装和配置
  8. 关于Adobe flash player 本地播放器
  9. 密歇根大学新进展:AI+可穿戴设备,20秒检测儿童内化障碍
  10. 运维人故障定责甩锅话语指南