线性回归

原理

输入

训练集数据D=(x1,y1)...(xM,yM)D = {(x_1,y_1) ... (x_M,y_M)}D=(x1​,y1​)...(xM​,yM​),xi∈X⊆Rpx_i \in \mathcal{X} \subseteq R^pxi​∈X⊆Rp,yi∈Ry_i \in Ryi​∈R

X=(x1x2...xM)T∈RM∗pX = (x_1\ x_2\ ...\ x_M)^T \in R^{M*p}X=(x1​ x2​ ... xM​)T∈RM∗p

f(x)=wTx+bf(x) = w^Tx+bf(x)=wTx+b

正则化参数λ1\lambda_1λ1​ ,λ2\lambda_2λ2​

输出

线性回归模型f^(x)\hat f(x)f^​(x)

损失函数

均方误差对应欧氏距离。基于均方误差最小化求解的方法称为最小二乘法least square method。

一般的几何意义,会将误差分散在每一个样本上。

L(w)=∑i=1M(wTxi−yi)2=(wTXT−YT)(Xw−Y)L(w) = \sum^M_{i=1}(w^Tx_i - y_i)^2 = (w^TX^T - Y^T)(Xw - Y)L(w)=∑i=1M​(wTxi​−yi​)2=(wTXT−YT)(Xw−Y)

对w进行求导,

∂L(w)∂w=2XTXw−2XTY\frac {\partial L(w)} {\partial w} = 2X^TXw - 2X^TY ∂w∂L(w)​=2XTXw−2XTY

w∗=(XTX)−1XTyw^* = (X^TX)^{-1}X^Tyw∗=(XTX)−1XTy,此时(XTX)(X^TX)(XTX)是满秩矩阵。我们也将(XTX)XT(X^TX)X^{T}(XTX)XT称为投影矩阵。这个矩阵可以将yyy投影到XXX所处于的空间中。

另外有一种几何意义,将误差分散在x的p个维度上。

令f(w)=wTx=xTβf(w) = w^Tx = x^T\betaf(w)=wTx=xTβ.相当于实际的y在x平面上的投影。由XT(Y−Xβ)=0X^T(Y - X\beta) = 0XT(Y−Xβ)=0可以推导出β=(XTX)−1XTy\beta = (X^TX)^{-1}X^Tyβ=(XTX)−1XTy。

概率推导

我们下面开始推导最小二乘的由来。

假设噪声是服从高斯分布的。ϵ∼N(0,σ2)\epsilon \sim N(0, \sigma^2)ϵ∼N(0,σ2)。由y=wTx+ϵy = w^Tx + \epsilony=wTx+ϵ很容易可以得到y∼N(wTx,σ2)y \sim N(w^Tx, \sigma^2)y∼N(wTx,σ2)。

计算MLE,
MLE(w)=logP(Y∣X;w)=log∏i=1MP(yi∣xi;w)=∑i=1MlogP(yi∣xi;w)=∑i=1Mlog1(2π)1/2σ+log[exp[−12σ2(yi−wTxi)2]]=∑i=1Mlog1(2π)1/2σ−12σ2(yi−wTxi)2MLE(w) = log P(Y|X; w) \\\\ = log \prod_{i=1}^{M} P(y_i|x_i;w) \\\\ = \sum_{i=1}^{M} log P(y_i|x_i;w) \\\\ = \sum_{i=1}^M log \frac{1}{(2\pi)^{1/2}\sigma} + log[exp[{-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2}]] \\\\ = \sum_{i=1}^M log \frac{1}{(2\pi)^{1/2}\sigma} - \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 \\\\ MLE(w)=logP(Y∣X;w)=logi=1∏M​P(yi​∣xi​;w)=i=1∑M​logP(yi​∣xi​;w)=i=1∑M​log(2π)1/2σ1​+log[exp[−2σ21​(yi​−wTxi​)2]]=i=1∑M​log(2π)1/2σ1​−2σ21​(yi​−wTxi​)2
对w进行求导,
w∗=argmaxwMLE(w)=argmaxw∑i=1M−12σ2(yi−wTxi)2=argminw∑i=1M(yi−wTxi)2w^* = argmax_w MLE(w) \\\\ = argmax_w \sum_{i=1}^M - \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 \\\\ = argmin_w \sum_{i=1}^M (y_i-w^Tx_i)^2 \\\\ w∗=argmaxw​MLE(w)=argmaxw​i=1∑M​−2σ21​(yi​−wTxi​)2=argminw​i=1∑M​(yi​−wTxi​)2

适用场景

自变量和因变量之间是线性关系。当然我们也可以用模型逼近yyy的衍生物,比如ln(y)ln(y)ln(y)。在形式上仍是线性回归,不过实质上已是在求输入空间到输出空间的非线性函数映射。

优点

  1. 模型计算简单,建模速度快
  2. 可解释性好

缺点

  1. 需要数据服从分布(具有严格假设)
  2. 对离群数据非常敏感

正则化

通常可以使用w∗=(XTX)−1XTyw^* = (X^TX)^{-1}X^Tyw∗=(XTX)−1XTy得到答案。但是实际数据不一定满足M>pM>pM>p,这就会导致XTXX^TXXTX不可逆,从而导致严重的过拟合问题。解决方案之一就是正则化。我们使用C(w)C(w)C(w)代表对于www进行的惩罚(penalty)。线性回归正则化的一般形式可以写成argminwL(w)+λC(w)argmin_w L(w) + \lambda C(w)argminw​L(w)+λC(w)。

Lasso Regression

C(w)=∣∣w∣∣1C(w) = ||w||_1C(w)=∣∣w∣∣1​。保证了w∗w^*w∗的唯一性。从下图可以看出,Lasso Regression倾向于让xxx的某一维度参数成为0,从而得到稀疏的www,具有选择参数的功能。

岭回归Ridge Regression

也叫权值衰减。C(w)=∣∣w∣∣22C(w) = ||w||_2^2C(w)=∣∣w∣∣22​。保证了w∗w^*w∗的唯一性。

L(w)=∑i=1M∣∣wTxi−yi∣∣2+λ∣∣w∣∣22=(wTXT−YT)(Xw−Y)+λwTw=wTXTXw−2wTxTY+YTY+λwTw=wT(XTX+λI)w−2wTxTY+YTYL(w) = \sum_{i=1}^M ||w^Tx_i - y_i||^2 + \lambda ||w||_2^2 \\\\ = (w^TX^T - Y^T)(Xw - Y) + \lambda w^T w \\\\ = w^TX^TXw - 2w^Tx^TY + Y^TY + \lambda w^T w \\\\ = w^T(X^TX + \lambda I)w - 2w^Tx^TY + Y^TY L(w)=i=1∑M​∣∣wTxi​−yi​∣∣2+λ∣∣w∣∣22​=(wTXT−YT)(Xw−Y)+λwTw=wTXTXw−2wTxTY+YTY+λwTw=wT(XTX+λI)w−2wTxTY+YTY

对w进行求导,
KaTeX parse error: Undefined control sequence: \part at position 8: \frac{\̲p̲a̲r̲t̲ ̲L(w)}{\part w} …
可以得到w∗=(XTX+λI)−1XTyw^* = (X^TX + \lambda I)^{-1}X^Tyw∗=(XTX+λI)−1XTy。

概率推导

从概率学的角度,将w看做一个变量,w∼N(0,σ02)w \sim N(0, \sigma_0^2)w∼N(0,σ02​),p(w∣y)=p(y∣w)∗p(w)/p(y)p(w|y) = p(y|w) * p(w)/ p(y)p(w∣y)=p(y∣w)∗p(w)/p(y)

我们已经知道p(y∣w)p(y|w)p(y∣w) 和 p(w)p(w)p(w)。
p(y∣w)=1(2π)1/2σ+exp[−12σ2(yi−wTxi)2]p(w)=1(2π)1/2σ0+exp[−12σ02∣∣w∣∣22]p(y|w) = \frac{1}{(2\pi)^{1/2}\sigma} + exp[-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2] \\\\ p(w) = \frac{1}{(2\pi)^{1/2}\sigma_0} + exp[-\frac 1 {2\sigma_0^{2}} ||w||_2^2] p(y∣w)=(2π)1/2σ1​+exp[−2σ21​(yi​−wTxi​)2]p(w)=(2π)1/2σ0​1​+exp[−2σ02​1​∣∣w∣∣22​]
通过求解w的MAP,
w∗=argmaxwp(w∣y)=argmaxwp(y∣w)∗p(w)=argmaxwlog[p(y∣w)∗p(w)]=argmaxwlog(1(2π)1/2σ∗(2π)1/2σ0)+log[exp[−12σ2(yi−wTxi)2−12σ02∣∣w∣∣22]]=argmaxw−12σ2(yi−wTxi)2−12σ02∣∣w∣∣22=argminw12σ2(yi−wTxi)2+12σ02∣∣w∣∣22=argminw(yi−wTxi)2+σ2σ02∣∣w∣∣22w^* = argmax_w p(w|y) \\\\ = argmax_w p(y|w)*p(w) \\\\ = argmax_w log[p(y|w)*p(w)] \\\\ = argmax_w log(\frac{1}{(2\pi)^{1/2}\sigma*(2\pi)^{1/2}\sigma_0}) + log[exp[-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2-\frac 1 {2\sigma_0^{2}} ||w||_2^2]] \\\\ = argmax_w -\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2-\frac 1 {2\sigma_0^{2}} ||w||_2^2 \\\\ = argmin_w \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 + \frac 1 {2\sigma_0^{2}} ||w||_2^2 \\\\ = argmin_w (y_i-w^Tx_i)^2 + \frac{\sigma^{2}}{\sigma_0^{2}}||w||_2^2 w∗=argmaxw​p(w∣y)=argmaxw​p(y∣w)∗p(w)=argmaxw​log[p(y∣w)∗p(w)]=argmaxw​log((2π)1/2σ∗(2π)1/2σ0​1​)+log[exp[−2σ21​(yi​−wTxi​)2−2σ02​1​∣∣w∣∣22​]]=argmaxw​−2σ21​(yi​−wTxi​)2−2σ02​1​∣∣w∣∣22​=argminw​2σ21​(yi​−wTxi​)2+2σ02​1​∣∣w∣∣22​=argminw​(yi​−wTxi​)2+σ02​σ2​∣∣w∣∣22​
发现和我们的损失函数是相同的。

Reference

  • 《美团机器学习实践》by美团算法团队,第三章
  • 《机器学习》by周志华,第三、四章
  • 白板推导系列,shuhuai007

机器学习基础专题:线性回归相关推荐

  1. 机器学习基础专题:特征工程

    特征工程 特征提取 将原始数据转化为实向量之后,为了让模型更好地学习规律,对特征做进一步的变换.首先,要理解业务数据和业务逻辑. 其次,要理解模型和算法,清楚模型需要什么样的输入才能有精确的结果. 探 ...

  2. 【机器学习基础】线性回归和梯度下降的初学者教程

    作者 | Lily Chen 编译 | VK 来源 | Towards Data Science 假设我们有一个虚拟的数据集,一对变量,一个母亲和她女儿的身高: 考虑到另一位母亲的身高为63,我们如何 ...

  3. 机器学习基础专题:逻辑回归

    逻辑回归 广义线性模型. 原理 输入 训练集数据T=(x1,y1)...(xM,yM)T = {(x_1,y_1) ... (x_M,y_M)}T=(x1​,y1​)...(xM​,yM​),xi∈X ...

  4. 机器学习基础专题:高斯混合模型和最大期望EM算法以及代码实现

    高斯混合模型 混合模型是潜变量模型的一种,是最常见的形式之一.而高斯混合模型(Gaussian Mixture Models, GMM)是混合模型中最常见的一种.zzz代表该数据点是由某一个高斯分布产 ...

  5. 机器学习基础专题:支持向量机SVM

    支持向量机 全称Support Vector Machine (SVM).可以分为硬间隔(hard margin SVM),软间隔(soft margin SVM),和核支持向量机(kernel ma ...

  6. 机器学习基础专题:感知机

    感知机 原理 思想是错误驱动.一开始赋予w一个初始值,通过计算被错误分类的样本不断移动分类边界. 输入 训练集数据D=(x1,y1)...(xM,yM)D = {(x_1,y_1) ... (x_M, ...

  7. 机器学习基础专题:分类

    线性分类 分类方式 硬分类 使用的是非概率模型,分类结果是决策函数的决策结果. 代表:线性判别分析.感知机 软分类 分类结果是属于不同类别的概率. 生成式 通过贝叶斯定理,使用MAP比较P(Y=0∣X ...

  8. 机器学习基础专题:线性判别器

    线性判别分析 全称是Linear Discriminant Analysis (LDA). 原理 给定训练样例集,通过降维的思路进行分类.将样例投影到一条直线上,使得同类样例的投影点接近,异类样例的投 ...

  9. 机器学习基础专题:样本选择

    样本选择 选择最少量的训练集S⊂\sub⊂完整训练集T,模型效果不会变差. 优势: 缩减模型计算时间 相关性太低的数据对解决问题没有帮助,直接剔除 去除噪声 数据去噪 噪声数据 特征值不对(缺失.超出 ...

最新文章

  1. WinCE5.0中文模拟器SDK(VS2005,VS2008)的配置
  2. SSM中向后端传递的属性为多个对象的实现方法
  3. Hbase(2)——基础语句(2)
  4. C# Global.asax.cs 定时任务
  5. vm虚拟远程部署windows驱动
  6. 大数据如何影响百姓生活
  7. 【HTTP】POST 与 PUT 方法区别
  8. 深度解读MRS IoTDB时序数据库的整体架构设计与实现
  9. js/jQuery中的宽高
  10. 根据开始日期和结束日期获取基金的当天净值,并计算收益率
  11. jquery 操作表格实例
  12. shiro+springMVC文档
  13. AI语音技术的架构(学习心得)
  14. Mac创建txt文件的两种方法
  15. 这家为AI for Science而生的新研究院,要让科研进入“安卓模式”
  16. Ueditor上传图片文件大小上限问题
  17. Scrapy--CrawlSpider
  18. 自媒体人如何搜集写作素材?建立自己的素材库
  19. ROC和AUC指标的理解
  20. picsart旧版本_picsart美易照片编辑旧版

热门文章

  1. SpringMVC背景介绍及常见MVC框架比较
  2. koa源码阅读之koa-compose/application.js
  3. PHP扩展开发系列01 - 我要成为一名老司机
  4. 《Android应用开发》——1.3节配置Eclipse
  5. 如何 提高企业网站大数据量 效率
  6. pku1384---Piggy-Bank(动态规划)
  7. vivado中设置多线程编译
  8. Verilog 中 wire 和 reg 数据类型区别
  9. python np.arange,np.linspace和np.logspace之间的区别
  10. jittor和pytorch生成网络对比之began