机器学习基础专题:线性回归
线性回归
原理
输入
训练集数据D=(x1,y1)...(xM,yM)D = {(x_1,y_1) ... (x_M,y_M)}D=(x1,y1)...(xM,yM),xi∈X⊆Rpx_i \in \mathcal{X} \subseteq R^pxi∈X⊆Rp,yi∈Ry_i \in Ryi∈R
X=(x1x2...xM)T∈RM∗pX = (x_1\ x_2\ ...\ x_M)^T \in R^{M*p}X=(x1 x2 ... xM)T∈RM∗p
f(x)=wTx+bf(x) = w^Tx+bf(x)=wTx+b
正则化参数λ1\lambda_1λ1 ,λ2\lambda_2λ2
输出
线性回归模型f^(x)\hat f(x)f^(x)
损失函数
均方误差对应欧氏距离。基于均方误差最小化求解的方法称为最小二乘法least square method。
一般的几何意义,会将误差分散在每一个样本上。
L(w)=∑i=1M(wTxi−yi)2=(wTXT−YT)(Xw−Y)L(w) = \sum^M_{i=1}(w^Tx_i - y_i)^2 = (w^TX^T - Y^T)(Xw - Y)L(w)=∑i=1M(wTxi−yi)2=(wTXT−YT)(Xw−Y)
对w进行求导,
∂L(w)∂w=2XTXw−2XTY\frac {\partial L(w)} {\partial w} = 2X^TXw - 2X^TY ∂w∂L(w)=2XTXw−2XTY
w∗=(XTX)−1XTyw^* = (X^TX)^{-1}X^Tyw∗=(XTX)−1XTy,此时(XTX)(X^TX)(XTX)是满秩矩阵。我们也将(XTX)XT(X^TX)X^{T}(XTX)XT称为投影矩阵。这个矩阵可以将yyy投影到XXX所处于的空间中。
另外有一种几何意义,将误差分散在x的p个维度上。
令f(w)=wTx=xTβf(w) = w^Tx = x^T\betaf(w)=wTx=xTβ.相当于实际的y在x平面上的投影。由XT(Y−Xβ)=0X^T(Y - X\beta) = 0XT(Y−Xβ)=0可以推导出β=(XTX)−1XTy\beta = (X^TX)^{-1}X^Tyβ=(XTX)−1XTy。
概率推导
我们下面开始推导最小二乘的由来。
假设噪声是服从高斯分布的。ϵ∼N(0,σ2)\epsilon \sim N(0, \sigma^2)ϵ∼N(0,σ2)。由y=wTx+ϵy = w^Tx + \epsilony=wTx+ϵ很容易可以得到y∼N(wTx,σ2)y \sim N(w^Tx, \sigma^2)y∼N(wTx,σ2)。
计算MLE,
MLE(w)=logP(Y∣X;w)=log∏i=1MP(yi∣xi;w)=∑i=1MlogP(yi∣xi;w)=∑i=1Mlog1(2π)1/2σ+log[exp[−12σ2(yi−wTxi)2]]=∑i=1Mlog1(2π)1/2σ−12σ2(yi−wTxi)2MLE(w) = log P(Y|X; w) \\\\ = log \prod_{i=1}^{M} P(y_i|x_i;w) \\\\ = \sum_{i=1}^{M} log P(y_i|x_i;w) \\\\ = \sum_{i=1}^M log \frac{1}{(2\pi)^{1/2}\sigma} + log[exp[{-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2}]] \\\\ = \sum_{i=1}^M log \frac{1}{(2\pi)^{1/2}\sigma} - \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 \\\\ MLE(w)=logP(Y∣X;w)=logi=1∏MP(yi∣xi;w)=i=1∑MlogP(yi∣xi;w)=i=1∑Mlog(2π)1/2σ1+log[exp[−2σ21(yi−wTxi)2]]=i=1∑Mlog(2π)1/2σ1−2σ21(yi−wTxi)2
对w进行求导,
w∗=argmaxwMLE(w)=argmaxw∑i=1M−12σ2(yi−wTxi)2=argminw∑i=1M(yi−wTxi)2w^* = argmax_w MLE(w) \\\\ = argmax_w \sum_{i=1}^M - \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 \\\\ = argmin_w \sum_{i=1}^M (y_i-w^Tx_i)^2 \\\\ w∗=argmaxwMLE(w)=argmaxwi=1∑M−2σ21(yi−wTxi)2=argminwi=1∑M(yi−wTxi)2
适用场景
自变量和因变量之间是线性关系。当然我们也可以用模型逼近yyy的衍生物,比如ln(y)ln(y)ln(y)。在形式上仍是线性回归,不过实质上已是在求输入空间到输出空间的非线性函数映射。
优点
- 模型计算简单,建模速度快
- 可解释性好
缺点
- 需要数据服从分布(具有严格假设)
- 对离群数据非常敏感
正则化
通常可以使用w∗=(XTX)−1XTyw^* = (X^TX)^{-1}X^Tyw∗=(XTX)−1XTy得到答案。但是实际数据不一定满足M>pM>pM>p,这就会导致XTXX^TXXTX不可逆,从而导致严重的过拟合问题。解决方案之一就是正则化。我们使用C(w)C(w)C(w)代表对于www进行的惩罚(penalty)。线性回归正则化的一般形式可以写成argminwL(w)+λC(w)argmin_w L(w) + \lambda C(w)argminwL(w)+λC(w)。
Lasso Regression
C(w)=∣∣w∣∣1C(w) = ||w||_1C(w)=∣∣w∣∣1。保证了w∗w^*w∗的唯一性。从下图可以看出,Lasso Regression倾向于让xxx的某一维度参数成为0,从而得到稀疏的www,具有选择参数的功能。
岭回归Ridge Regression
也叫权值衰减。C(w)=∣∣w∣∣22C(w) = ||w||_2^2C(w)=∣∣w∣∣22。保证了w∗w^*w∗的唯一性。
L(w)=∑i=1M∣∣wTxi−yi∣∣2+λ∣∣w∣∣22=(wTXT−YT)(Xw−Y)+λwTw=wTXTXw−2wTxTY+YTY+λwTw=wT(XTX+λI)w−2wTxTY+YTYL(w) = \sum_{i=1}^M ||w^Tx_i - y_i||^2 + \lambda ||w||_2^2 \\\\ = (w^TX^T - Y^T)(Xw - Y) + \lambda w^T w \\\\ = w^TX^TXw - 2w^Tx^TY + Y^TY + \lambda w^T w \\\\ = w^T(X^TX + \lambda I)w - 2w^Tx^TY + Y^TY L(w)=i=1∑M∣∣wTxi−yi∣∣2+λ∣∣w∣∣22=(wTXT−YT)(Xw−Y)+λwTw=wTXTXw−2wTxTY+YTY+λwTw=wT(XTX+λI)w−2wTxTY+YTY
对w进行求导,
KaTeX parse error: Undefined control sequence: \part at position 8: \frac{\̲p̲a̲r̲t̲ ̲L(w)}{\part w} …
可以得到w∗=(XTX+λI)−1XTyw^* = (X^TX + \lambda I)^{-1}X^Tyw∗=(XTX+λI)−1XTy。
概率推导
从概率学的角度,将w看做一个变量,w∼N(0,σ02)w \sim N(0, \sigma_0^2)w∼N(0,σ02),p(w∣y)=p(y∣w)∗p(w)/p(y)p(w|y) = p(y|w) * p(w)/ p(y)p(w∣y)=p(y∣w)∗p(w)/p(y)
我们已经知道p(y∣w)p(y|w)p(y∣w) 和 p(w)p(w)p(w)。
p(y∣w)=1(2π)1/2σ+exp[−12σ2(yi−wTxi)2]p(w)=1(2π)1/2σ0+exp[−12σ02∣∣w∣∣22]p(y|w) = \frac{1}{(2\pi)^{1/2}\sigma} + exp[-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2] \\\\ p(w) = \frac{1}{(2\pi)^{1/2}\sigma_0} + exp[-\frac 1 {2\sigma_0^{2}} ||w||_2^2] p(y∣w)=(2π)1/2σ1+exp[−2σ21(yi−wTxi)2]p(w)=(2π)1/2σ01+exp[−2σ021∣∣w∣∣22]
通过求解w的MAP,
w∗=argmaxwp(w∣y)=argmaxwp(y∣w)∗p(w)=argmaxwlog[p(y∣w)∗p(w)]=argmaxwlog(1(2π)1/2σ∗(2π)1/2σ0)+log[exp[−12σ2(yi−wTxi)2−12σ02∣∣w∣∣22]]=argmaxw−12σ2(yi−wTxi)2−12σ02∣∣w∣∣22=argminw12σ2(yi−wTxi)2+12σ02∣∣w∣∣22=argminw(yi−wTxi)2+σ2σ02∣∣w∣∣22w^* = argmax_w p(w|y) \\\\ = argmax_w p(y|w)*p(w) \\\\ = argmax_w log[p(y|w)*p(w)] \\\\ = argmax_w log(\frac{1}{(2\pi)^{1/2}\sigma*(2\pi)^{1/2}\sigma_0}) + log[exp[-\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2-\frac 1 {2\sigma_0^{2}} ||w||_2^2]] \\\\ = argmax_w -\frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2-\frac 1 {2\sigma_0^{2}} ||w||_2^2 \\\\ = argmin_w \frac 1 {2\sigma^{2}} (y_i-w^Tx_i)^2 + \frac 1 {2\sigma_0^{2}} ||w||_2^2 \\\\ = argmin_w (y_i-w^Tx_i)^2 + \frac{\sigma^{2}}{\sigma_0^{2}}||w||_2^2 w∗=argmaxwp(w∣y)=argmaxwp(y∣w)∗p(w)=argmaxwlog[p(y∣w)∗p(w)]=argmaxwlog((2π)1/2σ∗(2π)1/2σ01)+log[exp[−2σ21(yi−wTxi)2−2σ021∣∣w∣∣22]]=argmaxw−2σ21(yi−wTxi)2−2σ021∣∣w∣∣22=argminw2σ21(yi−wTxi)2+2σ021∣∣w∣∣22=argminw(yi−wTxi)2+σ02σ2∣∣w∣∣22
发现和我们的损失函数是相同的。
Reference
- 《美团机器学习实践》by美团算法团队,第三章
- 《机器学习》by周志华,第三、四章
- 白板推导系列,shuhuai007
机器学习基础专题:线性回归相关推荐
- 机器学习基础专题:特征工程
特征工程 特征提取 将原始数据转化为实向量之后,为了让模型更好地学习规律,对特征做进一步的变换.首先,要理解业务数据和业务逻辑. 其次,要理解模型和算法,清楚模型需要什么样的输入才能有精确的结果. 探 ...
- 【机器学习基础】线性回归和梯度下降的初学者教程
作者 | Lily Chen 编译 | VK 来源 | Towards Data Science 假设我们有一个虚拟的数据集,一对变量,一个母亲和她女儿的身高: 考虑到另一位母亲的身高为63,我们如何 ...
- 机器学习基础专题:逻辑回归
逻辑回归 广义线性模型. 原理 输入 训练集数据T=(x1,y1)...(xM,yM)T = {(x_1,y_1) ... (x_M,y_M)}T=(x1,y1)...(xM,yM),xi∈X ...
- 机器学习基础专题:高斯混合模型和最大期望EM算法以及代码实现
高斯混合模型 混合模型是潜变量模型的一种,是最常见的形式之一.而高斯混合模型(Gaussian Mixture Models, GMM)是混合模型中最常见的一种.zzz代表该数据点是由某一个高斯分布产 ...
- 机器学习基础专题:支持向量机SVM
支持向量机 全称Support Vector Machine (SVM).可以分为硬间隔(hard margin SVM),软间隔(soft margin SVM),和核支持向量机(kernel ma ...
- 机器学习基础专题:感知机
感知机 原理 思想是错误驱动.一开始赋予w一个初始值,通过计算被错误分类的样本不断移动分类边界. 输入 训练集数据D=(x1,y1)...(xM,yM)D = {(x_1,y_1) ... (x_M, ...
- 机器学习基础专题:分类
线性分类 分类方式 硬分类 使用的是非概率模型,分类结果是决策函数的决策结果. 代表:线性判别分析.感知机 软分类 分类结果是属于不同类别的概率. 生成式 通过贝叶斯定理,使用MAP比较P(Y=0∣X ...
- 机器学习基础专题:线性判别器
线性判别分析 全称是Linear Discriminant Analysis (LDA). 原理 给定训练样例集,通过降维的思路进行分类.将样例投影到一条直线上,使得同类样例的投影点接近,异类样例的投 ...
- 机器学习基础专题:样本选择
样本选择 选择最少量的训练集S⊂\sub⊂完整训练集T,模型效果不会变差. 优势: 缩减模型计算时间 相关性太低的数据对解决问题没有帮助,直接剔除 去除噪声 数据去噪 噪声数据 特征值不对(缺失.超出 ...
最新文章
- WinCE5.0中文模拟器SDK(VS2005,VS2008)的配置
- SSM中向后端传递的属性为多个对象的实现方法
- Hbase(2)——基础语句(2)
- C# Global.asax.cs 定时任务
- vm虚拟远程部署windows驱动
- 大数据如何影响百姓生活
- 【HTTP】POST 与 PUT 方法区别
- 深度解读MRS IoTDB时序数据库的整体架构设计与实现
- js/jQuery中的宽高
- 根据开始日期和结束日期获取基金的当天净值,并计算收益率
- jquery 操作表格实例
- shiro+springMVC文档
- AI语音技术的架构(学习心得)
- Mac创建txt文件的两种方法
- 这家为AI for Science而生的新研究院,要让科研进入“安卓模式”
- Ueditor上传图片文件大小上限问题
- Scrapy--CrawlSpider
- 自媒体人如何搜集写作素材?建立自己的素材库
- ROC和AUC指标的理解
- picsart旧版本_picsart美易照片编辑旧版
热门文章
- SpringMVC背景介绍及常见MVC框架比较
- koa源码阅读之koa-compose/application.js
- PHP扩展开发系列01 - 我要成为一名老司机
- 《Android应用开发》——1.3节配置Eclipse
- 如何 提高企业网站大数据量 效率
- pku1384---Piggy-Bank(动态规划)
- vivado中设置多线程编译
- Verilog 中 wire 和 reg 数据类型区别
- python np.arange,np.linspace和np.logspace之间的区别
- jittor和pytorch生成网络对比之began