梯度下降(Gradient descent)算法详解

说起梯度下降算法,其实并不是很难,它的重要作用就是求函数的极值。梯度下降就是求一个函数的最小值,对应的梯度上升就是求函数最大值。为什么这样说呢?兔兔之后会详细讲解的。
虽然梯度下降与梯度上升都是求函数极值的算法,为什么我们常常提到“梯度下降”而不是梯度上升“呢?主要原因是在大多数模型中,我们往往需要求函数的最小值。比如BP神经网络算法,我们得出损失函数,当然是希望损失函数越小越好,这个时候肯定是需要梯度下降算法的。梯度下降算法作为很多算法的一个关键环节,其重要意义是不言而喻的。

算法思想
梯度下降算法的思想是:先任取点(x0,f(x0)),求f(x)在该点x0的导数f"(x0),在用x0减去导数值f"(x0),计算所得就是新的点x1。然后再用x1减去f"(x1)得x2…以此类推,循环多次,慢慢x值就无限接近极小值点。
具体是有公式推导的,不过比较麻烦。其实这个算法是可以直观理解的。比如对于函数f(x)=x**2,当x大于0时,导数大于零,x减去导数值后变小;只要x大于零,每次减去一个大于零的导数,x值肯定变小。当x小于零时,导数小于零,减去小于零的数后x值增加,所以无论x0起始于何处,最终都能走到极值点0处。只不过有可能从单侧趋近(像走楼梯一样下降),也可能x一会儿大于极值点,一会儿小于极值点,交替地趋近,最终x趋于0.

import numpy as np
x=np.arange(-5,5,0.1) #定义域-5~5
y=x**2 #求解的函数
pointx=[] #用来储存每次梯度下降后的点
x0=-2 #初始值的横坐标-2,随便选的
for i in range(10): #先执行10次xnew=x0-2*x0 #该点减去该点的导数值x0=xnew #移动到新的点pointx.append(x0) #储存点运动轨迹
pointy=np.array(pointx)**2
plt.plot(x,y,color="green") #画函数图像
plt.plot(pointx,pointy,color="red") #画梯度轨迹
plt.show()

结果如图1所示

没有想到,竟然出现了水平的红线,说明点一直在两边震荡,互相踢皮球,根本没有下去。
这就涉及到一个关键的东西:学习率。
我们减去了导数,的确没有错,但是很多时候函数所在点的导数值是比较大的。所以我们可以将导数f"(x0)在乘上学习率alpha, 让梯度的步子小一点,就会解决该问题。比如我们让学习率为0.2,每次x减去0.2f"(x),结果如图2所示。

为了看起来方便,兔兔把初始值设为-5。不难发现这次的确好了很多,最终几乎收敛到极小值点。
在一般情况下,我们设置学习率在0~1之间。但是学习率也不是越小越好。如果太小的话,每次走的步子很小,需要很长时间才能到达最优解。不过对于一些模型,很可能出现导数很大的情况,alpha若不是足够小的话,很可能会出现梯度爆炸的,这一点一定小心。兔兔当年做BP神经网络时学习率没有设好,就发生了梯度爆炸的。
梯度爆炸性状就是,在极值点两边震荡,并且离极值点越来越远,最终数大的很离谱,一会儿是负10的几百次方,一会儿又是正的几百次方。形状类似图3。

兔兔在这里设初始值为-0.1,为了方便我就直接把前面代码中学习率改成了1.2,也达到了梯度爆炸的效果。不难发现这个点在两边震荡,最终远离了极值点。
还有一种情形是梯度消失。这个在BP神经网络等模型中比较常见。当激活函数选择sigmoid函数,神经元层数很多时(不懂没关系的,兔兔之后肯定会详细讲解的),就容易出现梯度消失的情况,也就是没有了梯度。解决方法就是换激活函数,减少神经元层数。
关于梯度上升道理也是一样的,方法与上面梯度下降相同,只是这里每次x0加上alpha*f"(x0)。用前面的方法分析,最终肯定可以收敛到极大值点的。不过这个不是很常用,兔兔就不详细介绍了。

多元函数的梯度下降
对于多元函数也是一样的。比如二元函数f(x,y),我们每次只需要x0减去f(x,y)在该点对x的偏导数的值,y0减去f(x,y)在该点对y的偏导数。多次循环操作,最终就可以得到二元函数的极小值。过程类似于走山坡,一直往坡下走,走到最低点的坑洼处。

import matplotlib.pyplot as plt
ax=plt.axes(projection='3d')
x=np.arange(-5,5)
y=np.arange(-5,5)
X,Y=np.meshgrid(x,y)
Z=X**2+Y**2
ax.plot_surface(X,Y,Z,cmap="rainbow")
plt.show()


兔兔在上面绘制了函数f(x,y)=x2+y2,现在随便取初始点(4,-4,f(4,-4)),做梯度下降。学习率alpha=0.3.

import matplotlib.pyplot as plt
x0=4;y0=-4 #初始点
xlist=[x0] #储存x变化
ylist=[y0] #储存y点变化
zlist=[x0**2+y0**2]
alpha=0.3 #学习率0.3
for i in range(5): #执行5次xnew=x0-2*x0*alpha ynew=y0-2*y0*alphax0=xnew;y0=ynewxlist.append(x0)ylist.append(y0)zlist.append(x0**2+y0**2)
ax.scatter(xlist,ylist,zlist,color='red',marker='^') #画点的位置
plt.show()

结果如图5所示

这里面红色的三角就是每次梯度下降后变化的位置(兔兔学艺不精,没有画好图)。但是也可以发现红色三角形不断向原点(0,0,0)靠近,最终是可以收敛到最小值点的。

总结
梯度下降算法核心就是不断减导数与学习率的乘积。在日后的学习过程中应该体会不同函数、不同算法模型中学习率的设置,并且学会处理梯度爆炸与梯度消失的情况。虽然该算法不难,却是之后神经网络算法、逻辑回归等各种算法的基础,有着重要的意义。

梯度下降(Gradient descent)算法详解相关推荐

  1. Lesson 4.34.4 梯度下降(Gradient Descent)基本原理与手动实现随机梯度下降与小批量梯度下降

    Lesson 4.3 梯度下降(Gradient Descent)基本原理与手动实现 在上一小节中,我们已经成功的构建了逻辑回归的损失函数,但由于逻辑回归模型本身的特殊性,我们在构造损失函数时无法采用 ...

  2. 梯度下降(Gradient Descent),一句代码,一个式子

    一直以来,总是觉得国外的PhD们的教育以及课程的安排很好很强大,虽然是说很累作业多工作量大,但是功率大了,效果好点儿,浪费的时间也少,年轻人哪有怕苦怕累的.比比身边好多每天睡超过12小时的研究生们,不 ...

  3. 机器学习(1)之梯度下降(gradient descent)

    机器学习(1)之梯度下降(gradient descent) 题记:最近零碎的时间都在学习Andrew Ng的machine learning,因此就有了这些笔记. 梯度下降是线性回归的一种(Line ...

  4. 【李宏毅机器学习】04:梯度下降Gradient Descent

    李宏毅机器学习04:梯度下降Gradient Descent 文章目录 李宏毅机器学习04:梯度下降Gradient Descent 一.梯度下降方法 二.梯度下降的改进方法 Tip 1: Tunin ...

  5. 梯度下降 gradient descent

    文章目录 导数 偏导数 方向导数 梯度 代价函数的梯度 梯度下降的详细算法 先决条件 算法过程 代价损失中 θ 偏导数公式推导 批量梯度下降(Batch Gradient Descent,BGD) 随 ...

  6. 梯度下降 Gradient Descent 详解、梯度消失和爆炸

    1.什么是梯度 在微积分中,对多元函数的参数求∂偏导,把求得的各个参数的偏导数以向量形式写出来即为梯度. 例如对于函数f(x,y),分别对x,y求偏导,求得的梯度向量就是 (∂f/∂x, ∂f/∂y) ...

  7. excel计算二元线性回归_用人话讲明白梯度下降Gradient Descent(以求解多元线性回归参数为例)...

    文章目录 1.梯度 2.多元线性回归参数求解 3.梯度下降 4.梯度下降法求解多元线性回归 梯度下降算法在机器学习中出现频率特别高,是非常常用的优化算法. 本文借多元线性回归,用人话解释清楚梯度下降的 ...

  8. 机器学习中的数学(1)-回归(regression)、梯度下降(gradient descent)

    前言: 上次写过一篇关于贝叶斯概率论的数学,最近时间比较紧,coding的任务比较重,不过还是抽空看了一些机器学习的书和视频,其中很推荐两个:一个是stanford的machine learning公 ...

  9. 机器学习代码实战——梯度下降(gradient descent)

    文章目录 1.实验目的 2.梯度下降 2.1.借助sklearn库 2.2.手写梯度下降函数 1.实验目的 本实验将使用两种方法实现梯度下降算法并可打印出参数,可视化梯度下降过程.第一种方法是借助sk ...

最新文章

  1. JQuery - Sizzle选择器引擎原理分析
  2. 思科服务器型号m1414,Cisco UCS M 系列模块化服务器
  3. MyBatis 架构分层与模块划分-基础支持层
  4. 记录我开发工作中遇到HTTP跨域和OPTION请求的一个坑
  5. 12 哈希表相关类——Live555源码阅读(一)基本组件类
  6. Python 文件操作中的读写模式:open(path, ‘-模式-‘,encoding=‘UTF-8‘)+python读写文件txt +文本数据预处理
  7. python 把多个list合并为一个并去重内容_110道Python面试题(上)
  8. mariadb驱动下载教程_性能测试教程[3] nmon analyser
  9. 大规模数据作成时的注意点。
  10. html5导出错误,JavaScript:toDataUrl()抛出“安全错误:可能无法导出受污染的画布”. - 程序园...
  11. POCO C++库学习和分析 -- 线程 (二)
  12. 计算机基础知识常用口诀,计算机基础知识(初中级教程)-20210712024844.pdf-原创力文档...
  13. java char类型 unicode字符集 utf-8字符编码
  14. IDEA 中 Lombok 编译报错 Java 找不到符号问题
  15. 还不了解MySQL的指令?有它就够了!
  16. 安卓bmi项目_bmi计算器
  17. 再探JS---eval函数
  18. 20181217股市复盘
  19. el-table表格操作列合并行
  20. box-shadow的具体使用方法(一分钟详解)

热门文章

  1. 如何用ps来切分图片
  2. Emergence、Cascading effect and Delay in network system
  3. 计算机,英语,人文书籍廉价大甩卖,有买有送
  4. 深入掌握Java日志体系,再也不迷路了
  5. TMS320C6678 交换网子系统(二)
  6. Drf简介,什么是drf
  7. C语言再学习——指针
  8. STM32CubeProgrammer启动问题解决
  9. 图像二值化的阈值求法
  10. Oracle数据库基本操作(三) —— DQL相关内容说明及应用