Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。在网上虽然已经有很多相关文章,但基本都是摆上论文中的公式泛泛而谈,bn真正是如何运作的很少有提及。本文主要分为以下几个部分:

(1)BN的原理

(2)使用pytorch验证本文的观点

(3)使用BN需要注意的地方(BN没用好就是个坑)

1.Batch Normalization原理

我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律)。而我们Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

看到这里应该还是蒙的,不要慌,喝口水,慢慢来。下面是从原论文中截取的原话,注意标黄的部分:

“对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。”  假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,,其中就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。上面的公式不用看,原文提供了更加详细的计算公式:

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。我们根据上图的公式可以知道代表着我们计算的feature map每个维度(channel)的均值,注意是一个向量不是一个值向量的每一个元素代表着一个维度(channel)的均值。代表着我们计算的feature map每个维度(channel)的方差,注意是一个向量不是一个值向量的每一个元素代表着一个维度(channel)的方差,然后根据计算标准化处理后得到的值。下图给出了一个计算均值和方差的示例:

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么代表该batch的所有feature的channel1的数据,同理代表该batch的所有feature的channel2的数据。然后分别计算的均值与方差,得到我们的两个向量。然后在根据标准差计算公式分别计算每个channel的值(公式中的是一个很小的常量,防止分母为零的情况)。在我们训练网络的过程中,我们是通过一个batch一个batch的数据进行训练的,但是我们在预测过程中通常都是输入一张图片进行预测,此时batch size为1,如果在通过上述方法计算均值和方差就没有意义了。所以我们在训练过程中要去不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后在我们验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理

细心的同学会发现,在原论文公式中不是还有两个参数吗?是的,是用来调整数值分布的方差大小,是用来调节数值均值的位置。这两个参数是在反向传播过程中学习得到的,的默认值是1,的默认值是0。

2.使用pytorch进行试验

你以为你都懂了?不一定哦。刚刚说了在我们训练过程中,均值和方差是通过计算当前批次数据得到的记为为,而我们的验证以及预测过程中所使用的均值方差是一个统计量记为的具体更新策略如下,其中momentum默认取0.1:

这里要注意一下,在pytorch中对当前批次feature进行bn处理时所使用的总体标准差,计算公式如下:

在更新统计量时采用的样本标准差,计算公式如下:

下面是我使用pytorch做的测试,代码如下:

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化=1,=0。

import numpy as np
import torch.nn as nn
import torchdef bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):# [batch, channel, height, width]feature_t = feature[:, i, :, :]mean_t = feature_t.mean()# 总体标准差std_t1 = feature_t.std()# 样本标准差std_t2 = feature_t.std(ddof=1)# bn process# 这里记得加上eps和pytorch保持一致feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)# update calculating mean and varmean[i] = mean[i] * 0.9 + mean_t * 0.1var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1print(feature)# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

首先我在最后设置了一个断点进行调试,查看下官方bn对feature处理后得到的统计均值和方差。我们可以发现官方提供的bn的running_mean和running_var和我们自己计算的calculate_mean和calculate_var是一模一样的(只是精度不同)。

然后我们打印出通过自定义bn_process函数得到的输出以及使用官方bn处理得到输出,明显结果是一样的(只是精度不同):

3.使用BN时需要注意的问题

(1)训练时要将traning参数设置为True,在验证时将trainning参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层(Conv)和激活层(例如Relu)之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,即使使用了偏置bias求出的结果也是一样的

最后给出李宏毅老师关于batch normalization的视频讲解:

李宏毅深度学习(2017)_哔哩哔哩_bilibili

Batch Normalization详解以及pytorch实验相关推荐

  1. Batch Normalization详解(原理+实验分析)

    Batch Normalization详解(原理+实验分析) 1. 计算过程 2. 前向传播过程 3. 反向传播过程 4. 实验分析 4.1 实验一:验证有没有BatchNorm下准确率的区别 4.2 ...

  2. 批归一化(Batch Normalization)详解

    批归一化(Batch Normalization)详解 文章目录 批归一化(Batch Normalization)详解 前言 一.数据归一化 二.BN解决的问题:Internal Covariate ...

  3. batch normalization详解

    1.引入BN的原因 1.加快模型的收敛速度 2.在一定程度上缓解了深度网络中的"梯度弥散"问题,从而使得训练深层网络模型更加容易和稳定. 3.对每一批数据进行归一化.这个数据是可以 ...

  4. 【深度学习】Batch Normalization详解

    Batch Normalization 学习笔记 原文地址:http://blog.csdn.net/hjimce/article/details/50866313 作者:hjimce 一.背景意义 ...

  5. Conditional Batch Normalization 详解(SFT思路来源)

    Conditional Batch Normalization 的概念来源于这篇文章:Modulating early visual processing by language .后来又先后被用在  ...

  6. 【小白学PyTorch】13.EfficientNet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 12 SENet详解及PyTorch实现 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小 ...

  7. 【小白学PyTorch】12.SENet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小白学PyTorch | 10 pytorch常见运算详解 小白学Py ...

  8. python batch normalization_Batch Normalization 详解

    一.背景意义 本篇博文主要讲解2015年深度学习领域,非常值得学习的一篇文献:<Batch Normalization: Accelerating Deep Network Training b ...

  9. GoogLeNet——CNN经典网络模型详解(pytorch实现)

    一.前言 论文地址:http://arxiv.org/abs/1602.07261 2014年,GoogLeNet和VGG是当年ImageNet挑战赛(ILSVRC14)的双雄,GoogLeNet获得 ...

最新文章

  1. LeetCode 74. Search a 2D Matrix--有序矩阵查找--python,java,c++解法
  2. 排序算法中——归并排序和快速排序
  3. 解决: AOSP 编译AndroidQ preview 失败
  4. 另一种公钥私钥认证方式
  5. Spring Boot 多模块项目实践(附打包方法)
  6. 论初始值的重要性-仅仅是更改初始值loss差别就非常大
  7. SpringBoot+Jquery+jsTree实现页面树型结构
  8. 数仓大法好!跨境电商 Shopee 的实时数仓之路
  9. Postgres 异常断电导致启动失败的解决方法
  10. ISP图像调试工程师——3D和2D降噪(熟悉图像预处理和后处理技术)
  11. html =拼接dom,在js代码拼接dom对象到页面上去的模板总结(必看)
  12. 谨慎选择镭射祛斑,极易反黑!一定要做好防晒,否则会变成永无止尽的黑斑地狱!
  13. 计算机网络基础知识点总结
  14. linux 查看gc情况
  15. 聊下如何设计知识中台?(附代码)
  16. 慧荣SM2262EN跑RDT教程
  17. 无处不在的微创新——验证码的故事
  18. 百度地图定位功能的错误has leaked ServiceConnection 解决
  19. ssm毕设项目康健医药公司进销存管理22jao(java+VUE+Mybatis+Maven+Mysql+sprnig)
  20. java基础之线程概述_繁星漫天_新浪博客

热门文章

  1. AcceptChanges()和RejectChanges()原理
  2. 智源大会自然语言处理论坛精华观点 | 刘群、陶建华、刘挺、黄萱菁、刘洋等解读NLP最新趋势...
  3. unity3d学习笔记-光照(4.烘焙照明baked lighting)
  4. c语言题目 生日 星座 出生石,输入月份打出星座问题,大神帮忙看看,为什么后面输入前面行而后面不行了?...
  5. mysql nvl2 函数_Oracle的nvl函数和nvl2函数详解
  6. php正则表达式经典实例,php正则表达式学习示例
  7. 关于镂空正方体的作图
  8. TensorFlow2学习:RNN生成古诗词
  9. 拼多多售后有哪些处理技巧?店盈通电商解答
  10. python--猜数字小游戏