1、Batch Normalization概念

Batch Normalization:批标准化

批: 一批数据,通常为mini-batch
标准化: 0均值,1方差

优点:

  1. 可以用更大学习率,加速模型收敛;
  2. 可以不用精心设计权值初始化;
  3. 可以不用dropout或较小的dropout;
  4. 可以不用L2或者较小的weight decay;
  5. 可以不用LRN(local response normalization局部响应值的标准化)


上面伪代码中最后一部分是affine transfrom,也就是scale and shift,公式中的gamma和beta是可学习参数,可以根据loss反向传播更新参数。

为什么在进行normalize更新之后要加一个affine transform呢?这一步可以增强模型的容纳能力,使模型更灵活,选择性更多,可以让模型判断是否需要对模型进行变换。

这个方法是在论文《Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的,主要是为了解决ICS问题(Internal Covariate Shift数据尺度变化)。

2、Pytorch的Batch Normalization 1d/2d/3d实现

Pytorch中nn.Batchnorm1d、nn.Batchnorm2d、nn.Batchnorm3d都继承于基类_Batchnorm;

2.1 _BatchNorm

_BatchNorm的主要参数

  • num_features:一个样本特征数量(最重要);
  • eps:分母修正项,避免分母为零;
  • momentum:指数加权平均估计当前mean/var;
  • affine:布尔变量,是否需要affine transform;
  • track_running_stats:训练状态还是测试状态;如果是训练状态,mean/var需要不断计算更新;如果在测试状态,mean/var是固定的;
def __init__(self,num_features,eps=1e-5,momentum=0.1,affine=True,track_running_stats=True)

2.2 nn.BatchNorm1d/nn.BatchNorm2d/nn.NatchNorm3d

nn.BatchNorm1d/nn.BatchNorm2d/nn.NatchNorm3d的主要属性

  • running_mean:均值;
  • running_var:方差;
  • weight:affine transform中的gamma;
  • bias:affine transform中的beta;

BN的公式:x^i←xi−μBσB2+ϵ\widehat{x}_{i} \leftarrow \frac{x_{i}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2}+\epsilon}}xi​←σB2​+ϵ​xi​−μB​​yi←γx^i+β≡BNγ,β(xi)y_{i} \leftarrow \gamma \widehat{x}_{i}+\beta \equiv \mathrm{B} \mathrm{N}_{\gamma, \beta}\left(x_{i}\right)yi​←γxi​+β≡BNγ,β​(xi​)

BN中的均值和方差在训练的时候采用指数加权平均进行计算,在测试时使用当前统计值:runningmean=(1−momentum)∗pre_running_mean+momentum∗mean_trunning_mean = (1-momentum) * pre_{\_}running_{\_}mean + momentum * mean_{\_}trunningm​ean=(1−momentum)∗pre_​running_​mean+momentum∗mean_​trunning_=(1−momentum)∗pre_running_var+momentum∗var_trunning_{\_}=(1-momentum)*pre_{\_}running_{\_}var+momentum*var_{\_}trunning_​=(1−momentum)∗pre_​running_​var+momentum∗var_​t

2.3 nn.BatchNorm1d/nn.BatchNorm2d/nn.NatchNorm3d对数据的要求

  • nn.BatchNorm1d input = Batch_size * 特征数 * 1d特征维度
  • nn.BatchNorm2d input = Batch_size * 特征数 * 2d特征维度
  • nn.BatchNorm3d input = Batch_size * 特征数 * 3d特征维度

2.3.1 nn.BatchNorm1d

在全连接层使用的就是nn.BatchNorm1d,全连接层中的每一个神经元就是一个特征,假设一个网络层有五个特征,也就是一个网络层有五个神经元,如下图中的每一列是一个数据,每个数据有5个特征作为网络层的输入,每一个特征的维度是红色圆圈圈出的部分,维度为1,这样就构成了一个样本的一个特征。

每次训练数据组成一个batch,假设一个batch有三个样本,这样的三个样本组成的batch就构成了nn.BatchNorm1d的输入数据形式,输入数据的形式为[3,5,1],有时候1可以忽略,因此可以表示为[3,5];

我们知道,nn.BatchNorm1d有四个参数需要计算,这四个参数需要在特征维度上进行计算,如上图,现在有三个样本,每个样本有五个特征,需要在三个样本的同样位置的特征上求取均值、方差、gamma和beta,在每一个特征维度上都有对应的均值、方差、gamma和beta。

下面通过代码学习nn.BatchNorm1d:

    batch_size = 3  # batch_sizenum_features = 5  # 每一个数据的特征个数momentum = 0.3features_shape = (1)  # 特征维度为1feature_map = torch.ones(features_shape)  # [1]                                                        # 1Dfeature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)  # [1,2,3,4,5]     # 2Dfeature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # [[][][]]            # 3Dprint("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)running_mean, running_var = 0, 1for i in range(2):outputs = bn(feature_maps_bs)print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))print("iteration:{}, running var:{} ".format(i, bn.running_var))mean_t, var_t = 2, 0running_mean = (1 - momentum) * running_mean + momentum * mean_trunning_var = (1 - momentum) * running_var + momentum * var_tprint("iteration:{}, 第二个特征的running mean: {} ".format(i, running_mean))print("iteration:{}, 第二个特征的running var:{}".format(i, running_var))

通过运行代码,得到输出为:

iteration:0, running mean: tensor([0.3000, 0.6000, 0.9000, 1.2000, 1.5000])
iteration:0, running var:tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000])
iteration:0, 第二个特征的running mean: 0.6
iteration:0, 第二个特征的running var:0.7iteration:1, running mean: tensor([0.5100, 1.0200, 1.5300, 2.0400, 2.5500])
iteration:1, running var:tensor([0.4900, 0.4900, 0.4900, 0.4900, 0.4900])
iteration:1, 第二个特征的running mean: 1.02
iteration:1, 第二个特征的running var:0.48999999999999994

2.3.2 nn.BatchNorm2d

nn.BatchNorm2d和nn.BatchNorm1d输入数据的主要不同在于特征维度上,卷积神经网络输出的一个特征图就是二维的形式。

如下图,假设一个特征图的维度为22,一个层有三个卷积核,会输出三个通道的22的特征图,一个特征图在BN中理解为一个特征,BN会在一个特征上求取均值、方差、gamma和beta。因此在nn.BatchNorm2d中输入数据的形式为[3,3,2,2]。

下面通过代码研究nn.BatchNorm2d的具体使用:

    batch_size = 3num_features = 6momentum = 0.3features_shape = (2, 2)feature_map = torch.ones(features_shape)  # 2d                                                     # 2Dfeature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)  # 3d         # 3Dfeature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)  # 4d             # 4Dprint("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)running_mean, running_var = 0, 1for i in range(2):outputs = bn(feature_maps_bs)print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

q代码对应输出为:

iter:0, running_mean.shape: torch.Size([6])
iter:0, running_var.shape: torch.Size([6])
iter:0, weight.shape: torch.Size([6])
iter:0, bias.shape: torch.Size([6])iter:1, running_mean.shape: torch.Size([6])
iter:1, running_var.shape: torch.Size([6])
iter:1, weight.shape: torch.Size([6])
iter:1, bias.shape: torch.Size([6])

2.3.3 nn.BatchNorm3d

下图所示为nn.BatchNorm3d的输入数据形式,一个数据的一个特征是3维的,其形式为[2,2,3],一个数据有3个特征,一共有3个样本,所以nn.BatchNorm3d的输入数据形式为[3,3,2,2,3]。

nn.BatchNorm3d的代码如下:

    batch_size = 3num_features = 4momentum = 0.3features_shape = (2, 2, 3)feature = torch.ones(features_shape)                                                # 3Dfeature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)  # 4Dfeature_maps = torch.stack([feature_map for i in range(batch_size)], dim=0)         # 5Dprint("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)running_mean, running_var = 0, 1for i in range(2):outputs = bn(feature_maps)print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

pytorch —— Batch Normalization相关推荐

  1. 深度学习总结:用pytorch做dropout和Batch Normalization时需要注意的地方,用tensorflow做dropout和BN时需要注意的地方,

    用pytorch做dropout和BN时需要注意的地方 pytorch做dropout: 就是train的时候使用dropout,训练的时候不使用dropout, pytorch里面是通过net.ev ...

  2. Batch Normalization原理及pytorch的nn.BatchNorm2d函数

    下面通过举个例子来说明Batch Normalization的原理,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2,4为batch的大小,3为channel ...

  3. PyTorch框架学习十七——Batch Normalization

    PyTorch框架学习十七--Batch Normalization 一.BN的概念 二.Internal Covariate Shift(ICS) 三.BN的一个应用案例 四.PyTorch中BN的 ...

  4. Pytorch中的Batch Normalization操作

    之前一直和小伙伴探讨batch normalization层的实现机理,作用在这里不谈,知乎上有一篇paper在讲这个,链接 这里只探究其具体运算过程,我们假设在网络中间经过某些卷积操作之后的输出的f ...

  5. 【学习笔记】Pytorch深度学习—Batch Normalization

    [学习笔记]Pytorch深度学习-Batch Normalization Batch Normalization概念 `Batch Normalization ` `Batch Normalizat ...

  6. Batch Normalization详解以及pytorch实验

    Batch Normalization是google团队在2015年论文<Batch Normalization: Accelerating Deep Network Training by R ...

  7. PyTorch 深度学习:31分钟快速入门——Batch Normalization

    Batch Normalization¶ 前面在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好.但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相 ...

  8. batch normalization

    20210702 深度学习中的五种归一化(BN.LN.IN.GN和SN)方法简介 https://blog.csdn.net/u013289254/article/details/99690730 h ...

  9. 深度神经网络中的Batch Normalization介绍及实现

    之前在经典网络DenseNet介绍_fengbingchun的博客-CSDN博客_densenet中介绍DenseNet时,网络中会有BN层,即Batch Normalization,在每个Dense ...

最新文章

  1. 华为服务器面板显示,服务器面板怎么查看
  2. Acwing第 5 场周赛【未完结】
  3. DP 状态机模型 AcWing算法提高课 详解
  4. Pycharm新建文件时自动添加基础信息
  5. 以太坊开发入门,如何搭建一个区块链DApp投票系统
  6. php 精度比较,PHP浮点数精度和比较
  7. C++中在使用自定义类型(结构体类型)的stl数据结构时,operate的用法
  8. 适用于中小型公司代理服务器的IPTABLES脚本
  9. 第十篇 requests模块
  10. c++求平均值_2020五一建模:C题 饲料混合加工(二)
  11. Texlive 2021安装卡在be patient解决方案
  12. 学生成绩预测模型_学生成绩分析预测
  13. 备案后可以改服务器信息吗,域名备案后可以更改服务器
  14. 隔离升压电源模块24V功率可达40W宽电压输入高电压稳压输出
  15. always_comb,always_ff,和always_latch语句
  16. EUI组件之HScrollBar VScrollBar (动态设置滑块图片)
  17. dedecms{dede:sql}{dede:php}标签的用法
  18. 在这里,有人用10万块的电脑刷贴吧,有人用200块的电脑打LOL
  19. 解决TortoiseSVN文件夹没有绿色对号
  20. Origin中画折线图实现双X轴和双Y轴(双坐标轴)

热门文章

  1. 大型企业网络设备管理
  2. MySQL迁移到ClickHouse方案
  3. 容器编排技术 -- Kubernetes 为 Namespace 设置最小和最大内存限制
  4. Chrome谷歌插件开发-01
  5. 如何区分localhost、127.0.0.1和0.0.0.0等ip地址
  6. #JAVA# 判断从键盘输入的字符串是否为回文
  7. 不重启docker容器修改 容器中的时区
  8. docker安装redis并将配置文件和数据文件映射到外部
  9. JAVA对接支付宝支付(超详细,一看就懂)
  10. 《高性能mysql》读书笔记一