北京 | 深度学习与人工智能研修

12月23-24日

再设经典课程 重温深度学习阅读全文>
前言

梯度下降法(Gradient Descent)是机器学习中最常用的优化方法之一,常用来求解目标函数的极值。

其基本原理非常简单:沿着目标函数梯度下降的方向搜索极小值(也可以沿着梯度上升的方向搜索极大值)。

但是如何调整搜索的步长(也叫学习率,Learning Rate)、如何加快收敛速度以及如何防止搜索时发生震荡却是一门值得深究的学问。接下来本文将分析第一个问题:学习率的大小对搜索过程的影响。全部源代码可在本人的GitHub:monitor1379(https://github.com/monitor1379/jianshu_blog/blob/master/scripts/gradient_descent_with_momentum_and_decay.py)中下载。

快速教程

前言啰嗦完了,接下来直接上干货:如何编写梯度下降法。代码运行环境为Python 2.7.11 + NumPy 1.11.0 + Matplotlib 1.5.1。

首先先假设现在我们需要求解目标函数func(x) = x * x的极小值,由于func是一个凸函数,因此它唯一的极小值同时也是它的最小值,其一阶导函数 为dfunc(x) = 2 * x。

import numpy as np

import matplotlib.pyplot as plt

# 目标函数:y=x^2

def func(x): return np.square(x)

# 目标函数一阶导数:dy/dx=2*x

def dfunc(x): return 2 * x

接下来编写梯度下降法函数:

# Gradient Descentdef GD(x_start, df, epochs, lr):    """    梯度下降法。给定起始点与目标函数的一阶导函数,求在epochs次迭代中x的更新值    :param x_start: x的起始点    :param df: 目标函数的一阶导函数    :param epochs: 迭代周期    :param lr: 学习率    :return: x在每次迭代后的位置(包括起始点),长度为epochs+1    """    xs = np.zeros(epochs+1)    x = x_start    xs[0] = x    for i in range(epochs):        dx = df(x)        # v表示x要改变的幅度        v = - dx * lr        x += v        xs[i+1] = x    return xs

需要注意的是参数df是一个函数指针,即需要传进我们的目标函数一阶导函数。

测试代码如下,假设起始搜索点为-5,迭代周期为5,学习率为0.3:

def demo0_GD():    
x_start = -5    
epochs = 5    
lr = 0.3    
x = GD(x_start, dfunc, epochs, lr=lr)    
print x    
# 输出:[-5.     -2.     -0.8    -0.32   -0.128  -0.0512]

继续修改一下demo0_GD函数以更加直观地查看梯度下降法的搜索过程:

def demo0_GD():    
"""演示如何使用梯度下降法GD()"""    
line_x = np.linspace(-5, 5, 100)  
 line_y = func(line_x)    
x_start = -5    
epochs = 5    
lr = 0.3    
x = GD(x_start, dfunc, epochs, lr=lr)    
color = 'r'    
plt.plot(line_x, line_y, c='b')    
plt.plot(x, func(x), c=color, label='lr={}'.format(lr))    
plt.scatter(x, func(x), c=color, )    
plt.legend()

plt.show()

从运行结果来看,当学习率为0.3的时候,迭代5个周期似乎便能得到蛮不错的结果了。

demo0_GD运行结果

梯度下降法确实是求解非线性方程极值的利器之一,但是如果学习率没有调整好的话会发生什么样的事情呢?

学习率对梯度下降法的影响

在上节代码的基础上编写新的测试代码demo1_GD_lr,设置学习率分别为0.1、0.3与0.9:

def demo1_GD_lr():    
# 函数图像  
 line_x = np.linspace(-5, 5, 100)    
line_y = func(line_x)    
plt.figure('Gradient Desent: Learning Rate')    
x_start = -5    
epochs = 5    
lr = [0.1, 0.3, 0.9]    
color = ['r', 'g', 'y']    
size = np.ones(epochs+1) * 10    
size[-1] = 70    
for i in range(len(lr)):        
x = GD(x_start, dfunc, epochs, lr=lr[i])        
plt.subplot(1, 3, i+1)

plt.plot(line_x, line_y, c='b')        
plt.plot(x, func(x), c=color[i], label='lr={}'.format(lr[i]))        
plt.scatter(x, func(x), c=color[i])        
plt.legend()

plt.show()

从下图输出结果可以看出两点,在迭代周期不变的情况下:

  • 学习率较小时,收敛到正确结果的速度较慢。

  • 学习率较大时,容易在搜索过程中发生震荡。

demo1_GD_lr运行结果

综上可以发现,学习率大小对梯度下降法的搜索过程起着非常大的影响,为了解决上述的两个问题,接下来的博客《【梯度下降法】二:冲量(momentum)的原理与Python实现》将讲解冲量(momentum)参数是如何在梯度下降法中起到加速收敛与减少震荡的作用。

原文链接:http://www.jianshu.com/p/186df2db8898

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看

LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

点击“阅读原文”直接打开报名链接

梯度下降法快速教程 | 第一章:Python简易实现以及对学习率的探讨相关推荐

  1. 梯度下降法快速教程 | 第二章:冲量(momentum)的原理与Python实现

    北京 | 深度学习与人工智能研修 12月23-24日 再设经典课程 重温深度学习阅读全文> 01 前言 梯度下降法(Gradient Descent)是机器学习中最常用的优化方法之一,常用来求解 ...

  2. 梯度下降法快速教程 | 第三章:学习率衰减因子(decay)的原理与Python实现

    北京 | 深度学习与人工智能 12月23-24日 再设经典课程 重温深度学习阅读全文> 正文共3017个字.11张图.预计阅读时间:8分钟 前言 梯度下降法(Gradient Descent)是 ...

  3. 廖雪峰python教程——第一章 Python基础

    第一章 Python基础 一.数据类型和变量 Python的数据类型包括整数.浮点数.字符串.布尔值.空值.变量.常量等.其中整数可以表示任意大小的整数:空值是Python里一个特殊的值,用None表 ...

  4. [转载] 《python程序设计应用教程》第一章 python语言概述

    参考链接: Python语言的优势和应用 第一章 python语言概述 1.1 python语言简介 ① 众多的开源的科学计算软件包都提供了python的调用接口,例如:计算机视觉库OpenCV.三维 ...

  5. 乐行学院RabbitMQ学习教程 第一章 RabbitMQ介绍(可供技术选型时使用)

    乐行学院RabbitMQ学习教程 第一章 RabbitMQ介绍 RabbitMQ介绍 1.RabbitMQ技术简介 2.RabbitMQ其他扩展插件 2.1监控工具rabbitmq-managemen ...

  6. 流畅的python读书笔记-第一章Python 数据模型

    第一章 python数据类型 1 隐式方法 利用collections.namedtuple 快速生成类 import collectionsCard = collections.namedtuple ...

  7. 萌新向Python数据分析及数据挖掘 第一章 Python基础 第三节 列表简介 第四节 操作列表...

    第一章 Python基础 第三节 列表简介 列表是是处理一组有序项目的数据结构,即可以在一个列表中存储一个序列的项目.列表中的元素包括在方括号([])中,每个元素之间用逗号分割.列表是可变的数据类型, ...

  8. 大数据技术技能分析大赛——第一章 python数据分析概述

    目标:掌握python,进行数据处理.统计分析.回归建模和数据可视化. 教材:<大数据分析务实初级教程(python)## 标题> 第一章 python数据分析概述 1.数据分析概述 1 ...

  9. Python入门到精通【精品】第一章 - Python概述

    Python入门到精通[精品]第一章 - Python概述 1. Python语言历史 2. Python语言特点 3. Python的下载和安装 3.1. Python的下载 3.2. Python ...

最新文章

  1. MySQL 学习笔记(1)— 创建/连接/选择/显示数据库(表) 查询单列(多列/所有列)/查询返回特定的行数 各种排序(单列/多列/降序/组合排序) 过滤数据
  2. java 中文符号占位_java – ‘占位符’字符以避免积极比较?
  3. java反射机制知识_Java反射机制讲解,程序员必须掌握的知识点
  4. 如何获取js对象的对象名
  5. 只要暴风骤雨才能使人迅速地成长
  6. java编写学生管理系统_Java实现学生管理系统
  7. Linux驱动程序学习步骤
  8. header()函数使用说明
  9. 我在公司内部的分享(秒针系统)
  10. 解决算法问题的思路总结
  11. 动画演示 Delphi 2007 IDE 功能[6] - 快速查看 Delphi 所有的核心数据类型
  12. odoo报表内部和外部布局
  13. Android 高德地图上自定义动画
  14. js replace 中文分号_关于js分号的问题?
  15. mavell 7040使用方法
  16. supervisorctl error (no such process)
  17. 飞思卡尔智能车 电机PID
  18. type definition error
  19. 小黑客,2020还没有邀请码注册 hackthebox ?reCAPTCHA验证码加载不出来?
  20. fcpx视频剪辑软件中文版

热门文章

  1. mysql properties文件路径_读取web项目properties文件路径 解决tomcat服务器找不到properties路径问题...
  2. pb的webserver增加的方法发布后没有显示_Egret 5.3 正式发布,为重度小游戏开发带来新技能...
  3. 【debug】python3安装win32com模块
  4. 参会人员管理系统C语言代码,某小型会议参会人员管理系统
  5. git关闭密码自动存储_RobotFramework实战篇PC端web自动化demo及持续集成
  6. HashMap原理解析
  7. 走在spring的路上。。。。
  8. LeetCode——7. Reverse Integer
  9. mysql常用命令整理
  10. Intellij IDEA更新SVN没有提示语