前言

梯度下降法是深度学习领域用于最优化的常见方法,根据使用的batch大小,可分为随机梯度下降法(SGD)和批量梯度下降法(BGD)和小批量梯度下降法(MBGD),这里简单介绍下并且提供Python代码演示。
如有谬误,请联系指正。转载请注明出处。
联系方式:
e-mail: FesianXu@163.com
QQ: 973926198
github: https://github.com/FesianXu
代码开源:click

梯度下降法

其基本思想就是使得网络的各个参数朝着优化函数的梯度的反方向更新,因为梯度的方向是增长最速方向,所以梯度的反方向就是减少的最速方向,因此只要设定一个步进 η \eta η,然后朝着梯度的反方向下降减小即可将优化函数最小化。公式表达如:
w i j l : = w i j l − η ∂ C ∂ w i j l w^l_{ij} := w^l_{ij}-\eta \frac{\partial{C}}{\partial{w^l_{ij}}} wijl​:=wijl​−η∂wijl​∂C​

以下讨论都假设使用的损失函数是MSE损失函数,模型为线性回归模型。
L = 1 2 m ∑ i = 1 m ( y − h ( x ( i ) ) ) 2 , m 是 样 本 总 数 L = \frac{1}{2m} \sum_{i=1}^m (y-h(x^{(i)}))^2,m是样本总数 L=2m1​i=1∑m​(y−h(x(i)))2,m是样本总数
h ( x i ) = θ T x ( i ) + b , 其 中 x ( i ) 为 第 i 个 样 本 , x ( i ) ∈ R n h(x_i) = \theta^Tx^{(i)}+b,其中x^{(i)}为第i个样本,x^{(i)} \in R^n h(xi​)=θTx(i)+b,其中x(i)为第i个样本,x(i)∈Rn

批量梯度下降法(BGD)

批量梯度下降法(Batch Gradient Descent, BGD)是梯度下降法的最原始的形式,其特点就是每一次训练迭代都需要利用所有的训练样本,其表达式为:
∂ L ∂ θ j = − 1 m ∑ i = 1 m [ ( y − h ( x ( i ) ) ) x j ( i ) ] , x j ( i ) 是 第 i 个 样 本 的 第 j 个 输 入 分 量 \frac{\partial{L}}{\partial{\theta_j}} = -\frac{1}{m}\sum_{i=1}^m [(y-h(x^{(i)}))x^{(i)}_j],x^{(i)}_j是第i个样本的第j个输入分量 ∂θj​∂L​=−m1​i=1∑m​[(y−h(x(i)))xj(i)​],xj(i)​是第i个样本的第j个输入分量
θ j : = θ j + η 1 m ∑ i = 1 m [ ( y − h ( x ( i ) ) ) x j ( i ) ] \theta_j := \theta_j+\eta \frac{1}{m}\sum_{i=1}^m [(y-h(x^{(i)}))x^{(i)}_j] θj​:=θj​+ηm1​i=1∑m​[(y−h(x(i)))xj(i)​]
  可以发现,对于每一个参数的更新都需要利用到所有的m个样本,这样会导致训练速度随着训练样本的增大产生极大地减慢,不适合于大规模训练样本的场景,但是,其解是全局最优解,精度高。

随机梯度下降法(SGD)

随机梯度下降法(Stochastic Gradient Descent, SGD)是梯度下降法的改进形式之一,其特点就是每一次训练迭代只需要训练样本集中的一个样本,其表达式为:
∂ L ∂ θ j = − ( y − h ( x ) ) x j , x j 是 选 定 样 本 的 第 j 个 输 入 分 量 \frac{\partial{L}}{\partial{\theta_j}} = -(y-h(x))x_j, x_j是选定样本的第j个输入分量 ∂θj​∂L​=−(y−h(x))xj​,xj​是选定样本的第j个输入分量
θ j : = θ j + η ( y − h ( x ) ) x j \theta_j := \theta_j+\eta (y-h(x))x_j θj​:=θj​+η(y−h(x))xj​
  和BGD进行对比可以发现,其参数更新过程中,只需要选取训练集中的一个训练样本,因此训练速度快,但是因为其只是利用了训练集的一部分知识,因此其解为局部最优解,精度较低。

小批量梯度下降法(MBGD)

小批量梯度下降法(Mini-Batch Gradient Descent, MBGD)是结合了SGD和BGD的一种改进版本,既有训练速度快,也有精度较高的特点,其基本特点就是每一次训练迭代在训练集中随机采样batch_size个样本,其表达式为:
∂ L ∂ θ j = − 1 M ∑ i = 1 M [ ( y − h ( x ( i ) ) ) x j ( i ) ] , 其 中 M 为 b a t c h _ s i z e \frac{\partial{L}}{\partial{\theta_j}} = -\frac{1}{M}\sum_{i=1}^M [(y-h(x^{(i)}))x^{(i)}_j],其中M为batch\_size ∂θj​∂L​=−M1​i=1∑M​[(y−h(x(i)))xj(i)​],其中M为batch_size
θ j : = θ j + η 1 M ∑ i = 1 M [ ( y − h ( x ( i ) ) ) x j ( i ) ] \theta_j := \theta_j+\eta \frac{1}{M}\sum_{i=1}^M [(y-h(x^{(i)}))x^{(i)}_j] θj​:=θj​+ηM1​i=1∑M​[(y−h(x(i)))xj(i)​]
  这个改进版本在深度学习的网络训练中有着广泛地应用,因为其既有较高的精度,也有较快的训练速度。

代码实现

这里以线性回归模型作为演示,实现梯度下降。代码实现基于python和numpy,matplotlib绘图,matlab生成数据集。代码由github托管。

数据描述

目标直线为 y = 2.5 x + 3.5 y=2.5x+3.5 y=2.5x+3.5,加上了一个高斯噪声,代码如:

x = -20:0.01:20;
line = 2.5*x+3.5;
x_rand = randn(1, length(line))*10;
line = line+x_rand;
x = x';
line = line';
samples = [x, line];

图像如:

主要代码如:

import numpy as np
import scipy.io as sio
import random
import matplotlib.pyplot as pltpath = u'./samples.mat'
mat = sio.loadmat(path)
dataset = mat['samples']
batch_size = 1def random_get_samples(mat, batch_size):batch_id = random.sample(range(mat.shape[0]), batch_size)ret_batch = mat[batch_id, 0]ret_line = mat[batch_id, 1]return ret_batch, ret_lineparams = {'w1': np.random.normal(size=(1)),'b': np.random.normal(size=(1))
}def predict(x):return params['w1']*x+params['b']learning_rate = 0.001
for i in range(3000):batch, line = random_get_samples(dataset, batch_size)y_pred = predict(batch)y_pred = np.reshape(y_pred, (batch_size, 1))line = np.reshape(line, (batch_size, 1))batch = np.reshape(batch, (batch_size, 1))delta = line-y_predparams['w1'] = params['w1']+learning_rate*np.sum(delta*batch)/batch_sizeparams['b'] = params['b']+learning_rate*np.sum(delta)/batch_sizeif i % 100 == 0:print(np.sum(np.abs(line-y_pred))/batch_size)print(params['w1'])
print(params['b'])
x = dataset[:, 0]
line = dataset[:, 1]
y = params['w1']*x+params['b']
plt.figure(1)
plt.plot(x, line, 'b--')
plt.plot(x, y, 'r--')
plt.show()

其中调整batch_size可以分别实现SGD,BGD和MBGD,当其为1时为SGD,当其为dataset.shape[0]时为BGD,当其为其他固定常数时如128,为MBGD。通过实验可以明显发现,当batch_size为1时,训练速度最快,但是精度欠佳,当batch_size=dataset.shape[0]时,训练速度最慢,但是精度较高,当batch_size=128或其他常数时,速度和精度兼有顾及。最后的结果图像如:

可见红线成功拟合了原数据。

随机梯度下降法,批量梯度下降法和小批量梯度下降法以及代码实现相关推荐

  1. 梯度下降法的不同形式——随机梯度下降法和小批量梯度下降法

    前文介绍了梯度下降法,其每次迭代均需使用全部的样本,因此计算量巨大.就此,提出了基于单个样本的随机梯度下降法(Stochastic gradient descent,SGD)和基于部分样本的小批量梯度 ...

  2. 『ML笔记』梯度下降法和随机梯度下降法和小批量梯度对比

    目录 1. 梯度下降法(gradient descent) 2. 随机梯度下降(Stochastic gradient descent) 3. 小批量梯度下降(Mini-Batch gradient ...

  3. 3. 机器学习中为什么需要梯度下降?梯度下降算法缺点?_浅谈随机梯度下降amp;小批量梯度下降...

    机器学习三要素 上次的报告中,我们介绍了一种用于求解模型参数的迭代算法--梯度下降法.首先需要明确一点,即"梯度下降算法"在一个完整的统计学习流程中,属于什么?根据<统计学习 ...

  4. 随机梯度下降、批量梯度下降、小批量梯度下降分类是什么?有什么区别?batch_size的选择如何实施、有什么影响?

    随机梯度下降.批量梯度下降.小批量梯度下降分类是什么?有什么区别?batch_size的选择如何实施.有什么影响? 目录

  5. Lesson 4.34.4 梯度下降(Gradient Descent)基本原理与手动实现随机梯度下降与小批量梯度下降

    Lesson 4.3 梯度下降(Gradient Descent)基本原理与手动实现 在上一小节中,我们已经成功的构建了逻辑回归的损失函数,但由于逻辑回归模型本身的特殊性,我们在构造损失函数时无法采用 ...

  6. 批量梯度下降,随机梯度下降和小批量梯度下降的区别

    批量梯度下降,随机梯度下降和小批量梯度下降的区别主要体现在用于计算梯度的样本的数量: 批量梯度下降:在每次迭代时,用整个数据集的所有样本上的梯度计算更新. 随机梯度下降:在每次迭代时,用单个样本上的梯 ...

  7. 梯度下降法的三种形式批量梯度下降法、随机梯度下降以及小批量梯度下降法

    梯度下降法的三种形式BGD.SGD以及MBGD 梯度下降法的三种形式BGD.SGD以及MBGD 阅读目录 1. 批量梯度下降法BGD 2. 随机梯度下降法SGD 3. 小批量梯度下降法MBGD 4. ...

  8. 常见优化算法批量梯度下降、小批量梯度下降、随机梯度下降的对比

    在机器学习领域中,梯度下降的方式有三种,分别是:批量梯度下降法BGD.随机梯度下降法SGD.小批量梯度下降法MBGD,并且都有不同的优缺点. 下面我们以线性回归算法(也可以是别的算法,只是损失函数(目 ...

  9. 【深度学习】——梯度下降优化算法(批量梯度下降、随机梯度下降、小批量梯度下降、Momentum、Adam)

    目录 梯度 梯度下降 常用的梯度下降算法(BGD,SGD,MBGD) 梯度下降的详细算法 算法过程 批量梯度下降法(Batch Gradient Descent) 随机梯度下降法(Stochastic ...

最新文章

  1. 爬虫正则表达式遇到的困难
  2. WebSocket-java实现
  3. LeetCode 835. 图像重叠
  4. flask-sqlalchemy Multiple Databases
  5. python对投标_batterytender-为Python del API投标-Jason Kölker Module
  6. 成都刘女士的第一场锤子科技发布会 | 现场特写
  7. oracle 10g 下载方法
  8. 华为交换机系统软件升级和安全漏洞修复教程
  9. magicbook java开发_荣耀MagicBook2019 Intel版值得买吗 MagicBook2019 Intel版笔记本详细评测...
  10. 700 boost yeezy_公司级Adidas Yeezy Boost 700上脚测评
  11. 《读书是一辈子的事》中篇 了解未来
  12. CentOS7安装PHP开发环境1-源码安装Nginx
  13. 判断某整数是否既是5又是7的整数倍()
  14. CTF之旅WEB篇(3)--ezunser PHP反序列化
  15. python语言arrows用法_python时区运算,时区,时间戳,夏令时讲解
  16. 2021下半年系统集成项目管理师客观题参考答题解析(3)
  17. 安装burp2022 --illegal-access=permit
  18. python字符串模糊匹配 - RapidFuzz
  19. 怎样有效地阅读一篇论文?
  20. English Learning - L2-2 英音地道语音语调 [iː] [ɜː] [æ] 2023.02.23 周四

热门文章

  1. Tensorflow的基本概念与常用函数
  2. 《JavaScript交互式网页设计》复习考试
  3. 浅谈 css的zoom属性
  4. 免费好用的APP你值得一试
  5. python三级联动菜单_Excel–这才是三级联动下拉菜单的正确做法
  6. MySQL实战演练——如何才能构建逾期用户画像?【数据可视化】
  7. Android Framework 电源子系统(04)核心方法updatePowerStateLocked分析-2 循环处理  更新显示设备状态
  8. 蜻蜓FM回应恶意代码事件 音频行业仍将现721格局
  9. PS更新升级Adobe Camera Raw(ACR)15.3
  10. 天平游码读数例题_使用天平游码时的读数方法-word