传统的梯度下降算法,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度。这种梯度下降法叫做这称为Batch gradient descent(BDG)。我们知道 Batch 梯度下降的做法是,在对训练集执行梯度下降算法时,必须处理整个训练集,然后才能进行下一步梯度下降。当训练数据量非常多时,每更新一次参数都要把数据集里的所有样本都看一遍,虽然收敛性能好,但是一次迭代需要等待多长时间,速度慢,会极大的降低训练速度。

随机梯度下降,stochastic gradient descent(SDG),每看一个数据就算一下损失函数,然后求梯度更新参数。这个方法速度比较快,但是永远不会收敛,可能在最优点附近晃来晃去,无法收敛。两次参数的更新也有可能互相抵消掉,造成目标函数震荡的比较剧烈。

因此,为了克服两种方法的缺点,现在一般采用的是一种折中方法,mini-batch gradient decent。这种方法把数据分为若干个batch,按batch来更新参数,这样,一个batch中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大。

蓝色:为 batch 梯度下降,即 mini batch size = m,
紫色:为 stochastic 梯度下降,即 mini batch size = 1,
绿色:为 mini batch 梯度下降,即 1 < mini batch size < m。

mini-batch梯度下降

如果选择介于1和最大训练数据量之间的一个batch_size数据量进行训练,叫mini-batch 梯度下降

当b=1的时候,Mini-batch梯度下降就等于随机梯度下降(SDG);当b=m的时候,Mini-batch梯度下降就等于BDG。所以小批量梯度下降法的效果和batcih size的选择相关。
如果训练集较小,一般小于2000的,就直接使用 Batch gradient descent 。这样做至少有 2 个好处:其一,由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。其二,由于不同权重的梯度值差别巨大,因此选取一个全局的学习率很困难。

例如,样本数为m,每一个batch的大小为64,那么我们就可以分为m/64个样本,如果m%64不等于0说明还有剩的样本,则第m/64+1个batch不足64,大小就等于m%64。

一般 Mini Batch gradient descent 的大小在 64 到 512 之间,选择 2 的 n 次幂会运行得相对快一些。

每次训练的不能保证使用的是同一份数据,所以每一个batch不能保证都下降,整体训练loss变化会有很多噪声,但是整体趋势是下降的,随后会在最优值附近波动,不会收敛,但会会更持续地靠近最小值。

mini-batch算法实现

**1.确定mini-batch size。**一般有32、64、128等2的n次幂,按自己的数据集而定,确定mini-batch_num=m/mini-batch_num + 1;

m batch_size
<2000 batch_size=m,即采用batch梯度下降法
>2000 batch_size=64,128,256,512 mini-batch梯度下降法

2.在分组之前将原数据集顺序打乱,随机打乱;
3.分组,将打乱后的数据集分组;
4.将分好后的mini-batch组放进迭代循环中,每次循环都做mini-batch_num次梯度下降。

使用mini-batch梯度下降法时,一次遍历训练集,能让你做m/batch_size个梯度下降。当然正常来说你想要多次遍历训练集,还需要为另一个while循环设置另一个for循环。所以你可以一直处理遍历训练集,直到最后你能收敛到一个合适的精度。
详细算法可参考吴恩达机器学习:
https://www.bilibili.com/video/BV164411b7dx?p=104

伪代码

repeat num iterations{遍历每一个batch{1.前向传播:(1)计算Z=W*X+b(2)计算激活项的值A=g(Z)2.计算cost函数J3.反向传播求解梯度4.更新权重}
}

总结

简单来说,
当每次是对整个训练集进行梯度下降的时候,就是 batch 梯度下降(BDG),
当每次只对一个样本进行梯度下降的时候,是 随机梯度下降(SDG),
当每次处理样本的个数介于二者之间,就是 mini batch 梯度下降

神经网络之Mini-Batch梯度下降相关推荐

  1. Mini batch梯度下降法(吴恩达深度学习视频笔记)

    深度学习并没有在大数据中表现很好,但是我们可以利用一个巨大的数据集来训练神经网络,而在巨大的数据集基础上进行训练速度很慢,因此进行优化算法能够很大程度地增加训练速度,提升效率. 本节,我们将谈谈Min ...

  2. 深度学习--TensorFlow(4)BP神经网络(损失函数、梯度下降、常用激活函数、梯度消失梯度爆炸)

    目录 一.概念与定义 二.损失函数/代价函数(loss) 三.梯度下降法 二维w与loss: 三维w与loss: 四.常用激活函数 1.softmax激活函数 2.sigmoid激活函数 3.tanh ...

  3. 用Numpy搭建神经网络第二期:梯度下降法的实现

    https://www.toutiao.com/a6696699352833851908/ 大数据文摘出品 作者:蒋宝尚 小伙伴们大家好呀~~用Numpy搭建神经网络,我们已经来到第二期了.第一期文摘 ...

  4. 为什么需要 Mini-batch 梯度下降,及 TensorFlow 应用举例

    本文知识点: 什么是 mini-batch 梯度下降 mini-batch 梯度下降具体算法 为什么需要 mini-batch 梯度下降 batch, stochastic ,mini batch 梯 ...

  5. Mini-batch 梯度下降 与Tensorflow中的应用

    mini-batch在深度学习中训练神经网络时经常用到,这是一种梯度下降方法,可以很快的降低cost,接下来系统介绍一下. 1. 什么是 mini-batch梯度下降 先来快速看一下BGD,SGD,M ...

  6. 梯度下降算法_批梯度下降法,Minibatch梯度下降法和随机梯度下降法之间的区别...

    什么是梯度下降法? 梯度下降法是一种机器学习中常用的优化算法,用来找到一个函数(f)的参数(系数)的值,使成本函数(cost)最小. 当参数不能解析计算时(如使用线性代数),并且必须通过优化算法搜索时 ...

  7. Lesson 11.1-11.5 梯度下降的两个关键问题反向传播的原理走出第一步:动量法开始迭代:batch和epochs在Fashion—MNIST数据集熵实现完整的神经网络

    在之前的课程中,我们已经完成了从0建立深层神经网络,并介绍了各类神经网络所使用的损失函数.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具A ...

  8. 2.2 理解 Mini-batch 梯度下降-深度学习第二课《改善深层神经网络》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 2.1 Mini-batch 梯度下降 回到目录 2.3 指数加权平均 理解 Mini-batch 梯度下降 (Understanding Mini-batch Gradie ...

  9. 002-深度学习数学基础(神经网络、梯度下降、损失函数)

    002-深度学习数学基础(神经网络.梯度下降.损失函数) 这里在进入人工智能的讲解之前,你必须知道几个名词,其实也就是要简单了解一下人工智能的数学基础,不然就真的没办法往下讲了. 本节目录如下: 前言 ...

最新文章

  1. 数据库分页存储过程(5)
  2. 解决依赖的moduleBuildConfig.DEBUG总是未false的问题
  3. linux:安装ubuntu18-04
  4. js特效代码-鼠标样式
  5. STM32CubeMX系列教程 5.0版本环境开发——1.工程搭建
  6. delete hive_Hive高级调优
  7. 判断设置两天后时间,时间戳
  8. (办公)eclipse连接github cannot open git-upload-pack(git-receive-pack)
  9. 漫谈moosefs中cgi各项的意义
  10. 关闭安卓系统导航栏右下角自动旋转按钮
  11. thinkphp3.2.3 支付宝授权登录php
  12. uniapp 之 禁用手机物理返回键
  13. 移远BC35-G固件升级
  14. Linux的账号与权限管理
  15. 在PPT中显示时间以提醒演讲者控制时间
  16. RecyclerView加载网络图片防止图片错乱问题
  17. 简单聊天室(java版)
  18. html光标定位到文本框,js获取光标位置和设置文本框光标位置示例代码
  19. 百度 2022 提前批-数据挖掘算法工程师面经
  20. 学做网站有哪些注意事项(下)

热门文章

  1. 爬虫第5课-从QQ音乐上爬取周杰伦前5页歌词
  2. java比较器参数只传进去一个_post请求springMVC后台,只传一个参数,后台用一个字符串接受,参数名对应,但入参却为null,求讲解和解决,谢谢...
  3. 超完整素数算法总结归纳
  4. 腾讯SNG全链路日志监控平台之构建挑战
  5. Lucene—全文检索工具包
  6. 搞笑:当程序员当了爸爸
  7. 不要太爽,这个微信群可以学英语,而且全程免费!
  8. Spring boot jpa 多表关联查询
  9. 笔记本联想拯救者ubuntu系统21.04改善合盖无法唤起屏幕的为问题
  10. centos安装c语言编译器,Centos7安装GCC编译器