目录

  • 1.概述
  • 2.梯度下降算法
    • 2.1 场景假设
    • 2.1 梯度下降
      • 2.1.1 微分
      • 2.2.2 梯度
      • 2.3 数学解释
      • 2.3.1 α
      • 2.3.2 梯度要乘以一个负号
  • 3. 实例
    • 3.2 多变量函数的梯度下降
  • 4. 代码实现
    • 4. 1 场景分析
    • 4. 2 代码
  • 5. 小结

1.概述

梯度下降(gradient descent)在机器学习中应用十分的广泛,不论是在线性回归还是Logistic回归中,它的主要目的是通过迭代找到目标函数的最小值,或者收敛到最小值。
本文将从一个下山的场景开始,先提出梯度下降算法的基本思想,进而从数学上解释梯度下降算法的原理,解释为什么要用梯度,最后实现一个简单的梯度下降算法的实例!

2.梯度下降算法

2.1 场景假设

梯度下降法的基本思想可以类比为一个下山的过程。
假设这样一个场景:一个人被困在山上,需要从山上下来(找到山的最低点,也就是山谷)。但此时山上的浓雾很大,导致可视度很低;因此,下山的路径就无法确定,必须利用自己周围的信息一步一步地找到下山的路。这个时候,便可利用梯度下降算法来帮助自己下山。怎么做呢,首先以他当前的所处的位置为基准,寻找这个位置最陡峭的地方,然后朝着下降方向走一步,然后又继续以当前位置为基准,再找最陡峭的地方,再走直到最后到达最低处;同理上山也是如此,只是这时候就变成梯度上升算法了。

2.1 梯度下降

梯度下降的基本过程就和下山的场景很类似。

首先,我们有一个可微分的函数。这个函数就代表着一座山。我们的目标就是找到这个函数的最小值,也就是山底。根据之前的场景假设,最快的下山的方式就是找到当前位置最陡峭的方向,然后沿着此方向向下走,对应到函数中,就是找到给定点的梯度 ,然后朝着梯度相反的方向,就能让函数值下降的最快!因为梯度的方向就是函数之变化最快的方向(在后面会详细解释)
所以,我们重复利用这个方法,反复求取梯度,最后就能到达局部的最小值,这就类似于我们下山的过程。而求取梯度就确定了最陡峭的方向,也就是场景中测量方向的手段。那么为什么梯度的方向就是最陡峭的方向呢?接下来,我们从微分开始讲起:

2.1.1 微分

看待微分的意义,可以有不同的角度,最常用的两种是:

函数图像中,某点的切线的斜率
函数的变化率
几个微分的例子:
1.单变量的微分,函数只有一个变量

2.多变量的微分,当函数有多个变量的时候,即分别对每个变量进行求微分

2.2.2 梯度

梯度实际上就是多变量微分的一般化。
下面这个例子:

我们可以看到,梯度就是分别对每个变量进行微分,然后用逗号分割开,梯度是用<>包括起来,说明梯度其实一个向量。

梯度是微积分中一个很重要的概念,之前提到过梯度的意义

在单变量的函数中,梯度其实就是函数的微分,代表着函数在某个给定点的切线的斜率
在多变量函数中,梯度是一个向量,向量有方向,梯度的方向就指出了函数在给定点的上升最快的方向
这也就说明了为什么我们需要千方百计的求取梯度!我们需要到达山底,就需要在每一步观测到此时最陡峭的地方,梯度就恰巧告诉了我们这个方向。梯度的方向是函数在给定点上升最快的方向,那么梯度的反方向就是函数在给定点下降最快的方向,这正是我们所需要的。所以我们只要沿着梯度的方向一直走,就能走到局部的最低点!

2.3 数学解释


此公式的意义是:J是关于Θ的一个函数,我们当前所处的位置为Θ0点,要从这个点走到J的最小值点,也就是山底。首先我们先确定前进的方向,也就是梯度的反向,然后走一段距离的步长,也就是α,走完这个段步长,就到达了Θ1这个点!

2.3.1 α

α在梯度下降算法中被称作为学习率或者步长,意味着我们可以通过α来控制每一步走的距离,以保证不要步子跨的太大扯着蛋,哈哈,其实就是不要走太快,错过了最低点。同时也要保证不要走的太慢,导致太阳下山了,还没有走到山下。所以α的选择在梯度下降法中往往是很重要的!α不能太大也不能太小,太小的话,可能导致迟迟走不到最低点,太大的话,会导致错过最低点!

2.3.2 梯度要乘以一个负号

梯度前加一个负号,就意味着朝着梯度相反的方向前进!我们在前文提到,梯度的方向实际就是函数在此点上升最快的方向!而我们需要朝着下降最快的方向走,自然就是负的梯度的方向,所以此处需要加上负号;那么如果时上坡,也就是梯度上升算法,当然就不需要添加负号了。

3. 实例


3.2 多变量函数的梯度下降


4. 代码实现

4. 1 场景分析

下面我们将用python实现一个简单的梯度下降算法。场景是一个简单的线性回归的例子:假设现在我们有一系列的点,如下图所示:


4. 2 代码

首先,我们需要定义数据集和学习率

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2019/1/21 21:06
# @Author  : Arrow and Bullet
# @FileName: gradient_descent.py
# @Software: PyCharm
# @Blog    :https://blog.csdn.net/qq_41800366from numpy import *# 数据集大小 即20个数据点
m = 20
# x的坐标以及对应的矩阵
X0 = ones((m, 1))  # 生成一个m行1列的向量,也就是x0,全是1
X1 = arange(1, m+1).reshape(m, 1)  # 生成一个m行1列的向量,也就是x1,从1到m
X = hstack((X0, X1))  # 按照列堆叠形成数组,其实就是样本数据
# 对应的y坐标
y = np.array([3, 4, 5, 5, 2, 4, 7, 8, 11, 8, 12,11, 13, 13, 16, 17, 18, 17, 19, 21
]).reshape(m, 1)
# 学习率
alpha = 0.01

接下来我们以矩阵向量的形式定义代价函数和代价函数的梯度

# 定义代价函数
def cost_function(theta, X, Y):diff = dot(X, theta) - Y  # dot() 数组需要像矩阵那样相乘,就需要用到dot()return (1/(2*m)) * dot(diff.transpose(), diff)# 定义代价函数对应的梯度函数
def gradient_function(theta, X, Y):diff = dot(X, theta) - Yreturn (1/m) * dot(X.transpose(), diff)

最后就是算法的核心部分,梯度下降迭代计算

# 梯度下降迭代
def gradient_descent(X, Y, alpha):theta = array([1, 1]).reshape(2, 1)gradient = gradient_function(theta, X, Y)while not all(abs(gradient) <= 1e-5):theta = theta - alpha * gradientgradient = gradient_function(theta, X, Y)return thetaoptimal = gradient_descent(X, Y, alpha)
print('optimal:', optimal)
print('cost function:', cost_function(optimal, X, Y)[0][0])

当梯度小于1e-5时,说明已经进入了比较平滑的状态,类似于山谷的状态,这时候再继续迭代效果也不大了,所以这个时候可以退出循环!
运行代码,计算得到的结果如下:

print('optimal:', optimal)  # 结果 [[0.51583286][0.96992163]]
print('cost function:', cost_function(optimal, X, Y)[0][0])  # 1.014962406233101

通过matplotlib画出图像,

# 根据数据画出对应的图像
def plot(X, Y, theta):import matplotlib.pyplot as pltax = plt.subplot(111)  # 这是我改的ax.scatter(X, Y, s=30, c="red", marker="s")plt.xlabel("X")plt.ylabel("Y")x = arange(0, 21, 0.2)  # x的范围y = theta[0] + theta[1]*xax.plot(x, y)plt.show()plot(X1, Y, optimal)

所拟合出的直线如下

5. 小结

至此,就基本介绍完了梯度下降法的基本思想和算法流程,并且用python实现了一个简单的梯度下降算法拟合直线的案例!
最后,我们回到文章开头所提出的场景假设:
这个下山的人实际上就代表了反向传播算法,下山的路径其实就代表着算法中一直在寻找的参数Θ,山上当前点的最陡峭的方向实际上就是代价函数在这一点的梯度方向,场景中观测最陡峭方向所用的工具就是微分 。在下一次观测之前的时间就是有我们算法中的学习率α所定义的。
可以看到场景假设和梯度下降算法很好的完成了对应!

本文部分内容来自一位前辈,非常感谢分享!谢谢!

【算法】梯度下降算法及python实现相关推荐

  1. 梯度下降算法的python实现

    前言 梯度下降算法 Gradient Descent GD是沿梯度下降的方向连续迭代逼近求最小值的过程,本文将实现以下梯度下降算法的python实现. 简单梯度下降算法 批量梯度下降算法 随机梯度下降 ...

  2. 神经网络中的常用算法-梯度下降算法

    目录 一.概述 二.算法思想 1.一维 2.多维 三.梯度下降算法类型 1.批量梯度下降算法 2.随机梯度下降算法 3.小批量梯度下降算法 一.概述 梯度下降法(Gradient descent )是 ...

  3. 神经网络中的常用算法-梯度下降算法的优化

    一.概述 梯度下降法(Gradient descent )是一个一阶最优化算法,通常也称为最陡下降法 ,要使用梯度下降法找到一个函数的局部极小值 ,必须向函数上当前点对应梯度(或者是近似梯度)的反方向 ...

  4. 机器学习算法(优化)之一:梯度下降算法、随机梯度下降(应用于线性回归、Logistic回归等等)...

    本文介绍了机器学习中基本的优化算法-梯度下降算法和随机梯度下降算法,以及实际应用到线性回归.Logistic回归.矩阵分解推荐算法等ML中. 梯度下降算法基本公式 常见的符号说明和损失函数 X :所有 ...

  5. 梯度下降算法的正确步骤_梯度下降算法

    梯度下降算法的正确步骤 Title: What is the Gradient Descent Algorithm and its working. 标题:什么是梯度下降算法及其工作原理. Gradi ...

  6. 梯度下降算法和牛顿算法原理以及使用python用梯度下降和最小二乘算法求回归系数

    梯度下降算法 以下内容参考 微信公众号 AI学习与实践平台 SIGAI 导度和梯度的问题 因为我们做的是多元函数的极值求解问题,所以我们直接讨论多元函数.多元函数的梯度定义为: 其中称为梯度算子,它作 ...

  7. Python推荐算法:LFM梯度下降算法实现

    推荐算法步骤: LFM隐语义模型思路:根据用户(M)和产品(N)的共现矩阵R(M*N),去挖掘可能潜在的特征表现,假设有K个隐含特征,通过矩阵分得到M*K和K*M两个矩阵分别为P和Q.T,当损失函数的 ...

  8. 一文清晰讲解机器学习中梯度下降算法(包括其变式算法)

    本篇文章向大家介绍梯度下降(Gradient Descent)这一特殊的优化技术,我们在机器学习中会频繁用到. 前言 无论是要解决现实生活中的难题,还是要创建一款新的软件产品,我们最终的目标都是使其达 ...

  9. Python实现线性回归2,梯度下降算法

    接上篇 4.梯度下降算法 <斯坦福大学公开课 :机器学习课程>吴恩达讲解第二课时,是直接从梯度下降开始讲解,最后采用向量和矩阵的方式推导了解析解,国内很多培训视频是先讲解析解后讲梯度下降, ...

最新文章

  1. GPU微观物理结构框架
  2. redis set 超时_Redis 更新(set) key值过期时间被重置
  3. nginx phase handler的原理和选择
  4. linux 查看shell脚本执行了多长时间
  5. 自定义简单控件之标题控件
  6. Atitit 网络协议概论 艾提拉著作 目录 1. 有的模型分七层,有的分四层。我觉得 1 1.1. 三、链接层 确定了0和1的分组方式 1 1.2. 网络层(ip mac转换层 3 1.3. 传输
  7. [SQLite]www.sqlite.org官网.NET最新版本所有内容下载
  8. Matlab高尔顿板仿真模拟实验
  9. Alfa: 1 vulnhub walkthrough
  10. 关于java构造函数 的错误 there is no default constructor available in ...
  11. CAN总线学习笔记(1)- CAN基础知识
  12. 卸载linux grub rescue,卸载linux后出现“grub rescue”,怎么处理?
  13. Consistent hashing kills tencent2012笔试题附加题
  14. unity3D用什么语言开发好?
  15. 如何在SM30维护表时自动写入表字段的默认值-事件(EVENT)
  16. java暗黑再临-战神之怒_《暗黑破坏神-黑暗再临》暴力+召唤:德鲁伊新人单通攻略...
  17. 玲珑oj 1032A-B(组合数学)
  18. Good feelings
  19. 绝对零度试验机的创造战记:2.小型HTML5本地音乐播放器
  20. 深度学习(二) 神经网络基础算法推导与实践

热门文章

  1. Web 前端学习之 表格
  2. java文件读取与保存
  3. 幼儿安全教育道路交通安全宣传PPT模板
  4. 如何收割淘宝逛逛人群
  5. 银行信用卡中心通常关注哪些用户运营指标?
  6. linux egrep用法,grep,egrep及相应的正则表达式用法详解
  7. 【游记】台州市第一届信息学竞赛
  8. Netty基础入门——Reactor模式
  9. 双变量OLS回归模型(Python3)
  10. mysql打字看不见鼠标_电脑打字不显示怎么办?键盘不能打字不能正常输入怎么办?...