深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用

BatchNorm是算法岗面试中几乎必考题,本文将带你理解BatchNorm的原理和代码实现,以及详细介绍BatchNorm在CNN中的应用。NLP中常见的LayerNorm的解读,详见我的另一篇博客深入理解NLP中LayerNorm的原理以及LN的代码详解

BatchNorm

  • 深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用
  • 一、BatchNorm论文
    • 1.2 问题:为什么在Normalize后还要将xxx复原成yyy?
  • 二、BatchNorm代码
    • 2.1 torch.nn.BatchNorm1d
    • 2.2 torch.nn.BatchNorm2d
    • 2.3 BatchNorm层的参数γ,β和统计量
      • 2.3.1 train模式
      • 2.3.2 eval模式
    • 2.4 代码:Pytorch实战演练
  • 三、BatchNorm在CNN中的应用
    • 3.1 图解:卷积神经网络中的BatchNorm
    • 3.2 BatchNorm torch代码实现
  • 四、BatchNorm的优缺点
  • 五、BatchNorm反向传播公式推导
  • 六、参考资料

一、BatchNorm论文

论文题目:Batch Normalization: Accelerating Deep Network Training byReducing Internal Covariate Shift
论文地址:https://arxiv.org/pdf/1502.03167.pdf

BatchNorm伪代码如下:

1.2 问题:为什么在Normalize后还要将xxx复原成yyy?

答:因为我们这里做的是标准化,但是可能真正训练的时候还是方差大一点,或者均值比0大一些比较好的话,那么这里允许你还原回去,至于还原成什么样,神经网络会自己找出好的均值和方差(γ\gammaγ和β\betaβ都是可学习参数)李沐 09:40

二、BatchNorm代码

y=x−mean⁡(x)Var⁡(x)+eps∗gamma⁡+beta⁡\mathrm{y}=\frac{x-\operatorname{mean}(x)}{\sqrt{\operatorname{Var}(x)}+e p s} * \operatorname{gamma}+\operatorname{beta}y=Var(x)​+epsx−mean(x)​∗gamma+beta
根据数据维度的不同,PyTorch中的BatchNorm有不同的形式:

2.1 torch.nn.BatchNorm1d

官方文档:torch.nn.BatchNorm1d
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

  • 2D input: (mini_batch, num_feature),常见的结构化数据,如,房价预测问题中x的特征数有100个,torch.nn.BatchNorm1d(100)
  • 3D input: (mini_batch, num_feature, additional_channel),使用时 torch.nn.BatchNorm1d(num_feature),不过这种维度一般不常用

2.2 torch.nn.BatchNorm2d

官方文档:torch.nn.BatchNorm2d
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

  • 4D input: (mini_batch, num_feature_map, p, q),常用于CV的图像数据,如CIFAR10(3x32x32),torch.nn.BatchNorm2d(3)

2.3 BatchNorm层的参数γ,β和统计量

Batch Norm层有可学习的参数γ和β,以及统计量running mean和running var

  • (可学习参数)γ : weight of BatchNorm
  • (可学习参数)β : bias of BatchNorm
  • (统计量)running mean: 预测阶段会使用这个均值
  • (统计量)running var: 预测阶段会使用这个方差

默认初始化参数γ\gammaγ和β\betaβ 为1和0

pytorch中用state_dict()可以查看上面这些信息

print("--- 4D:(mini_batch, num_feature, p, q) ---")
m = nn.BatchNorm2d(3, momentum=0.1)  # 例如, CIFAR10数据集是三通道的,3x32x32
print(m.state_dict().keys())
# 输出:odict_keys(['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'])

2.3.1 train模式

在pytorch中可以使用model.train()将BatchNorm层切换到train模式。

在train模式下参数γ和β会随着网络的反向传播进行梯度更新,而统计量running mean和running var则会用一种特定的方式进行更新。在Pytorch中的更新方式如下:

x^new =(1−\hat{x}_{\text {new }}=(1-x^new ​=(1− momentum )×x^+) \times \hat{x}+)×x^+ momentum ×xt\times x_{t}×xt​

  • x^\hat{x}x^: running mean or running variance
  • xtx_{t}xt​: input mean and variance(训练时的第t个batch的均值和方差)
  • 默认momentum为0.1

2.3.2 eval模式

在pytorch中可以使用model.eval()将BatchNorm层切换到eval模式。

在eval模式下,我们的模型不可能再等到预测的样本数量达到一个batch时,再进行归一化,而是直接使用train模式得到的统计量running mean和running var进行归一化

2.4 代码:Pytorch实战演练

import torch
import torch.nn as nnbs = 64print("Pytorch Batch Norm Layer详解")
print("--- 2D input:(mini_batch, num_feature) ---")
# With Learnable Parameters
m = nn.BatchNorm1d(400)  # 例如,房价预测:x的特征数是400,y是房价
# Without Learnable Parameters(无学习参数γ和β)
# m = nn.BatchNorm1d(100, affine=False)
inputs = torch.randn(bs, 400)
print(m(inputs).shape)print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("")print("--- 3D input:(mini_batch, num_feature, other_channel) ---")
m = nn.BatchNorm1d(32)
inputs = torch.randn(bs, 32, 32)  # 这种格式的数据不常用
print(m(inputs).shape)print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("")print("--- 4D input:(mini_batch, num_feature, H, W) ---")
m = nn.BatchNorm2d(3)  # 例如, CIFAR10数据集是三通道的,3x32x32inputs = torch.randn(bs, 3, 32, 32)
print(m(inputs).shape)print("Batch Norm层的γ和β是要训练学习的参数")
print("γ:", m.state_dict()['weight'].shape)  # gammar
print("β:", m.state_dict()['bias'].shape)  # beta
print("Batch Norm层的running_mean和running_var是统计量(主要用于预测阶段)")
print("running_mean:", m.state_dict()['running_mean'].shape)
print("running_var:", m.state_dict()['running_var'].shape)

输出:

Pytorch Batch Norm Layer详解
--- 2D input:(mini_batch, num_feature) ---
torch.Size([64, 400])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([400])
β: torch.Size([400])--- 3D input:(mini_batch, num_feature, other_channel) ---
torch.Size([64, 32, 32])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([32])
β: torch.Size([32])--- 4D input:(mini_batch, num_feature, H, W) ---
torch.Size([64, 3, 32, 32])
Batch Norm层的γ和β是要训练学习的参数
γ: torch.Size([3])
β: torch.Size([3])
Batch Norm层的running_mean和running_var是统计量(主要用于预测阶段)
running_mean: torch.Size([3])
running_var: torch.Size([3])

三、BatchNorm在CNN中的应用

我们在第二部分的代码中发现,BatchNorm2d的参数γ和β数量是跟特征图的数量是一致的,并不是我们直观认为的num_feature*H*W个参数,这是为什么呢?

《百面机器学习》P221是这样解释的:
BatchNorm批量归一化在卷积神经网络中应用时,需要注意卷积神经网络的参数共享机制。每一个卷积核的参数在不同位置的神经元当中是共享的,因此同一个特征图的所有神经元也应该被一起归一化!

  • 换句话说就是,你一个特征图用的是共享的卷积核参数,所以这个特征图中的每个神经元(共H*W个)也应该共享参数γ,β\gamma, \betaγ,β。如果有fff个卷积核,就对应fff个特征图和fff组不同的γ\gammaγ和β\betaβ参数

下面的解释来自hjimce

  • 假如某一层卷积层有6个特征图,每个特征图的大小是100*100,这样就相当于这一层网络有6*100*100个神经元,如果采用BN,就会有6*100*100个参数γ、β,这样岂不是太恐怖了。因此卷积层上的BN使用,其实也是使用了类似权值共享的策略,把一整张特征图当做一个神经元进行处理。
  • 卷积神经网络经过卷积后得到的是一系列的特征图,如果min-batch sizes为m,那么网络某一层输入数据可以表示为四维矩阵(m,f,p,q),m为min-batch sizes,f为特征图个数,p、q分别为特征图的宽高。在cnn中我们可以把每个特征图看成是一个特征处理,因此在使用Batch Normalization,mini-batch size 的大小相当于m*p*q,于是对于每个特征图都只有一对可学习参数:γ、β。

3.1 图解:卷积神经网络中的BatchNorm

这里我特意画了一个图来让大家看清楚CNN中Batchnorm到底是怎么做的

总结来说:

  1. 对于某个特征图而言,一个batch共有m个这样的特征图,并且每个特征图有p*q个神经元,把所有的m*p*q个神经元拉直,然后求得平均值和方差。
  2. 对m个这样特征图的p*q个神经元的每个神经元,利用求出的平均值和方差做下数据变换。

参考资料:BN的操作流程


3.2 BatchNorm torch代码实现

https://d2l.ai/chapter_convolutional-modern/batch-norm.html

import torch
from torch import nn
from d2l import torch as d2ldef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# Use `is_grad_enabled` to determine whether the current mode is training# mode or prediction modeif not torch.is_grad_enabled():# If it is prediction mode, directly use the mean and variance# obtained by moving averageX_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# When using a fully-connected layer, calculate the mean and# variance on the feature dimensionmean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# When using a two-dimensional convolutional layer, calculate the# mean and variance on the channel dimension (axis=1). Here we# need to maintain the shape of `X`, so that the broadcasting# operation can be carried out latermean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# In training mode, the current mean and variance are used for the# standardizationX_hat = (X - mean) / torch.sqrt(var + eps)# Update the mean and variance using moving averagemoving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # Scale and shiftreturn Y, moving_mean.data, moving_var.data

下面是来自于Keras卷积层的BN实现的一小段主要源码:

# Keras BatchNorm
input_shape = self.input_shape
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
if train:m = K.mean(X, axis=reduction_axes)brodcast_m = K.reshape(m, broadcast_shape)std = K.mean(K.square(X - brodcast_m) + self.epsilon, axis=reduction_axes)std = K.sqrt(std)brodcast_std = K.reshape(std, broadcast_shape)mean_update = self.momentum * self.running_mean + (1-self.momentum) * mstd_update = self.momentum * self.running_std + (1-self.momentum) * stdself.updates = [(self.running_mean, mean_update),(self.running_std, std_update)]X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)
else:brodcast_m = K.reshape(self.running_mean, broadcast_shape)brodcast_std = K.reshape(self.running_std, broadcast_shape)X_normed = ((X - brodcast_m) /(brodcast_std + self.epsilon))
out = K.reshape(self.gamma, broadcast_shape) * X_normed + K.reshape(self.beta, broadcast_shape)

附:pytorch中取mean的操作

import torchbs = 64
a = torch.randn(bs, 100, 32, 28)
# 将轴0,2,3的元素都放在一起取平均值
print(torch.mean(a, axis=(0, 2, 3)).shape)  # torch.Size([100])

附:CNN网络中的BatchNorm2d

四、BatchNorm的优缺点

BN的优点:

  • 解决内部协变量偏移,简单来说训练过程中,各层分布不同,增大了学习难度,BN缓解了这个问题。当然后来也有论文证明BN有作用和这个没关系,而是可以使损失平面更加的平滑,从而加快收敛速度。
  • 缓解了梯度饱和问题(如果使用sigmoid这种含有饱和区间的激活函数的话),加快收敛。

BN的缺点

  • Batch size比较小的时候,效果会比较差。因为他是用一个batch中的均值和方差来模拟全部数据的均值和方差。比如你一个batch只有2个样本,那你两个样本的均值和方差就不能很好地代表全班人的均值和方差,所以效果肯定就不好。
  • BN是计算机视觉CV的标配,但在自然语音处理NLP中效果一般较差,取而代之的是LN。关于LayerNorm的详解,可以参考我另一篇博客:深入理解NLP中LayerNorm的原理以及LN的代码详解

五、BatchNorm反向传播公式推导

详见我的Notion笔记

六、参考资料

[1] 李宏毅2021机器学习 第5节 Batch Normalization(学习笔记)
[2] 深度学习(二十九)Batch Normalization 学习笔记 (讲得挺全面的)
[3] BatchNormalization、LayerNormalization、InstanceNorm、GroupNorm简介(不同Norm的对比)
[4] BatchNorm behaves different in train() and eval() #5406
[5] BatchNorm2d原理、作用及其pytorch中BatchNorm2d函数的参数讲解
[6] 神经网络之BN层
[7] 5 分钟理解 BatchNorm
[8] Pytorch的BatchNorm层使用中容易出现的问题
[9] 【深度学习】深入理解Batch Normalization批标准化
[10] BN踩坑记–谈一下Batch Normalization的优缺点和适用场景
[11] 深度神经网络架构【斯坦福21秋季:实用机器学习中文版】

深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用相关推荐

  1. python 靶心_手把手教你使用Python实战反欺诈模型|原理+代码

    原标题:手把手教你使用Python实战反欺诈模型|原理+代码 作者 | 萝卜 来源 | 早起Python(ID: zaoqi-python) 本文将基于不平衡数据,使用Python进行 反欺诈模型数据 ...

  2. 赠书:深入理解MySQL主从原理

    根据经验,想要快速学习一门技术有3种方式. 第一种方式是通过代码来理解它的实现,反推它的逻辑. 这种方式的难度很大,而且起点相对高,能够沉浸其中的人非常少,过程相对来说是苦闷的,但如果能够沉下心来看代 ...

  3. 不同类的方法 事务问题_深入理解 Spring 事务原理

    Spring事务的基本原理 Spring事务的本质其实就是数据库对事务的支持,没有数据库的事务支持,spring是无法提供事务功能的.对于纯JDBC操作数据库,想要用到事务,可以按照以下步骤进行: 获 ...

  4. python多元线性回归模型案例_Python 实战多元线性回归模型,附带原理+代码

    原标题:Python 实战多元线性回归模型,附带原理+代码 作者 | 萝卜 来源 | 早起Python( ID:zaoqi-python ) 「多元线性回归模型」非常常见,是大多数人入门机器学习的第一 ...

  5. gcc 删除elf_ELF文件格式解析器 原理 + 代码

    本文为看雪论坛精华文章 看雪论坛作者ID:菜鸟m号 附件链接:[原创] ELF文件格式解析器 原理 + 代码 写在前面: 读<Linux二进制>,发现作者对 ELF文件格式部分并没有做详细 ...

  6. 深入理解RCU|核心原理

    hi,大家好,今天给大家分享并行程序设计中最重要的锁-RCU锁,RCU锁本质是用空间换时间,是对读写锁的一种优化加强,但不仅仅是这样简单,RCU体现出来的垃圾回收思想,也是值得我们学习和借鉴,各个语言 ...

  7. python常用代码_Python常用算法学习(4) 数据结构(原理+代码)-最全总结

    数据结构简介 1,数据结构 数据结构是指相互之间存在着一种或多种关系的数据元素的集合和该集合中数据元素之间的关系组成.简单来说,数据结构就是设计数据以何种方式组织并存贮在计算机中.比如:列表,集合与字 ...

  8. python原理及代码_原理+代码|详解层次聚类及Python实现

    前言 聚类分析是研究分类问题的分析方法,是洞察用户偏好和做用户画像的利器之一.聚类分析的方法非常多,能够理解经典又最基础的聚类方法 -- 层次聚类法(系统聚类) 的基本原理并将代码用于实际的业务案例是 ...

  9. Java 多线程 —— 深入理解 volatile 的原理以及应用

    转载自  Java 多线程 -- 深入理解 volatile 的原理以及应用 推荐阅读:<java 多线程-线程怎么来的> 这一篇主要讲解一下volatile的原理以及应用,想必看完这一篇 ...

最新文章

  1. 好礼相送|CSDN云原生 Meetup 成都站报名热烈启动,12.18见!
  2. 关于sigma pix的理解
  3. hibernate Restrictions 用法 查询
  4. 【报告分享】阿里巴巴全生态就业体系与就业质量研究报告.pdf(附下载链接)...
  5. 陈国良院士将出席“首届对象存储技术与应用大会”
  6. 笔记本linux版刚买回来怎么检查,新电脑买回来要怎么做
  7. android对象引用释放,Android程序的内存泄漏与规避方法
  8. Linux 使用yum下载软件
  9. 解决......lib/include/THC/THCGeneral.h:12:18: fatal error: cuda.h: No such file or directory报错问题
  10. 面试题16:不含重复字符的最长子字符串(Java版)
  11. 管家婆物流配货单快速实现批量拣货
  12. git显示当前分支的父分支名称
  13. 让自己的app支持小程序
  14. 墙裂推荐5款做微商必备的软件
  15. Java编程那些事儿68——抽象类和接口(一)
  16. 【微信小程序】去水印小程序源码,微信和QQ小程序都能用!
  17. 鼎信网关PCM数据包解析转换
  18. linux屏幕拷贝,使用gnome-screenshot在Linux中截取屏幕截图的综合指南
  19. Arduino工程源码分析
  20. 11.1 什么是模块,Python模块化编程

热门文章

  1. 定了!北京冬奥会售票群体出炉,门票需要预定吗?
  2. 群晖服务器显示器设置,【群晖 DS216+II 网络存储 NAS 服务器使用总结】噪音|炒菜|显示屏|优点_摘要频道_什么值得买...
  3. eclipse运行python老是报错_eclipse python
  4. 深度学习卷积神经网络——经典网络GoogLeNet(Inception V3)网络的搭建与实现
  5. 做副业赚钱,这几个热门自媒体平台收益超多
  6. R语言 Nomogram个体得分 各变量Point得分
  7. mac小技巧:如何设置Mac快速锁屏
  8. 求解非线性方程的实根matlab
  9. 软件工程知识点汇总(期末总复习)
  10. element UI 设置滚动条颜色