Batch_Size(批尺寸)是机器学习中一个重要参数,涉及诸多矛盾,下面逐一展开。

首先,为什么需要有 Batch_Size 这个参数?

Batch 的选择,首先决定的是下降的方向。如果数据集比较小,完全可以采用全数据集 ( Full Batch Learning )的形式,这样做至少有 2 个好处:其一,由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。其二,由于不同权重的梯度值差别巨大,因此选取一个全局的学习率很困难。 Full Batch Learning 可以使用 Rprop 只基于梯度符号并且针对性单独更新各权值。

对于更大的数据集,以上 2 个好处又变成了 2 个坏处:其一,随着数据集的海量增长和内存限制,一次性载入所有的数据进来变得越来越不可行。其二,以 Rprop 的方式迭代,会由于各个 Batch 之间的采样差异性,各次梯度修正值相互抵消,无法修正。这才有了后来 RMSProp 的妥协方案。

既然 Full Batch Learning 并不适用大数据集,那么走向另一个极端怎么样?

所谓另一个极端,就是每次只训练一个样本,即 Batch_Size = 1。这就是在线学习(Online Learning)。线性神经元在均方误差代价函数的错误面是一个抛物面,横截面是椭圆。对于多层神经元、非线性网络,在局部依然近似是抛物面。使用在线学习,每次修正方向以各自样本的梯度方向修正,横冲直撞各自为政,难以达到收敛。如图所示:

可不可以选择一个适中的 Batch_Size 值呢?

当然可以,这就是批梯度下降法(Mini-batches Learning)。因为如果数据集足够充分,那么用一半(甚至少得多)的数据训练算出来的梯度与用全部数据训练出来的梯度是几乎一样的。

在合理范围内,增大 Batch_Size 有何好处?

  • 内存利用率提高了,大矩阵乘法的并行化效率提高。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。
  • 在一定范围内,一般来说 Batch_Size 越大,其确定的下降方向越准,引起训练震荡越小。

盲目增大 Batch_Size 有何坏处?

  • 内存利用率提高了,但是内存容量可能撑不住了。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,要想达到相同的精度,其所花费的时间大大增加了,从而对参数的修正也就显得更加缓慢。
  • Batch_Size 增大到一定程度,其确定的下降方向已经基本不再变化。

调节 Batch_Size 对训练效果影响到底如何?

这里跑一个 LeNet 在 MNIST 数据集上的效果。MNIST 是一个手写体标准库,我使用的是 Theano 框架。这是一个 Python 的深度学习库。安装方便(几行命令而已),调试简单(自带 Profile),GPU / CPU 通吃,官方教程相当完备,支持模块十分丰富(除了 CNNs,更是支持 RBM / DBN / LSTM / RBM-RNN / SdA / MLPs)。在其上层有 Keras 封装,支持 GRU / JZS1, JZS2, JZS3 等较新结构,支持 Adagrad / Adadelta / RMSprop / Adam 等优化算法。如图所示:


运行结果如上图所示,其中绝对时间做了标幺化处理。运行结果与上文分析相印证:

  • Batch_Size 太小,算法在 200 epoches 内不收敛。
  • 随着 Batch_Size 增大,处理相同数据量的速度越快。
  • 随着 Batch_Size 增大,达到相同精度所需要的 epoch 数量越来越多。
  • 由于上述两种因素的矛盾, Batch_Size 增大到某个时候,达到时间上的最优。
  • 由于最终收敛精度会陷入不同的局部极值,因此 Batch_Size 增大到某些时候,达到最终收敛精度上的最优。

欢迎一起讨论。

本文转自http://blog.csdn.net/ycheng_sjtu/article/details/49804041,感谢原作者的付出和分享。


神经网络算法学习---mini-batch++++mini-batch和batch的区别相关推荐

  1. MATLAB遗传神经网络算法学习

    误差反向传播(BP)神经网络根据反向传播的误差来调节连接权值和阈值,具有很强的非线性模拟能力 第一层为输入层,节点数目M由输入向量维数确定:中间层为双隐含层,节点数可选,一般不同层有不同的节点数:最后 ...

  2. BP神经网络算法学习

    BP(Back Propagation)网络是1986年由Rumelhart和McCelland为首的科学家小组提出,是一种按误差逆传播算法训练的多层前馈网络,是眼下应用最广泛的神经网络模型之中的一个 ...

  3. BP神经网络算法学习---基础理论1

    本文转自http://blog.csdn.net/acdreamers/article/details/44657439,对于BP基本原理的介绍非常的干净利索清晰,感谢原作者的付出和分享. 今天来讲B ...

  4. 深度学习笔记第二门课 改善深层神经网络 第三周 超参数调试、Batch正则化和程序框架...

    本文是吴恩达老师的深度学习课程[1]笔记部分. 作者:黄海广[2] 主要编写人员:黄海广.林兴木(第四所有底稿,第五课第一二周,第三周前三节).祝彦森:(第三课所有底稿).贺志尧(第五课第三周底稿). ...

  5. TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve

    TF之BN:BN算法对多层中的每层神经网络加快学习QuadraticFunction_InputData+Histogram+BN的Error_curve 目录 输出结果 代码设计 输出结果 代码设计 ...

  6. 深度学习(二十九)Batch Normalization 学习笔记

    Batch Normalization 学习笔记 原文地址:http://blog.csdn.net/hjimce/article/details/50866313 作者:hjimce 一.背景意义 ...

  7. 深度学习之 神经网络算法原理

    深度学习之 神经网络算法原理 什么是神经网络算法? 初中映射 神经网络 求映射? 求解参数 图解求参 参考文献 什么是神经网络算法? 初中映射 初中的时候 y = f(x) 老师进过 映射 . 通过 ...

  8. DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文)、CNN优化技术、调参学习实践、CNN经典结构及其演化、案例应用之详细攻略

    DL之CNN:计算机视觉之卷积神经网络算法的简介(经典架构/论文).CNN优化技术.调参学习实践.CNN经典结构.案例应用之详细攻略 目录 卷积神经网络算法的简介 0.Biologically Ins ...

  9. scikit-learn学习之神经网络算法

    ====================================================================== 本系列博客主要参考 Scikit-Learn 官方网站上的 ...

最新文章

  1. 相机标定 matlab opencv ROS三种方法标定步骤(3)
  2. SVN使用教程之——分支、合并
  3. Win7x64+VS2012+OpenCV2.4.3+CMake2.8.10+TBB41重编译OpenCV
  4. WebSocket 从入门到写出开源库
  5. 看来美国的霸道不仅仅是针对Lenovo的,SONY也被威胁--索尼被判侵犯专利,PlayStation游戏机销售面临威胁【ZZ】...
  6. linux设备驱动学习(二)——字符设备编写及测试
  7. python编写安全工具_Python3学习系列(四):编写属于自己的邮件伪造工具
  8. Python数据类型--字典
  9. 【Jmeter篇】jmeter+Ant+Jenkins实现自动化测试集成(一)
  10. 飞鸽推广的超级商务微博的栏目介绍(二)
  11. 物理服务器转虚拟路径,服务器配置虚拟路径
  12. 红豆、绿豆、黑豆、花生、莲子、薏仁米放在一起吃,可以吗?
  13. SQL Server活动监视器
  14. WPS设置标题行固定
  15. linux 磁盘配额 期限,linux磁盘配额管理
  16. 根据域名查询IP地址的网站推荐
  17. 符号_变压器电路图符号大全
  18. 【CSS3】浅谈display属性flex弹性布局、gird网格布局
  19. Java练习习题,百钱买百鸡问题,用100文钱买鸡,公鸡5文钱一只,母鸡3文钱一只,小鸡3只1文钱
  20. 小米测试总监的十年测试路,愿测试人都不再迷茫

热门文章

  1. ffmpeg编解码详细过程
  2. strcpy()源代码
  3. 聊聊 top 命令中的 CPU 使用率
  4. 一句话输出没有结束符的字符串
  5. strstrsubstr、AfxGetApp
  6. 【Pytorch神经网络实战案例】11 循环神经网络结构训练语言模型并进行简单预测
  7. 智慧交通day04-特定目标车辆追踪03:siamese在目标跟踪中的应用-SiamRPN(2017)
  8. 相邻位数字差值的绝对值不能超过_热点争议中技术问题,伺服控制有几个零点?对应真绝对值多圈编码器意义...
  9. LeetCode 5832. 构造元素不等于两相邻元素平均值的数组
  10. LeetCode 549. 二叉树中最长的连续序列(树上DP)