在统计学中,线性回归(Linear regression)是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析维基百科。

简单线性回归

当只有一个自变量的时候,成为简单线性回归。

简单线性回归模型的思路

为了得到一个简单线性回归模型,假设存在以房屋面积为特征,以价格为样本输出,包含四个样本的样本集,如图:

寻找一条直线,最大程度上拟合样本特征与样本输出之间的关系。

假设最佳拟合的直线方程为:,则对于样本特征  的每一个取值  的预测值为:。而我们希望的就是真值  和预测值  之间的差距尽量小。

可以用  表示两者之间的差距,对于所有的样本,使用求和公式求和处理:

∑ i =1 m | y ( i) −y ^ ( i) |

但是这个公式有一个问题,不容易求导,为了解决这个问题,可先对  进行平方,如此最后的公式就变成了:

∑ i =1 m (y ( i) −y ^ ( i) ) 2

最后,替换掉  ,即为:

∑ i =1 m (y ( i) −a x ( i) −b ) 2

因此,找到的一个简单线性回归模型就是找到合适的 a 和 b,使得该函数的值尽可能的小,该函数也称为损失函数(loss function)。

最小二乘法

找到合适的 a 和 b,使得  的值尽可能的小,这样的方法称为最小二乘法。

如何求 a 和 b 呢?令该函数为 ,分别使对 a 和 b 求导的结果为0。

对 b 求导:,得:

b =y ¯ ¯ ¯ −a x ¯ ¯ ¯

对 a 求导:,得:

a =∑ m i =1 (x ( i) −x ¯ ¯ ¯ )(y ( i) −y ¯ ¯ ¯ ) ∑ m i =1 (x ( i) −x ¯ ¯ ¯ ) 2

注:这里略去了公式的推导过程。还有很多内容因为篇幅有限不详细写了,如果想全面了解的可以点击这个链接跳转到我已经录制好的视频

简单线性回归的实现

有了数学的帮助,实现简单线性回归就比较方便了。

首先声明一个样本集:

import numpy as npx = np.array([1., 2., 3., 4., 5.])
y = np.array([1., 3., 2., 3., 5.])

公式中用到了 x 和 y 的均值:

x_mean = np.mean(x)
y_mean = np.mean(y)

求 a 和 b 的值有两种方法。第一种是使用 for 循环:

# 分子
num = 0.0# 分母
d = 0.0for x_i, y_i in zip(x, y):num += (x_i - x_mean) * (y_i - y_mean)d += (x_i - x_mean) ** 2a = num / d
b = y_mean - a * x_mean

第二种是使用矩阵乘:

num = (x - x_mean).dot(y - y_mean)
d = (x - x_mean).dot(x - x_mean)a = num / d
b = y_mean - a * x_mean

注:使用矩阵乘效率更高。

求出了 a 和 b,简单线性模型就有了:。对当前示例作图表示:

衡量线性回归法的指标

误差

一个训练后的模型通常都会使用测试数据集测试该模型的准确性。对于简单线性归回模型当然可以使用  来衡量,但是它的取值和测试样本个数 m 存在联系,改进方法很简单,只需除以 m 即可,即均方误差(Mean Squared Error):

M SE : 1 m ∑ i =1 m (y ( i) t est −y ^ ( i) t est ) 2

np.sum((y_predict - y_true) ** 2) / len(y_true)

值得一提的是 MSE 的量纲是样本单位的平方,有时在某些情况下这种平方并不是很好,为了消除量纲的不同,会对 MSE 进行开方操作,就得到了均方根误差(Root Mean Squared Error):

R MS E: 1 m ∑ i =1 m (y ( i) t est −y ^ _t est ( i) ) 2 − −− −− −− −− −− −− −− −− −− √ =M SE t es t − −− −− −− √

import mathmath.sqrt(np.sum((y_predict - y_true) ** 2) / len(y_true))

还有一种衡量方法是平均绝对误差(Mean Absolute Error),对测试数据集中预测值与真值的差的绝对值取和,再取一个平均值:

M AE : 1 m ∑ i =1 m | y ( i) t es t −y ^ ( i) t es t |

np.sum(np.absolute(y_predict - y_true)) / len(y_true)

注:Scikit Learn 的 metrics 模块中的 mean_squared_error() 方法表示 MSE,mean_absolute_error() 方法表示 MAE,没有表示 RMSE 的方法。

R Squared

更近一步,MSE、RMSE 和 MAE 的局限性在于对模型的衡量只能做到数值越小表示模型越好,而通常对模型的衡量使用1表示最好,0表示最差,因此引入了新的指标:R Squared,计算公式为:

R 2 =1 −S S r es i d u al S S t ot a l

,表示使用模型产生的错误;,表示使用  预测产生的错误。

更深入的讲,对于每一个预测样本的 x 的预测值都为样本的均值  ,这样的模型称为基准模型;当我们的模型等于基准模型时, 的值为0,当我们的模型不犯任何错误时  得到最大值1。

 还可以进行转换,转换结果为:

R 2 =1 −M SE ( y ^ ,y ) V ary

实现也很简单:

1 - np.sum((y_predict - y_true) ** 2) / len(y_true) / np.var(y_true)

注:Scikit Learn 的 metrics 模块中的 r2_score() 方法表示 R Squared。

多元线性回归

多元线性回归模型的思路

当有不只一个自变量时,即为多元线性回归,如图:

对于有 n 个自变量来说,我们想获得的线性模型为:

y =θ 0 +θ 1 x 1 +θ 2 x 2 +.. .+θ n x n

根据简单线性回归的思路,我们的目标即为:

找到 , , ,...,,使得  尽可能的小,其中  。

:训练数据中第 i 个样本的预测值;:训练数据中第 i 个样本的第 j 个自变量。

如果用矩阵表示即为:

y ^ ( i) =X ( i) ⋅θ

其中:

更进一步,将  也使用矩阵表示,即为:

y ^ =X b ⋅θ

其中:, 

因此,我们目标就成了:使  尽可能小。而对于这个公式的解,称为多元线性回归的正规方程解(Nomal Equation):

还有很多内容因为篇幅有限不详细写了,如果想全面了解的可以点击这个链接跳转到我已经录制好的视频

θ =(X T b Xb ) − 1 (X T b y )

实现多元线性回归

将多元线性回归实现在 LinearRegression 类中,且使用 Scikit Learn 的风格。

_init_() 方法首先初始化线性回归模型,_theta 表示 interception_ 表示截距,chef_ 表示回归模型中自变量的系数:

class LinearRegression:def __init__(self):self.coef_ = Noneself.interceiption_ = Noneself._theta = None

fit_normal() 方法根据训练数据集训练模型,X_b 表示添加了  的样本特征数据,并且使用多元线性回归的正规方程解求出 

def fit_normal(self, X_train, y_train):X_b = np.hstack([np.ones((len(X_train), 1)), X_train])self._theta = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y_train)self.interception_ = self._theta[0]self.coef_ = self._theta[1:]return self

predict() 方法为预测方法,同样使用了矩阵乘:

def predict(self, X_predict):X_b = np.hstack([np.ones((len(X_predict), 1)), X_predict])return X_b.dot(self._theta)

score() 根据给定的测试数据集使用 R Squared 指标计算模型的准确度:

def score(self, X_test, y_test):y_predict = self.predict(X_test)return r2_score(y_test, y_predict)

Scikit Learn 中的线性回归实现放在 linear_model 模块中,使用方法如下:

from sklearn.linear_model import LinearRegression

线性回归的特点

线性回归算法是典型的参数学习的算法,只能解决回归问题,其对数据具有强解释性。

缺点是多元线性回归的正规方程解  的时间复杂度高,为 ,可优化为 

转载于:https://blog.51cto.com/14014179/2313081

机器学习之优雅落地线性回归法相关推荐

  1. 机器学习9衡量线性回归法的指标,MSE,RMS,MAE

    文章目录 一.衡量线性回归法的指标,MSE,RMS,MAE 1.MSE均方误差(Mean Squared Error) 2.RSE均方误差(Root Mean Squared Error) 3.平均绝 ...

  2. 机器学习(一)—— 线性回归

    机器学习(一)-- 线性回归 目录 0. 回归(Regression)的由来 1. 回归关系 2. 线性回归的整体思路 (1)根据数据提出假设模型 (2)求解参数 1)梯度下降法 2)正规方程求解参数 ...

  3. 机器学习(一)- 线性回归/(拟合)模型

    # 前年学习记录的笔记,分享一下- Linear Models for Regression 目录 一.使用线性回归模型前数据处理及注意 二.线性回归,针对线性数据,通过最小二乘法让损失函数(cost ...

  4. 机器学习特征筛选:互信息法(mutual information)

    机器学习特征筛选:互信息法(mutual information) 互信息法多为分类问题的分类变量的筛选方法 经典的互信息也是评价定性自变量对定性因变量的相关性的,为了处理定量数据,最大信息系数法被提 ...

  5. 机器学习特征筛选:相关系数法(correlation)

    机器学习特征筛选:相关系数法(correlation) 通过计算特征与特征之间的相关系数的大小,可判定两两特征之间的相关程度. 取值区间在[-1, 1]之间,取值关系如下: corr(x1,x2)相关 ...

  6. java 多项式拟合最多的项数_机器学习(1)--线性回归和多项式拟合

    机器学习(1)--线性回归和多项式拟合 机器学习(2)逻辑回归 (数学推导及代码实现) 机器学习(3)softmax实现Fashion-MNIST分类 一 线性回归 线性回归,顾名思义是利用线性模型对 ...

  7. 机器学习经典算法之线性回归sklearn实现

    机器学习经典算法之线性回归sklearn实现 from sklearn import linear_model from sklearn import datasets import numpy as ...

  8. 机器学习之单变量线性回归(Linear Regression with One Variable)

    机器学习之单变量线性回归(Linear Regression with One Variable) 1. 模型表达(Model Representation) 我们的第一个学习算法是线性回归算法,让我 ...

  9. ML之MLiR:利用多元线性回归法,从大量数据(csv文件)中提取五个因变量(输入运输任务总里程数、运输次数、三种不同的车型,预测需要花费的小时数)来预测一个自变量

    ML之MLiR:利用多元线性回归法,从大量数据(csv文件)中提取五个因变量(输入运输任务总里程数.运输次数.三种不同的车型,预测需要花费的小时数)来预测一个自变量 输出结果 代码设计 from nu ...

最新文章

  1. Android 自定义 —— View moveTo与 rMoveTo 的区别
  2. 概率统计笔记:高斯威沙特分布
  3. pandas 中有关isin()函数的介绍,python中del解释
  4. boost::log::attribute_value用法的测试程序
  5. js 如何去除字符两端的引号
  6. 南抖音北快手,智障界的两泰斗
  7. Hbase最新官方文档中文翻译与注解1-10|hbase简介与配置信息等
  8. iPhone出现白苹果怎么修复?简单3步即可解决
  9. 关于:昨天H - 康托展开题目的探究。
  10. 玉米田 炮兵阵地 状态压缩DP
  11. 应用系统安全规范-自己想到和网络搜索到的点子记录整合一下
  12. android 添加json动画,Lottie 站在巨人的肩膀上实现 Android 酷炫动画效果
  13. 《脉脉:人才流动与迁徙2022》,遭”哄抢”的复合型程序员成IT黑马
  14. 前端面试题之【CSS】
  15. 格式化JSON stringify 的使用
  16. Jmeter入门实战(二)如何使用Jmeter的BeanShell断言,把响应数据中的JSON跟数据库中的记录对比
  17. Object.assign 原理及其实现
  18. sa-token使用简单使用
  19. 2020年11月编程排行出炉,Java市场占有率仍第一
  20. 1、esp32(arduino)接入阿里云MQTT及数据处理

热门文章

  1. iOS 11 安全区域适配总结
  2. C# TripleDES NoPadding 时对待加密内容进行补字节(8个字节为一个Block)
  3. Android开源中国客户端学习 (自定义View)左右滑动控件ScrollLayout
  4. Android中Handler
  5. 阿里巴巴连任 Java 全球管理组织席位
  6. Mongo、Redis、Memcached对比及知识总结
  7. PHP实时生成并下载超大数据量的EXCEL文件
  8. 一套完整的数字无线监控系统需要哪些设备和材料?
  9. 《Adobe InDesign CS6中文版经典教程》—第2课2.1节概述
  10. OpenSSH7.0兼容性测试报告