近期在实验的过程中,发现了一个问题。实验结果受Batch_size影响在波动。Batch_size在分别为4和8的时候,实验效果不同,当Batch_size为8的时候实验效果还要差一点.这一点我很意外,当所有超参设置都一样的时候,BN层效果与Batch_size大小的影响应该是呈现一个正相关.这里放一张Kaiming He在2018ECCV上发表的《Group Normalization》论文中的一张图。图中蓝色曲线是ResNet50在ImageNet上分类的错误率,当Batch_size越大的时候,错误率越小。那么按照论文中的思想。为什么我们的实验结果跟它相反?

本文主要为https://blog.csdn.net/qq_37541097/article/details/104434557该文章做一个简单的补充。

BN层

网上关于这个的讲解非常多,我个人理解就是将网络中间层进行规范化操作,减去均值除以方差。使得feature map数据满足分布规律。加速网络收敛.

下面我讲一下各个参数代表的含义


求图像的均值,这里的均值计算如下,在Pytorch中,数据维度为【B,C,H,W】。那么控制Channel不变,也即在一个通道下计算该Batch_size下所有图像的均值。

同上一样,控制Channel不变,也即在一个通道下计算该Batch_size下所有图像的方差

原像素点值减均值除以方差,这里加一个eps是为了防止除数为0,这个eps是一个constant,不需要做BP的。

这里就是对规范化后的x在进行一次更新参数,其中gamma,beta都是可学习的参数。(Pytorch源码中这些值都是Paramters参数.代表需要做BP的可学习的).这个操作我觉得是增加模型的鲁棒性.其中gamma和beta分别初始化为了1和0.

源码详解:

import numpy as np
import torch.nn as nn
import torch
def bn_process(feature, mean, var):feature_shape = feature.shapefor i in range(feature_shape[1]):feature_t = feature[:, i, :, :] #单通道均值mean_t = feature_t.mean()#总体标准差std_t1 = feature_t.std()#样本标准差std_t2 = feature_t.std(ddof=1)#bn processfeature[:, i, :, :] = (feature_t-mean_t)/np.sqrt(std_t1**2+1e-5)#updata caculating mean and varmean[i] = mean[i]*0.9+mean_t*0.1var[i] = var[i]*0.9+(std_t2**2)*0.1print(feature)feature1 = torch.randn(size=(2, 2, 2, 2))
caculate_mean= [0.0, 0.0]
caculate_var = [1.0, 1.0]bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
bn_process(feature1.numpy().copy(), caculate_mean, caculate_var)
print(output)

在BN层中,主要涉及到四个需要更新的参数,分别是running_mean,running_var,weight,bias。这里的weight,bias是Pytorch官方实现中的叫法,有点误导人,其实weight就是gamma,bias就是beta。当然它这样的叫法也符合实际的应用场景。其实gamma,beta就是对规范化后的值进行一个加权求和操作running_mean,running_var是当前所求得的所有batch_size下的均值和方差,每经过一个mini_batch我们都会更新running_mean,running_var.为什么要更新它?因为测试的时候,往往是一个一个的图像feed至网络的,如果你在这里对其进行计算均值方差显然是不合理的,所以model.eval()这个语句就是控制BN层中的running_mean,running_std不更新。采用训练结束后的running_mean,running_std来规范化该张图像。

debug一下

可以发现bn层中parameters中只有weight和bias,这也证明了我们上面所说的是正确的.同时还有两个另外两个重要的参数就是running_mean,running_var


这里可以看出running_mean和running_var中的required_grad都是False.言外之意这是不需要BP更新的,实际上他是在前向传播中更新的,具体更新如下,momentum默认值是0.1。

BN层(Pytorch)相关推荐

  1. Pytorch中BN层入门思想及实现

    批归一化层-BN层(Batch Normalization) 作用及影响: 直接作用:对输入BN层的张量进行数值归一化,使其成为均值为零,方差为一的张量. 带来影响: 1.使得网络更加稳定,结果不容易 ...

  2. 【PyTorch】eval() ==>主要是针对某些在train和predict两个阶段会有不同参数的层,比如Dropout层和BN层

    model的eval方法主要是针对某些在train和predict两个阶段会有不同参数的层.比如Dropout层和BN层 torch为了方便大家,设计这个eval方法就是让我们可以不用手动去针对这些层 ...

  3. 【pytorch】BN层计算

    官方文档 有一个针对BN层的详细的理解: Pytorch的BatchNorm层使用中容易出现的问题 class torch.nn.BatchNorm2d(num_features, eps=1e-05 ...

  4. pytorch 批量归一化BatchNorm1d和BatchNorm2d的用法、BN层参数 running_mean running_var变量计算 验证

    前提知识 BN层包括mean var gamma beta四个参数,.对于图像来说(4,3,2,2),一组特征图,一个通道的特征图对应一组参数,即四个参数均为维度为通道数的一维向量,图中gamma.b ...

  5. PyTorch中BN层与CONV层的融合(merge_bn)

    之前发了很久之前写好的一篇关于Caffe中merge_bn的博客,详情可见 Caffe中BN层与CONV层的融合(merge_bn) 今天由于工作需要要对PyTorch模型进行merge_bn,发现网 ...

  6. PyTorch中BN层中新加的 num_batches_tracked 有什么用?

    从PyTorch 0.4.1开始, BN层中新增加了一个参数 track_running_stats, BatchNorm2d(128, eps=1e-05, momentum=0.1, affine ...

  7. pytorch实现卷积层和BN层融合

    批归一化 数据的规范化也即(x−mean(x)/var(x))(x-mean(x)/var(x))(x−mean(x)/var(x)) 可以将数据的不同特征规范到均值为0,方差相同的标准正态分布,这就 ...

  8. BN层(Pytorch讲解)

    批量归一化(BN:Batch Normalization 解决在训练过程中,中间层数据分布过度异常问题,让数据分布符合正态分布 主要达到三个目的: 1.加快网络的训练和收敛的速度 2.控制梯度爆炸 3 ...

  9. 【剑指offer】BN层详解

    [剑指offer]系列文章目录 梯度消失和梯度爆炸 交叉熵损失函数 文章目录 [剑指offer]系列文章目录 BN层的本质原理 BN层的优点总结 BN层的过程 代码实现 BN层的本质原理 BN层(Ba ...

  10. Caffe中BN层与CONV层的融合(merge_bn)

    半年前写的博客,今天发现没有发出去,还好本地有md的文档,决定重新发一下 毕竟网上来回抄袭的blog太多了,代码质量也莫得保证 今天需要用pytorch融合下bn层,写个脚本稍后再传到csdn上 原理 ...

最新文章

  1. php 操作分表代码
  2. 一文了解文件上传全过程(1.8w字深度解析)「前端进阶必备」
  3. 定义应用程序的基础--模式(Bridge-桥接,Factory-工厂)
  4. uva437The Tower of Babylon
  5. java中两个map的融合(两个map有相同字段)
  6. 前端学习(1271):async/await处理多个异步请求
  7. mysql统计今天发布了多少条_Mysql统计总结 - 最近30天,昨天的数据统计
  8. 微信果断出手 将封禁拼团砍价链接,网友:终于可以清静了
  9. MongoDB中文问题
  10. HighCharts 详细使用及API文档说明
  11. 企业经常说绩效管理难,误区在哪?附绩效管理系统解决方案
  12. 串行加法器和并行加法器_N位并行加法器(4位二进制加法器和减法器)
  13. 玩单片机需要学数电、模电吗?
  14. 基于阿里语音识别(ASR)C/C++ SDK2.0编写的unimrcp中间件
  15. Figma常用快捷键(Mac版)
  16. C++中的仿函数(functors)和仿函数适配器(adapter function)
  17. SQL篇·Oracle字段根据逗号等分割
  18. 没有对比就没有伤害:《明日之后》竟成最良心国产末日手游?
  19. 酷派手机(Coolpad 8297-T01)在Android开发工具如AndroidStudio、Eclipse中无法打印Log
  20. VB编程:DoWhile...Loop当循环计算0~100累加和-15_彭世瑜_新浪博客

热门文章

  1. Android Effect 解析
  2. 有关EEPROM AT24C02字节写入和页写入
  3. 3dmax学习7——车削修改器
  4. 技术员 Ghost Win 7 Sp1(x86/x64)装机版/纯净版 201808
  5. 2020人脸识别报告:上万家企业入局,八大技术六个趋势一文看尽
  6. Struts2学习笔记(4)-ActionSupport类及Action接口详解
  7. Visio 2003 Professional 安装序列号
  8. 用gambit学博弈论--完全信息动态博弈(一)
  9. kali局域网扫描ip_kali 扫描局域网的QQ
  10. 2022-全球最佳混响插件评测