https://blog.csdn.net/bigFatCat_Tom/article/details/91619977
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/#normalization-layers-source
https://blog.csdn.net/t20134297/article/details/104960101

基本原理

对小批量(mini-batch)3d数据组成的4d[batch_size,num_features,height,width]输入进行批标准化(Batch Normalization)操作

在卷积神经网络的卷积层之后总会添加BatchNorm2d进行数据的归一化处理,这使得数据在进行Relu之前不会因为数据过大而导致网络性能的不稳定,BatchNorm2d()函数数学原理如下:

在每一个小批量(mini-batch)数据中,计算输入各个维度的均值和标准差。gamma与beta是可学习的大小为C的参数向量(C为输入大小)

也就是说:BN层中含有统计数据数值,即均值和方差

在训练时,该层计算每次输入的均值与方差,并进行移动平均。移动平均默认的动量值为0.1。

在验证时,训练求得的均值/方差将用于标准化验证数据。

在训练过程中model.train(),train过程的BN的统计数值—均值和方差是通过当前batch数据估计的。

并且测试时,model.eval()后,若track_running_stats=True,模型此刻所使用的统计数据是Running status 中的,即通过指数衰减规则,积累到当前的数值。否则依然使用基于当前batch数据的估计值。

BN层的统计数据更新是在每一次训练阶段model.train()后的forward()方法中自动实现的,而不是在梯度计算与反向传播中更新optim.step()中完成

从上面的分析可以看出来,正确的冻结BN的方式是在模型训练时,把BN单独挑出来,重新设置其状态为eval (在model.train()之后覆盖training状态

在训练过程中 nn.BatchNorm2d() 的作用是根据统计的mean 和var来对数据进行标准化,并且这个mena和var在每个batch中都会进行,为了使得数据更有统计意义,使得整个训练数据的特征都能够被保存,则在每个batch过程中,都会对网络的mean和var进行更新,这里就涉及到新的 batch的统计数据mean和var与网络已经保存的这两个统计数据之间的取舍问题了,而这个0.8就指定了保存的比例,这个参数名为momentum.

class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

BatchNorm2d()内部的参数如下:

1.num_features:一般输入参数为batch_sizenum_featuresheight*width,即为其中特征的数量

2.eps:分母中添加的一个值, 为保证数值稳定性(分母不能趋近或取0),给分母加上的值。默认为1e-5。

3.momentum:一个用于运行过程中均值和方差的一个移动平均默认的动量值参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数)

4.affine:一个布尔值,当设为true,给该层添加可学习的仿射变换参数。会给定可以学习的系数矩阵gamma和beta 。(不可学习默认是常数1和0).

Shape: - 输入:(N, C,H, W) - 输出:(N, C, H, W)(输入输出相同)

上面的讲解还不够形象,我们具体通过如下的代码进行讲解:

#encoding:utf-8
import torch
import torch.nn as nn
#num_features - num_features from an expected input of size:batch_size*num_features*height*width
#eps:default:1e-5 (公式中为数值稳定性加到分母上的值)
#momentum:动量参数,用于running_mean and running_var计算的值,default:0.1
m=nn.BatchNorm2d(2,affine=True) #affine参数设为True表示weight和bias将被使用
input=torch.randn(1,2,3,4)
output=m(input)print(input)
print(m.weight)
print(m.bias)
print(output)
print(output.size())

输出

tensor([[[[ 1.4174, -1.9512, -0.4910, -0.5675],[ 1.2095,  1.0312,  0.8652, -0.1177],[-0.5964,  0.5000, -1.4704,  2.3610]],[[-0.8312, -0.8122, -0.3876,  0.1245],[ 0.5627, -0.1876, -1.6413, -1.8722],[-0.0636,  0.7284,  2.1816,  0.4933]]]])
Parameter containing:
tensor([0.2837, 0.1493], requires_grad=True)
Parameter containing:
tensor([0., 0.], requires_grad=True)
tensor([[[[ 0.2892, -0.4996, -0.1577, -0.1756],[ 0.2405,  0.1987,  0.1599, -0.0703],[-0.1824,  0.0743, -0.3871,  0.5101]],[[-0.0975, -0.0948, -0.0347,  0.0377],[ 0.0997, -0.0064, -0.2121, -0.2448],[ 0.0111,  0.1232,  0.3287,  0.0899]]]],grad_fn=<NativeBatchNormBackward>)
torch.Size([1, 2, 3, 4])

分析:输入是一个1234 四维矩阵,gamma和beta为一维数组,是针对input[0][0],input[0][1]两个34的二维矩阵分别进行处理的,我们不妨将input[0][0]的按照上面介绍的基本公式来运算,看是否能对的上output[0][0]中的数据。首先我们将input[0][0]中的数据输出,并计算其中的均值和方差。

print("输入的第一个维度:")
print(input[0][0]) #这个数据是第一个3*4的二维数据
#求第一个维度的均值和方差
firstDimenMean=torch.Tensor.mean(input[0][0])
firstDimenVar=torch.Tensor.var(input[0][0],False)   #false表示贝塞尔校正不会被使用
print(m)
print('m.eps=',m.eps)
print(firstDimenMean)
print(firstDimenVar)

输出结果如下:

输入的第一个维度:
tensor([[ 1.4174, -1.9512, -0.4910, -0.5675],[ 1.2095,  1.0312,  0.8652, -0.1177],[-0.5964,  0.5000, -1.4704,  2.3610]])
BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
m.eps= 1e-05
tensor(0.1825)
tensor(1.4675)

我们可以通过计算器计算出均值和方差均正确计算。最后通过公式计算input[0][0][0][0]的值,代码如下:

batchnormone=((input[0][0][0][0]-firstDimenMean)/(torch.pow(firstDimenVar,0.5)+m.eps))\*m.weight[0]+m.bias[0]
print(batchnormone)

输出结果如下:

tensor(0.2892, grad_fn=<AddBackward0>)

结果值等于output[0][0][0][0]。ok,代码和公式完美的对应起来了。

ps:上面计算方差时有一个贝塞尔校正系数,具体可以通过如下链接参考:https://www.jianshu.com/p/8dbb2535407e

从公式上理解即在计算方差时一般的计算方式如下:

通过贝塞尔校正的样本方差如下:

目的是在总体中选取样本时能够防止边缘数据不被选到。

pytorch——nn.BatchNorm2d()函数相关推荐

  1. Batch Normalization原理及pytorch的nn.BatchNorm2d函数

    下面通过举个例子来说明Batch Normalization的原理,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2,4为batch的大小,3为channel ...

  2. python batchnorm2d_BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解

    BN原理.作用: 函数参数讲解: BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 1. ...

  3. PyTorch学习笔记(1)nn.Sequential、nn.Conv2d、nn.BatchNorm2d、nn.ReLU和nn.MaxPool2d

    文章目录 一.nn.Sequential 二.nn.Conv2d 三.nn.BatchNorm2d 四.nn.ReLU 五.nn.MaxPool2d 一.nn.Sequential torch.nn. ...

  4. PyTorch基础(12)-- torch.nn.BatchNorm2d()方法

    Batch Normanlization简称BN,也就是数据归一化,对深度学习模型性能的提升有很大的帮助.BN的原理可以查阅我之前的一篇博客.白话详细解读(七)----- Batch Normaliz ...

  5. pytorch的nn.CrossEntropyLoss()函数使用方法

    nn.CrossEntropyLoss()函数计算交叉熵损失 用法: # output是网络的输出,size=[batch_size, class] #如网络的batch size为128,数据分为1 ...

  6. pytorch中批量归一化BatchNorm1d和BatchNorm2d函数

    class torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True) [source] 对小批量(mini-ba ...

  7. pytorch之torch.nn.Conv2d()函数详解

    文章目录 一.官方文档介绍 二.torch.nn.Conv2d()函数详解 参数详解 参数dilation--扩张卷积(也叫空洞卷积) 参数groups--分组卷积 三.代码实例 一.官方文档介绍 官 ...

  8. Pytorch —— nn.Module类(nn.sequential)

    对于前面的线性回归模型. Logistic回归模型和神经网络,我们在构建的时候定义了需要的参数.这对于比较小的模型是可行的,但是对于大的模型,比如100 层的神经网络,这个时候再去手动定义参数就显得非 ...

  9. Pytorch的RELU函数

    4.1.2 激活函数 PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档1,这些激活函数可作为独立的layer使用.这里将介绍最常用的激活函数ReLU,其数学表达式为: 代码: rel ...

  10. nn.Linear()函数详解

    nn.Linear()函数详解 torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[原文地址] ...

最新文章

  1. spring cloud快速搭建
  2. 图像分类、检测,语义分割等方法梳理
  3. 致敬2016,拥抱2017
  4. JDK源码(12)-Enum
  5. 理想更新“货车并线预警”遭用户吐槽 李想:仍在优化
  6. Spring Cloud Stream与RabbitMQ整合
  7. 自用迷你版的Deferred
  8. DateTime Calendar
  9. liunx上mysql源码安装mysql_搞定linux上MySQL编程(一):linux上源码安装MySQL
  10. 大学生必看的一分钟——俞洪敏语录
  11. 拓端tecdat|R语言基于ARMA-GARCH过程的VaR拟合和预测
  12. 关于julia的路径问题,往往很重要!
  13. 一、MySQL数据库优化策略
  14. 微服务结合领域驱动设计落地
  15. GCC --verbose选项, -lpthread 和-pthread的区别
  16. oracle的sql硬解析和软解析,[ORACLE]oracle SQL执行过程 软解析(soft prase)硬解析(hard prase)以及 Soft Soft Parse...
  17. 家长警惕 这4类孩子最易反复感冒
  18. 聊天文字在气泡背景图片上的展示
  19. 多功能搜索友联自助交换多色彩皮肤网站图片本地化附带交易系统网址导航源码蜘蛛
  20. 【border相关】【P3426】 [POI2005]SZA-Template

热门文章

  1. Springboot项目中static文件和templates文件的区别
  2. java componentorientation_Java JLabel.applyComponentOrientation方法代码示例
  3. android webview 太大,android – ScrollView中的WebView:“查看太大而无法适应绘图缓存” – 如何重新布局?...
  4. oracle中外键的使用方法,Oracle数据库中外键的相关操作整理
  5. Flutter跨组件共享状态的利器Provider原理解析
  6. mysql 索引重复 更新_MySQL——ON DUPLICATE KEY UPDATE添加索引值实现重复插入变更update...
  7. 阿里巴巴高级技术专家章剑锋:大数据发展的 8 个要点
  8. Kotlin实战【三】表示与选择
  9. python基础篇–变量和简单的数据类型(下)
  10. mysql客户端备份数据库失败,mysqlhotcopy的使用和安装方法【快速备份mysql数据库】及错误解...