文章目录

1.梯度

2.多元线性回归参数求解

3.梯度下降

4.梯度下降法求解多元线性回归

梯度下降算法在机器学习中出现频率特别高,是非常常用的优化算法。

本文借多元线性回归,用人话解释清楚梯度下降的原理和步骤。

(PS:也不知道为啥,在markdown里写好的公式,有一部分在这儿无法正常显示,只好转图片贴过来了)

1.梯度

梯度是什么呢?

我们还是从最简单的情况说起,对于一元函数来讲,梯度就是函数的导数

而对于多元函数而言,梯度是一个向量,也就是说,把求得的偏导数以向量的形式写出来,就是梯度

例如,我们在用人话讲明白线性回归LinearRegression一文中,求未知参数

时,对损失函数求偏导,此时的梯度向量为
,其中:

那篇文章中,因为一元线性回归中只有2个参数,因此令两个偏导数为0,能很容易求得

的解。

但是,这种求导的方法在多元回归的参数求解中就不太实用了,为什么呢?

2.多元线性回归参数求解

多元线性回归方程的一般形式为:

可以简写为矩阵形式(一般加粗表示矩阵或向量):

其中,

之前我们介绍过一元线性回归的损失函数可以用残差平方和:

代入多元线性回归方程就是:

用矩阵形式表示:

上面的展开过程涉及矩阵转置,这里简单提一下矩阵转置相关运算,以免之前学过但是现在忘了:

好了,按照一元线性回归求解析解的思路,现在我们要对Q求导并令导数为0(原谅我懒,后面写公式就不对向量或矩阵加粗了,大家能理解就行):

上面的推导过程涉及矩阵求导,这里以

求导为例展开讲下,为什么
,其他几项留给大家举一反三。

首先:

为了直观点,我们将

记为A,因为Y是n维列向量,X是n×(p+1)的矩阵,因此
是(p+1)维行向量:

那么上面求导可以简写为:

这种形式的矩阵求导属于分母布局,即分子为行向量或者分母为列向量(这里属于后者)。

搞不清楚的可以看看这篇:矩阵求导实例,这里我直接写出标量/列向量求导的公式,如下(y表示标量,X表示列向量):

根据上式,显然有:

前面我们将

记为A,
,那么上面算出来的结果就是
,即

说了这么多有的没的,最终我想说是的

,里面涉及到矩阵求逆,

但实际问题中可能X没有逆矩阵,这时计算的结果就不够精确

第二个问题就是,如果维度多、样本多,即便有逆矩阵,计算机求解的速度也会很慢

所以,基于上面这两点,一般情况下我们不会用解析解求解法求多元线性回归参数,而是采用梯度下降法,它的计算代价相对更低。

3.梯度下降

好了,重点来了,本文真正要讲的东西终于登场了。

梯度下降,就是通过一步步迭代,让所有偏导函数都下降到最低。如果觉得不好理解,我们就还是以最简单的一元函数为例开始讲。

下图是我用Excel简单画的二次函数图像(看起来有点歪,原谅我懒……懒得调整了……),函数为

,它的导数为y=2x。

如果我们初始化的点在x=1处,它的导函数值,也就是梯度值是2,为正,那就让它往左移一点,继续计算它的梯度值,若为正,就继续往左移。

如果我们初始化的点在x=-1处,该处的梯度值是-2,为负,那就让它往右移。

多元函数的逻辑也一样,先初始化一个点,也就是随便选择一个位置,计算它的梯度,然后往梯度相反的方向,每次移动一点点,直到达到停止条件

这个停止条件,可以是足够大的迭代步数,也可以是一个比较小的阈值,当两次迭代之间的差值小于该阈值时,认为梯度已经下降到最低点附近了。

二元函数的梯度下降示例如上图(图片来自梯度下降),对于这种非凸函数,可能会出现这种情况:初始化的点不同,最后的结果也不同,也就是陷入局部最小值

这种问题比较有效的解决方法,就是多取几个初始点。不过对于我们接下来讲的多元线性回归,以及后面要讲的逻辑回归,都不存在这个问题,因为他们的损失函数都是凸函数,有全局最小值。

用数学公式来描述梯度下降的步骤,就是:

解释下公式含义:

  • 为k时刻的点坐标,
    为下一刻要移动到的点的坐标,例如
    就代表初始化的点坐标,
    就代表第一步到移动到的位置;
  • g代表梯度,前面有个负号,就代表梯度下降,即朝着梯度相反的反向移动;
  • 被称为步长,用它乘以梯度值来控制每次移动的距离,这个值的设定也是一门学问,设定的过小,迭代的次数就会过多,设定的过大,容易一步跨太远,直接跳过了最小值。

4.梯度下降法求解多元线性回归

回到前面的多元线性回归,我们用梯度下降算法求损失函数的最小值。

首先,求梯度,也就是前面我们已经给出的求偏导的公式:

将梯度代入随机梯度下降公式:

这个式子中,X矩阵和Y向量都是已知的,步长是人为设定的一个值,只有参数

是未知的,而每一步的
是由
决定的,也就是每一步的点坐标。

算法过程:

1. 初始化

向量的值,即
,将其代入
得到当前位置的梯度;

2. 用步长

乘以当前梯度,得到从当前位置下降的距离;

3. 更新

,其更新表达式为

4. 重复以上步骤,直到更新到某个

,达到停止条件,这个
就是我们求解的参数向量。

参考链接:

深入浅出--梯度下降法及其实现

梯度下降与随机梯度下降概念及推导过程

excel计算二元线性回归_用人话讲明白梯度下降Gradient Descent(以求解多元线性回归参数为例)...相关推荐

  1. 用人话讲明白线性回归LinearRegression

    文章目录 1.什么是回归 2.一元线性回归 3.损失函数 4.最小二乘估计 5.小结 1.什么是回归 当我们学习一门新课程.接触一个新专业时,总会对该领域的专有名词感到困惑,甚至看完解释仍难以理解其含 ...

  2. 梯度下降参数不收敛_一文讲透梯度下降

    本文始发于个人公众号:TechFlow,原创不易,求个关注 在之前的文章当中,我们一起推导了线性回归的公式,今天我们继续来学习上次没有结束的内容. 上次我们推导完了公式的时候,曾经说过由于有许多的问题 ...

  3. 【优化】近端梯度下降(Proximal Gradient Descent)求解Lasso线性回归问题

    文章目录 近端梯度下降的背景 常见线性回归问题 近端算子(Proximal Operator) 近端梯度下降迭代递推方法 以Lasso线性回归问题为例 参考资料 近端梯度下降的背景 近端梯度下降(Pr ...

  4. 用人话讲明白聚类算法kmeans

    文章目录 1.什么是聚类 2.K-Means步骤 3.K-Means的数学描述 4.初始中心点怎么确定 5.K值怎么确定 6.小结 1.什么是聚类 先来回顾一下本系列第一篇就讲到的机器学习的种类. 监 ...

  5. 梯度下降法求解多元线性回归 — NumPy

    梯度下降法求解多元线性回归问题 使用梯度下降法求解一元线性回归的方法也可以被推广到求解多元线性回归问题. 这是多元线性回归的模型: 其中的 X 和 W 都是 m+1 维的向量. 下图为它的损失函数: ...

  6. 随机梯度下降法_动量梯度下降法(gradient descent with momentum)

    简介 动量梯度下降法是对梯度下降法的改良版本,通常来说优化效果好于梯度下降法.对梯度下降法不熟悉的可以参考梯度下降法,理解梯度下降法是理解动量梯度下降法的前提,除此之外要搞懂动量梯度下降法需要知道原始 ...

  7. excel计算二元线性回归_怎么看懂Excel线性回归参数

    虽然之前用python做线性回归的时候看上去好像很简单,但是直到我在excel上实践了线性回归,发现有很多指标值得去参考,对于在python上实现有指导作用. 打开excel2016,先找个数据 我们 ...

  8. 深度学习笔记--pytorch从梯度下降到反向传播BP到线性回归实现,以及API调用和手写数据集的实现

    梯度下降和反向传播 目标 知道什么是梯度下降 知道什么是反向传播 1. 梯度是什么? 梯度:是一个向量,导数+变化最快的方向(学习的前进方向) 回顾机器学习 收集数据 x x x ,构建机器学习模型 ...

  9. 梯度下降求最小值和线性方程(线性回归详解)

    线性回归 梯度下降法一般用于求解最小值,以下分别例举两种求最小值情况: 1)二次函数求最小值的情况 2)预测线性函数的情况(已知x和y,求解最合适的w和b,是预测误差最小) 以下就按照上述的两种情况进 ...

最新文章

  1. 2003配置php环境,2003配置PHP环境(有利于升级)
  2. 重磅!Nature子刊:利用GAN来​“深度伪造大脑数据”可以改善残疾人的脑机接口...
  3. python 3.x urllib学习
  4. 常用DOS系统功能调用(INT 21H)一览表
  5. mysql useing查询,MySQL数据库之多表查询using优化与案例
  6. 点击User Profile Service Application 报错
  7. mat opencv 修改roi_设置图片ROI(OpenCV学习笔记之二)
  8. 信息学奥赛一本通 1026:空格分隔输出 | OpenJudge NOI 1.1 06
  9. 目标检测(九)--YOLO v1,v2,v3
  10. 数据库管理软件SQLPro for SQLite for Mac 2022.30
  11. Apache Flink 的迁移之路,2 年处理效果提升 5 倍
  12. keil5.24 +注册机 下载
  13. gnu开源代码_GNU Health扩展了对Raspberry Pi的支持,Megadeth的吉他手使用了开源原则,以及更多的开源新闻。...
  14. keras中sample_weight的使用
  15. 2008年最后的感动
  16. Android多进程从头讲到尾,成功定级腾讯T3-2
  17. java 序列化规范_Java序列化格式详解
  18. 细说MySQL的时间戳(Timestamp)类型
  19. 学生dreamweaver网页设计作业成品:电商网页设计——仿淘宝静态首页(HTML+CSS)
  20. matlab数学实验报告西安交通大学微分方程模型高为16米,数学实验第二次作业——常微分方程数值求解...

热门文章

  1. 4. js移动端点触(tap)事件
  2. vue 使用fetch 出现问题解决以及 相应知识学习
  3. 【FPGA】十一、I2C通信回环
  4. JavaScript学习之旅-导言篇
  5. 2017 AMC8中文
  6. 网络工程师学习笔记——RIP路由汇总实验配置精讲
  7. Windows 下xampp的安装使用以及本地静态页面的访问(HTML+CSS+JS)
  8. Git版本控制器(涵盖GitHub\Gitee码云\GitLab),全网最详细教程
  9. 新浪微博客户端开发详解-总结(四)
  10. 提现业务流程介绍与设计