这篇文章是论文 Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift 的翻译,其中精简了部分内容的同时解释了相关的概念,如有错误,敬请指教。

Abstract

在神经网络训练过程中,前一层权重参数的改变会造成之后每层输入样本分布的改变,这造成了训练过程的困难。为了解决这个问题,通常会使用小的学习率和参数初始化技巧,这样导致了训练速度变慢,尤其是训练具有饱和非线性的模型。我们将这一现象定义为internal covariate shift,并提出通过规范化输入来解决。将标准化作为模型架构的一部分,并且对每一个training mini-batch使用标准化来增强算法的效果。BN允许我们使用更高的学习率,并且不用太过关心初始化。同时由于有正则化的效果,有时还能省略Dropout,加速模型训练。BN在ImageNet上的测试取得了很好的效果。

1 Introduction

SGD(Stochastic gradient descent)是一种训练神经网络的有效方法。SGD的改进算法,如momentum、Adagrad都应用广泛。SGD通过优化权重参数来减小loss:

训练过程中的mini-batch的梯度计算如下:

使用mini-batch的优点如下:

  • mini-batch的梯度是对整个数据集梯度的估计,当数据集过大时,计算量暴增,使用mini-batch能提高梯度更新的效率(可参考GoodFellow的DeepLearning教材)
  • 对于含有m个单独样本的mini-batch的一次计算比单独计算m次更快,这得益于并行计算。

虽然SGD简单有效,但是超参数调整十分麻烦,尤其是学习率和初始化。由于每层网络都受到之前所有网络的影响,随着网络加深,每层网络参数的微小变化都会被之后的网络逐渐放大。
每层网络参数的变化会导致一些问题,因为之后的网络必须适应新的分布。假设有一个网络的loss为:

其中F1,F2为任意变换,内层的函数

可以看成是外层函数的一个参数,即:

对于一步梯度更新:

可以看做是独立的网络F2,输入为x。输入分布的特殊性质可以使得训练过程更高效,比如训练集和测试集完全相同的时候。因此,如果固定x的分布,参数就不用再补偿因为x的分布改变引起的变化。
在子网络中固定输入的分布对于外层网络有积极的影响。考虑一个具有sigmoid激活函数的网络:

其中,W和b是需要学习的参数,g(x)为sigmoid函数。由sigmoid函数的性质可知,当x的绝对值增大时,g的导数会逐渐减小至零。这意味着对于所有的x=Wu+b,除开绝对值较小的部分,其余的x值都会使梯度逐渐消失(gradient vanish)。由于x受W和b的影响,改变这两个参数有可能使得x的很多维度陷入饱和的非线性区,降低收敛速度。这个效应会随着网络层数的加深变得愈发明显。在实践过程中,这个问题通常可以使用ReLU来解决,但是如果可以控制非线性输入的分布更加稳定,就可以使优化器更少陷入饱和区域,以达到加速的目的。

2 Towards Reducing Internal Covariate Shift

我们将 Internal Covariate Shift1 定义为训练过程中网络参数的改变引起的网络激活函数分布的改变。BN的目标是减少 Internal Covariate Shift ,以达到加速的目的。已知的是,如果网络的输入进行了白化23,训练时会收敛得更快。这可以通过线性变换,使得输入为零均值,单位方差,并且去除了相关性的数据分布。白化操作可以通过直接更改网络或者改变优化算法来实现。但是如果直接对原始数据进行白化,需要计算输入x的协方差矩阵和求逆,计算量会很大。这促使我们寻找一种替代的方法来完成输入标准化,并且在参数更新过程中不需要对整个数据集进行分析。
注:
1. 白化的解释可参考 UFLDL的教程4
2. Covariate Shift的解释可以参考这里

3 Normalization via Mini-Batch Statistics

因为对每一层的输入进行白化代价太大,这里作出两个必要的简化。

独立地标准化每一个标量特征(均值为零,方差为1),而不是将输入层和输出层的特征一起进行白化。对于一个d维的输入,标准化每一个单独的维度的公式为:

其中,期望和方差都是在 整个数据集 上进行计算。LeCun5 的文献表明,这种标准化的过程即使在数据没有去除相关性的条件下,仍然可以加速收敛。
注意,如果简单地对一层的输入进行标准化可能会影响这一层的表达能力。比如在使用sigmoid激活函数的时候,如果简单地将输入数据进行零均值单位方差标准化,将使得原始数据更加集中,对应于sigmoid函数的中间部分,也就是只使用了函数的线性区域。详细的解释可以参考 这里 为避免这一问题,可以使用重构变换

其中的参数γ(k) \gamma^{(k)} 和 β(k) \beta^{(k)} 代表了放缩和平移的变换参数,需要和原始模型的参数一起学习。这样能恢复模型的表达能力。
注:对于重构变换的理解是,在标准化的过程中减去了均值,这相当于对原始数据的平移;除以标准差,这相当于对原始数据的放缩。因此重构变换可以考虑为反向的过程,即先进行放缩,也就是乘上γ(k) \gamma^{(k)} ;再进行平移,也就是加上β(k) \beta^{(k)} ,只不过这两个参数需要进行学习。最特殊的情况是,如果设置 γ(k)=Var[x(k)]−−−−−−−√\gamma^{(k)} = \sqrt{Var[x^{(k)}]} 和 β(k)=E[x(k)] \beta^{(k)} = E[x^{(k)}] ,就可以恢复原始的数据。
在每个训练步骤的参数更新是基于整个训练集的情况下,我们可以使用整个训练集的来标准化激活函数。但是这对于SGD不可行,因此作出第二个简化:
因为在SGD中使用mini-batches,每一个mini-batches在每一次激活过程中产生均值和方差的估计。这样,标准化过程使用的统计数据就能在梯度反向传播过程中用上。注意使用mini-batches是通过方差的每一个单独的维度进行的,而不是协方差。如果是协方差的情况,将需要进行规范化,因为mini-batch的大小可能比白化的激活函数要少,这将导致协方差矩阵奇异。
考虑一个mini-batch B,大小为m。BN算法流程如下:

其中,ε \varepsilon 是为保证数值稳定性添加到mini-bacth中的常数。
BN算法可以添加到网络中来控制激活。注意到BN()并不是单独地在每一个训练样本中处理激活,而是既依赖训练样本,也依赖于mini-batch中的其他样本。经过伸缩和平移变换的yy 传播到其他层。如果忽略ε\varepsilon ,只要每一个mini-batch的元素是从同一个分布中采样,x^\hat{x} 的取值分布都满足零均值和单位方差。子网络的输入具有固定的均值和方差。虽然所有x^(k)\hat{x}^{(k)} 的联合分布会在训练过程中发生改变。由于标准化子网络的引入会加速子网络的训练过程,最终加速整个网络的训练过程。
在训练过程中需要反向传播ℓ\ell 的梯度,同时需要计算BN中参数的梯度,使用链式法则,简化前的公式为:

因此,BN是一个可微的变换,并且引入了正则化的效果。这保证了在模型训练时,每一层展示出更少的内部internal covariate shift的输入中进行学习,因此加速了训练过程。更重要的是,学习到的仿射变换参数可以保证网络的容量不受影响。

3.1 Training and Inference with Batch-Normalized Networks

对于BN的实际应用,原先接受 xx 作为输入的单层网络现在接受BN(x)BN(x) 作为输入。使用BN的网络可以用SGD及相关的改进算法(如Adagrad)来进行优化。但是在推理过程中,不需要网络进行标准化,而是希望网络的输出只依赖于原始的输入。因此,在网络训练完成后,使用如下公式:

即使用整个数据集所有样本,而不是一个mini-batch的统计资料。如果忽略ε \varepsilon ,标准化的激活就具有和训练时相同的零均值和单位方差。使用无偏方差估计:

其中,期望是在mini-batch上进行计算,标准差是在所有样本上进行计算。因为均值和方差在推理过程中都是固定的,标准化只是一个用于激活中的线性变换。下图是使用BN的网络的训练步骤:

3.2 Batch-Normalized Convolutional Networks

BN可以被应用到任何网络的激活中。考虑一个如下的变换:

其中,WW和bb 是学习到的参数,g()g() 是非线性函数比如sigmoid或者ReLU。我们在非线性运算之前添加BN,即标准化x=W+b x=W+b。当然,也可以直接标准化输入层中的uu,但是由于uu 其他非线性函数的输出,在训练过程中它的分布可能会发生改变。对它的一阶和二阶矩进行约束无法消除covariate shift。反之,x=W+b x=W+b 则更可能具有一个对称、非稀疏的分布,更加近似正态分布。标准化这个分布更有希望产生一个具有稳定分布的激活。
注意到标准化x=W+b x=W+b 的过程中,偏差bb ,可以忽略(去均值的过程)。因此,z=g(BN(Wu+b))z=g(BN(Wu+b)) 可以替换为:
z=g(BN(Wu))z=g(BN(Wu))
BN变换是应用到x=Wu x=Wu 的每一个维度。在每一个维度都具有一对独立的参数γ(k)\gamma^{(k)},β(k)\beta^{(k)}。
对于卷级层,我们希望标准化过程符合卷积的一些性质以便于同一个特征图中的不同位置的元素,都可以使用相同的标准化。为了实现这一点,我们对一个mini-batch中的所有位置的激活同时进行标准化。对于一个大小为mm 的mini-batch,特征图的尺寸为p∗qp*q,我们使用的有效mini-batch大小为m′=m∗pqm^{'}=m*pq,在每一个特征图中学习到一对γ(k)\gamma^{(k)},β(k)\beta^{(k)} 参数,而不是在每一处激活中(有点类似于权值共享)。

3.3 Batch Normalization enables higher learning rates

通常,大的学习率会造成单层的参数放大,在反向传播的过程中,又会造成梯度的放大。但是如果使用BN的话,反向传播过程可以不受影响。假设放大倍数为aa

不难推出:

由于放大过程不会影响单层的Jacobian矩阵,最终也不会影响反向传播。更重要的是,更大的权重会导致更小的梯度,因此BN具有稳定参数增长的作用。
我们可以猜想BN可能导致Jacobian矩阵具有接近于1的奇异值,这对于训练过程是有利的6。考虑两个相邻的层,同时对这两个层进行标准化输入。假设两个层之间的变换为:
z^=F(x^)\hat{z}=F(\hat{x})
如果假定z^\hat{z}和x^\hat{x} 都属于正太分布并且不相关,那么
F(x^)≈Jx^F(\hat{x})\approx J \hat{x} 对于给定的模型参数就近似于一个线性变换。并且如果z^\hat{z}和x^\hat{x} 都具有单位协方差,即I=cov[z^]=Jcov[x^]JT=JJTI = cov[\hat{z}]=Jcov[\hat{x}]J^{T}=JJ^{T},因此JJ 的所有奇异值为1,这在反向啊传播的过程中可以稳定梯度的量级。实际的变换是非线性的,标准化之后的输入也不可能完全保证正态分布或者完全独立,但是BN仍然可以使反向传播的过程表现得更好。

4 Experiments

4.1 Activations over time

为了验证BN对于covariate shift的抑制作用,在MNIST数据集上进行验证。使用3个全连接层,每层100个神经元。每一个隐含层计算y=g(Wu+b)y=g(Wu+b),激活函数为sigmoid,使用正态分布初始化W<script type="math/tex" id="MathJax-Element-45">W</script>。loss使用交叉熵代价函数。训练50000步,每一个mini-batch有60个样本。实验结果见Figure 1。

从图(a)中可以看出,BN在少量数据的条件下能显著提高正确率。(b)(c)给出了15%、50%和85%的输入数据在整个训练过程中分布的差异,可以看到BN使得分布更加平滑,减少了internal covariate shift。

4.2 ImageNet classification

在Inception network中使用momentum优化器,mini-batch样本数为32,并使用3.2节中针对卷积的BN方法。

4.2.1 Accelerating BN Networks

简单地在网络中使用BN,并不能完全发挥出BN的优势,因此对网络的参数作出如下改变:

  • 增大学习率
  • 去除Dropout层
  • 减少L2正则
  • 增大学习率衰减
  • 去除LRN
  • 更加彻底地打乱训练样本
  • 减少图像亮度失真

4.2.2 Single-Network Classification

在LSVRC2012数据集上训练如下网络:

  • Inception:学习率0.0015
  • BN-Baseline:Inception+BN
  • BN-x5:Inception+BN+4.2.1中的改进,学习率0.0075
  • BN-x30:与BN-x5类似,初始学习率为0.045
  • BN-x5-Sigmoid:与BN-x5类似,但是使用sigmoid函数而不是ReLU
    Figure 2给出了验证集上整个训练过程中的正确率,可以看出,达到Inception相同正确率的水平,BN-x5训练步数最少。

Figure 3给出了最大正确率和训练步数的表格:

5 Conclusion

给出keras中卷积层的BN实现的源代码:

input_shape = self.input_shape
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)  #len()用来统计行数
broadcast_shape[self.axis] = input_shape[self.axis]
if train:  m = K.mean(X, axis=reduction_axes)  #求各维度均值brodcast_m = K.reshape(m, broadcast_shape)  #展开均值std = K.mean(K.square(X - brodcast_m) + self.epsilon, axis=reduction_axes)   #求方差std = K.sqrt(std)  #求标准差brodcast_std = K.reshape(std, broadcast_shape)  #展开标准差mean_update = self.momentum * self.running_mean + (1-self.momentum) * m  #更新均值std_update = self.momentum * self.running_std + (1-self.momentum) * std  #更新方差self.updates = [(self.running_mean, mean_update),  (self.running_std, std_update)]  X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)  #标准化
else:  brodcast_m = K.reshape(self.running_mean, broadcast_shape)  brodcast_std = K.reshape(self.running_std, broadcast_shape)  X_normed = ((X - brodcast_m) /  (brodcast_std + self.epsilon))
out = K.reshape(self.gamma, broadcast_shape) * X_normed + K.reshape(self.beta, broadcast_shape) 

References


  1. Improving predictive inference under covariate shift by weighting the log-likelihood function ↩
  2. LeCun, Y., Bottou, L., Orr, G., and Muller, K. Efficient
    backprop. In Orr, G. and K., Muller (eds.), Neural Networks:
    Tricks of the trade. Springer, 1998b. ↩
  3. Wiesler, Simon and Ney, Hermann. A convergence analysis
    of log-linear training. In Shawe-Taylor, J., Zemel,
    R.S., Bartlett, P., Pereira, F.C.N., andWeinberger, K.Q.
    (eds.), Advances in Neural Information Processing Systems
    24, pp. 657–665,Granada, Spain, December 2011. ↩
  4. http://ufldl.stanford.edu/wiki/index.php/%E7%99%BD%E5%8C%96 ↩
  5. LeCun, Y., Bottou, L., Orr, G., and Muller, K. Efficient
    backprop. In Orr, G. and K., Muller (eds.), Neural Networks:
    Tricks of the trade. Springer, 1998b. ↩
  6. Saxe, Andrew M., McClelland, James L., and Ganguli,
    Surya. Exact solutions to the nonlinear dynamics
    of learning in deep linear neural networks. CoRR,
    abs/1312.6120, 2013. ↩

Batch Normalization详细解读相关推荐

  1. 白话详细解读(七)----- Batch Normalization

    转载:https://www.cnblogs.com/guoyaohua/p/8724433.html Batch Normalization作为最近一年来DL的重要成果,已经广泛被证明其有效性和重要 ...

  2. Batch Normalization原文详细解读

    这篇博客分为两部分, 一部分是[3]中对于BN(Batch Normalization的缩写)的大致描述 一部分是原文[1]中的完整描述 .####################先说下书籍[3]## ...

  3. lgg7深度详细参数_深度学习平均场理论第七讲:Batch Normalization会导致梯度爆炸?...

    前言 Batch Normalization (BN)对于深度学习而言是一项非常重要的技术.尽管BN在网络训练的过程中表现力非常强大,但是大家始终没有一个很好的理论上的清晰理解.今天我们就试图解读这篇 ...

  4. 【深度学习】Batch Normalization(BN)超详细解析

    单层视角 神经网络可以看成是上图形式,对于中间的某一层,其前面的层可以看成是对输入的处理,后面的层可以看成是损失函数.一次反向传播过程会同时更新所有层的权重W1,W2,-,WL,前面层权重的更新会改变 ...

  5. 解读Batch Normalization

    [活动]Python创意编程活动开始啦!!!     CSDN日报20170424 --<技术方向的选择>    程序员4月书讯:Angular来了! 解读Batch Normalizat ...

  6. Batch Normalization(BN)超详细解析

    单层视角 神经网络可以看成是上图形式,对于中间的某一层,其前面的层可以看成是对输入的处理,后面的层可以看成是损失函数.一次反向传播过程会同时更新所有层的权重W1,W2,-,WL,前面层权重的更新会改变 ...

  7. 【论文理解】Batch Normalization论文中关于BN背景和减少内部协变量偏移的解读(论文第1、2节)

    最近在啃Batch Normalization的原论文(Title:Batch Normalization: Accelerating Deep Network Training by Reducin ...

  8. 批标准化(Batch Normalization )最详细易懂的解释

    12. 批标准化(Batch Normalization ) 大纲:Tips for Training Deep Network Training Strategy: Batch Normalizat ...

  9. 关于Batch Normalization的理解和认识

    1 前言 Batch Normalization作为最近几年来DL的重要成果,已经广泛被证明其有效性和重要性.目前几乎已经成为DL的标配了,任何 有志于学习DL的同学们朋友们都应该好好学一学BN.BN ...

最新文章

  1. oc75--不可变字典NSDictionary
  2. Linux varnish代理服务器安装以及健康检查
  3. 为什么越来越多的人都不再愿意做程序员了?
  4. python处理中文字符串_处理python字符串中的中文字符
  5. JavaWeb14-HTML篇笔记(一)
  6. 能改变原生web前端元素样式的water.css
  7. 手机开启热点给其他设备上网和用插卡随身路由给其他设备上网有何区别呢?
  8. 隐藏网络计算机,如何在网络中隐藏自己的计算机名称
  9. java 正则表达式 替换 html,java 正则表达式 替换 html
  10. 机器人基础原理1_2——机器人分类与常见坐标系
  11. 百分百解决python manage.py makemigrations没有反应
  12. Unity不规则按钮
  13. 网络信息安全:五、GRE和IPSEC
  14. 陈景润定理对筛法理论的贡献
  15. JUC并发编程学习笔记
  16. krpano获取地址栏传参
  17. 【项目管理】--- 时间管理 --- 缩短工期
  18. Office Tool Plus的使用
  19. MySQL环境变量的配置(三)(Windows 11)
  20. opc-ua协议机器数据采集-python

热门文章

  1. 通过买新电脑a时买的正版Windows 10 pro for OEM key升级电脑b操作系统Windows 10 home 到 专业版pro
  2. Qt5简单函数计算器
  3. Dart list数组集合类型
  4. 微信开发工具制作会动的海绵宝宝
  5. 【黑金动力社区】【FPGA黑金开发板】他和它的故事 之模块的沟通
  6. 女朋友过生日,男子买了一条项链,女友:值不了多少钱
  7. QrCodeUtil--二维码工具类
  8. selenium自动化测试实战教学(12306自动化订票)春节出行必备
  9. android屏幕跳转,Android 几种屏幕间跳转的跳转Intent Bundle
  10. 安全测试:xss,cookie,xst注入攻防