4.0 单变量线性回归问题

4.0.1 提出问题

在互联网建设初期,各大运营商需要解决的问题就是保证服务器所在的机房的温度常年保持在23摄氏度左右。在一个新建的机房里,如果计划部署346台服务器,我们如何配置空调的最大功率?

这个问题虽然能通过热力学计算得到公式,但是总会有误差。因此人们往往会在机房里装一个温控器,来控制空调的开关或者风扇的转速或者制冷能力,其中最大制冷能力是一个关键性的数值。更先进的做法是直接把机房建在海底,用隔离的海水循环降低空气温度的方式来冷却。

通过一些统计数据(称为样本数据),我们得到了表4-1。

表4-1 样本数据

样本序号 服务器数量(千台)X 空调功率(千瓦)Y
1 0.928 4.824
2 0.469 2.950
3 0.855 4.643
... ... ...

在上面的样本中,我们一般把自变量X称为样本特征值,把因变量Y称为样本标签值。

这个数据是二维的,所以我们可以用可视化的方式来展示,横坐标是服务器数量,纵坐标是空调功率,如图4-1所示。

图4-1 样本数据可视化

通过对上图的观察,我们可以判断它属于一个线性回归问题,而且是最简单的一元线性回归。于是,我们把热力学计算的问题转换成为了一个统计问题,因为实在是不能精确地计算出每块电路板或每台机器到底能产生多少热量。

头脑灵活的读者可能会想到一个办法:在样本数据中,我们找到一个与346非常近似的例子,以它为参考就可以找到合适的空调功率数值了。

不得不承认,这样做是完全科学合理的,实际上这就是线性回归的解题思路:利用已有值,预测未知值。也就是说,这些读者不经意间使用了线性回归模型。而实际上,这个例子非常简单,只有一个自变量和一个因变量,因此可以用简单直接的方法来解决问题。但是,当有多个自变量时,这种直接的办法可能就会失效了。假设有三个自变量,很有可能不能够在样本中找到和这三个自变量的组合非常接近的数据,此时我们就应该借助更系统的方法了。

4.0.2 一元线性回归模型

回归分析是一种数学模型。当因变量和自变量为线性关系时,它是一种特殊的线性模型。

最简单的情形是一元线性回归,由大体上有线性关系的一个自变量和一个因变量组成,模型是:

Y=a+bX+ε(1)(1)Y=a+bX+ε

X是自变量,Y是因变量,ε是随机误差,a和b是参数,在线性回归模型中,a和b是我们要通过算法学习出来的。

什么叫模型?第一次接触这个概念时,可能会有些不明觉厉。从常规概念上讲,是人们通过主观意识借助实体或者虚拟表现来构成对客观事物的描述,这种描述通常是有一定的逻辑或者数学含义的抽象表达方式。

比如对小轿车建模的话,会是这样描述:由发动机驱动的四轮铁壳子。对能量概念建模的话,那就是爱因斯坦狭义相对论的著名推论:E=mc2E=mc2。

对数据建模的话,就是想办法用一个或几个公式来描述这些数据的产生条件或者相互关系,比如有一组数据是大致满足y=3x+2y=3x+2这个公式的,那么这个公式就是模型。为什么说是“大致”呢?因为在现实世界中,一般都有噪音(误差)存在,所以不可能非常准确地满足这个公式,只要是在这条直线两侧附近,就可以算作是满足条件。

对于线性回归模型,有如下一些概念需要了解:

  • 通常假定随机误差的均值为0,方差为σ^2(σ^2﹥0,σ^2与X的值无关)
  • 若进一步假定随机误差遵从正态分布,就叫做正态线性模型
  • 一般地,若有k个自变量和1个因变量(即公式1中的Y),则因变量的值分为两部分:一部分由自变量影响,即表示为它的函数,函数形式已知且含有未知参数;另一部分由其他的未考虑因素和随机性影响,即随机误差
  • 当函数为参数未知的线性函数时,称为线性回归分析模型
  • 当函数为参数未知的非线性函数时,称为非线性回归分析模型
  • 当自变量个数大于1时称为多元回归
  • 当因变量个数大于1时称为多重回归

我们通过对数据的观察,可以大致认为它符合线性回归模型的条件,于是列出了公式1,不考虑随机误差的话,我们的任务就是找到合适的a和b,这就是线性回归的任务。

图4-2 线性回归和非线性回归的区别

如图4-2所示,左侧为线性模型,可以看到直线穿过了一组三角形所形成的区域的中心线,并不要求这条直线穿过每一个三角形。右侧为非线性模型,一条曲线穿过了一组矩形所形成的区域的中心线。在本章中,我们先学习如何解决左侧的线性回归问题。

我们接下来会用几种方法来解决这个问题:

  1. 最小二乘法;
  2. 梯度下降法;
  3. 简单的神经网络法;
  4. 更通用的神经网络算法。

4.0.3 公式形态

这里要解释一下线性公式中w和x的顺序问题。在很多教科书中,我们可以看到下面的公式:

y=wTx+b(1)(1)y=wTx+b

或者:

y=w⋅x+b(2)(2)y=w⋅x+b

而我们在本书中使用:

y=x⋅w+b(3)(3)y=x⋅w+b

这三者的主要区别是样本数据x的形状定义,相应地会影响到w的形状定义。举例来说,如果x有三个特征值,那么w必须有三个权重值与特征值对应,则:

公式1的矩阵形式

x是列向量:

x=⎛⎝⎜x1x2x3⎞⎠⎟x=(x1x2x3)

w也是列向量:

w=⎛⎝⎜w1w2w3⎞⎠⎟w=(w1w2w3)

y=wTx+b=(w1w2w3)⎛⎝⎜x1x2x3⎞⎠⎟+by=wTx+b=(w1w2w3)(x1x2x3)+b

=w1⋅x1+w2⋅x2+w3⋅x3+b(4)(4)=w1⋅x1+w2⋅x2+w3⋅x3+b

w和x都是列向量,所以需要先把w转置后,再与x做矩阵乘法。

公式2的矩阵形式

公式2与公式1的区别是w的形状,在公式2中,w直接就是个行向量:

w=(w1w2w3)w=(w1w2w3)

而x的形状仍然是列向量:

x=⎛⎝⎜x1x2x3⎞⎠⎟x=(x1x2x3)

这样相乘之前不需要做矩阵转置了:

y=wx+b=(w1w2w3)⎛⎝⎜x1x2x3⎞⎠⎟+by=wx+b=(w1w2w3)(x1x2x3)+b

=w1⋅x1+w2⋅x2+w3⋅x3+b(5)(5)=w1⋅x1+w2⋅x2+w3⋅x3+b

公式3的矩阵形式

x是个行向量:

x=(x1x2x3)x=(x1x2x3)

w是列向量:

w=⎛⎝⎜w1w2x3⎞⎠⎟w=(w1w2x3)

所以x在前,w在后:

y=x⋅w+b=(x1x2x3)⎛⎝⎜w1w2w3⎞⎠⎟+by=x⋅w+b=(x1x2x3)(w1w2w3)+b

=x1⋅w1+x2⋅w2+x3⋅w3+b(6)(6)=x1⋅w1+x2⋅w2+x3⋅w3+b

比较公式4,5,6,其实最后的运算结果是相同的。

我们再分析一下前两种形式的x矩阵,由于x是个列向量,意味着特征由行表示,当有2个样本同时参与计算时,x需要增加一列,变成了如下形式:

x=⎛⎝⎜x11x12x13x21x22x23⎞⎠⎟x=(x11x21x12x22x13x23)

x的第一个下标表示样本序号,第二个下标表示样本特征,所以x21x21是第2个样本的第1个特征。看x21x21这个序号很别扭,一般我们都是认为行在前、列在后,但是x21x21却是处于第1行第2列,和习惯正好相反。

如果采用第三种形式,则两个样本的x的矩阵是:

x=(x11x21x12x22x13x23)x=(x11x12x13x21x22x23)

第1行是第1个样本的3个特征,第2行是第2个样本的3个特征,这与常用的阅读习惯正好一致,第1个样本的第2个特征在矩阵的第1行第2列,因此我们在本书中一律使用第三种形式来描述线性方程。

另外一个原因是,在很多深度学习库的实现中,确实是把x放在w前面做矩阵运算的,同时w的形状也是从左向右看,比如左侧有2个样本的3个特征输入(2x3表示2个样本3个特征值),右侧是1个输出,则w的形状就是3x1。否则的话就需要倒着看,w的形状成为了1x3,而x变成了3x2,很别扭。

对于b来说,它永远是1行,列数与w的列数相等。比如w是3x1的矩阵,则b是1x1的矩阵。如果w是3x2的矩阵,意味着3个特征输入到2个神经元上,则b是1x2的矩阵,每个神经元分配1个bias。

4.1 最小二乘法

4.1.1 历史

最小二乘法,也叫做最小平方法(Least Square),它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小。最小二乘法还可用于曲线拟合。其他一些优化问题也可通过最小化能量或最小二乘法来表达。

1801年,意大利天文学家朱赛普·皮亚齐发现了第一颗小行星谷神星。经过40天的跟踪观测后,由于谷神星运行至太阳背后,使得皮亚齐失去了谷神星的位置。随后全世界的科学家利用皮亚齐的观测数据开始寻找谷神星,但是根据大多数人计算的结果来寻找谷神星都没有结果。时年24岁的高斯也计算了谷神星的轨道。奥地利天文学家海因里希·奥尔伯斯根据高斯计算出来的轨道重新发现了谷神星。

高斯使用的最小二乘法的方法发表于1809年他的著作《天体运动论》中。法国科学家勒让德于1806年独立发明“最小二乘法”,但因不为世人所知而默默无闻。勒让德曾与高斯为谁最早创立最小二乘法原理发生争执。

1829年,高斯提供了最小二乘法的优化效果强于其他方法的证明,因此被称为高斯-马尔可夫定理。

4.1.2 数学原理

线性回归试图学得:

z(xi)=w⋅xi+b(1)(1)z(xi)=w⋅xi+b

使得:

z(xi)≃yi(2)(2)z(xi)≃yi

其中,xixi是样本特征值,yiyi是样本标签值,zizi是模型预测值。

如何学得w和b呢?均方差(MSE - mean squared error)是回归任务中常用的手段:

J=∑i=1m(z(xi)−yi)2=∑i=1m(yi−wxi−b)2(3)(3)J=∑i=1m(z(xi)−yi)2=∑i=1m(yi−wxi−b)2

JJ称为损失函数。实际上就是试图找到一条直线,使所有样本到直线上的残差的平方和最小。

图4-3 均方差函数的评估原理

图4-3中,圆形点是样本点,直线是当前的拟合结果。如左图所示,我们是要计算样本点到直线的垂直距离,需要再根据直线的斜率来求垂足然后再计算距离,这样计算起来很慢;但实际上,在工程上我们通常使用的是右图的方式,即样本点到直线的竖直距离,因为这样计算很方便,用一个减法就可以了。

假设我们计算出初步的结果是虚线所示,这条直线是否合适呢?我们来计算一下图中每个点到这条直线的距离,把这些距离的值都加起来(都是正数,不存在互相抵消的问题)成为误差。

因为上图中的几个点不在一条直线上,所以不能有一条直线能同时穿过它们。所以,我们只能想办法不断改变红色直线的角度和位置,让总体误差最小(用于不可能是0),就意味着整体偏差最小,那么最终的那条直线就是我们要的结果。

如果想让误差的值最小,通过对w和b求导,再令导数为0(到达最小极值),就是w和b的最优解。

推导过程如下:

∂J∂w=∂(∑mi=1(yi−wxi−b)2)∂w=2∑i=1m(yi−wxi−b)(−xi)(4)(4)∂J∂w=∂(∑i=1m(yi−wxi−b)2)∂w=2∑i=1m(yi−wxi−b)(−xi)

令公式4为0:

∑i=1m(yi−wxi−b)xi=0(5)(5)∑i=1m(yi−wxi−b)xi=0

∂J∂b=∂(∑mi=1(yi−wxi−b)2)∂b=2∑i=1m(yi−wxi−b)(−1)(6)(6)∂J∂b=∂(∑i=1m(yi−wxi−b)2)∂b=2∑i=1m(yi−wxi−b)(−1)

令公式6为0:

∑i=1m(yi−wxi−b)=0(7)(7)∑i=1m(yi−wxi−b)=0

由式7得到(假设有m个样本):

∑i=1mb=m⋅b=∑i=1myi−w∑i=1mxi(8)(8)∑i=1mb=m⋅b=∑i=1myi−w∑i=1mxi

两边除以m:

b=1m(∑i=1myi−w∑i=1mxi)=y¯−wx¯(9)(9)b=1m(∑i=1myi−w∑i=1mxi)=y¯−wx¯

其中:

y¯=1m∑i=1myi,x¯=1m∑i=1mxi(10)(10)y¯=1m∑i=1myi,x¯=1m∑i=1mxi

将公式10代入公式5:

∑i=1m(yi−wxi−y¯+wx¯)xi=0∑i=1m(yi−wxi−y¯+wx¯)xi=0

∑i=1m(xiyi−wx2i−xiy¯+wx¯xi)=0∑i=1m(xiyi−wxi2−xiy¯+wx¯xi)=0

∑i=1m(xiyi−xiy¯)−w∑i=1m(x2i−x¯xi)=0∑i=1m(xiyi−xiy¯)−w∑i=1m(xi2−x¯xi)=0

w=∑mi=1(xiyi−xiy¯)∑mi=1(x2i−x¯xi)(11)(11)w=∑i=1m(xiyi−xiy¯)∑i=1m(xi2−x¯xi)

将公式10代入公式11:

w=∑mi=1(xi⋅yi)−∑mi=1xi⋅1m∑mi=1yi∑mi=1x2i−∑mi=1xi⋅1m∑mi=1xi(12)(12)w=∑i=1m(xi⋅yi)−∑i=1mxi⋅1m∑i=1myi∑i=1mxi2−∑i=1mxi⋅1m∑i=1mxi

分子分母都乘以m:

w=m∑mi=1xiyi−∑mi=1xi∑mi=1yim∑mi=1x2i−(∑mi=1xi)2(13)(13)w=m∑i=1mxiyi−∑i=1mxi∑i=1myim∑i=1mxi2−(∑i=1mxi)2

b=1m∑i=1m(yi−wxi)(14)(14)b=1m∑i=1m(yi−wxi)

而事实上,式13有很多个变种,大家会在不同的文章里看到不同版本,往往感到困惑,比如下面两个公式也是正确的解:

w=∑mi=1yi(xi−x¯)∑mi=1x2i−(∑mi=1xi)2/m(15)(15)w=∑i=1myi(xi−x¯)∑i=1mxi2−(∑i=1mxi)2/m

w=∑mi=1xi(yi−y¯)∑mi=1x2i−x¯∑mi=1xi(16)(16)w=∑i=1mxi(yi−y¯)∑i=1mxi2−x¯∑i=1mxi

以上两个公式,如果把公式10代入,也应该可以得到和式13相同的答案,只不过需要一些运算技巧。比如,很多人不知道这个神奇的公式:

∑i=1m(xiy¯)=y¯∑i=1mxi=1m(∑i=1myi)(∑i=1mxi)=1m(∑i=1mxi)(∑i=1myi)=x¯∑i=1myi=∑i=1m(yix¯)(17)(17)∑i=1m(xiy¯)=y¯∑i=1mxi=1m(∑i=1myi)(∑i=1mxi)=1m(∑i=1mxi)(∑i=1myi)=x¯∑i=1myi=∑i=1m(yix¯)

4.1.3 代码实现

我们下面用Python代码来实现一下以上的计算过程:

计算w值

# 根据公式15
def method1(X,Y,m):x_mean = X.mean()p = sum(Y*(X-x_mean))q = sum(X*X) - sum(X)*sum(X)/mw = p/qreturn w# 根据公式16
def method2(X,Y,m):x_mean = X.mean()y_mean = Y.mean()p = sum(X*(Y-y_mean))q = sum(X*X) - x_mean*sum(X)w = p/qreturn w# 根据公式13
def method3(X,Y,m):p = m*sum(X*Y) - sum(X)*sum(Y)q = m*sum(X*X) - sum(X)*sum(X)w = p/qreturn w

由于有函数库的帮助,我们不需要手动计算sum(), mean()这样的基本函数。

计算b值

# 根据公式14
def calculate_b_1(X,Y,w,m):b = sum(Y-w*X)/mreturn b# 根据公式9
def calculate_b_2(X,Y,w):b = Y.mean() - w * X.mean()return b

4.1.4 运算结果

用以上几种方法,最后得出的结果都是一致的,可以起到交叉验证的作用:

w1=2.056827, b1=2.965434
w2=2.056827, b2=2.965434
w3=2.056827, b3=2.965434

代码位置

ch04, Level1

神经网络系列之四 -- 线性回归方法与原理相关推荐

  1. Tomcat原理系列之四:Tomat如何启动spring(加载web.xml)

    Tomcat原理系列之四:Tomat如何启动spring 熟悉的web.xml ContextLoaderListener Tomcat的初始化StandardContext.startInterna ...

  2. 【JVM系列3】方法重载和方法重写原理分析,看完这篇终于彻底搞懂了

    深入分析Java虚拟机中方法执行流程及方法重载和方法重写原理 前言 思考 栈帧 局部变量表(Local Variables) 操作数栈(Operand Stacks) 动态连接(Dynamic Lin ...

  3. 特征工程系列:特征筛选的原理与实现(下)

    0x00 前言 我们在<特征工程系列:特征筛选的原理与实现(上)>中介绍了特征选择的分类,并详细介绍了过滤式特征筛选的原理与实现.本篇继续介绍封装式和嵌入式特征筛选的原理与实现. 0x01 ...

  4. 积神经网络的参数优化方法——调整网络结构是关键!!!你只需不停增加层,直到测试误差不再减少....

    积神经网络(CNN)的参数优化方法 from:http://blog.csdn.net/u010900574/article/details/51992156 著名: 本文是从 Michael Nie ...

  5. (一)神经网络入门之线性回归

    作者:chen_h 微信号 & QQ:862251340 微信公众号:coderpai 简书地址:https://www.jianshu.com/p/0da... 这篇教程是翻译Peter R ...

  6. 线性回归算法数学原理_线性回归算法-非数学家的高级数学

    线性回归算法数学原理 内部AI (Inside AI) Linear regression is one of the most popular algorithms used in differen ...

  7. DFF之--(一)神经网络入门之线性回归

    这篇教程是翻译Peter Roelants写的神经网络教程,作者已经授权翻译,这是原文. 该教程将介绍如何入门神经网络,一共包含五部分.你可以在以下链接找到完整内容. (一)神经网络入门之线性回归 L ...

  8. 神经网络的参数优化方法

    转载自:https://www.cnblogs.com/bonelee/p/8528863.html 著名: 本文是从 Michael Nielsen的电子书Neural Network and De ...

  9. 进程——Windows核心编程学习手札系列之四

    进程 --Windows核心编程学习手札系列之四 进程是一个正在运行的程序的实例,有两个部分组成:一个是操作系统用来管理进程的内核对象,内核对象是系统用来存放关于进程的统计信息的地方:另一个是地址空间 ...

最新文章

  1. python分解word文档为多个_将一个word文档按一页或多页拆分成多个文档
  2. Orchard:使用VS2010来生成一个地图Content Part
  3. pytorch笔记 pytorch模型中的parameter与buffer
  4. 机器学习总结——机器学习课程笔记整理
  5. 撰写论文时word使用诀窍标题
  6. 信息学奥赛一本通(C++)在线评测系统——基础(三)数据结构 —— 1339:【例3-4】求后序遍历
  7. USACO 1.1 Your Ride Is Here
  8. 了解JUnit的Runner架构
  9. MATLAB绘图辅助操作
  10. 10-Mybatis 多表查询之多对多
  11. 网络共享服务器 samba
  12. Kotlin — 适用于移动端跨平台
  13. .ajax 上传图片,ajax图片上传并预览
  14. go导出mysql中的excel表_golang web 开发 从数据库 导出到excel案例
  15. 采用ATSC标准、欧洲DVB-T和日本ISDB-T标准的国家
  16. 自媒体矩阵mcn是什么怎么做自媒体mcn矩阵运营
  17. explicit,violate,volatile,mutable
  18. [LGOJ5558]心上秋(倍增)
  19. 嵌入式系统和嵌入式操作系统
  20. pow函数 真假硬币

热门文章

  1. 企业知识库的意义何在?到底如何高效搭建一个知识库?
  2. 【App数据运营分析】
  3. 格力电器又加薪了!人均每月加薪1000元
  4. 一日一命令:find 命令详解
  5. 一日一Shader·进阶版笔刷【SS_18】
  6. linux复制后权限变化,学霸Linux基础命令吐血总结,给你当新华字典用
  7. 【转】VMware15虚拟机安装教程
  8. WiFi底层通信接口@Netlink
  9. NYOJ - 独木舟上的旅行
  10. 第 l 个数到第 r 个数的和