文章目录

  • 1. 线性回归
    • 1.1 基本形式
    • 1.2 成本函数
  • 2. w 的计算方式
    • 2.1 标准方程法
      • 2.1.1 普通形式
      • 2.1.2 向量形式
      • 2.1.3 Python 实现
      • 2.1.4 计算复杂度
    • 2.2 梯度下降法
      • 2.2.1 梯度下降原理
      • 2.2.2 Python 实现
  • 3. Sklearn 实现
  • 参考资料

相关文章:

机器学习 | 目录

机器学习 | 回归评估指标

监督学习 | 非线性回归 之多项式回归原理及Sklearn实现

监督学习 | 线性回归 之正则线性模型原理及Sklearn实现

监督学习 | 线性分类 之Logistic回归原理及Sklearn实现

1. 线性回归

线性回归,又称普通最小二乘法(Ordinary Least Squares, OLS),是回归问题最简单也最经典的线性方法。线性回归需按照参数 w 和 b,使得对训练集的预测值与真实的回归目标值 y 之间的均方误差(MSE)最小。

均方误差(Mean Squared Error)是预测值与真实值之差的平方和除以样本数。

线性回归没有参数,这是一个优点,但也因此无法控制模型的复杂度。

1.1 基本形式

线性回归预测模型:

(1)f(x)=w1x1+w2x2+⋅⋅⋅+wnxn+bf(x)=w_1 x_1 + w_2 x_2 + \cdot \cdot \cdot + w_n x_n + b \tag{1}f(x)=w1​x1​+w2​x2​+⋅⋅⋅+wn​xn​+b(1)

  • f(x)f(x)f(x) 是预测值

  • nnn 是特征的数量

  • xix_ixi​ 是第 iii 个特征值

  • 偏置项 bbb 以及特征权重 w1,w2,⋅⋅⋅,wnw_1, w_2,\cdot \cdot \cdot ,w_nw1​,w2​,⋅⋅⋅,wn​

这可以用更为简介的向量化表示。

线性回归预测模型(向量化):

(2)f(x)=wT⋅x+b=θT⋅x\begin{aligned} f(x) &= w^T \cdot x + b \\ &= \theta^T \cdot x \\ \end{aligned}\tag{2} f(x)​=wT⋅x+b=θT⋅x​(2)

  • w=(w1;w2;...;wn)w=(w_1;w_2;...;w_n)w=(w1​;w2​;...;wn​)

  • www 和 bbb 学习得到,模型就得以确定

  • θ=(w;b)\theta=(w;b)θ=(w;b)

1.2 成本函数

在线性回归中,我们选择 MSE(均方误差)作为其成本函数(Cost Function),其原因在与:首先它是一个凸函数,其次是因为它是可导的,这两个条件决定了可以利用梯度下降法来求得 θ\thetaθ 的最小值。

2. w 的计算方式

关于 www 的计算,有两种方法,一种是利用最小二乘法最小化 MSE 导出 www 的计算公式(标准方程);另一种是利用梯度下降法找出 MSE 的最小值。

首先来看最小利用最小二乘法计算 www 的公式。

利用最小二乘法最小化成本函数 ,可以得出 θ\thetaθ 的计算方程(标准方程):[1]

(3)θ^=(XT⋅X)(−1)⋅XT⋅y\hat{\theta} = (X^T \cdot X)^{(-1)} \cdot X^T \cdot y \tag{3}θ^=(XT⋅X)(−1)⋅XT⋅y(3)

  • θ^\hat{\theta}θ^ 是使成本函数 MSE 最小化的 θ\thetaθ 值

  • yyy 是包含 y(1)y^{(1)}y(1) 到 y(m)y^{(m)}y(m) 的目标值向量

则最终学得的多元线性回归模型为:

(4)f(xˉi)=xˉi⋅θT=xˉi⋅(XTX)−1XTy\begin{aligned} f(\bar{x}_i) &= \bar{x}_i \cdot \theta^T \\ &= \bar{x}_i\cdot(X^TX)^{-1}X^Ty\\ \end{aligned}\tag{4} f(xˉi​)​=xˉi​⋅θT=xˉi​⋅(XTX)−1XTy​(4)

其中:xˉi=(xi;1)\bar{x}_i=(x_i;1)xˉi​=(xi​;1)。

推导过程如下:

2.1 标准方程法

2.1.1 普通形式

标准方程法,又称标准最小二乘法,即通过最小二乘法求出 www 和 bbb 或向量形式下的 θ\thetaθ,由最小二乘法导出的 θ\thetaθ 计算公式称为标准方程。

首先来推导普通形式下 www 和 bbb 的计算公式。

线性回归试图学得:

(5)f(xi)=wxi+b,使得f(xi)≃yif(x_i)=wx_i+b, 使得 f(x_i)\simeq y_i \tag{5}f(xi​)=wxi​+b,使得f(xi​)≃yi​(5)

如何确定 www 和 bbb 呢?关键在于如何衡量 f(x)f(x)f(x) 与 yyy 之间的差别。

回想一下,训练模型就是设置模型参数知道模型最适应训练集的过程。要达到这个目的,我们首先需要知道怎么衡量模型对训练数据的拟合程度是好还是差,在 机器学习 | 回归评估指标 里,我们了解到回归模型最常见的性能指标有均方误差(MSE)。因此以 MSE 为线性回归模型的成本函数,在训练线性回归模型时,我们需要找到最小化 MSE 的 www 值 w∗w^*w∗,即:

(6)(w∗,b∗)=arg⁡min⁡(w,b)∑i=1m(f(xi)−yi)2=arg⁡min⁡(w,b)∑i=1m(yi−wxi−b)2\begin{aligned} (w^*,b^*) &= \mathop{\arg\min}\limits_{(w,b)}\sum_{i=1}^m(f(x_i)-y_i)^2 \\ &= \mathop{\arg\min}\limits_{(w,b)}\sum_{i=1}^m(y_i-wx_i-b)^2 \\ \end{aligned} \tag{6} (w∗,b∗)​=(w,b)argmin​i=1∑m​(f(xi​)−yi​)2=(w,b)argmin​i=1∑m​(yi​−wxi​−b)2​(6)

均方误差有非常好的几何意义,它对应了常用的欧几里得距离(或简称欧氏距离,Euclidean distance),基于均方误差最小化来进行模型求解的方法称为“最小二乘法”(least square method),在线性回归中,最小二乘法就是试图找出一条直线,使得所有样本到线上的欧氏距离之和最小。

求解 www 和 bbb 使得 E(w,b)=∑i=1m(yi−wxi−b)2E_{(w,b)}=\sum_{i=1}^m(y_i-wx_i-b)^2E(w,b)​=∑i=1m​(yi​−wxi​−b)2 最小化的过程,称为线性回归模型的最小二乘“参数估计”(parameter estimation),我们可以将E(w,b)E_{(w,b)}E(w,b)​ 分别对 www 和 bbb 求偏导,得到:[2]

(7)∂E(w,b)∂w=2(w∑i=1mxi2−∑i=1m(yi−b)xi)\frac{\partial E_{(w,b)}}{\partial w} = 2\bigg(w\sum_{i=1}^m x_i^2 - \sum_{i=1}^m(y_i-b)x_i \bigg) \tag{7} ∂w∂E(w,b)​​=2(wi=1∑m​xi2​−i=1∑m​(yi​−b)xi​)(7)
(9)∂E(w,b)∂b=2(mb−∑i=1m(yi−wxi))\frac{\partial E_{(w,b)}}{\partial b} = 2\bigg(mb - \sum_{i=1}^m(y_i-wx_i)\bigg) \tag{9}∂b∂E(w,b)​​=2(mb−i=1∑m​(yi​−wxi​))(9)

然后令公式 (8)、(9) 为零可以得到 www 和 bbb 最优解的闭式(closed-form)解:

(10)w=∑i=1myi(xi−xˉ)∑i=1mxi2−1m(∑i=1mxi)2w = \frac{\sum_{i=1}^my_i(x_i-\bar{x})}{\sum_{i=1}^mx_i^2 - \frac{1}{m}\big(\sum_{i=1}^m x^i\big)^2} \tag{10}w=∑i=1m​xi2​−m1​(∑i=1m​xi)2∑i=1m​yi​(xi​−xˉ)​(10)
(11)b=1m∑i=1m(yi−wxi)b = \frac{1}{m}\sum_{i=1}^m(y_i-wx_i) \tag{11} b=m1​i=1∑m​(yi​−wxi​)(11)

2.1.2 向量形式

更一般的情形是对如有 nnn 个属性的数据集 DDD ,这时我们试图学得:

(12)f(xi)=wTxi+b,使得f(xi)≃yif(x_i)=w^Tx_i+b, 使得 f(x_i)\simeq y_i \tag{12}f(xi​)=wTxi​+b,使得f(xi​)≃yi​(12)

这称为多元线性回归(multivariate linear regression).

类似的,可利用最小二乘法对 www 和 bbb 进行估计。为便于讨论,我们吧 www 和 bbb 转换为向量形式 θ=(w;b)\theta = (w;b)θ=(w;b),即:

(13)f(xi)=wT⋅xi+b=θT⋅xi\begin{aligned} f(x_i) &= w^T \cdot x_i + b \\ &= \theta^T \cdot x_i \\ \end{aligned}\tag{13} f(xi​)​=wT⋅xi​+b=θT⋅xi​​(13)

相应的,把数据集 DDD 表示为一个 m×(d+1)m \times (d+1)m×(d+1) 大小的矩阵 XXX,其中每行对应与一个示例,该行前 ddd 个元素对应与示例的 ddd 个属性值,最后一个元素恒为 1,即:

(14)X=(x11x12⋯x1n1x21x22⋯x2n1⋮⋮⋱⋮⋮xm1xm2⋯xmn1)=(x1nT1x2nT1⋮⋮xmT1)X = \left( \begin{array}{cc} x_{11} & x_{12} & \cdots& x_{1n} & 1\\ x_{21} & x_{22} & \cdots\ & x_{2n} & 1\\ \vdots & \vdots & \ddots\ & \vdots & \vdots\\ x_{m1} & x_{m2} & \cdots& x_{mn} & 1\\ \end{array} \right) =\left( \begin{array}{cc} x_{1n}^T & 1\\ x_{2n}^T & 1\\ \vdots & \vdots\\ x_{m}^T & 1\\ \end{array} \right)\tag{14} X=⎝⎜⎜⎜⎛​x11​x21​⋮xm1​​x12​x22​⋮xm2​​⋯⋯ ⋱ ⋯​x1n​x2n​⋮xmn​​11⋮1​⎠⎟⎟⎟⎞​=⎝⎜⎜⎜⎛​x1nT​x2nT​⋮xmT​​11⋮1​⎠⎟⎟⎟⎞​(14)

再把标记也写成向量形式 y=(y1;y2;⋯ ;ym)y=(y_1;y_2;\cdots;y_m)y=(y1​;y2​;⋯;ym​),类似于公式 (7) ,有:

(15)θ^=arg⁡min⁡θ(y−Xθ)T(y−Xθ)\hat{\theta} = \mathop{\arg\min}\limits_{\theta}(y-X\theta)^T(y-X\theta)\tag{15}θ^=θargmin​(y−Xθ)T(y−Xθ)(15)

令 Eθ=(y−Xθ)T(y−Xθ)E_{\theta}=(y-X\theta)^T(y-X\theta)Eθ​=(y−Xθ)T(y−Xθ) ,对 θ\thetaθ 求导得到:

(16)dEθdθ=2XT(Xθ−y)\frac{dE_{\theta}}{d\theta}=2X^T(X\theta-y)\tag{16}dθdEθ​​=2XT(Xθ−y)(16)

令上式为零可得 θ\thetaθ 的最优解的闭式接,但由于涉及矩阵逆的计算,比单变量情形要复杂一些,下面我们做一个简单的讨论。

当 XTXX^TXXTX 为满秩矩阵(full-rank matrix)或正定矩阵(positive definite matrix)时,令公式 (15) 为零可得标准方程

(16)θ^=(XTX)−1XTy\hat{\theta}=(X^TX)^{-1}X^Ty \tag{16}θ^=(XTX)−1XTy(16)

其中 (XTX)−1(X^TX)^{-1}(XTX)−1 是矩阵 (XTX)(X^TX)(XTX) 的逆矩阵,令 xˉi=(xi,1)\bar{x}_i=(x_i,1)xˉi​=(xi​,1) ,则最终学得的多元线性回归模型为:

KaTeX parse error: No such environment: align* at position 9: \begin{̲a̲l̲i̲g̲n̲*̲}̲ f(\bar{x}…

2.1.3 Python 实现

我们利用 Python 简单实现一下 θ\thetaθ 以及回归方程的计算,首先方程 y=4+3×xy=4+3\times xy=4+3×x生成 100 个数据点并可视化:

import numpy as np
import matplotlib.pyplot as pltX = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)plt.plot(X, y, "b.")
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.axis([0, 2, 0, 15])
plt.show()
<Figure size 640x480 with 1 Axes>

接着我们利用公式 (16) 来计算 θ\thetaθ:

X_b = np.c_[np.ones((100, 1)), X]  # 向量形式下 x 的输入为 (x, 1)
theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
theta
array([[4.10499602],[2.78527083]])

我们用区间的首尾两个点(x=0 和 x=2)来画出拟合直线。计算出 θ\thetaθ 之后就可以利用公式 (17) 来计算两个点的的预测数据 y_predict :

X_new = np.array([[0], [2]])
X_new_b = np.c_[np.ones((2, 1)), X_new]  # add x0 = 1 to each instance
y_predict = X_new_b.dot(theta)
y_predictplt.plot(X_new, y_predict, "r-")
plt.plot(X, y, "b.")
plt.axis([0, 2, 0, 15])
plt.show()

2.1.4 计算复杂度

标准方程需对矩阵 (XTX)(X^TX)(XTX) 求逆,这是一个 n×nn \times nn×n的矩阵( nnn 是特征数量)。对这种矩阵求逆计算复杂度通常为 O(n2.4)O(n^{2.4})O(n2.4) 到 O(n3)O(n^{3})O(n3) 之间(取决于计算实现)。换句话说,如果将特征数量翻倍,那么计算时间将乘以大约 22.4=5.32^{2.4}=5.322.4=5.3 倍到 23=82^{3}=823=8 倍之间。[3]

特征数量较大时(例如 100 000)时,标准方程的计算将极其缓慢

好的一面是,相对于训练集中的实例数量 O(m)O(m)O(m) 来说,方程式线性的,所以能够有效的处理大量的训练集,只要内存足够。

同样,线性回归模型一经训练(不论是标准方程还是其他算法),预测就非常快速,因为计算复杂度相对于想要预测的实例数量和特征数量来说,都是线性的。换句话说,对两倍的实例(或者是两倍的特征数)进行预测,大概需要两倍的时间。因此,我们来看看其他的优化算法:梯度下降算法。

2.2 梯度下降法

2.2.1 梯度下降原理

关于梯度下降法的推导及 Python 实现,请参考我的另一片文章:机器学习 | 梯度下降原理及Python实现。

2.2.2 Python 实现

机器学习 | 梯度下降原理及Python实现

3. Sklearn 实现

我们将使用线性回归根据体质指数 (BMI) 预测预期寿命。

对于线性模型,我们将使用 sklearn.linear_model.LinearRegression 类(Sklearn 官方文档)。

我们将使用线性回归模型对数据进行拟合并画出拟合直线。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegressionbmi_life_data = pd.read_csv("data/bmi_and_life_expectancy.csv")bmi_life_model = LinearRegression()
bmi_life_model.fit(bmi_life_data[['BMI']], bmi_life_data[['BMI']])y_1 = bmi_life_model.predict(np.array(min(bmi_life_data['BMI'])).reshape(-1,1))
y_2 = bmi_life_model.predict(np.array(max(bmi_life_data['BMI'])).reshape(-1,1))y_1 = y_1.tolist()
y_1 = [y for x in y_1 for y in x]y_2 = y_2.tolist()
y_2 = [y for x in y_2 for y in x]plt.plot(bmi_life_data['BMI'], bmi_life_data['BMI'], 'b.')
plt.plot([min(bmi_life_data['BMI']), max(bmi_life_data['BMI'])], [y_1, y_2], "r-")
plt.xlabel("BMI")
plt.ylabel('life_expectancy')
plt.show()

参考资料

[1] 周志华. 机器学习[M]. 北京: 清华大学出版社, 2016: 53-56

[2] Aurelien Geron, 王静源, 贾玮, 边蕤, 邱俊涛. 机器学习实战:基于 Scikit-Learn 和 TensorFlow[M]. 北京: 机械工业出版社, 2018: 103-106.

[3] Aurelien Geron, 王静源, 贾玮, 边蕤, 邱俊涛. 机器学习实战:基于 Scikit-Learn 和 TensorFlow[M]. 北京: 机械工业出版社, 2018: 106-107.

监督学习 | 线性回归 之多元线性回归原理及Sklearn实现相关推荐

  1. 监督学习 | 线性回归 之正则线性模型原理及Sklearn实现

    文章目录 1. 正则线性模型 1.1 Ridge Regression(L2) 1.1.1 Sklearn 实现 1.1.2 Ridge + SDG 1.1.2.1 Sklearn 实现 1.2 La ...

  2. 线性回归原理----简单线性回归、多元线性回归

    回归分析是用来评估变量之间关系的统计过程.用来解释自变量X与因变量Y的关系.即当自变量X发生改变时,因变量Y会如何发生改变. 线性回归是回归分析的一种,评估的自变量X与因变量Y之间是一种线性关系,当只 ...

  3. 机器学习——一元线性回归和多元线性回归

    一元线性回归:梯度下降法 一元线性回归是线性回归的最简单的一种,即只有一个特征变量.首先是梯度下降法,这是比较经典的求法.一元线性回归通俗易懂地说,就是一元一次方程.只不过这里的斜率和截距要通过最小二 ...

  4. 【机器学习】线性回归,多元线性回归、自回归及衡量指标

    经典线性模型自变量的线性预测就是因变量的估计值. 广义线性模型:自变量的线性预测的函数是因变量的估计值. 常见的广义线性模型有:probit模型.poisson模型.对数线性模型等.对数线性模型里有: ...

  5. 绘制线性回归和多元线性回归

    本文用C#语言实现一元线性回归和多元线性回归.结合"winform双缓冲绘制坐标轴图像"https://www.luweidong.cn/details/89 实现绘制曲线图,效果 ...

  6. 机器学习:回归分析—— 一元线性回归、多元线性回归的简单实现

    回归分析 回归分析概述 基本概念 可以解决的问题 基本步骤和分类 线性回归 一元线性回归 多元线性回归 回归分析概述 基本概念 回归分析是处理多变量间相关关系的一种数学方法.相关关系不同于函数关系,后 ...

  7. matlab重复线性回归,(MATLAB)一元线性回归和多元线性回归

    (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归1.一元线性回归 2.多元线性回归2.1数据说明 2.2程序运行结果 ...

  8. (MATLAB)一元线性回归和多元线性回归

    (MATLAB)一元线性回归和多元线性回归 1.一元线性回归 2.多元线性回归 2.1数据说明 2.2程序运行结果 1.一元线性回归 直接看代码,目标是建立 y y y和 x x x的函数关系,即求 ...

  9. 简单线性回归和多元线性回归

    有很多初学者不知道如何用R语言做回归,这里我讲解一下简单线性回归和多元线性回归. 当回归模型包含一个因变量和一个自变量时,我们称为简单线性回归.比如:身高和体重的关系. 当有不止一个预测变量时, 则称 ...

最新文章

  1. ​基于BCI的现代神经反馈有助于认知增强(一)
  2. Docker容器虚拟化技术---Docker运维管理(Swarm集群管理)3
  3. 伯克利弹跳机器人再进化:超精准着陆,指哪打哪
  4. 多线程编程(2): 线程的创建、启动、挂起和退出
  5. 大数据各子项目的环境搭建之建立与删除软连接(博主推荐)
  6. opendrive道路标准基础知识
  7. Linux:什么是 i386、i586、 i686、noarch?
  8. 轻松获得oblog2.52的WebShell
  9. 开封文化艺术职业学院学报杂志社开封文化艺术职业学院学报编辑部2022年第4期目录
  10. cad怎么画立体图形教学_怎么在CAD中绘制三维立体图
  11. 什么是迭代器(Iterator)
  12. Adding Animations之Zooming a View
  13. php制作搜索框_搜索功能(search.php)模板制作 - WordPress模板开发
  14. cesium去除控件及版权信息
  15. php中文网教程 百度云,网盘直链问题请教
  16. C++ 取整,四舍五入
  17. JavaScript 实例:当当网 首页选项卡切换效果
  18. Python快速上手系列--循环结构--基础篇
  19. 成都计算机学校什么时候开学,2018年成都中小学放假开学时间表
  20. 【AD】电源布线学习记录/关于过孔的大小/覆铜

热门文章

  1. 《大数据》第1期“研究”——大数据管理系统评测基准的 挑战与研究进展(上)...
  2. 【软件工程】抽象泄漏
  3. 【数据结构与算法】二项队列的Java实现
  4. 【Java】求解N皇后问题
  5. extern 在c/c++ 中的作用
  6. Servlet处理文件下载的编码问题,乱码。
  7. Windows Azure Web Site (15) 取消Azure Web Site默认的IIS ARR
  8. 2016年3月16日作业
  9. mysql 5.6 设置long_query_time的值无效的原因
  10. win7下cocos2dx2.2+vs2010+python2.7环境搭建