之前在 机器学习算法数学基础之 —— 线性代数篇 中,总结过求解线性回归的两种方法:

  1. 最小二乘法
  2. 梯度下降法

这篇文章重点总结一下梯度下降法中的一些细节和需要注意的地方。


梯度下降法是什么

假设有一个估计函数:  ,
其代价函数(cost function)为: 

这个代价函数是 x(i) 的估计值与真实值 y(i) 的差的平方和,前面乘上 1/2,是因为在求导的时候,这个系数就不见了。

梯度下降法的流程:

1)首先对 θ 赋值,这个值可以是随机的,也可以让 θ 是一个全零的向量。
2)改变 θ 的值,使得 J(θ) 的值按梯度下降的方向减小。

参数 θ 与误差函数 J(θ) 的关系图

红色部分表示 J(θ) 有着比较高的取值,我们希望能够让 J(θ) 的值尽可能的低,也就是取到深蓝色的部分。θ0、θ1 表示 θ 向量的两个维度。上面提到梯度下降法的第一步,是给 θ 一个初值,假设随机的初值位于图上红色部分的十字点。然后我们将 θ 按梯度下降的方向进行调整,就会使 J(θ) 往更低的方向进行变化,如图所示,算法的结束将在 θ 下降到无法继续下降为止。

θ 的更新: 

θi 表示更新前的值,减号后边的部分表示按梯度方向减少的量,α 表示步长,也就是每次按梯度减少的方向变化多少。

梯度是有方向的,对于一个向量 θ,每一维分量 θi 都可以求出一个梯度的方向,我们就可以找到一个整体的方向,在变化的时候,我们就朝着下降最多的方向进行变化,就可以达到一个最小点,不管它是局部的还是全局的。

所以梯度下降法的原理简单来说就是:

在变量空间的某个点,函数沿梯度方向的变化率最大,所以在优化目标函数的时候,沿着负梯度方向减小函数值,可以最快达到优化目标。


特征缩放(Feature Scaling)

在 Stanford 的 Machine Learning 课程中,老师提到了使用梯度下降法时可能用到的特征缩放法,为的是当所要研究的数据集中出现数据值范围相差较大的 feature 时,保证所有 feature 有相近范围内的值。

以课程中的房价问题为例,数据集的 feature 中有房屋尺寸  和房间数量  两项,房屋尺寸的范围是 0~2000 feet^2,房间数量的范围是 1~5。此时以两个参数  和  为横纵坐标,绘制代价函数  的等高线图如下:

代价函数的等高线图

解决方法是对两个 feature 进行归一化处理:

x1 = size / 2000
x2 = number of rooms / 5

进行特征缩放后,代价函数的等高线图就变成了这样:

特征缩放后的代价函数等高线图

很明显,归一化处理前,代价函数的等高线图高又窄,在梯度下降过程中,需要反复迭代很多次,才能达到理想的位置,会花费较长的时间。归一化处理后,两个 feature 的数值大小处于相近的范围内,因此横纵坐标  和  两个参数的变化范围也变得相近,对  的影响不再有 feature 值在范围大小伤的影响,从而反映代价函数上。在等高线图上,梯度下降过程中变化受 x1 大小影响的  的变化减小了,变得和 x2 的变化趋势一致,从而等高线图近似圆形。

至于把 feature 缩小到什么范围,没有特别的要求,只要各个 feature 的数值范围相近就可以了,只是为了计算速度可以有所提升。

顺便引用一下 Andrew Ng 老师在课程中用到的例题:


步长 α(学习率)的选择

在参数  的更新式:  中,  称为步长,决定了参数每次被增加或减小的大小,步长如何选择是决定梯度下降法表现的一个因素。如果步长过大,可能会直接越过最佳点,导致无法收敛,甚至发散;如果步长过小,可能导致迭代次数过多,降低效率。

步长的选择过大时,导致越过最佳点

关于步长的选择,有很多种方法,如果初始步长不合适,在后边要不断进行调整,调整的方法有很多种,关于这个问题国内外也有很多论文对其研究过。一旦找到了合适的步长,大多数情况下就不需要再改变了,有人会觉得随着  越来越小,如果步长不变,更新式中第二项的值也会越来越小,迭代的进度会慢下来。但是,随着越来越接近最佳点,梯度也会越来越大,也就是  的微分值也会越来越大,所以迭代的速度并不会变慢。

最常用的选择步长的方法是按3倍调整,即:0.001、0.003、0.01、0.03、0.1、0.3、1 …… 按这个倍率进行测试,寻找能使  下降速度最快的步长范围,确定范围后再对其进行微调。


用 Python 实现梯度下降法

import numpy as np
import randomdef gradient_descent(alpha, x, y, ep=0.0001, max_iter=10000):converged = Falseiter = 0m = x.shape[0] # 数据的行数# 初始化参数(theta)t0 = np.random.random(x.shape[1])t1 = np.random.random(x.shape[1])# 代价函数, J(theta)J = sum([(t0 + t1*x[i] - y[i])**2 for i in range(m)])# 进行迭代while not converged:# 计算训练集中每一行数据的梯度 (d/d_theta j(theta))grad0 = 1.0/m * sum([(t0 + t1*x[i] - y[i]) for i in range(m)]) grad1 = 1.0/m * sum([(t0 + t1*x[i] - y[i])*x[i] for i in range(m)])# 更新参数 theta# 此处注意,参数要同时进行更新,所以要建立临时变量来传值temp0 = t0 - alpha * grad0temp1 = t1 - alpha * grad1t0 = temp0t1 = temp1# 均方误差 (MSE)e = sum( [ (t0 + t1*x[i] - y[i])**2 for i in range(m)] ) if abs(J-e) <= ep:print 'Converged, iterations: ', iter, '!!!'converged = TrueJ = e   # 更新误差值iter += 1  # 更新迭代次数if iter == max_iter:print 'Max interactions exceeded!'converged = Truereturn t0,t1

参考:3 Types of Gradient Descent Algorithms for Small & Large Data Sets

机器学习:用梯度下降法实现线性回归相关推荐

  1. Python机器学习:梯度下降法003线性回归中的梯度下降法

    接下来使用下列公式编程

  2. 梯度下降算法c语言实现,机器学习中梯度下降法原理及用其解决线性回归问题的C语言实现...

    本文讲梯度下降(Gradient Descent)前先看看利用梯度下降法进行监督学习(例如分类.回归等)的一般步骤: 1, 定义损失函数(Loss Function) 2, 信息流forward pr ...

  3. [机器学习-2]梯度下降法及实现(python)

    [机器学习-2]梯度下降法及实现(python) 样例(Example) 利普西斯连续(L-continuity) 利普西斯光滑(L-smoothness) 凸集(Convex Set) 凸函数(Co ...

  4. 基于matlab的梯度下降法实现线性回归

    基于matlab的梯度下降法实现线性回归 1 绪论 1.1线性回归的定义 1.2单变量线性回归 1.3多变量线性回归 2 梯度下降 2.1 cost function 2.2 梯度下降:解决线性回归的 ...

  5. 梯度下降法求解线性回归

    梯度下降法求解线性回归 通过梯度下降法求解简单的一元线性回归 分别通过梯度下降算法和sklearn的线性回归模型(即基于最小二乘法)解决简单的一元线性回归实际案例,通过结果对比两个算法的优缺. 通过最 ...

  6. Python机器学习:梯度下降法004实现线性回归中的梯度下降法

    直接上代码 import numpy as np import matplotlib.pyplot as plt 生成模拟数据 np.random.seed(666) x = 2 * np.rando ...

  7. 线性回归介绍及分别使用最小二乘法和梯度下降法对线性回归C++实现

    回归:在这类任务中,计算机程序需要对给定输入预测数值.为了解决这个任务,学习算法需要输出函数f:Rn→R.除了返回结果的形式不一样外,这类问题和分类问题是很像的.这类任务的一个示例是预测投保人的索赔金 ...

  8. 机器学习之梯度下降法(GD)和坐标轴下降法(CD)

    梯度下降法 梯度下降法(Gradient Descent, GD)常用于求解无约束情况下凸函数(Convex Function)的极小值,是一种迭代类型的算法,因为凸函数只有一个极值点,故求解出来的极 ...

  9. 机器学习_梯度下降法(BGD、SGD、MBGD)

    除了最小二乘法求解损失函数之外,梯度下降法是另一种求解损失函数的方法. 梯度下降的直观理解,先用一个例子说明: 假设找凸函数y=x**2的最小值,通过迭代的方式,假设x=2时,y=4,x=0.8时,y ...

最新文章

  1. 开源项目贡献者_如何吸引新的贡献者加入您的开源项目
  2. Lync Server 2013 标准版部署(二)DNS记录权限
  3. 使用BaseAdapter实现复杂的ListView
  4. UIImageView01
  5. 外键 级联操作 mysql
  6. poi 升级至4.x 的问题总结(POI Excel 单元格内容类型判断并取值)
  7. linux 7 %3e命令,Linux操作系统常用基础命令
  8. linux 环境下安装 docker 精简步骤
  9. 决策树(六)--随机森林
  10. SpringBoot2整合Jooq和Flyway(一)
  11. html flash轮播图,flash滚动图片制作 图片列表左右滚动轮播
  12. 我国中药产业的国际竞争力研究
  13. win10下Java的JDK11下载与安装教程
  14. 用nmap扫描内网conficker
  15. 最好用的 7 款 Vue admin 后台管理系统测评
  16. Mac 右键展示Copy path
  17. 【微信小程序】获取Bmob后端云数据库某一数据表的所有记录
  18. C++模板的特化与偏特化
  19. MySQL 系统自带的数据库有哪些?每个数据库的作用是什么?
  20. python爬虫scrapy爬取新闻标题及链接_python爬虫框架scrapy爬取梅花网资讯信息

热门文章

  1. c++ python混合编程 restful_How to use Python to build a RESTful Web Service
  2. java 简单阻塞队列,制作一个简单的任务队列(使用阻塞队列)
  3. 服务器可用性监测系统,可用性监控区别
  4. linux系统上安装toma,Linux-tar - osc_btnnkvs0的个人空间 - OSCHINA - 中文开源技术交流社区...
  5. 前端JS:判断list(数组)中的json对象是否重复
  6. 二叉树的BFS及DFS
  7. HtmlDom 基础
  8. Docker 加速器升级版
  9. oracle创建索引后sqlldr导入错误
  10. 流程控制库async