深度学习中 Internal Covariate Shift 问题以及 Batch Normalization 的作用

  • 前言
  • 一、Batch Normalization是什么?
    • 1.1 Internal Covariate Shift
    • 1.2 Internal Covariate Shift 带来的影响
    • 1.3 如何减缓 Internal Covariate Shift 问题带来的影响
      • 白化(Whitening)
      • 白化存在的问题
  • 二、Batch Normalization
    • 2.1 传统 Normalization
    • 2.2 改进
    • 2.3 原文算法[^4]
  • 三、Batch Normalization 在测试阶段的使用
    • 3.1 测试阶段如何计算μl\mu_lμl​ 和 σl2\sigma_l^2σl2​ [^5]
    • 3.2 完成BN层的计算
  • 四、Batch Normalization总结

前言

提示:这里可以添加本文要记录的大概内容:
最近在看到一篇论文:FedBN: Federated Learning on Non-IID Features via Local Batch Normalization

论文合理的利用了 Batch Norlization 解决了联邦学习中,不同的边缘节点上数据存在的非独立同分布(Non-IID)的问题,降低了 Feature shift 带来的影响。

为此,这里记录对 Batch Normalization 的理解,以及DedBN论文的总结。


一、Batch Normalization是什么?

相信大家第一次接触,应该是在训练模型的时候,模型训练困难,总是波动很大或者损失函数下降十分缓慢,甚至是无法收敛。

有研究表明:随着 DNN 网络层次的加深,参数的变化导致每一层的输入分布会发生改变,进而上层的网络需要不停地去适应这些分布变化,使得我们的模型训练变得困难。

为什么会这样呢?因为 Internal Covariate Shift

1.1 Internal Covariate Shift

因为在训练过程中,每一层的输入分布会随着前一层参数的变化而变化,这种现象称之为 Internal Covariate Shift1

下图为一个多层全连接的神经网络结构示意图,左侧的网络层为底层,右侧的网络层称之为顶层。

以本图为例,每一层 lll 可以理解为两个操作:

  • 线性变换:Y[l]=W[l]×input+b[l]Y^{[l]} = W^{[l]} \times input + b^{[l]}Y[l]=W[l]×input+b[l]。其中W[l]W^{[l]}W[l]表示lll层的权重,b[l]b^{[l]}b[l]表示lll层的偏置,Y[l]Y^{[l]}Y[l]表示lll层的线性输出
  • 非线性变换: Z[l]=g[l](Y[l])Z^{[l]} = g^{[l]}(Y^{[l]})Z[l]=g[l](Y[l])。g[l]g^{[l]}g[l]表示lll层的激活函数

在模型的反向传播过程中,根据计算的梯度来更新每一层的W[l]W^{[l]}W[l]和b[l]b^{[l]}b[l],那么Y[l]Y^{[l]}Y[l]的分布也会改变,Z[l]Z^{[l]}Z[l]的分布也随之改变。

然而Z[l]Z^{[l]}Z[l]作为下一层(l+1)(l+1)(l+1)的输入,这就使得(l+1)(l+1)(l+1)层的神经元也需要不断的适应这样的变化,这就会降低整个网络的收敛速度。

1.2 Internal Covariate Shift 带来的影响

① 上层网络需要不停调整来适应输入数据分布的变化,导致网络学习速度的降低

如上所提到的,梯度下降使得每一层的参数都在不断发生变化,
进而使得每一层的线性与非线性计算结果分布产生变化。
后层网络就要不停地去适应这种分布变化,这个时候就会使得整个网络的学习速率过慢。

② 网络的训练过程容易陷入梯度饱和区,减缓网络收敛速度

梯度饱和和梯度消失的后果有点类似(但不要混淆哦)。梯度饱和:常常是和激活函数相关的,比如sigmod和tanh就属于典型容易进入梯度饱和区的函数。
即自变量进入某个区间后,梯度变化会非常小,
表现在图上就是函数曲线进入某些区域后,越来越趋近一条直线,梯度变化很小。
梯度饱和会导致训练过程中梯度变化缓慢,从而造成模型训练缓慢

下图为 sigmoid(左)sigmoid(左)sigmoid(左) 和 Tanh(右)Tanh(右)Tanh(右) 的激活函数与对应的一阶导数曲线图2


两者的导数均在原点处取得最大值,sigmoidsigmoidsigmoid最大为0.25,TanhTanhTanh最大为1;
在远离原点的正负方向上,两者导数均趋近于0,即存在饱和区。

饱和区: 一旦陷入饱和区,两者的偏导都接近于0,导致权重的更新量很小,比如某些权重很大,导致相关的神经元一直陷在饱和区,更新量又接近于0,以致很难跳出或者要花费很长时间才能跳出饱和区。

1.3 如何减缓 Internal Covariate Shift 问题带来的影响

注意:Internal Covariate Shift 是因为参数更新带来的网络中每一层输入值分布的改变,并且随着网络层数的加深而变得更加严重.
因此可以通过固定每一层网络输入值的分布来对减缓ICS问题。

白化(Whitening)

这不是本文探讨的重点,想了解的可以参考 Whitening 3 ,这里我们只需要知道,白化后的数据会有如下性质:

  • 特征之间相关性较低;
  • 所有特征具有相同的方差。

通过白化操作,可以有效地减缓 Internal Covariate Shift 的问题,进而固定了每一层网络输入分布,加速网络训练过程的收敛。

白化存在的问题

白化存在如下缺点:

  • 白化过程计算成本太高,比如 PCA 中,需要计算协方差矩阵,并且在每一轮训练中的每一层我们都需要做如此高成本计算的白化操作;
  • 白化过程由于改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力。底层网络学习到的参数信息会被白化操作丢失掉。

二、Batch Normalization

这里说明下传统 Normalization 使用的原因:

(1)由于神经网络学习过程本质上是为了学习数据的分布,一旦训练数据与测试数据的分布不同,
那么网络的泛化能力也大大降低;
(2)另一方面,在mini-batch梯度下降训练的时候,每批训练数据的分布不相同,那么网络就
要在每次迭代的时候去学习以适应不同的分布,这样将会大大降低网络的训练速度。

而 Normalization 能够很好的使得样本处于同一个分布

2.1 传统 Normalization

上面所说的传统 Normalization 在数学里面都学过,也叫归一化,常见的形式如下:

在非线性变换之后或者线性变换之前对 xxx 进行标准化处理(减去均值,除标准差),让数据处于均值为0、方差为1的分布中,以降低样本间的差异性。而 Batch Normalization 则是指对一个 batch 进行 Normalization .

这样使得对应的样本计算出的梯度处于中间的中心区域(图中红色显示区域)。因为梯度一直都能保持比较大的状态,所以很明显对神经网络的参数调整效率比较高,就是变动大,就是说向损失函数最优值迈动的步子大,反向传播信息流动性更强,加快训练收敛速度。

如果仅仅使用这样的方法,对网络某一层 lll 的输出数据做归一化,然后送到网络下一层B中,这样会影响本层网络lll 所学到的特征,从而导致数据表达能力的缺失。

另一方面,通过让每一层的输入分布均值为0,方差为1,会使得输入在经过sigmoid或tanh激活函数时,容易陷入非线性激活函数的线性区域,即0附近的线性区域。

这样一来非线性激活函数就起不到相应的非线性变换的作用,或者就是相当于一个线性层罢了,那么网络的非线性表达能力就下降了。

2.2 改进

因此,BN又引入了两个可学习(learnable)的参数 γ\gammaγ与 β\betaβ,这两个参数的引入是为了恢复数据本身的表达能力,对规范化后的数据进行线性变换,即:

Zj~=γjZj^+βj\tilde{Z_{j}} = \gamma_j\hat{Z_{j}} + \beta_jZj​~​=γj​Zj​^​+βj​

这个操作使数据在中心区域附近的线性区域往旁边的非线性区域进行了一定的偏移,即通过γj\gamma_jγj​ 和 βj\beta_jβj​ 把原来的输出值从标准正态分布左移或者右移一点,使得曲线更加胖一点或瘦一点,每个实例挪动的程度不一样,这样等价于非线性函数的值从正中心周围的线性区往非线性区进行了扩散移动。

核心思想: 找到一个线性和非线性的较好平衡点,既能享受非线性的较强表达能力的好处(γj\gamma_jγj​ 和 βj\beta_jβj​ 带来的),又避免太靠非线性区两头使得网络收敛速度太慢(Normalization 带来的)。这两个参数的核心思想就是兼顾线性的快速收敛,与非线性的较强表达能力。这两个参数需要通过学习得到的。

2.3 原文算法4


三、Batch Normalization 在测试阶段的使用

3.1 测试阶段如何计算μl\mu_lμl​ 和 σl2\sigma_l^2σl2​ 5

在训练阶段,各层的 μl\mu_lμl​ 和 σl2\sigma_l^2σl2​ 是通过当前层得到的输入 batch 计算而得。

而测试阶段有可能仅输入一个或者极少样本,它对应的 μl\mu_lμl​ 和 σl2\sigma_l^2σl2​是没有意义的,这时候该如何计算 μl\mu_lμl​ 和 σl2\sigma_l^2σl2​ 呢?

针对每一层 lll 而言:因为在训练结束后,每一层的参数已经固定好了,那么每一层有很多个已经计算过的mini-batch,则有这些 batch 对应的μl\mu_lμl​ 和 σl2\sigma_l^2σl2​,在训练时,把这些值都保存下来;在测试时,通过计算μl\mu_lμl​的数学期望,以及 σl2\sigma_l^2σl2​ 的无偏估计,从而间接求出该层的全局统计量:

  • μ^l=E(μbatch)\hat{\mu}_l = E(\mu_{batch})μ^​l​=E(μbatch​)
  • σ^l2=mm−1E(σbatch2)\hat{\sigma}_l^2 = \frac m {m-1} E({\sigma}_{batch}^2)σ^l2​=m−1m​E(σbatch2​)。

这样的好处是,能够追踪训练过程中所有 mini-batch 的样本的特性。

需要注意的是: 对每一层输入进行归一化的时候,是按照一维度一维度的归一化,即每一个神经元 lil_ili​ 表示一维特征,在对该维度求μli\mu_{l_i}μli​​ 和 σli2\sigma_{l_i}^2σli​2​,然后对该维度的输入进行 Normalization ,这一点千万别搞错了。

在得到每一个特征的均值和方差后,就可以对测试样本进行 Normalization 了.

3.2 完成BN层的计算

这里特别指出:BN 层往往是添加在每一层网络中非线性层之前(激活函数计算之前)。

因此可以通过 Batch Normalization 得到:

BN(Xtest)=γ⋅Xtest−μtestσtest2+ε+βBN(X_{test}) = {\gamma} {\cdot} {\frac {X_{test} - {\mu}_{test}} {\sqrt{\sigma^2_{test}+ \varepsilon}}} + {\beta}BN(Xtest​)=γ⋅σtest2​+ε​Xtest​−μtest​​+β


四、Batch Normalization总结

BN的优势可以总结为如下6

(1)BN使得网络中的每一层输入数据都相对稳定,能够给加快模型的训练速度

BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络
中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。

(2)BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定

训练时,学习率设置的过大会导致步长过大,从而使得整个模型来回震荡。但是使用了BN层的网络则不会。为什么呢?

假设对参数 WWW 进行缩放,得到 aWaWaW。那么有:

  • 缩放前:BN层的输入值为 W⋅μW \cdot {\mu}W⋅μ;均值为 μ1\mu_1μ1​, 方差为 σ12\sigma_1^2σ12​。
  • 缩放后:BN层的输入值为 aW⋅μaW \cdot {\mu}aW⋅μ; 均值为 μ2\mu_2μ2​, 方差为 σ22\sigma_2^2σ22​。

其中 μ2=aμ1\mu_2 = a\mu_1μ2​=aμ1​ ,σ22=a2σ12\sigma_2^2 = a^2 \sigma_1^2σ22​=a2σ12​

忽略计算过程中的 ε\varepsilonε,则有:

BN(aWu)=γ⋅aWu−μ2σ22+β=γ⋅aWu−aμ1a2σ12+β=γ⋅Wu−μ1σ12+β=BN(Wu)BN(aWu)=\gamma\cdot\frac{aWu-\mu_2}{\sqrt{\sigma^2_2}}+\beta=\gamma\cdot\frac{aWu-a\mu_1}{\sqrt{a^2\sigma^2_1}}+\beta=\gamma\cdot\frac{Wu-\mu_1}{\sqrt{\sigma^2_1}}+\beta=BN(Wu)BN(aWu)=γ⋅σ22​​aWu−μ2​​+β=γ⋅a2σ12​​aWu−aμ1​​+β=γ⋅σ12​​Wu−μ1​​+β=BN(Wu)

∂BN((aW)u)∂u=γ⋅aWσ22=γ⋅aWa2σ12=∂BN(Wu)∂u\frac{\partial{BN((aW)u)}}{\partial{u}}=\gamma\cdot\frac{aW}{\sqrt{\sigma^2_2}}=\gamma\cdot\frac{aW}{\sqrt{a^2\sigma^2_1}}=\frac{\partial{BN(Wu)}}{\partial{u}}∂u∂BN((aW)u)​=γ⋅σ22​​aW​=γ⋅a2σ12​​aW​=∂u∂BN(Wu)​

∂BN((aW)u)∂(aW)=γ⋅uσ22=γ⋅uaσ12=1a⋅∂BN(Wu)∂W\frac{\partial{BN((aW)u)}}{\partial{(aW)}}=\gamma\cdot\frac{u}{\sqrt{\sigma^2_2}}=\gamma\cdot\frac{u}{a\sqrt{\sigma^2_1}}=\frac{1}{a}\cdot\frac{\partial{BN(Wu)}}{\partial{W}}∂(aW)∂BN((aW)u)​=γ⋅σ22​​u​=γ⋅aσ12​​u​=a1​⋅∂W∂BN(Wu)​

从第一个式子可以看到,在缩放后BN层的输出和缩放前BN层的输出的值一样,即缩放效果 aaa 被消掉,这里的缩放就是W更新一次之后,相比原来更新前,其实就是一个增大减小,使用一个a来表示。

意思就是,即便每一层的参数在不断地改变,但是这对BN层的输出区间没有太大影响,能够使得输出稳定在有 γ和β{\gamma} 和 {\beta}γ和β 决定的分布区间内.这就能够有效地降低底层、顶层等网络层之间由于分布不断改变,而使得网络训练速度慢、训练不稳定等问题。

此外,权重的缩放并不会影响对 μ\muμ 的梯度计算。
并且当权重增大,即 aaa 大于1,则 1a\frac 1 aa1​ 越小,意味着权重 WWW 的在增大的过程中,会被 aaa 抑制(相比无BN层的梯度);
而当权重越小,即 aaa 小于1,则 1a\frac 1 aa1​ 越大,意味着权重 WWW 的梯度越大(相比无BN层的梯度);
即从而保证了梯度不会随着参数的放大缩小而发生剧烈的改变,保证了参数更新显得更加稳定。

(3)BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题

Batch Normalization 使得激活函数的输入稳定地落在梯度非饱和区(激活函数的两端为饱和区域),缓解梯度消失的问题;
另外通过 γ 和 β 参数进行变换,使得数据保留更多的原始信息。

(4)BN能够起到一定的正则化效果

尽管每一个batch的数据都是来源于总体样本,但是每一个mini-batch计算的样本均值和方差还是会有所不同,这就相当于
在网络中添加了一些随机噪音;
此外,通过将输入特征转化到了同一scale,模型就不会特意的偏向某一个神经元,从而减轻了过拟合的问题;
也就自然起到了正则化的效果。但是BN本身并不能真正的抑制过拟合。

  1. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ↩︎

  2. 网络权重初始化方法总结(上):梯度消失、梯度爆炸与不良的初始化 ↩︎

  3. 机器学习(七)白化whitening ↩︎

  4. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift ↩︎

  5. 吴恩达深度学习-测试时的Batch Normalization ↩︎

  6. Batch Normalization原理与实战 ↩︎

深度学习中 Internal Covariate Shift 问题以及 Batch Normalization 的作用相关推荐

  1. 花书+吴恩达深度学习(八)优化方法之 Batch normalization

    目录 0. 前言 1. Batch normalization 训练 2. Batch normalization 测试 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书 ...

  2. 深度学习(十三)——花式池化, Batch Normalization

    https://antkillerfarm.github.io/ 花式池化 池化和卷积一样,都是信号采样的一种方式. 普通池化 池化的一般步骤是:选择区域P,令Y=f(P)Y=f(P).这里的f为池化 ...

  3. 对 BatchNormalization 中 Internal Convariate Shift 的理解

    前言:写的不好,主要解释了对内部协变量漂移(Internal Convariate Shift)的理解. 之前对BatchNormalization的理解不是很透彻,在搭建神经网络的时候也没有很注意去 ...

  4. 深度学习中的归一化方法总结(BN、LN、IN、GN、SN、PN、BGN、CBN、FRN、SaBN)

    目录 概要 Batch Normalization(BN) (1)提出BN的原因 (2)BN的原理 (3)BN优点 (4)BN缺点 Instance Normalization(IN) (1)提出IN ...

  5. 2.3)深度学习笔记:超参数调试、Batch正则化和程序框架

    目录 1)Tuning Process 2)Using an appropriate scale to pick hyperparameters 3)Hyperparameters tuning in ...

  6. 深度学习——Internal Covariate Shift与Normalization

    转载自 https://blog.csdn.net/sinat_33741547/article/details/87158830 Internal Covariate Shift与Normaliza ...

  7. 深度学习论文--Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    本文翻译论文为深度学习经典模型之一:GoogLeNet-BN 论文链接:https://arxiv.org/abs/1502.03167v3 摘要:训练深度神经网络的难度在于:前一层网络参数的变化,导 ...

  8. 详解深度学习中的Normalization,不只是BN(1)

    " 深度神经网络模型训练之难众所周知,其中一个重要的现象就是 Internal Covariate Shift. Batch Normalization 大法自 2015 年由Google ...

  9. 深度学习中的Normalization模型(附实例公式)

    来源:运筹OR帷幄 本文约14000字,建议阅读20分钟. 本文以非常宏大和透彻的视角分析了深度学习中的多种Normalization模型,从一个新的数学视角分析了BN算法为什么有效. [ 导读 ]不 ...

最新文章

  1. 永成科技C++笔试题
  2. RAC环境下的备份与恢复(四)
  3. UVA-10954 Add All
  4. Spring MVC 特性实现文件下载
  5. Mac常用开源软件与下载链接一览
  6. 集群服务器下使用SpringBoot @Scheduled注解定时任务
  7. 【CAD开发】3dxml文件格式读取(Python、C++、C#)
  8. Java+SSM+Jsp+Mysql项目大学生健康管理系统
  9. javashop多用户商城系统源码
  10. jquery api的整体解读
  11. JMETER使用CURL导入功能
  12. LINUX下截图快捷方式
  13. ps水彩效果教程-庞姿姿
  14. CondaSSLError: OpenSSL appears to be unavailable on this machine.
  15. mac 连接linux sh,ssh工具 – windows和mac 上ssh连接linux 服务器工具推荐 – The Hu Post...
  16. eUSB是什么/可以干什么?
  17. 20181027解题报告
  18. pandas学习(创建多层索引、数据重塑与轴向旋转)
  19. ARCore HDR 光估测深度解析
  20. 连锁多门店收银系统之进销存的采购进货单源码功能逻辑

热门文章

  1. spring boot + vue 使用poi实现Excel导出功能(包括Excel样式调整,以及前后端代码)
  2. MySQL Error Query database. Causejava.sql.SQLException: Incorrect key file for table ‘/tmp/#sql_181c
  3. ubuntu上显卡驱动安装——GeForce GTX 1080 Ti
  4. CSS3day(CSS三大特性,行高的继承,选择器的权重,盒子模型:外边距,边框,内边距)
  5. 《水调歌头·丙辰中秋》 苏轼
  6. Android11 Wifi 加密类型详解
  7. H5技术的潮流----阿冬专栏
  8. 吲哚菁绿ICG-Osu,ICG-PEG12-Osu,吲哚菁绿-聚乙二醇-活性酯
  9. 【计算机毕业设计】列车票务信息管理系统
  10. 易语言POST教程分享一波