Batch Normalization :深度网络中的BN层

参考文献:

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

Sergey Ioffe
Google Inc., sioffe@google.com
Christian Szegedy
Google Inc., szegedy@google.com

Internal Covariate Shift 问题

在训练神经网络的过程中,实际上我们希望网络学习的是输入数据的分布特征,但是由于深度网络中每一层的输入都是前一层的输出,而前一层的参数是随着学习变化的,这就导致它的输出也随之改变,这种分布的不稳定导致训练深度网络较为困难(训练困难的意思就是说,可能需要很小心地设置初值,并且给一个很小的学习率,从而效率变慢),特别是对于有饱和现象的非线性激活函数,比如sigmoid,tanh等等,如果分布的变化过程中落到了饱和区域,那么就会很难继续下降。这种现象被称为:Internal covariate shift,内部协方差漂移。BN操作就是为了解决这一问题而产生的。

mini-batch SGD

在做梯度下降的时候,通常SGD效果表现较好,包括SGD的一些变种,比如momentum,Adagrad都取得了SOTA的效果。SGD是minimize所有训练样本上的loss的平均值,或者同样的,loss的和,但是在训练时候常常不会直接把所有的样本投喂进去训练,而是把大小为N的样本分成多个大小为m的mini-batch,然后顺次将mini-batch投入训练,这里,用minibatch上的对参数的导数来代替全体样本的导数,也就是梯度,然后按照这个梯度进行下降。这样的好处是,一方面比一个个训练要快,另一方面,在一个batch上的grad可以看做是整体的一个estimate。

但是sgd调节超参数很麻烦,而且由于每层的输入都是要收到前面的层影响,因此前面的参数的每一点微小的改变都会随着网络层数的加深被逐渐放大,最终影响分布。

考虑一个两层的网络:

可以发现,如果把第一层的输出记为x,那么第二层的梯度只和x直接相关,因此可以看成一个独立的子网络,这个网络输入就是x。我们一般希望对于一个网络的训练来说,训练和测试具有相同的分布,那么这一点对于这个子网络也同样适用。

如果对于sigmoid函数做激活的话,如果落在饱和区域,那么就会有梯度消失的现象。而且网络越深越明显。如果我们能让非线性的输入更稳定,那么就会更好训练网络,并且更快的收敛。

normalization via mini-batch statistics

The first is that instead of whitening the features in layer inputs and outputs jointly, we willnormalize each scalar feature independently, by making it have the mean of zero and the variance of 1.

Batch Normalizing Transform

其中,均值和方差是通过minibatch里的数据统计计算得出的,而gamma和beta,也就是scale和shift是可以学习的。然后,作者说明了BN层是可导的。

trainning & inference

上面是训练的情况,那么inference的时候没有batch,因此BN Transform前两步骤求mu和sigma就不能做,此时的解决方法就是:用训练集上的每个mini-batch的平均值和方差求期望,并计算出无偏估计量(对均值就是样本均值,对方差要乘以m/(m-1)的系数)。

具体做法如下:

对于convolutional network的BN层

对于卷积层,记每个batch大小为m,feature map为多个p×q,那么希望对m个batch中的数据对应的fm中的每个位置相同的点做归一化,For convolutional layers, we additionally want the normalization to obey the convolutional property – so that different elements of the same feature map, at different locations, are normalized in the same way.

BN层的作用

上面可以看出,Moreover, larger weights lead to smaller gradients, and Batch Normalization will stabilize the parameter growth.

BN层可以用更大的学习率 lr,也可以看做是对模型的一个规范化(所以有了BN层可以取消或者减轻dropout的使用)。

BN层的加速

  • 增大学习率
  • 去掉dropout
  • 减少 l2 权重正则化
  • 加速学习率衰减
  • 去掉LRN层(局部相应归一化)
  • 更彻底打乱训练样本
  • 减少光度的畸变(因为batch normalization训练速度更快,所以看到训练样本更少,因此希望让它更关注真实样本)

2018年04月22日16:58:55

将人间变成地狱的原因,恰恰是人们试图将其变成天堂。 —— 诗人,荷尔德林 【塔楼之诗】

Batch Normalization :深度网络中的BN层相关推荐

  1. 关于深度网络中的Normalization:BN/RBN/WN/LN的记录

    深度前馈网络中前层输入的变化往往会引起后面层的变化,后面的层需要不断地调整自己的参数去适应前层的输入变化,这被称为internal covariance shift.这不仅会使网络训练变得缓慢,同时会 ...

  2. Batch Normalization应该放在ReLU非线性激活层的前面还是后面?

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 编辑:CVDaily  转载自:计算机视觉Daily https: ...

  3. 聊聊Batch Normalization在网络结构中的位置

    Batch Normalization在网络结构中的位置 1. 什么是Batch Normalization? 谷歌在2015年就提出了Batch Normalization(BN),该方法对每个mi ...

  4. caffe中的batchNorm层(caffe 中为什么bn层要和scale层一起使用)

    caffe中的batchNorm层 链接: http://blog.csdn.net/wfei101/article/details/78449680 caffe 中为什么bn层要和scale层一起使 ...

  5. 【Dlib】dlib实现深度网络学习之 input层

    1. dlib::input 模板类,深度神经网络的简单输入层,它将某种图像作为输入并将其加载到网络中. 这是一个基本的输入层,它只是简单地将图像复制到一个张量中. 注意:dlib::input只支持 ...

  6. 【深度学习】聊聊Batch Normalization在网络结构中的位置

    炼丹知识点 Knowledge Points of alchemy "葡萄是一点一点成熟的,知识是一天一天积累的." Batch Normalization 1. 什么是Batch ...

  7. Batch Normalization在CNN中的原理,nb与lrb的区别

    参靠<Batch Normalization 学习笔记> 通过上面的学习,我们知道BN层是对于每个神经元做归一化处理,甚至只需要对某一个神经元进行归一化,而不是对一整层网络的神经元进行归一 ...

  8. 深度学习中的 BN (BatchNormalization)理解

    CNN 三大算子: CONV + BN +RELU 1.为什么 BN 指导思想: 机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的. 具有统一规格的数据, ...

  9. tensorflow中的BN层实现

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.keras imp ...

  10. 单个神经元在深度网络中的作用

    目录 背景描述 任务一:场景分类任务的分析 数据集与模型 网络分析方法 实验结果 实验结果分析 实验结果验证 任务二:场景生成任务的分析 背景描述 大量的实验证明,深度神经网络擅于找到大型数据集上的分 ...

最新文章

  1. 对抗网络用于人脸转正--Beyond Face Rotation
  2. C++知识点41——运算符的重载概念与分数类实现(下)
  3. SPOJ 287 Smart Network Administrator
  4. 2016.9.9《Oracle查询优化改写技巧与案例》电子工业出版社一书中的技巧
  5. Python学习笔记:常用内建模块6 (urllib)
  6. #1117. 编码 ( 字典树版 ) 题解分析
  7. 【Python】SQLAlchemy长时间未请求,数据库连接断开的原因、解决方案
  8. linux sybase 自动备份,Linux平台下Sybase数据库备份方法分析.doc
  9. vite配置 vite.config.js
  10. 微信小程序获取当前地址以及选择地址详解 地点标记
  11. 学习ROS过程中遇到的一些小问题以及解决办法的记录
  12. mysql创建的是拉丁_mysql 拉丁1 转换成 utf8
  13. 深度相机---(3)双目立体视觉
  14. 定期存款转消费卡项目需求书
  15. 标准C程序设计七---05
  16. 信用卡是超前消费的一种手段
  17. python基本语句大全_python常见语句汇总
  18. idea全局查找字段
  19. 从4G到5G,从物联网到云计算 通信的下一个引爆点在哪里?
  20. 25.有5个人做在一起, 问第五个人多少岁? 他说比第四个人大2岁. 问第四个人岁数, 他说比第是三个人大2岁. 问第三个人, 又说比第二人大两岁. 问第二个人, 说比第一个人大两岁. 最后问第一个人

热门文章

  1. android apk 应用分发平台
  2. 插件:Could not find library corresponding to plugin……
  3. 什么是:arguments
  4. ajax 点击下一页,ajax调用不会进入下一页
  5. oracle远程投毒漏洞复现,oracle TNS Listener远程投毒(CVE-2012-1675)漏洞分析、复现...
  6. 瑞星杀毒软件卸载方法
  7. linux下格式化SD卡
  8. 台式机通过网线连接笔记本的wifi网络
  9. 吴恩达深度学习系列笔记
  10. U盘启动盘制作与ISO分享