本文始发于个人公众号:TechFlow,原创不易,求个关注

在之前的文章当中,我们一起推导了线性回归的公式,今天我们继续来学习上次没有结束的内容。

上次我们推导完了公式的时候,曾经说过由于有许多的问题,比如最主要的复杂度问题。随着样本和特征数量的增大,通过公式求解的时间会急剧增大,并且如果特征为空,还会出现公式无法计算的情况。所以和直接公式求解相比,实际当中更倾向于使用另外一种方法来代替,它就是今天这篇文章的主角——梯度下降法

梯度下降法可以说是机器学习和深度学习当中最重要的方法,可以说是没有之一。尤其是在深度学习当中,几乎清一色所有的神经网络都是使用梯度下降法来训练的。那么,梯度下降法究竟是一种什么样的方法呢,让我们先从梯度的定义开始。

梯度的定义

我们先来看看维基百科当中的定义:梯度(gradient)是一种关于多元导数的概括。平常的一元(单变量)函数的导数是标量值函数,而多元函数的梯度是向量值函数。多元可微函数

在点
上的梯度,是以
上的偏导数为分量的

向量

这句话很精炼,但是不一定容易理解,我们一点一点来看。我们之前高中学过导数,但是高中时候计算的求导往往针对的是一元函数。也就是说只有一个变量x,求导的结果是一个具体的值,它是一个标量。而多元函数在某个点求导的结果是一个向量,n元函数的求导的结果分量就是n,导数的每个分量是对应的变量在该点的偏导数。这个偏导数组成的向量,就是这个函数在该点的梯度。

那么,根据上面的定义,我们可以明确一点,梯度是一个向量,它既有方向,也有大小。

梯度的解释

维基百科当中还列举了两个关于梯度的例子,帮助我们更好的理解。

第一个例子是最经典的山坡模型,假设我们当下站在一个凹凸不平的山坡上,我们想要以最快的速度下山,那么我们应该该从什么方向出发呢?很简单,我们应该计算一下脚下点的梯度,梯度的方向告诉我们下山最快的方向,梯度的大小代表这点的坡度

第二个例子是房间温度模型,假设我们对房间建立坐标系,那么房间里的每一个点都可以表示成

,该点的温度是
。如果假设房间的温度不随时间变化,那么房间里每个点的梯度表示温度变热最快的方向,梯度的大小代表温度变化的速率。

通过这两个例子,应该很容易理解梯度的方向和大小这两个概念。

举例

假设f是一个定义在三维空间里的函数,那么,f在某一点的梯度,可以写成:

这里的

都是标准单位向量,表示坐标轴
的方向。

我们举个例子:

套入刚才的梯度公式,可以得到:

如果我们知道

的坐标,代入其中,就可以知道对应的梯度了。

梯度下降法

理解了梯度的概念之后,再来看梯度下降法其实就是一张图的事。请看下面这张图。

这里的黑色的曲线表示我们损失函数的函数曲线,我们要做的,就是找到这个最佳的参数x,使得损失函数的值最小。损失函数的值达到最小,也就说明了模型的效果达到了极限,这也就是我们预期的。

我们一开始的时候显然是不知道最佳的x是多少的(废话,知道了还求啥),所以我们假设一开始的时候在一个随机的位置。就假设是图中的

的位置。接着我们对
求梯度。我们之前说了,梯度就是该点下降最陡峭的方向,梯度的大小就是它的陡峭程度。我们既然知道了梯度的方向之后,其实就很简单了,我们要做的就是

朝着梯度下降,也就是最陡峭的方向向前走一小步

我们假设,

处的梯度是
,那么我们根据
通过迭代的方法优化损失函数。说起来有些空洞,我写出来就明白了。

从上面这个公式可以看出来,这是一个迭代公式。也就是说我们通过不停地迭代,来优化参数。理论上来说,这样的迭代是没有穷尽的,我们需要手动终止迭代。什么时候可以停止呢?我们可以判断每一次迭代的梯度,当梯度已经小到逼近于0的时候,就说明模型的训练已经收敛了,这个时候可以停止训练了。

这里的

是一个固定的参数,称为学习率,它表示梯度对于迭代的影响程度。学习率越大,说明梯度对于参数变化的影响越大。如果学习率越小,自然每一次迭代参数的变化也就越小,说明到收敛需要的迭代次数也就越多,也可以单纯理解成,收敛需要的时间也就越长。

那么是不是学习率越大越好呢?显然也不是的。因为如果学习率过大,很有可能会导致在迭代的过程当中错过最优点。就好像油门踩猛了,一下子就过头了,于是可能会出现永远也无法收敛的情况。比如我们可以参考下面这张图:

从这张图上可以看到,变量一直在最值附近震荡,永远也达不成收敛状态。

如果学习率设置得小一些是不是就没事了?也不是,如果设置的学习率过小,除了会导致迭代的次数非常庞大以至于训练花费的时间过久之外,还有可能由于小数的部分过大,导致超出了浮点数精度的范围,以至于出现非法值Nan这种情况出现。同样,我们可以参考一下下图:

这张图画的是学习率过小,导致一直在迭代,迟迟不能收敛的情况。

从上面这两张图,我们可以看得出来,在机器学习领域学习率的设置非常重要。一个好的参数不仅可以缩短模型训练的时间,也可以使模型的效果更好。但是设置学习率业内虽然有种种方法,但是不同的问题场景,不同的模型的学习率设置方法都略有差别,也正因此,很多人才会调侃自己是调参工程师。

我们来看一下一个合适的学习率的迭代曲线是什么样的。

到这里还没有结束,好的学习率并不能解决所有的问题。在有些问题有些模型当中,很有可能最优解本身就是无法达到的,即使用非常科学的方法,设置非常好的参数。我们再来看一张图:

这张图有不止一个极值点,如果我们一开始的时候,参数落在了区间的左侧,那么很快模型就会收敛到一个极值,但是它并不是全局最优解,只是一个局部最优解。这时候无论我们如何设置学习率,都不可能找到右侧的那个全局最优解。同样,如果我们一开始参数落在了区间右侧,那里的曲线非常平坦,使得每次迭代的梯度都非常小,非常接近0.那么虽然最终可以到达全局最优解,但是需要经过漫长的迭代过程。

所以,模型训练、梯度下降虽然方法简单,但是真实的使用场景也是非常复杂的。我们不可以掉以轻心,不过好在,对于线性回归的最小二乘法来说,损失函数是一个凸函数,意味着它一定有全局最优解,并且只有一个。随着我们的迭代,一定可以达到收敛。

代码实战

Talk is cheap, show me the code.

光说不练假把式,既然我们已经学习到了梯度下降的精髓,也该亲身用代码体验一下了。我们还是用之前线性回归的问题。如果有遗忘的同学可以点击下方的链接回顾一下之前的内容:

机器学习基础——推导线性回归公式​mp.weixin.qq.com

还是和之前一样,我们先生成一批点:

import 

这是根据函数

随机出来的,我们接下来就要通过梯度下降的方法来做线性回归。首先,我们来推导一下梯度公式:

在使用梯度下降算法的时候,我们其实计算当前

下的梯度。这个量反应的是当我们的
发生变化的时候,整个的损失函数MSE(mean square error 均方差)会变化多少。而梯度,可以通过对变量求偏导得到。写成:

我们单独计算

的损失函数偏导,写成:
,带入之前的损失函数公式,计算化简可以得到:

这只是

的偏导数,我们可以把向量
中每一个变量的偏导数合在一起计算。标记为:

我们不难看出,在这个公式当中,我们涉及了全量的训练样本X。因此这种方法被称为批量梯度下降。因此,当我们的训练样本非常大的时候,会使得我们的算法非常缓慢。但是使用梯度下降算法,和特征的数量成正比,当特征数量很大的时候,梯度下降要比方程直接求解快得多。

需要注意一点,我们推导得到的梯度是向上的方向。我们要下降,所以需要加一个负号,最后再乘上学习率,得到的公式如下:

根据公式,写出代码就不复杂了:

eta 

我们调用一下这段代码,来查看一下结果:

和我们设置的参数非常接近,效果算是很不错了。如果我们调整学习率和迭代次数,最后的效果可能会更好。

观察一下代码可以发现,我们在实现梯度下降的时候,用到了全部的样本。显然,随着样本的数量增大,梯度下降会变得非常慢。为了解决这个问题,专家们后续推出了许多优化的方法。不过由于篇幅的限制,我们会在下一篇文章当中和大家分享,感兴趣的同学可以小小地期待一下。

梯度下降非常重要,可以说是机器学习领域至关重要的基础之一,希望大家都能学会。

今天的文章就到这里,如果觉得有所收获,请顺手点个关注吧,你们的支持是我最大的动力。

梯度下降参数不收敛_一文讲透梯度下降相关推荐

  1. 双线macd指标参数最佳设置_一文讲透双线MACD指标及其实战运用

    原标题:一文讲透双线MACD指标及其实战运用 船长的舍得交易体系技术理论模型中,我们要用到两大指标,分别是均线系统和双线MACD指标. 很多小伙伴都喜欢用双线MACD这个指标,但是90%的人都不知道其 ...

  2. 10自带sftp服务器_一文讲透FTP和SFTP的区别

    阅读本文约需要10分钟,您可以先关注我们或收藏本文,避免下次无法找到. FTP和SFTP都是文件传输协议,我们知道FTP使用的是20和21端口,SFTP使用的是22端口.另外,SFTP前面的S应该是S ...

  3. cstring只获取到第一个数_一文讲透 Dubbo 负载均衡之最小活跃数算法

    (给ImportNew加星标,提高Java技能) 作者:why技术(本文来自作者投稿) 本文是对于Dubbo负载均衡策略之一的最小活跃数算法的详细分析.文中所示源码,没有特别标注的地方均为2.6.0版 ...

  4. 反向传播算法(过程及公式推导)_一文讲透神经网络的反向传播,要点介绍与公式推导...

    神经网络的反向传播是什么 神经网络的反向传播,实际上就是逐层计算梯度下降所需要的$w$向量的"变化量"(代价函数$J(w1,b1,w2,b2,w3,b3...wn,bn)$对于$w ...

  5. python方差分析误差棒_一文讲透,带你学会用Python绘制带误差棒的柱状图和条形图...

    Python数据可视化,作为数据常用的必备技能,是目前大数据和数据分析的一个热门,而matplotlib库作为Python中最为常用和经典的二维绘图库,受到了很多人的青睐,最近已经和大家共同探讨了多种 ...

  6. python 多线程和协程结合_一文讲透 “进程、线程、协程”

    本文从操作系统原理出发结合代码实践讲解了以下内容: 什么是进程,线程和协程? 它们之间的关系是什么? 为什么说Python中的多线程是伪多线程? 不同的应用场景该如何选择技术方案? ... 什么是进程 ...

  7. ubuntu linux开机启动自动加载ko驱动程序_一文讲透 CentOS 开机流程

    一.Linux开机流程: BIOS: (Basic Input Output System)基本输入输出系统,它是一组固化到计算机内主板上一个ROM芯片 上的程序,保存着计算机最重要的基本输入输出的程 ...

  8. python selenium自动化框架_一文讲透!实现一个Python+Selenium的自动化测试框架如此简单!...

    首先你得知道什么是Selenium? Selenium是一个基于浏览器的自动化测试工具,它提供了一种跨平台.跨浏览器的端到端的web自动化解决方案.Selenium主要包括三部分:Selenium I ...

  9. itstime后面跟什么_一文讲透什么是引流

    这个问题老生常谈,都快腻了,还是有人时不时问老马.究其原因,很多人从想做引流.到动手操作,整个流程都是懵逼的状态. 引流不难,难的是一直卡在某个阶段,或者一直停留在那里.这样,你做再多次引流,还是患得 ...

最新文章

  1. Spring AOP源码分析(八)SpringAOP要注意的地方
  2. hadoop学习笔记(三):hdfs体系结构和读写流程(转)
  3. Android 让系统自动生成缩略图并写入媒体库
  4. reentrantlock非公平锁不会随机挂起线程?_程序员必须要知道的ReentrantLock 及 AQS 实现原理...
  5. 腾讯图片处理 Tencent AlloyTeam 2013
  6. 微软ASP.NET AJAX框架剖析
  7. python求平面坐标最接近的点_从Python中的集合中有效地找到最接近的坐标对
  8. Javascript 监控键盘输入事件
  9. 洛谷P1141 01迷宫【bfs】
  10. JSP项目引入Vue.js进行项目开发(工程搭建)
  11. NXOPEN/UG二次开发C#---获取NX的版本
  12. 【超纯水制备技术分享】超纯水工艺设计流程以及纯水系统前处理技术——离子交换脱盐工艺介绍
  13. 生日悖论分析基于python
  14. 【万字长文】——作者底层逻辑辨析【自组织场景宣言】,拉开未来序幕!
  15. echarts 3d饼图
  16. 安云网络-高防服务器租用的稳定性
  17. 关于如何解决360篡改edge等浏览器主页的解决方案
  18. 首次使用idea需要配置哪些东西?
  19. 李代数与李群间的转换-指数映射、对数映射
  20. 数学论文(优化方向)写作总结

热门文章

  1. Verilog语法_3(同步有限状态机)
  2. consin(consin英文怎么念)
  3. 【前缀和】【dp】开发区规划
  4. 【开源项目】Sa-Token快速登录(使用+源码解析)
  5. 第一篇 厚黑学 三、厚黑经
  6. 55/45 Jump Game 跳跃游戏
  7. 软件工程应用与实践(8)——视频清晰度切换
  8. 一个莆田系医院网站提醒的浏览器插件
  9. Docker下Nacos持久化配置
  10. 创建Maven项目报错