在上一节中,我们介绍了最简单的学习算法——最小二乘法去预测奥运会男子100米时间。但是可以发现,它的自变量只有一个:年份。通常,我们所面对的数据集往往不是单个特征,而是有成千上万个特征组成。那么我们就引入特征的向量来表示,这里涉及到矩阵的乘法,向量,矩阵求导等一些线性代数的知识。


一. 将拟合函数由单变量改写为多变量

设我们的拟合函数

f(xi;ω)=ωTxi

f(\boldsymbol{x_i}; \boldsymbol{\omega}) = \boldsymbol{\omega}^T\boldsymbol{x_i}

其中, w\boldsymbol{w}表示拟合函数的参数,xi\boldsymbol{x_i}表示数据集中第i条数据。

对于上节中的f(x;a,b)=ax+bf(x;a,b) = ax + b,我们可以令

ω=[ab],xi=[x1]

\boldsymbol{\omega} = \begin{bmatrix} a\\b \end{bmatrix}, \boldsymbol{x_i} = \begin{bmatrix} x\\1 \end{bmatrix}

则这两个函数等价。为了方便推导,我们在损失函数前边加上1N\frac{1}{N},由于N是定值,它代表数据集的记录数。那么,损失函数可以写为:

L=1N∑i=1N(yi−ωTxi)2=1N(y−Xω)T(y−Xω)(1)

L=\frac{1}{N}\sum_{i=1}^{N}(y_i-\boldsymbol{\omega^Tx_i})^2=\frac{1}{N}(\boldsymbol{y}-\boldsymbol{X\omega})^T(\boldsymbol{y}-\boldsymbol{X\omega}) (1)
那么上式的推导过程也很简单,令

X=⎡⎣⎢⎢⎢⎢⎢xT1xT2⋮xTn⎤⎦⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢11⋮1x1x2⋮xn⎤⎦⎥⎥⎥⎥

\boldsymbol{X}=\begin{bmatrix} \boldsymbol{x_1^T} \\ \boldsymbol{x_2^T} \\ \vdots\\ \boldsymbol{x_n^T} \end{bmatrix} =\begin{bmatrix} 1 & x_1\\ 1 & x_2\\ \vdots & \vdots\\ 1 & x_n \end{bmatrix}

y=⎡⎣⎢⎢⎢⎢⎢y1y2⋮yn⎤⎦⎥⎥⎥⎥⎥,ω=[ω0ω1]

\boldsymbol{y}=\begin{bmatrix} y_1\\ y_2\\ \vdots\\ y_n \end{bmatrix}, \boldsymbol{\omega}=\begin{bmatrix} \omega_0\\ \omega_1 \end{bmatrix}

带入(1)式即可得证,此处略过。

二.多特征下求解参数 ω\boldsymbol{\omega}

L=1N(y−Xω)T(y−Xω)=1N(yT−ωTXT)(y−Xω)=1N(yTy−yTXω−ωTXTy+ωTXTXω)=1N(ωTXTXω−2ωTXTy+yTy)(2)

\begin{align} L&=\frac{1}{N}(\boldsymbol{y}-\boldsymbol{X\omega})^T(\boldsymbol{y}-\boldsymbol{X\omega}) \\ &=\frac{1}{N}(\boldsymbol{y^T-\omega^TX^T})(\boldsymbol{y}-\boldsymbol{X\omega})\\ &=\frac{1}{N}(\boldsymbol{y^Ty-y^TX\omega-\omega^TX^Ty+\omega^TX^TX\omega})\\ &=\frac{1}{N}(\boldsymbol{\omega^TX^TX\omega-2\omega^TX^Ty+y^Ty})(2) \end{align}
我们的目标是让损失函数最小,即求(2)的最小值,我们对 ω\boldsymbol{\omega}求偏导数,令其等于0,就可以求出 LL取得极小值时参数ω\boldsymbol{\omega}的值。

∂L∂ω=1N(2XTXω−2XTy)=0(3)⇒XTXω=XTy⇒ω=(XTX)−1XTy

\frac{\partial{L}}{\partial{\boldsymbol{\omega}}}=\frac{1}{N}(2\boldsymbol{X^TX\omega-2X^Ty})=0(3)\\ \Rightarrow\\ \boldsymbol{X^TX\omega=X^Ty}\\ \Rightarrow\\ \boldsymbol{\omega=(X^TX)^{-1}X^Ty}
至此,我们已经求出了参数值,接下来就可以预测了。

至于(3)的求导,注意以下求导公式即可:

f(w)f(\boldsymbol{w}) ∂f∂w\frac{\partial{f}}{\partial{\boldsymbol{w}}}
wTx\boldsymbol{w^Tx} x\boldsymbol{x}
xTw\boldsymbol{x^Tw} x\boldsymbol{x}
wTw\boldsymbol{w^Tw} 2w\boldsymbol{2w}
wTCw\boldsymbol{w^TCw} 2Cw\boldsymbol{2Cw}

机器学习笔记(二)——多变量最小二乘法相关推荐

  1. 吴恩达机器学习笔记二之多变量线性回归

    本节目录: 多维特征 多变量梯度下降 特征缩放 学习率 正规方程 1.多维特征 含有多个变量的模型,模型中的特征为(x1,x2,-xn), 比如对房价模型增加多个特征 这里,n代表特征的数量, x(i ...

  2. 吴恩达机器学习笔记 —— 5 多变量线性回归

    http://www.cnblogs.com/xing901022/p/9321045.html 本篇主要讲的是多变量的线性回归,从表达式的构建到矩阵的表示方法,再到损失函数和梯度下降求解方法,再到特 ...

  3. 机器学习笔记(二)模型评估与选择

    2.模型评估与选择 2.1经验误差和过拟合 不同学习算法及其不同参数产生的不同模型,涉及到模型选择的问题,关系到两个指标性,就是经验误差和过拟合. 1)经验误差 错误率(errorrate):分类错误 ...

  4. 机器学习笔记二 单型线性回归

    线性回归 (一)介绍 (二) 数学模型 2.1 一元线性回归公式 2.2 方差 - 损失函数 Cost Function 2.3 优化方法 Optimization Function 2.4 算法步骤 ...

  5. 机器学习笔记二十四 中文分词资料整理

    一.常见的中文分词方案 1. 基于字符串匹配(词典) 基于规则的常见的就是最大正/反向匹配,以及双向匹配. 规则里糅合一定的统计规则,会采用动态规划计算最大的概率路径的分词. 以上说起来很简单,其中还 ...

  6. 脑电图机器学习笔记(二):SVM 脑电波原信号和傅立叶变换的 癫痫信号检测

    使用SVM进行癫痫检测 背景: 这是一片论文的简单复现,只是还原思想,不知道是不是我看的不够仔细,我觉得论文说的也就是这样简单的操作 论文名称:Seizure prediction with spec ...

  7. 机器学习笔记(5)——逻辑回归

    上一篇:机器学习笔记(4)--多变量线性回归 逻辑回归实际是一种有监督学习中的分类算法,称为回归是历史原因 前言 前面我们已经学习了线性回归,线性回归适用于预测一个连续值,就是说预测值可能的范围存在连 ...

  8. 机器学习笔记三—卷积神经网络与循环神经网络

    系列文章目录 机器学习笔记一-机器学习基本知识 机器学习笔记二-梯度下降和反向传播 机器学习笔记三-卷积神经网络与循环神经网络 机器学习笔记四-机器学习可解释性 机器学习笔记五-机器学习攻击与防御 机 ...

  9. 吴恩达《机器学习》学习笔记三——多变量线性回归

    吴恩达<机器学习>学习笔记三--多变量线性回归 一. 多元线性回归问题介绍 1.一些定义 2.假设函数 二. 多元梯度下降法 1. 梯度下降法实用技巧:特征缩放 2. 梯度下降法的学习率 ...

  10. 吴恩达《机器学习》学习笔记二——单变量线性回归

    吴恩达<机器学习>学习笔记二--单变量线性回归 一. 模型描述 二. 代价函数 1.代价函数和目标函数的引出 2.代价函数的理解(单变量) 3.代价函数的理解(两个参数) 三. 梯度下降- ...

最新文章

  1. Tensorflow C++ API调用Keras模型实现RGB图像语义分割
  2. JSP笔记-页面重定向
  3. JAVA入门到精通-第73讲-学生管理系统5-dao.sqlhelper
  4. cassandra集群环境搭建——注意seeds节点,DHT p2p集群管理难道初始化都应如此吗?...
  5. volatile关键字及JMM模型
  6. Hyperledger Fabric 1.0 实例简析 第一课 network_setup.sh分析
  7. win7 计算机库 桌面,【备忘】win7下再硬盘安装win7(桌面库和家庭组图标删除)...
  8. NSUserDefaults数据保存使用
  9. 我从大厂面试中学到的关于 C# 的知识
  10. java实现线程的方式_java多线程实现的四种方式
  11. 100转换成二进制 java,一段简单的java代码,十进制转二进制
  12. 苹果HomePod mini出现连接不上Wi-Fi怎么办?解决办法来啦!
  13. Cobbler结合windows DHCP服务器的使用
  14. 亿图图示 软件下载与安装 20200715
  15. 硬件编程语言和编程器件
  16. html 打开高德地图,根据经纬度定位到某个地方(位置标注)
  17. python 计算标准体重程序
  18. SpringBoot整合极光推送
  19. 互联网,大数据和人工智能对我们的生活带来的影响
  20. QT项目-“kun容道”

热门文章

  1. java struts1_struts1.x
  2. mac xampp连接mysql数据库_请问在mac下xampp无法读取mysql的数据
  3. C++基础16-类和对象之联编,重写,虚析构
  4. Django连接现有mysql数据库
  5. mixin机制 vue_vue mixins组件复用的几种方式(小结)
  6. spring注解-声明式事务
  7. 史上最全的前端开发面试题(含详细答案)
  8. Jmeter响应中中文乱码怎么解决?
  9. 集合框架(九)----Map
  10. 深入沟通的重要性——《大道至简》第四章读后感