7.3 小批量随机梯度下降

在每一次迭代中,梯度下降使用整个训练数据集来计算梯度,因此它有时也被称为批量梯度下降(batch gradient descent)。而随机梯度下降在每次迭代中只随机采样一个样本来计算梯度。正如我们在前几章中所看到的,我们还可以在每轮迭代中随机均匀采样多个样本来组成一个小批量,然后使用这个小批量来计算梯度。下面就来描述小批量随机梯度下降。

设目标函数 f ( x ) : R d → R f(\boldsymbol{x}): \mathbb{R}^d \rightarrow \mathbb{R} f(x):Rd→R。在迭代开始前的时间步设为0。该时间步的自变量记为 x 0 ∈ R d \boldsymbol{x}_0\in \mathbb{R}^d x0​∈Rd,通常由随机初始化得到。在接下来的每一个时间步 t > 0 t>0 t>0中,小批量随机梯度下降随机均匀采样一个由训练数据样本索引组成的小批量 B t \mathcal{B}_t Bt​。我们可以通过重复采样(sampling with replacement)或者不重复采样(sampling without replacement)得到一个小批量中的各个样本。前者允许同一个小批量中出现重复的样本,后者则不允许如此,且更常见。对于这两者间的任一种方式,都可以使用

g t ← ∇ f B t ( x t − 1 ) = 1 ∣ B ∣ ∑ i ∈ B t ∇ f i ( x t − 1 ) \boldsymbol{g}_t \leftarrow \nabla f_{\mathcal{B}_t}(\boldsymbol{x}_{t-1}) = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}_t}\nabla f_i(\boldsymbol{x}_{t-1}) gt​←∇fBt​​(xt−1​)=∣B∣1​i∈Bt​∑​∇fi​(xt−1​)

来计算时间步 t t t的小批量 B t \mathcal{B}_t Bt​上目标函数位于 x t − 1 \boldsymbol{x}_{t-1} xt−1​处的梯度 g t \boldsymbol{g}_t gt​。这里 ∣ B ∣ |\mathcal{B}| ∣B∣代表批量大小,即小批量中样本的个数,是一个超参数。同随机梯度一样,重复采样所得的小批量随机梯度 g t \boldsymbol{g}_t gt​也是对梯度 ∇ f ( x t − 1 ) \nabla f(\boldsymbol{x}_{t-1}) ∇f(xt−1​)的无偏估计。给定学习率 η t \eta_t ηt​(取正数),小批量随机梯度下降对自变量的迭代如下:

x t ← x t − 1 − η t g t . \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \eta_t \boldsymbol{g}_t. xt​←xt−1​−ηt​gt​.

基于随机采样得到的梯度的方差在迭代过程中无法减小,因此在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减,例如 η t = η t α \eta_t=\eta t^\alpha ηt​=ηtα(通常 α = − 1 \alpha=-1 α=−1或者 − 0.5 -0.5 −0.5)、 η t = η α t \eta_t = \eta \alpha^t ηt​=ηαt(如 α = 0.95 \alpha=0.95 α=0.95)或者每迭代若干次后将学习率衰减一次。如此一来,学习率和(小批量)随机梯度乘积的方差会减小。而梯度下降在迭代过程中一直使用目标函数的真实梯度,无须自我衰减学习率。

小批量随机梯度下降中每次迭代的计算开销为 O ( ∣ B ∣ ) \mathcal{O}(|\mathcal{B}|) O(∣B∣)。当批量大小为1时,该算法即为随机梯度下降;当批量大小等于训练数据样本数时,该算法即为梯度下降。当批量较小时,每次迭代中使用的样本少,这会导致并行处理和内存使用效率变低。这使得在计算同样数目样本的情况下比使用更大批量时所花时间更多。当批量较大时,每个小批量梯度里可能含有更多的冗余信息。为了得到较好的解,批量较大时比批量较小时需要计算的样本数目可能更多,例如增大迭代周期数。

7.3.1 读取数据

本章里我们将使用一个来自NASA的测试不同飞机机翼噪音的数据集来比较各个优化算法 [1]。我们使用该数据集的前1,500个样本和5个特征,并使用标准化对数据进行预处理。

%matplotlib inline
import numpy as np
import time
import torch
from torch import nn, optim
import sys
sys.path.append("..")
import d2lzh_pytorch as d2ldef get_data_ch7():  # 本函数已保存在d2lzh_pytorch包中方便以后使用data = np.genfromtxt('../../data/airfoil_self_noise.dat', delimiter='\t')data = (data - data.mean(axis=0)) / data.std(axis=0)return torch.tensor(data[:1500, :-1], dtype=torch.float32), \torch.tensor(data[:1500, -1], dtype=torch.float32) # 前1500个样本(每个样本5个特征)features, labels = get_data_ch7()
features.shape # torch.Size([1500, 5])

7.3.2 从零开始实现

3.2节(线性回归的从零开始实现)中已经实现过小批量随机梯度下降算法。我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。具体来说,我们添加了一个状态输入states并将超参数放在字典hyperparams里。此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法里的梯度不需要除以批量大小。

def sgd(params, states, hyperparams):for p in params:p.data -= hyperparams['lr'] * p.grad.data

下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。它初始化一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_ch7(optimizer_fn, states, hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net, loss = d2l.linreg, d2l.squared_lossw = torch.nn.Parameter(torch.tensor(np.random.normal(0, 0.01, size=(features.shape[1], 1)), dtype=torch.float32),requires_grad=True)b = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)def eval_loss():return loss(net(features, w, b), labels).mean().item()ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):l = loss(net(X, w, b), y).mean()  # 使用平均损失# 梯度清零if w.grad is not None:w.grad.data.zero_()b.grad.data.zero_()l.backward()optimizer_fn([w, b], states, hyperparams)  # 迭代模型参数if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())  # 每100个样本记录下当前训练误差# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

当批量大小为样本总数1,500时,优化使用的是梯度下降。梯度下降的1个迭代周期对模型参数只迭代1次。可以看到6次迭代后目标函数值(训练损失)的下降趋向了平稳。

def train_sgd(lr, batch_size, num_epochs=2):train_ch7(sgd, None, {'lr': lr}, features, labels, batch_size, num_epochs)train_sgd(1, 1500, 6)

输出:

loss: 0.243605, 0.014335 sec per epoch

当批量大小为1时,优化使用的是随机梯度下降。为了简化实现,有关(小批量)随机梯度下降的实验中,我们未对学习率进行自我衰减,而是直接采用较小的常数学习率。随机梯度下降中,每处理一个样本会更新一次自变量(模型参数),一个迭代周期里会对自变量进行1,500次更新。可以看到,目标函数值的下降在1个迭代周期后就变得较为平缓。

train_sgd(0.005, 1)

输出:

loss: 0.243433, 0.270011 sec per epoch

虽然随机梯度下降和梯度下降在一个迭代周期里都处理了1,500个样本,但实验中随机梯度下降的一个迭代周期耗时更多。这是因为随机梯度下降在一个迭代周期里做了更多次的自变量迭代,而且单样本的梯度计算难以有效利用矢量计算。

当批量大小为10时,优化使用的是小批量随机梯度下降。它在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

train_sgd(0.05, 10)

输出:

loss: 0.242805, 0.078792 sec per epoch

7.3.3 简洁实现

在PyTorch里可以通过创建optimizer实例来调用优化算法。这能让实现更简洁。下面实现一个通用的训练函数,它通过优化算法的函数optimizer_fn和超参数optimizer_hyperparams来创建optimizer实例。

# 本函数与原书不同的是这里第一个参数优化器函数而不是优化器的名字
# 例如: optimizer_fn=torch.optim.SGD, optimizer_hyperparams={"lr": 0.05}
def train_pytorch_ch7(optimizer_fn, optimizer_hyperparams, features, labels,batch_size=10, num_epochs=2):# 初始化模型net = nn.Sequential(nn.Linear(features.shape[-1], 1))loss = nn.MSELoss()optimizer = optimizer_fn(net.parameters(), **optimizer_hyperparams)def eval_loss():return loss(net(features).view(-1), labels).item() / 2ls = [eval_loss()]data_iter = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(features, labels), batch_size, shuffle=True)for _ in range(num_epochs):start = time.time()for batch_i, (X, y) in enumerate(data_iter):# 除以2是为了和train_ch7保持一致, 因为squared_loss中除了2l = loss(net(X).view(-1), y) / 2 optimizer.zero_grad()l.backward()optimizer.step()if (batch_i + 1) * batch_size % 100 == 0:ls.append(eval_loss())# 打印结果和作图print('loss: %f, %f sec per epoch' % (ls[-1], time.time() - start))d2l.set_figsize()d2l.plt.plot(np.linspace(0, num_epochs, len(ls)), ls)d2l.plt.xlabel('epoch')d2l.plt.ylabel('loss')

使用PyTorch重复上一个实验。

train_pytorch_ch7(optim.SGD, {"lr": 0.05}, features, labels, 10)

输出:

loss: 0.245491, 0.044150 sec per epoch

小结

  • 小批量随机梯度每次随机均匀采样一个小批量的训练样本来计算梯度。
  • 在实际中,(小批量)随机梯度下降的学习率可以在迭代过程中自我衰减。
  • 通常,小批量随机梯度在每个迭代周期的耗时介于梯度下降和随机梯度下降的耗时之间。

参考文献

[1] 飞机机翼噪音数据集。https://archive.ics.uci.edu/ml/datasets/Airfoil+Self-Noise


注:除代码外本节与原书此节基本相同,原书传送门

7.3_minibatch-sgd相关推荐

  1. 批量梯度下降(BGD)、随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解

    批量梯度下降(BGD).随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解 </h1><div class="clear"></div> ...

  2. Pytorch实现MNIST(附SGD、Adam、AdaBound不同优化器下的训练比较) adabound实现

     学习工具最快的方法就是在使用的过程中学习,也就是在工作中(解决实际问题中)学习.文章结尾处附完整代码. 一.数据准备   在Pytorch中提供了MNIST的数据,因此我们只需要使用Pytorch提 ...

  3. 从 SGD 到 Adam —— 深度学习优化算法概览 各种优化器 重点

    20210701 https://blog.51cto.com/u_15064630/2571266 [机器学习基础]优化算法详解 详细 https://blog.csdn.net/u01338501 ...

  4. Adam那么棒,为什么还对SGD念念不忘 (1) —— 一个框架看懂优化算法

    机器学习界有一群炼丹师,他们每天的日常是: 拿来药材(数据),架起八卦炉(模型),点着六味真火(优化算法),就摇着蒲扇等着丹药出炉了. 不过,当过厨子的都知道,同样的食材,同样的菜谱,但火候不一样了, ...

  5. Adam那么棒,为什么还对SGD念念不忘 (3)—— 优化算法的选择与使用策略

    在前面两篇文章中,我们用一个框架梳理了各大优化算法,并且指出了以Adam为代表的自适应学习率优化算法可能存在的问题.那么,在实践中我们应该如何选择呢? 本文介绍Adam+SGD的组合策略,以及一些比较 ...

  6. 一个框架看懂优化算法之异同 SGD/AdaGrad/Adam

    Adam那么棒,为什么还对SGD念念不忘 (1) -- 一个框架看懂优化算法 机器学习界有一群炼丹师,他们每天的日常是: 拿来药材(数据),架起八卦炉(模型),点着六味真火(优化算法),就摇着蒲扇等着 ...

  7. Adam那么棒,为什么还对SGD念念不忘 (2)—— Adam的两宗罪

    在上篇文章中,我们用一个框架来回顾了主流的深度学习优化算法.可以看到,一代又一代的研究者们为了我们能炼(xun)好(hao)金(mo)丹(xing)可谓是煞费苦心.从理论上看,一代更比一代完善,Ada ...

  8. 深度学习中的随机梯度下降(SGD)简介

    随机梯度下降(Stochastic Gradient Descent, SGD)是梯度下降算法的一个扩展. 机器学习中反复出现的一个问题是好的泛化需要大的训练集,但大的训练集的计算代价也更大.机器学习 ...

  9. Adam 那么棒,为什么还对 SGD 念念不忘?一个框架看懂深度学习优化算法

    作者|Juliuszh 链接 | https://zhuanlan.zhihu.com/juliuszh 本文仅作学术分享,若侵权,请联系后台删文处理 机器学习界有一群炼丹师,他们每天的日常是: 拿来 ...

  10. pytorch .item_pytorch + SGD

    梯度下降是模型优化常用方法,原理也比较简单,简言之就是参数沿负梯度方向更新,参数更新公式如下. ,其中 表示的是步长,用于控制每次更新移动的步伐. 我们将使用pytorch来试验下这个方法. 首先先生 ...

最新文章

  1. Nat. Commun. | 序列到功能的深度学习框架加速工程核糖调节剂设计和优化
  2. 苹果新的编程语言 Swift 语言进阶(三)--基本运算和扩展运算
  3. Dimple.js基础
  4. 【经济法常识转摘】借款人逾期不还钱,利率如何确定?
  5. Python格式化输出的三种方式
  6. Vue项目打包成桌面程序exe除了使用electron-vue你还可以这样
  7. 鼠标放到控件上 DIV悬浮提示效果(四种)
  8. 下载文件(弹出迅雷来下载)
  9. java框架谁搭建_从零开始搭建一个开发框架(Java + Hibernate + Spring + Oracle)
  10. python循环结构语句_python控制语句---循环结构语句
  11. bzoj4009: [HNOI2015]接水果
  12. java权限管理面试_java shiro面试题
  13. 新版PassXYZ已上线,新增一次性密码(OTP)管理功能
  14. 智能人物画像综合分析系统——Day6
  15. 大型网站建设方案(学院网站建设方案)
  16. 如何提升数据分析的高级感:反客为主、展示神迹、引经据典、繁花似锦
  17. 深圳赏给我的耳光:说到底,生活就是一场接着一场的较量
  18. QQ空间无敌装逼,复制下面的任一代码粘贴即可出现意想不到的图案。
  19. kivy android wifi,Kivy / Buildozer VM Ubuntu不能连接到网络的问题解决
  20. python二级证书含金量排名_计算机二级证书含金量有多高?你真的知道吗???...

热门文章

  1. mac brew安装php7.4
  2. JSON与properties文件的相互转换
  3. 南宁第一职业技术学校计算机专业,南宁第一职业技术学校
  4. C++ fgets()函数
  5. 【Linux编程】三分钟让你学会Linux下用户密码更改
  6. strcat函数用法的一点看法
  7. 关于我和计算机的故事
  8. python一般用几个空格表示缩进_Python 就是使用缩进来表示代码块,一般使用几个空格来表示一个缩进_女子礼仪答案_学小易找答案...
  9. n (n - 1)的用法
  10. PreTranslateMessage()函数的使用说明