深度学习之卷积神经网络(8)BatchNorm层

  • BatchNorm层概念
  • BatchNorm层实现
    • 1. 向前传播
    • 2. 反向更新
    • 3. BN层实现
    • 4. 完整代码

卷积神经网络的出现,网络参数量大大减低,使得几十层的深层网络称为可能。然而,在残差网络出现之前,网络的加深使得网络训练变得十分不稳定,甚至出现网络长时间不更新甚至不收敛的现象,同时网络对超参数比较敏感,超参数的微量扰动也会导致网络的训练轨迹完全改变。

 2015年,Google研究人员Sergey Ioffe等提出了一种参数标准化(Normalize)的手段,并基于参数标准化设计了Batch Normalization(简写为BatchNorm,或BN)层。BN层的提出,使得网络的超参数的设定更加自由,比如更大的学习率、更随意的网络初始化等,同时网络的收敛速度更快,性能也更好。BN层提出后便广泛地应用在各种深度网络模型上,卷积层、BN层、ReLU层、池化层一度成为网络模型的标配单元块,通过堆叠Conv-BN-ReLU-Pooling方式往往可以获得不错的模型性能。

BatchNorm层概念

 首先我们来探索,为什么需要对网络中的数据进行标准化操作?这个问题很难从理论层面解释透彻,即使是BN层的作者给出的解释也未必让所有人信服。与其纠结其缘由,不如通过具体问题来感受数据标准化后的好处。

 考虑Sigmoid激活函数和它的梯度分布,如下图所示,Sigmoid函数在x∈[−2,2]x∈[-2,2]x∈[−2,2]区间的导数值在[0.1,0.25][0.1,0.25][0.1,0.25]区间分布; 当x>2x>2x>2或x<−2x<-2x<−2时,Sigmoid函数的导数变得很小,逼近于0,从而容易出现梯度弥散现象。为了避免因为输入较大或者较小而导致Sigmoid函数出现梯度弥散现象,将函数输入x标准化映射到0附近的一段较小区间将变得非常重要,可以从下图看到,通过标准化重映射后,值被映射在0附近,此处的导数值不至于过小,从而不容易出现梯度弥散现象。这时使用标准化手段收益的一个例子。

Sigmoid函数及其导数曲线

 我们再看另一个例子。考虑2个输入节点的线性模型,如图所示:
L=a=x1w1+x2w2+b\mathcal L=a=x_1 w_1+x_2 w_2+bL=a=x1​w1​+x2​w2​+b
讨论如下两种输入分布下的问题:

  • 输入x1∈[1,10],x2∈[1,10]x_1∈[1,10],x_2∈[1,10]x1​∈[1,10],x2​∈[1,10]
  • 输入x1∈[1,10],x2∈[100,1000]x_1∈[1,10],x_2∈[100,1000]x1​∈[1,10],x2​∈[100,1000]

由于模型相对简单,可以绘制出两种x1x_1x1​、x2x_2x2​下,函数的损失等高线图,图(b)是x1∈[1,10],x2∈[100,1000]x_1∈[1,10],x_2∈[100,1000]x1​∈[1,10],x2​∈[100,1000]时的某条优化轨迹线示意,图(c)是x1∈[1,10],x2∈[1,10]x_1∈[1,10],x_2∈[1,10]x1​∈[1,10],x2​∈[1,10]时的某条优化轨迹线示意,图中的圆环中心即为全局极值点。

数据标准化举例示意图

考虑到:
∂L∂w1=x1∂L∂w2=x2\frac{∂\mathcal L}{∂w_1}=x_1\\ \frac{∂\mathcal L}{∂w_2}=x_2∂w1​∂L​=x1​∂w2​∂L​=x2​
当x1x_1x1​、x2x_2x2​输入分布相近时,∂L∂w1\frac{∂\mathcal L}{∂w_1}∂w1​∂L​、∂L∂w2\frac{∂\mathcal L}{∂w_2}∂w2​∂L​偏导数值相当,函数的优化轨迹如图(c)所示; 当x1x_1x1​、x2x_2x2​输入分布差距较大时,比如x1≪x2x_1≪x_2x1​≪x2​,则:
∂L∂w1≪∂L∂w2\frac{∂\mathcal L}{∂w_1}≪\frac{∂\mathcal L}{∂w_2}∂w1​∂L​≪∂w2​∂L​
损失函数等势线在w2w_2w2​轴更加陡峭,某条可能的优化轨迹如图(b)所示。对比两条优化轨迹线可以观察到,x1x_1x1​、x2x_2x2​分布相近时图(c)中收敛更加快速,优化轨迹更理想。

 通过上述的两个例子,我们能够经验性归纳出: 网络层输入xxx分布相近,并且分布在较小范围内时(如0附近),更有利于函数的优化。那么如何保证输入xxx分布相近呢?数据标准化可以实现此目的,通过数据标准化操作可以将数据xxx映射到x^\hat{x}x^:
x^=x−μrσr2+ϵ\hat{x}=\frac{x-μ_r}{\sqrt{σ_r^2+ϵ}}x^=σr2​+ϵ​x−μr​​
其中μrμ_rμr​、σr2σ_r^2σr2​来自统计的所有数据的均值和方差,ϵϵϵ是为防止出现除0错误而设置的较小的数字,如1e−81e-81e−8。
在基于Batch的训练阶段,如何获取每个网络层所有输入的统计数据μrμ_rμr​、σr2σ_r^2σr2​呢?考虑Batch内部的均值μBμ_BμB​和方差σB2σ_B^2σB2​:
μB=1m∑i=1mxiμ_B=\frac{1}{m} \sum_{i=1}^mx_i μB​=m1​i=1∑m​xi​
σB2=1m∑i=1m(xi−μB)2σ_B^2=\frac{1}{m} \sum_{i=1}^m(x_i-μ_B)^2 σB2​=m1​i=1∑m​(xi​−μB​)2
可以视为近似于μrμ_rμr​、σr2σ_r^2σr2​,其中mmm为Batch样本数。因此,在训练阶段,通过
x^train=xtrain−μBσB2+ϵ\hat{x}_{train}=\frac{x_{train}-μ_B}{\sqrt{σ_B^2+ϵ}}x^train​=σB2​+ϵ​xtrain​−μB​​
标准化输入,并记录每个Batch的统计数据μBμ_BμB​、σB2σ_B^2σB2​,用于统计真实的全局μrμ_rμr​、σr2σ_r^2σr2​。

 在测试阶段,根据记录的每个Batch的μBμ_BμB​、σB2σ_B^2σB2​估计出所有训练数据的μrμ_rμr​、σr2σ_r^2σr2​,按着
x^test=xtest−μrσr2+ϵ\hat{x}_{test}=\frac{x_{test}-μ_r}{\sqrt{σ_r^2+ϵ}}x^test​=σr2​+ϵ​xtest​−μr​​
将每层的输入标准化。

 上述的标准化运算并没有引入额外的待优化变量,μrμ_rμr​、σr2σ_r^2σr2​和μBμ_BμB​、σB2σ_B^2σB2​均由统计得到,不需要参与梯度更新。实际上为了提高BN层的表达能力,BN层作者引入了“scale and shift”技巧,将x^\hat{x}x^变量再次映射变换:
x~=x^⋅γ+β\tilde{x}=\hat{x}\cdotγ+βx~=x^⋅γ+β
其中γγγ参数实现对标准化后的x^\hat{x}x^再次进行缩放,βββ参数实现对标准化后的x^\hat{x}x^进行平移,不同的是,γγγ、βββ参数均由反向传播算法自动优化,实现网络层“按需”缩放平移数据的分布的目的。

 下面我们来学习在TensorFlow中实现的BN层的方法。

BatchNorm层实现

1. 向前传播

 我们将BN层的输入记为xxx,输出记为x^\hat{x}x^。分训练阶段和测试阶段来讨论前向传播过程。

训练阶段: 首先计算当前Batch的μBμ_BμB​、σB2σ_B^2σB2​,根据
x^train=xtrain−μBσB2+ϵ⋅γ+β\hat{x}_{train}=\frac{x_{train}-μ_B}{\sqrt{σ_B^2+ϵ}}\cdotγ+βx^train​=σB2​+ϵ​xtrain​−μB​​⋅γ+β
计算BN层的输出。

 同时按照
μr←momentum⋅μr+(1−momentum)⋅μBσr2←momentum⋅σr2+(1−momentum)⋅σB2μ_r←\text{momentum}\cdotμ_r+(1-\text{momentum})\cdotμ_B\\ σ_r^2←\text{momentum}\cdotσ_r^2+(1-\text{momentum})\cdotσ_B^2μr​←momentum⋅μr​+(1−momentum)⋅μB​σr2​←momentum⋅σr2​+(1−momentum)⋅σB2​
迭代更新全局训练数据的统计值μrμ_rμr​和σr2σ_r^2σr2​,其中momentum\text{momentum}momentum是需要设置一个超参数,用于平衡μrμ_rμr​、σr2σ_r^2σr2​的更新幅度:

当momentum=0\text{momentum}=0momentum=0时,μrμ_rμr​和σr2σ_r^2σr2​直接被设置为最新一个Batch的μBμ_BμB​和σB2σ_B^2σB2​;

当momentum=1\text{momentum}=1momentum=1时,μrμ_rμr​和σr2σ_r^2σr2​保持不变,忽略最新一个Batch的μBμ_BμB​和σB2σ_B^2σB2​;

在TensorFlow中,momentum\text{momentum}momentum默认设置为0.99。

测试阶段: BN层根据
x~test=xtest−μrσr2+ϵ⋅γ+β\tilde{x}_{test}=\frac{x_{test}-μ_r}{\sqrt{σ_r^2+ϵ}}\cdotγ+βx~test​=σr2​+ϵ​xtest​−μr​​⋅γ+β
计算出x~test\tilde{x}_{test}x~test​,其中μrμ_rμr​、σr2σ_r^2σr2​、γγγ、βββ均来自训练阶段统计或优化的结果,在测试阶段直接使用,并不会更新这些参数。

2. 反向更新

 在训练模式下的反向更新阶段,反向传播算法根据损失L\mathcal LL求解梯度∂L∂γ\frac{∂\mathcal L}{∂γ}∂γ∂L​和∂L∂β\frac{∂\mathcal L}{∂β}∂β∂L​,并按着梯度更新法则自动优化γγγ、βββ参数。

 需要注意的是,对于2D特征图输入X:[b,h,w,c]\boldsymbol X:[b,h,w,c]X:[b,h,w,c],BN层并不是计算每个点的μBμ_BμB​、σB2σ_B^2σB2​,而是在通道轴ccc上面统计每个通道上面所有数据的μBμ_BμB​、σB2σ_B^2σB2​,因此μBμ_BμB​、σB2σ_B^2σB2​是每个通道上所有其它维度的均值和方差。以shape为[100,32,32,3][100,32,32,3][100,32,32,3]为例,在通道轴ccc上面的均值计算如下:

import tensorflow as tf# 构造输入
x = tf.random.normal([100,32,32,3])
# 将其他维度合并,仅保留通道维度
x = tf.reshape(x, [-1,3])
# 计算其他维度的均值
ub = tf.reduce_mean(x, axis=0)
print(ub)

运行结果如下:

数据有ccc个通道数,则有ccc个均值产生。

 除了在ccc轴上面统计数据μBμ_BμB​、σB2σ_B^2σB2​的方式,我们也很容易将其推广至其它维度计算均值的方式,如图所示:

  • Layer Norm: 统计每个样本的所有特征的均值和方差
  • Instance Norm: 统计每个样本的每个通道上特征的均值和方差
  • Group Norm: 将ccc通道分成若干组,统计每个样本的通道组内的特征均值和方差

 上面提到的Normalization方法均由独立的几篇论文提出,并在某些应用上验证了其相当于或者由于BatchNorm算法的效果。由此可见没深度学习算法研究并非难于上青天,只要多思考、多锻炼算法工程能力,人人都有机会发表创新性成果。

不同标准化方案示意图

3. BN层实现

 在TensorFlow中,通过layers.BatchNormalization()类可以非常方便地实现BN层:

# 创建BN层
layer = layers.BatchNormalization()

与全连接层、卷积层不同,BN层的训练阶段和测试阶段的行为不同,需要通过设置training标志位来区分训练模式还是测试模式。

 以LeNet-5的网络模型为例,在卷积层后添加BN层,代码如下:

network = Sequential([  # 网络容器layers.Conv2D(6, kernel_size=3, strides=1),  # 第一个卷积层,6个3×3卷积核# 插入BN层layers.BatchNormalization(),layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Conv2D(16, kernel_size=2, strides=1),  # 第二个卷积层,16个3×3卷积核# 插入BN层layers.BatchNormalization(),layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Flatten(),  # 打平层,方便全连接层处理layers.Dense(120, activation='relu'),  # 全连接层,120个节点# 此处也可以插入BN层layers.Dense(84, activation='relu'),  # 全连接层,84个节点# 此处也可以插入BN层layers.Dense(10),  # 全连接层,10个节点
])

在训练阶段,需要设置网络的参数training=True以区分BN层是训练还是测试模型,代码如下:

with tf.GradientTape() as tape:# 插入通道维度x = tf.expand_dims(x, axis=3)# 向前计算,设置计算模式,[b,784] => [b,10]out = network(x, training=True)

在测试阶段,需要设置training=False,避免BN层采用错误的行为,代码如下:

for x, y in test_db:  # 遍历所有训练集样本# 插入通道维度,=>[b,28,28,1]x = tf.expand_dims(x, axis=3)# 向前计算,获得10类别的概率分布,[b,784] => [b,10]out = network(x, training=False)

4. 完整代码

加入BN层的LeNet-5完整代码如下:

import osfrom Chapter08 import metrics
from Chapter08.metrics import loss_meteros.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequential, losses, optimizers, datasets# 加载MNIST数据集
def preprocess(x, y):# 预处理函数x = tf.cast(x, dtype=tf.float32) / 255y = tf.cast(y, dtype=tf.int32)return x, y# 加载MNIST数据集
(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
# 创建数据集
batchsz = 128
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(60000).batch(batchsz).repeat(10)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(batchsz)network = Sequential([  # 网络容器layers.Conv2D(6, kernel_size=3, strides=1),  # 第一个卷积层,6个3×3卷积核# 插入BN层layers.BatchNormalization(),layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Conv2D(16, kernel_size=2, strides=1),  # 第二个卷积层,16个3×3卷积核# 插入BN层layers.BatchNormalization(),layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Flatten(),  # 打平层,方便全连接层处理layers.Dense(120, activation='relu'),  # 全连接层,120个节点# 此处也可以插入BN层layers.Dense(84, activation='relu'),  # 全连接层,84个节点# 此处也可以插入BN层layers.Dense(10),  # 全连接层,10个节点
])# build一次网格模型,给输入x的形状,其中4为随意给的batchsize
network.build(input_shape=(4, 28, 28, 1))
# 统计网络信息
network.summary()# 创建损失函数的类,在实际计算时直接调用实例即可
criteon = losses.CategoricalCrossentropy(from_logits=True)optimizer = optimizers.Adam(lr=0.01)# 训练部分实现如下
# 构建梯度记录环境
# 训练20个epochdef train_epoch(epoch):for step, (x, y) in enumerate(train_db):  # 循环优化with tf.GradientTape() as tape:# 插入通道维度,=>[b,28,28,1]x = tf.expand_dims(x, axis=3)# 向前计算,获得10类别的概率分布,[b,784] => [b,10]out = network(x, training=True)# 真实标签one-hot编码,[b] => [b,10]y_onehot = tf.one_hot(y, depth=10)# 计算交叉熵损失函数,标量loss = criteon(y_onehot, out)# 自动计算梯度grads = tape.gradient(loss, network.trainable_variables)# 自动更新参数optimizer.apply_gradients(zip(grads, network.trainable_variables))if step % 100 == 0:print(step, 'loss:', loss_meter.result().numpy())loss_meter.reset_states()# 计算准确度if step % 100 == 0:# 记录预测正确的数量,总样本数量correct, total = 0, 0for x, y in test_db:  # 遍历所有训练集样本# 插入通道维度,=>[b,28,28,1]x = tf.expand_dims(x, axis=3)# 向前计算,获得10类别的概率分布,[b,784] => [b,10]out = network(x, training=False)# 真实的流程时先经过softmax,再argmax# 但是由于softmax不改变元素的大小相对关系,故省去pred = tf.argmax(out,axis=-1)y = tf.cast(y, tf.int64)# 统计预测样本总数correct += float(tf.reduce_sum(tf.cast(tf.equal(pred, y), tf.float32)))# 统计预测样本总数total += x.shape[0]# 计算准确率print('test acc:', correct/total)def train():for epoch in range(30):train_epoch(epoch)if __name__ == '__main__':train()

深度学习之卷积神经网络(8)BatchNorm层相关推荐

  1. 深度学习之卷积神经网络(11)卷积层变种

    深度学习之卷积神经网络(11)卷积层变种 1. 空洞卷积 2. 转置卷积 矩阵角度 转置卷积实现 3. 分离卷积 卷积神经网络的研究产生了各种各样优秀的网络模型,还提出了各种卷积层的变种,本节将重点介 ...

  2. 深度学习之卷积神经网络(7)池化层

    深度学习之卷积神经网络(7)池化层 在卷积层中,可以通过调节步长参数s实现特征图的高宽成倍缩小,从而降低了网络的参数量.实际上,处理通过设置步长,还有一种专门的网络层可以实现尺寸缩减功能,它就是这里要 ...

  3. 深度学习之卷积神经网络(3)卷积层实现

    深度学习之卷积神经网络(3)卷积层实现 1. 自定义权值 2. 卷积层类  在TensorFlow中,既可以通过自定义权值的底层实现方式搭建神经网络,也可以直接调用现成的卷积层类的高层方式快速搭建复杂 ...

  4. 【深度学习】卷积神经网络实现图像多分类的探索

    [深度学习]卷积神经网络实现图像多分类的探索 文章目录 1 数字图像解释 2 cifar10数据集踩坑 3 Keras代码实现流程 3.1 导入数据 3.2 浅层CNN 3.3 深层CNN 3.4 进 ...

  5. 【深度学习】卷积神经网络速成

    [深度学习]卷积神经网络速成 文章目录 [深度学习]卷积神经网络速成 1 概述 2 组成 2.1 卷积层 2.2 池化层 2.3 全连接层 3 一个案例 4 详细分析 1 概述 前馈神经网络(feed ...

  6. 深度学习~卷积神经网络(CNN)概述

    目录​​​​​​​ 1. 卷积神经网络的形成和演变 1.1 卷积神经网络结构 1.2 卷积神经网络的应用和影响 1.3 卷积神经网络的缺陷和视图 1.3.1 缺陷:可能错分 1.3.2 解决方法:视图 ...

  7. 深度学习之卷积神经网络(13)DenseNet

    深度学习之卷积神经网络(13)DenseNet  Skip Connection的思想在ResNet上面获得了巨大的成功,研究人员开始尝试不同的Skip Connection方案,其中比较流行的就是D ...

  8. 深度学习之卷积神经网络(12)深度残差网络

    深度学习之卷积神经网络(12)深度残差网络 ResNet原理 ResBlock实现 AlexNet.VGG.GoogleLeNet等网络模型的出现将神经网络的法阵带入了几十层的阶段,研究人员发现网络的 ...

  9. 深度学习之卷积神经网络(10)CIFAR10与VGG13实战

    深度学习之卷积神经网络(10)CIFAR10与VGG13实战 MNIST是机器学习最常用的数据集之一,但由于手写数字图片非常简单,并且MNIST数据集只保存了图片灰度信息,并不适合输入设计为RGB三通 ...

最新文章

  1. C语言实现bmp图像几何变换(移动,旋转,镜像,转置,缩放)
  2. jquery+bootstrap实现tab切换, 每次切换时都请求数据, 点击提交分别向不同的地址提交数据...
  3. 時鐘,天氣預報--js
  4. Blazor 事件处理开发指南
  5. HDU-2332 机器人的舞蹈 递推
  6. ieee期刊的科技写作思路曹文平_科技论文写作与发表教程(第六版)
  7. 6月6号=》80页-100页
  8. jdk8 源码 比较器
  9. 配置F5 负载均衡(转)
  10. jpa的批量修改_jpa批量处理
  11. Log4Qt 日志格式化(TTCCLayout)
  12. unity的切屏显示顺序
  13. 抄袭爆款:先饱带动后饱!
  14. Echo,Linux上最忧伤的命令(故事)
  15. GeoHash算法详解
  16. 从两家主流报表工具的报jia看报表行业的报jia水深-----常用报表工具对比---主流报表对比
  17. ad 卡尔曼_理解卡尔曼五个方程
  18. 【入门】已知一个圆的半径,求解该圆的面积和周长
  19. early_param分析
  20. 二三四层交换机的区别

热门文章

  1. 编写程序定义一个有 10 个 int 型元素的数组,并以其在数组中的位置作为各元素的初值。
  2. QT之QHash简介
  3. Android开发之虹软人脸识别活体检测基本步骤
  4. linux怎么取消raid磁盘阵列,Linux下彻底关闭某个RAID磁盘阵列
  5. ubuntu16下vue-cli安装
  6. poj1942(求组合数)
  7. ASP.NET 例程完全代码版(5)——通过web.config配置数据库连接池
  8. IP5的接口模式运行测试
  9. 甲骨文将Exadata Cloud转化为内部软件包
  10. 深入Spring:自定义注解加载和使用