线性回归到底要干什么,顾名思义很简单,即在已有数据集上通过构建一个线性的模型来拟合该数据集特征向量的各个分量之间的关系,对于需要预测结果的新数据,我们利用已经拟合好的线性模型来预测其结果。关于线性回归的方法,现在使用得比较广泛的就是梯度下降和最小二乘法;我打算把最小二乘法和梯度下降分两篇博客来写,这篇就来说一说我对线性回归及最小二乘法的理解以及原理实现。

线性模型在二维空间中就是一条直线,在三维空间是一个平面,高维空间的线性模型不好去描述长什么样子;如果这个数据集能够用一个线性模型来拟合它的数据关系,不管是多少维的数据,我们构建线性模型的方法都是通用的。之前看吴恩达机器学习视频,第一节课讲的就是线性回归算法,课程里面提到了一个非常简单的案例:房屋估价系统。房屋估价系统问题就是当知道房屋面积、卧室个数与房屋价格的对应关系之后,在得知一个新的房屋信息后如何得到对应的新房屋价格,如果我们将房屋面积用x1表示,卧室个数用x2表示,即房屋价格h(x)可以被表示为房屋面积与卧室个数的一维线性方程:

这就是我们所说的线性模型,当然这个案例中只有房屋面积和卧室个数两个特征分量作,现实情况下我们要拟合的模型可能有相当多的特征分量,那么线性模型中对应的权重值θ也会有相同多的数量。为了方便表示我们使用矩阵和向量来表示这些数据

向量θ(长度为n)中每一个分量都是估计表达式函数h(x)中一个参数,矩阵X(m*n)表示由数据集中每一个样本的特征向量所组成的矩阵。其实这样一个看起来很简单的式子,它的本质经常就是一个规模极其庞大的线性方程组。如果我们用向量Y(长度为m)来表示数据集的实际值,如果用实际值来建立一个方程组,参数向量θ中的每一个值就是我们要求的未知量;大多数情况下建立的是一个超定方程组(没有确定的解),这个时候我们只能求出超定方程组的最优解。通过建立一个损失函数来衡量估计值和实际之间的误差的大小,我们将最小化损失函数作为一个约束条件来求出参数向量的最优解。


函数J(θ)即为损失函数,它计算出数据集中每一个样例的估计值和实际值的平方差并求取平均。然后就是我们的最小二乘法登场了,最小二乘法通过数学推导(后面给出证明)可以直接得到一个标准方程,这个标准方程的解就是最优的参数向量。

推导方法一

最小二乘法通过数学推导出标准方程的过程其实非常简单,知乎上有一篇博客https://zhuanlan.zhihu.com/p/22474562写得很详细,这里借鉴一下:



推导方法二


推导方法三 (利用Hessian矩阵判断是否是极小值点)

通过正规方程计算得到最优的参数向量之后,就可以确定线性回归方程了,预测只需要将特征向量代入到回归方程之中,就可以计算出估计值了。

既然弄清了原理,那么实现就会显得非常简单;考虑到这个算法过程中使用矩阵乘法的次数很多,所以我使用了python语言以及调用numpy库来实现线性回归的算法,这里使用了sklearn库中的波士顿房价数据集,代码如下


import numpy as np
from sklearn.datasets import load_boston  # 导入博士顿房价数据集
from sklearn import linear_modelclass LinerRegression:M_x = []  #M_y = []  #M_theta = []  # 参数向量trained = Falsedef __init__(self):passdef regression(self, data, target):self.M_x = np.mat(data)# 每个向量添加一个分量1,用来对应系数θ0fenliang = np.ones((len(data), 1))self.M_x = np.hstack((self.M_x, fenliang))self.M_y = np.mat(target)M_x_T = self.M_x.T  # 计算X矩阵的转置矩阵self.M_theta = (M_x_T * self.M_x).I * M_x_T * self.M_y.T  # 通过最小二乘法计算出参数向量print('Start to train it with my own implementation of liearRegression')print(self.M_theta)self.trained = Truedef predict(self, vec):if not self.trained:print("You haven't finished the regression!")returnM_vec = np.mat(vec)fenliang = np.ones((len(vec), 1))M_vec = np.hstack((M_vec, fenliang))estimate = np.matmul(M_vec, self.M_theta)return estimatedef test_my_linear_regression():# 从sklearn的数据集中获取相关向量数据集data和房价数据集targetdata, target = load_boston(return_X_y=True)lr = LinerRegression()lr.regression(data, target)# 提取一批样例观察一下拟合效果test = data[::51]M_test = np.mat(test)real = target[::51]  # 实际数值realestimate = np.array(lr.predict(M_test))  # 回归预测数值estimate# 打印结果for i in range(len(test)):print("实际值:", real[i], " 估计值:", estimate[i, 0])def test_sklearn_liear_regression():data, target = load_boston(return_X_y=True)clf = linear_model.LinearRegression()clf.fit(data, target)print('Start to train it with Sklearn ')print(clf.coef_)print(clf.intercept_)if __name__ == '__main__':test_sklearn_liear_regression()test_my_linear_regression()

运行结果如下:

Start to train it with Sklearn
[-1.08011358e-01  4.64204584e-02  2.05586264e-02  2.68673382e+00-1.77666112e+01  3.80986521e+00  6.92224640e-04 -1.47556685e+003.06049479e-01 -1.23345939e-02 -9.52747232e-01  9.31168327e-03-5.24758378e-01]
36.45948838509001
Start to train it with my own implementation of liearRegression
[[-1.08011358e-01][ 4.64204584e-02][ 2.05586264e-02][ 2.68673382e+00][-1.77666112e+01][ 3.80986521e+00][ 6.92224640e-04][-1.47556685e+00][ 3.06049479e-01][-1.23345939e-02][-9.52747232e-01][ 9.31168327e-03][-5.24758378e-01][ 3.64594884e+01]]
实际值: 24.0  估计值: 30.003843377016256
实际值: 20.5  估计值: 23.97222284868958
实际值: 18.6  估计值: 19.79013683546337
实际值: 19.4  估计值: 17.286018936122648
实际值: 50.0  估计值: 43.18949843697001
实际值: 20.9  估计值: 21.695808865539043
实际值: 33.4  估计值: 35.56226856966765
实际值: 21.7  估计值: 22.718066074783085
实际值: 17.2  估计值: 13.70756369210629
实际值: 20.0  估计值: 18.51247609292085

可以看到用Sklearn里LinearRegarssion训练得到的参数其实和我实现的结果是一样的。
另外
绝大多数样例通过线性回归模型预测的结果与真实结果十分接近,但是存在有一定的误差,可以接受。

参考资料
https://blog.csdn.net/qq_32864683/article/details/80488523
https://blog.csdn.net/perfect_accepted/article/details/78383434

[机器学习-原理及实现篇]线性回归-最小二乘法相关推荐

  1. [机器学习-原理篇]学习之线性回归、岭回归、Lasso回归

    线性回归.岭回归.Lasso回归 前言 一,线性回归--最小二乘 二,Lasso回归 三,岭回归 四, Lasso回归和岭回归的同和异 五, 为什么 lasso 更容易使部分权重变为 0 而 ridg ...

  2. 机器学习原理篇:基础数学理论 Ⅱ

    机器学习原理篇:基础数学理论 Ⅱ 文章目录 机器学习原理篇:基础数学理论 Ⅱ 一.前言 二.概率论 三.数理统计 四.最优化理论 1.目标函数 2.线性规划 3.梯度下降法 五.思考 1.微积分的主要 ...

  3. 《机器学习实战》8.2 线性回归基础篇之预测鲍鱼年龄

    <机器学习实战>8.2 线性回归基础篇之预测鲍鱼年龄 搜索微信公众号:'AI-ming3526'或者'计算机视觉这件小事' 获取更多人工智能.机器学习干货 csdn:https://blo ...

  4. 【机器学习笔记】:大话线性回归(二)拟合优度和假设检验

    大家好,我是东哥. 前一篇文章给大家介绍了线性回归的模型假设,损失函数,参数估计,和简单的预测.具体内容请看下面链接:[机器学习笔记]:大话线性回归(一) 但其实还有很多问题需要我们解决:这个模型的效 ...

  5. 【机器学习】用QR分解求最小二乘法的最优闭式解

    [机器学习]用QR分解求最小二乘法的最优闭式解 写在前面 QR分解 定义 QR的求解 线性回归模型 用QR分解求解最优闭式解 矩阵的条件数 实验 运行结果 写在前面 今天刷知乎,看到张皓在面试官如何判 ...

  6. 《机器学习》实验一:线性回归

    <机器学习>实验一:线性回归 <机器学习>实验一:线性回归 实验目的 实验原理 1. 线性回归 2. 梯度下降法 3. 最小二乘法 实验内容 实验器材 实验步骤 1. 随机生成 ...

  7. 《机器学习实战》8.4 线性回归之乐高玩具套件二手交易价格预测

    <机器学习实战>8.4 线性回归之乐高玩具套件二手交易价格预测 搜索微信公众号:'AI-ming3526'或者'计算机视觉这件小事' 获取更多人工智能.机器学习干货 csdn:https: ...

  8. Coursera公开课笔记: 斯坦福大学机器学习第四课“多变量线性回归(Linear Regression with Multiple Variables)”

    Coursera公开课笔记: 斯坦福大学机器学习第四课"多变量线性回归(Linear Regression with Multiple Variables)" 斯坦福大学机器学习第 ...

  9. 斯坦福大学机器学习第三课“多变量线性回归“

    斯坦福大学机器学习第三课"多变量线性回归(Linear Regression with Multiple Variables)" 斯坦福大学机器学习第四课"多变量线性回归 ...

最新文章

  1. 『设计模式』开发设计的七大原则,我做人还是挺有原则,那些代码呢?
  2. 使用FTP获取RFC文档
  3. 执行一次怎么会写入两次数据_浅谈 Redis 数据持久化之 AOF 模式
  4. 彻底凉凉!两头部网红女主播账号被封,逃税被罚近亿元 还被曝不给员工交社保...
  5. 同一台Windows机器中启动多个Memcached服务
  6. 想要做好SEO优化,你必须懂得SEO的核心因素
  7. 计算机机房里面难闻的气味,新装中央空调气味刺鼻为什么?怎么办?-中央空调 异味 刚开...
  8. 影响中国互联网的100人
  9. nginx代理出现Provisional headers are shown
  10. 不从Win7/Win8.1升级,直接全新安装并激活Win10方法
  11. python将数字拆分_Python 整数拆分
  12. java web添加背景图片_java web项目中如何插入背景图片
  13. matlab 复数函数拟合,lsqcurvefit拟合结果为复数
  14. 定义一个点类(Point)、圆类(Circle)和圆柱体类(Cylinder)的层次结构。圆包括圆心和半径两个数据成员,圆心具有点类的所有特征。圆柱体类由半径和高构成。要求各类提供支持初始化的构造函数
  15. ROS基础学习之ros-tutorials(一)
  16. 【分治法】逆序对的数量(结合归并排序,含详细思想、解法、代码及注释)
  17. -2147483648是不是int常量
  18. phpstudy安装后使原本的mysql连接不上
  19. 163个人电子邮箱免费注册,外贸邮箱用哪个比较好?
  20. 《绝地求生:大逃杀》百万级别数据分析项目,还有源码+数据

热门文章

  1. mysql 长时间连接没操作,断开连接
  2. Python 实现网络爬虫小程序
  3. 图像处理相关知识(不断更新)
  4. install cuda5 on ubuntu12.04
  5. nginx负载均衡的5种策略
  6. ETCD v3 restful api 使用详解
  7. JavaScript玩转机器学习:保存并加载 tf.Model
  8. MiniO纠删码快速入门
  9. 深入理解Java虚拟机(第三版)-13.Java内存模型与线程
  10. leetcode 实现 strStr()