Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

论文链接: https://arxiv.org/abs/1502.03167

一、 Problem Statement

在训练的时候,输入分布的变化要求较低的学习率和较为严谨的参数初始化,使得训练模型困难。此外,各层输入分布的变化带来了一个问题,因为各层需要不断地适应新的分布,把这个问题称为内部协变量偏移(internal covariate shift)。

二、 Direction

对每一个training mini-batch使用了normalization。使得训练网络的时候可以使用更大的学习率较少的关注参数初始化。同时,它也充当了正则化项, 在某些情况可以取消使用Dropout

三、 Method

SGD最优化权重参数www:
w=arg min⁡w1N∑i=1Nl(xi,w)w = \argmin_{w} \frac{1}{N}\sum_{i=1}^Nl(x_i,w) w=wargmin​N1​i=1∑N​l(xi​,w)
mini-batch使用是:
1m∂l(xi,w)∂w\frac{1}{m}\frac{\partial l(x_i, w)}{\partial w} m1​∂w∂l(xi​,w)​
使用mini-batch的好处:

  1. 小批量损失梯度是对训练集上梯度的估计,随着batch的增大,其质量也随之提高
  2. 用batch取计算比计算m次单个样本更加有效。

但SGD需要对超参数比较敏感,随着网络加深,小的变化也会被放大。此外还有内部协变量偏移的问题。
比如下面有一个网络计算:
l=F2(F1(x,w1),w2)l=F_2(F_1(x,w_1),w_2) l=F2​(F1​(x,w1​),w2​)
学习F2F_2F2​的时候,会把F1(x,w1)F_1(x,w_1)F1​(x,w1​)当作输入,这样的梯度就是:
w2←w2−αm∑i=1m∂F2(xi,w2)∂w2w_2 \leftarrow w_2 - \frac{\alpha}{m}\sum_{i=1}^m\frac{\partial F_2(x_i,w_2)}{\partial w_2} w2​←w2​−mα​i=1∑m​∂w2​∂F2​(xi​,w2​)​

因此,各层输入x的分布随时间保持固定是有利的。

  1. w2w_2w2​就不需要重新调整来补偿输入xxx分布的变化
  2. 对于激活函数等也有积极的影响。

比如,假设FFF是sigmoid activate function z=f(Wx+b),f=11+exp(−x)z=f(Wx+b), \quad f=\frac{1}{1+exp(-x)}z=f(Wx+b),f=1+exp(−x)1​。随着∣x∣|x|∣x∣增大,f′(x)f'(x)f′(x)趋向于0,这意味着,梯度会消失,导致训练困难。
同时,由于xxx受到w,bw,bw,b和下面所有层的参数的影响,这些参数在训练过程中的变化可能会使xxx的许多维数进入非线性的饱和状态从而减慢收敛速度,随着网络加深,影响扩大。这个问题通常可以由ReLU解决
ReLU(x)=max(x,0)ReLU(x)=max(x,0)ReLU(x)=max(x,0),小心的初始化。如果我们能确保非线性输入的分布随着网络的训练保持稳定,优化器就不太可能陷入饱和状态,加速训练。

1. Towards Reducing Internal Covariate Shift

作者一开始提出了:
x^(k)=x(k)−E[x(k)]Var[x(k)]\hat{x}^{(k)} = \frac{x^{(k)}-E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}} x^(k)=Var[x(k)]​x(k)−E[x(k)]​
这里的期望和协方差是基于整个训练集的。这样的normalization可能会改变每一层所表示的东西。为了解决这个问题,作者引入了两个变量γ(k)\gamma^{(k)}γ(k)和β(k)\beta^{(k)}β(k),确保嵌入到网络中的变换可以表示同样的变换。
y(k)=γ(k)x^(k)β(k)y^{(k)} = \gamma^{(k)} \hat{x}^{(k)} \beta^{(k)} y(k)=γ(k)x^(k)β(k)

这两个超参数,随着网络的学习得到。事实上,如果它们是最优的话,
γ(k)=Var[x(k)],β(k)=E[x(k)]\gamma^{(k)} = \sqrt{Var[x^{(k)}]}, \quad \beta^{(k)} = E[x^{(k)}] γ(k)=Var[x(k)]​,β(k)=E[x(k)]

但基于整个训练集计算均值和协方差是不实际的。 因此作者做了第二个简化: 只计算每一个mini-batch里面的均值和协方差,具体如下:

其中,又引入了一个常量参数ϵ\epsilonϵ,用来作数值稳定。这样的话,任何x^\hat{x}x^的值的分布都有0期望,协方差为1的性质。
下面是Batch Normalization的导数形式:
∂l∂x^i=∂l∂yi⋅γ∂l∂σB2=∑i=1m∂l∂x^i⋅(xi−μB)⋅−12(σB2+ϵ)−3/2∂l∂μB=(∑i=1m∂l∂x^i⋅−1σB2+ϵ)+∂l∂σB2⋅∑i=1m−2(xi−μB)m∂l∂xi=∂l∂x^i⋅1σB2+ϵ+∂l∂σB2⋅2(xi−μB)m+∂l∂μB⋅1m∂l∂γ=∑i=1m∂l∂yi⋅x^i∂l∂β=∑i=1m∂l∂yi\begin{aligned} &\frac{\partial l}{\partial \hat{x}_i} = \frac{\partial l}{\partial y _i} \cdot \gamma \\ &\frac{\partial l}{\partial \sigma^2_{\Beta}} = \sum_{i=1}^m \frac{\partial l}{\partial \hat{x}_i} \cdot (x_i - \mu_{\Beta}) \cdot \frac{-1}{2}(\sigma_{\Beta}^2+\epsilon)^{-3/2} \\ &\frac{\partial l}{\partial \mu_{\Beta}} = (\sum_{i=1}^m \frac{\partial l}{\partial \hat{x}_i}\cdot \frac{-1}{\sqrt{\sigma_{\Beta}^2+\epsilon}})+\frac{\partial l}{\partial \sigma_{\Beta}^2}\cdot \frac{\sum_{i=1}^m-2(x_i-\mu_{\Beta})}{m}\\ &\frac{\partial l}{\partial x_i} = \frac{\partial l}{\partial \hat{x}_i}\cdot \frac{1}{\sqrt{\sigma^2_{\Beta}+\epsilon}}+\frac{\partial l}{\partial \sigma^2_{\Beta}} \cdot \frac{2(x_i - \mu_{\Beta})}{m}+\frac{\partial l}{\partial \mu_{\Beta}}\cdot \frac{1}{m} \\ &\frac{\partial l}{\partial \gamma} = \sum_{i=1}^m \frac{\partial l}{\partial y_i}\cdot \hat{x}_i \\ &\frac{\partial l}{\partial \beta} = \sum_{i=1}^m \frac{\partial l}{\partial y_i} \end{aligned} ​∂x^i​∂l​=∂yi​∂l​⋅γ∂σB2​∂l​=i=1∑m​∂x^i​∂l​⋅(xi​−μB​)⋅2−1​(σB2​+ϵ)−3/2∂μB​∂l​=(i=1∑m​∂x^i​∂l​⋅σB2​+ϵ​−1​)+∂σB2​∂l​⋅m∑i=1m​−2(xi​−μB​)​∂xi​∂l​=∂x^i​∂l​⋅σB2​+ϵ​1​+∂σB2​∂l​⋅m2(xi​−μB​)​+∂μB​∂l​⋅m1​∂γ∂l​=i=1∑m​∂yi​∂l​⋅x^i​∂β∂l​=i=1∑m​∂yi​∂l​​

2. Training and Inference with Batch Normalized Networks。

从上面可以知道,训练的时候,均值和协方差是通过mini-batch的样本计算的。但是在推理的时候,我们想要输出只与输入有关系。所以一旦模型训练好了,使用
x^=x−E[x]Var[x]+ϵ\hat{x} = \frac{x-E[x]}{\sqrt{Var[x] + \epsilon}} x^=Var[x]+ϵ​x−E[x]​

这里的Var[x]=mm−1EB[σB2]Var[x]=\frac{m}{m-1}E_{\Beta}[\sigma_{\Beta}^2]Var[x]=m−1m​EB​[σB2​],这里的mmm是训练时候batch size的值,σB2\sigma^2_{\Beta}σB2​是训练时候的样本协方差。
具体的算法如下:

3. Batch Normalized Convolutional Networks

在卷积的时候,把B\BetaB设置为m′=∣B∣=m⋅pqm'=|\Beta|=m\cdot pqm′=∣B∣=m⋅pq,其中,mmm为batch size, p和q为feature map的大小。学习每一个feature map的γ(k)\gamma^{(k)}γ(k)和β(k)\beta^{(k)}β(k)。

4. Batch Normalization enables higher learning rates

在传统的深度网络中,太高的学习率或者太低的学习率会导致梯度消失或者爆炸,也有可能进入局部最低。Batch Normlization有助于解决这个问题。批处理规范化还使训练对参数规模更具弹性。通常情况下,较大的学习速率会增加层参数的规模,从而在反向传播过程中放大梯度,导致模型爆炸。
然而,使用了Batch Normalization, 通过一个网络层的时候,反向梯度不会受到learning rate的影响:
BN(Wx)=BN((αW)x)BN(Wx) = BN((\alpha W)x) BN(Wx)=BN((αW)x)
相对应的梯度为:
∂BN(αW)x∂x=∂BN(Wx)∂x\frac{\partial \text{BN}(\alpha W)x}{\partial x} = \frac{\partial \text{BN}(Wx)}{\partial x} \\ ∂x∂BN(αW)x​=∂x∂BN(Wx)​
∂BN(αW)x∂W=1α⋅∂BN(Wx)∂W\frac{\partial \text{BN}(\alpha W)x}{\partial W} = \frac{1}{\alpha} \cdot \frac{\partial \text{BN}(Wx)}{\partial W} ∂W∂BN(αW)x​=α1​⋅∂W∂BN(Wx)​
可以看到,大的学习率会导致更小的梯度。

5. Batch Normalization regularizes the model

Batch Normalization起到正则化的作用,避免了overfitting。

四、 Conclusion

  • 所提出的Batch Normalization 可以固定网络层输入的均值和协方差
  • 且通过减少梯度对参数的大小和初始值的依赖,有助于网络中梯度的传递。
  • Batch Normalization也有正则化的作用,减少了Dropout的使用
  • Batch Normalization通过防止网络陷入饱和模式,使得使用饱和非线性成为可能。

简单的在神经网络中使用BN,并没有完全利用好这篇文章的方法,作者还提出:

  1. 增大学习率,加速网络训练
  2. 删除Dropout,加速训练,也能避免过拟合。
  3. 减少L2L_2L2​ 权重正则化。在Inception网络中,使用了L2L_2L2​ loss来控制过拟合,通过控制这个Loss的权重提升精度。
  4. 加速学习率的衰减。
  5. 移除local response normalization。
  6. 彻底打乱数据样本。
  7. 减少photometric distortions。因为标网络训练更快,观察每个训练示例的次数更少,所以我们让培训师通过较少的扭曲来关注更“真实”的图像。

Batch-normalized 应该放在非线性激活层的前面还是后面?
在BN的原始论文中,BN是放在非线性激活层前面的。

We add the BN transform immediately before the nonlinearity

五、 Reference

  1. https://www.zhihu.com/question/283715823/answer/438882036

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论文笔记相关推荐

  1. Batch normalization:accelerating deep network training by reducing internal covariate shift的笔记

    说实话,这篇paper看了很久,,到现在对里面的一些东西还不是很好的理解. 下面是我的理解,当同行看到的话,留言交流交流啊!!!!! 这篇文章的中心点:围绕着如何降低  internal covari ...

  2. 批归一化《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》

    批归一化<Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ...

  3. 【BN】《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》

    ICML-2015 在 CIFAR-10 上的小实验可以参考如下博客: [Keras-Inception v2]CIFAR-10 文章目录 1 Background and Motivation 2 ...

  4. 《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》阅读笔记与实现

    今年过年之前,MSRA和Google相继在ImagenNet图像识别数据集上报告他们的效果超越了人类水平,下面将分两期介绍两者的算法细节. 这次先讲Google的这篇<Batch Normali ...

  5. 【论文泛读】 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    [论文泛读] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ...

  6. 论文阅读:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    文章目录 1.论文总述 2.Why does batch normalization work 3.BN加到卷积层之后的原因 4.加入BN之后,训练时数据分布的变化 5.与BN配套的一些操作 参考文献 ...

  7. 读文献——《Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift》

    在自己阅读文章之前,通过网上大神的解读先了解了一下这篇文章的大意,英文不够好的惭愧... 大佬的文章在https://blog.csdn.net/happynear/article/details/4 ...

  8. 深度学习论文--Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    本文翻译论文为深度学习经典模型之一:GoogLeNet-BN 论文链接:https://arxiv.org/abs/1502.03167v3 摘要:训练深度神经网络的难度在于:前一层网络参数的变化,导 ...

  9. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障.BatchNorm就是在深度神经网络训 ...

最新文章

  1. WPF/Silverlight深度解决方案:(十六)传值实现
  2. 2019年JAVA比较火的框架_2019年Java技术中当前流行的三大框架
  3. 图片不能置于底层怎么办_PPT中常遇到的图片问题和解决方案
  4. P7717-「EZEC-10」序列【Trie】
  5. php中可以实现分支,PHP中的分支及循环语句
  6. python爬取百度贴吧中的所有邮箱_python写的百度贴吧邮箱采集(带界面)
  7. 阿里云:构建全球企业内外安全网络最佳实践
  8. 【ElasticSearch】ElasticSearch在数十亿级别数据下,如何提高查询效率? 性能优化
  9. Android精品开源项目整理_V20140221
  10. Python数据处理(入门教程)
  11. UWB定位系统油库人员定位解决方案
  12. ads1115多片并联
  13. 谈谈百度竞价的一些思路
  14. Ubuntu系统如何屏幕截图
  15. 【项目笔记_答题器】rp552d usb hid 在seewo win10 设备上启动无法识别
  16. 2022Java面试题大全(整理版)面试题附答案详解,最全面详细
  17. 多模态情感识别数据集和模型(下载地址+最新综述2021.8)
  18. Ubuntu Windows双系统切换最简方法!!!
  19. 通过百度API实现图片车牌号识别
  20. 【盘点】值得推荐的优质文章!

热门文章

  1. ios系统跳转苹果商店
  2. Linux-PAM系统管理指南
  3. sass,sass-loader的使用
  4. 面向 JavaScript 开发人员的 5 大物联网库
  5. oracle 查询一个月内每天某个时间段的数据
  6. 小米面试总结(附答案)
  7. 新小程序消息推送平台,不止于推送!
  8. 学会这10种实用的定时任务,拿捏所有业务场景
  9. 药品冷链物流中的温度监测
  10. 微信小程序单位__rpx2px