一文弄懂Batch Norm / Layer Norm / Instance Norm / Group Norm 归一化方法
文章目录
- 前因
- 总览
- Batch Normalization
- Layer Normalization
- Instance Normalization
- Group Normalization
- 总结
- 参考
前因
Normalization现在已经成了神经网络中不可缺少的一个重要模块了,并且存在多种不同版本的归一化方法,把我们秀得头晕眼花,其本质都是减去均值除以方差,进行线性映射后,使得数据满足某个稳定分布,如下图所示:
出于更好地理解与区分这些方法, 写下本文。
总览
如果不理解此图,可以先看后面的小节,再来会看这幅图就会一目了然。
Batch Normalization
批次归一化是用的最多,也是提出最早、应用最广的归一化方法,其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失,并且起到一定的正则化作用。
如上图所示,Batch Norm在通道维度进行归一化,最后得到C个统计量u,δ。假设输入特征为[N, H, W, C],在C
的每个维度上对[N, H, W]计算其均值、方差,用于该维度上的归一化操作。
借用参考链接[1]的一个例子,假设有N本书,每本书有C页,每页可容纳HxW个字符,Batch Norm就是页为单位:假设每本书都为C页,首先计算N本书中第1页的字符【N, H, W】
均值方差,得到统计量u1,δ1
,然后对N本书的第一页利用该统计量对第一页的元素进行归一化操作,剩下的C-1页同理。
关于BN的一些细节可以参看我的另一篇博客
代码理解:
import torch
import torch.nn as nnfrom einops import rearrange, reducex = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
bn = nn.BatchNorm2d(c, eps=1e-10, affine=False, track_running_stats=False)
# affine=False 关闭线性映射
# track_running_stats=False 只计算当前的平均值与方差,而非更新全局统计量
y = bn(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> (b w h) c')
mean = rearrange(x_.mean(dim=0), 'c -> 1 c 1 1')
std = rearrange(x_.std(dim=0), 'c -> 1 c 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=1.9073486328125e-06
本代码包括之后的代码为了方便理解,均使用了einops库来进行爱因斯坦求和操作,如果不属于可以参考以下资源:
[1]https://github.com/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb
[2]https://zhuanlan.zhihu.com/p/342675997
Layer Normalization
BN是按照样本数计算归一化统计量的,当样本数很少时,效果会变得很差。不适用某些场景,比如说硬件资源受限,在线学习等场景。
Batch Norm以Batch为单位计算统计量,与Batch Norm不同,Layer Norm以样本为单位计算统计量,因此最后会得到N个u,δ。假设输入特征为[N, H, W, C],在N
的每个维度上对[H, W,C]计算其均值、方差,用于该维度上的归一化操作。
还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Layer Norm就是以本为单位:首先计算第一本书中的所有字符【H, W, C】
均值方差,得到统计量u1,δ1
,然后利用该统计量对第一本数进行归一化操作,剩下的N-1本书同理。
代码理解:
x = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
ln = nn.LayerNorm([c,h,w], eps=1e-12, elementwise_affine=False)
# elementwise_affine=False 关闭映射
y = ln(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> (h w c) b')
mean = rearrange(x_.mean(dim=0), 'b -> b 1 1 1')
std = rearrange(x_.std(dim=0), 'b -> b 1 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=2.384185791015625e-05
值得注意的是:Layer Normalization操作在进行线性映射时是进行元素级别的映射,维度为输入的Normalized_Shape.
Instance Normalization
这种归一化方法最初用于图像的风格迁移。其作者发现,在生成模型中, feature map 的各个 channel 的均值和方差会影响到最终生成图像的风格,因此可以先把图像在 channel 层面归一化,然后再用目标风格图片对应 channel 的均值和标准差“去归一化”,以期获得目标图片的风格。
还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Instance Norm就是以每本书的每一页为单位:首先计算第1本书中第1页的所有字符【H, W】
均值方差,得到统计量u1,δ1
,然后利用该统计量对第1本书第1页进行归一化操作,剩下的NC-1页同理。
代码理解:
x = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
In = nn.InstanceNorm2d(c, eps=1e-12, affine=False, track_running_stats=False)y = In(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> b c (h w)')
mean = rearrange(x_.mean(dim=2), 'b c -> b c 1 1')
std = rearrange(x_.std(dim=2), 'b c -> b c 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=5.340576171875e-05
Group Normalization
这种归一化方式适用于占用显存比较大的任务,例如图像分割。对这类任务,可能 batchsize 只能是个位数,再大显存就不够用了。而当 batchsize 是个位数时,BN 的表现可能很差,因为没办法通过几个样本的数据量,来近似总体的均值和标准差。
其在计算均值和标准差时,先把每一个样本feature map的 channel 分成 G 组,每组将有 C/G 个 channel,然后将这些 channel 中的元素求均值和标准差。各组 channel 用其对应的归一化参数独立地归一化。
介于Layer Norm与Instance Norm之间
还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Group Norm就是以每本书的G页为单位:首先计算第1本书中第1组G页中的所有字符【H, W, G】
均值方差,得到统计量u1,δ1
,然后利用该统计量对第1本书第1组G页进行归一化操作,剩下的NC/G-1
组同理。
代码理解:
x = torch.randn((32,6,224,224)) # [B,C,H,W]
b, c, h, w = x.shape
group_num = 3# pytorch
gn = nn.GroupNorm(group_num, c, eps=1e-12, affine=False)
y = gn(x)
# print(gn.weight.shape)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b (g n) h w -> b g (n h w)', g=group_num) # [32, 3, 2x224x224]
mean = rearrange(x_.mean(dim=2), 'b g -> b g 1') # [32, 3, 1]
std = rearrange(x_.std(dim=2), 'b g -> b g 1')y_ = (x_-mean)/std
y_ = rearrange(y_, 'b g (n h w) -> b (g n) h w', g=group_num, h=h, w=w)# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=3.0994415283203125e-05
总结
BatchNorm:batch方向做归一化,算NxHxW
的均值,对小batchsize效果不好;BN主要缺点是对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布。
LayerNorm:channel方向做归一化,算CxHxW
的均值,主要对RNN(处理序列)作用明显,目前大火的Transformer也是使用的这种归一化操作;
InstanceNorm:一个channel内做归一化,算H*W的均值,用在风格化迁移;因为在图像风格化中,生成结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,因而对HW做归一化。可以加速模型收敛,并且保持每个图像实例之间的独立。
GroupNorm:将channel方向分group,然后每个group内做归一化,算(C//G)HW的均值;这样与batchsize无关,不受其约束,在分割与检测领域作用较好。
参考
[1] 如何区分并记住常见的几种 Normalization 算法
https://bbs.cvmart.net/topics/469/Normalization
[2] 【基础算法】六问透彻理解BN(Batch Normalization)
https://zhuanlan.zhihu.com/p/93643523
[3] 模型优化之Layer Normalization
https://zhuanlan.zhihu.com/p/54530247
[4] PyTorch学习之归一化层(BatchNorm、LayerNorm、InstanceNorm、GroupNorm)
https://blog.csdn.net/shanglianlm/article/details/85075706
一文弄懂Batch Norm / Layer Norm / Instance Norm / Group Norm 归一化方法相关推荐
- 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
<繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...
- 一文弄懂神经网络中的反向传播法
最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...
- 一文弄懂神经网络中的反向传播法——BackPropagation【转】
本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法--BackPropagation 最近在看深度学习的东 ...
- 【一文弄懂】优先经验回放(PER)论文-算法-代码
[一文弄懂]优先经验回放(PER)论文-算法-代码 文章目录 [一文弄懂]优先经验回放(PER)论文-算法-代码 前言: 综合评价: 继续前言唠叨 per论文简述: 参考博客: 背景知识 A MOTI ...
- 一文弄懂 Diffusion Model
什么是 Diffusion Model 一.前向 Diffusion 过程 Diffusion Model 首先定义了一个前向扩散过程,总共包含T个时间步,如下图所示: 最左边的蓝色圆圈 x0 表示真 ...
- 一文弄懂各种loss function
有模型就要定义损失函数(又叫目标函数),没有损失函数,模型就失去了优化的方向.大家往往接触的损失函数比较少,比如回归就是MSE,MAE,分类就是log loss,交叉熵.在各个模型中,目标函数往往都是 ...
- 一文弄懂String的所有小秘密
文章目录 简介 String是不可变的 传值还是传引用 substring() 导致的内存泄露 总结 一文弄懂String的所有小秘密 简介 String是java中非常常用的一个对象类型.可以说ja ...
- 一文弄懂EnumMap和EnumSet
文章目录 简介 EnumMap 什么时候使用EnumMap EnumSet 总结 一文弄懂EnumMap和EnumSet 简介 一般来说我们会选择使用HashMap来存储key-value格式的数据, ...
- CAD2010 为了保护_一文弄懂,锂电池的充电电路,以及它的保护电路方案设计
原标题:一文弄懂,锂电池的充电电路,以及它的保护电路方案设计 锂电池特性 首先,芯片哥问一句简单的问题,为什么很多电池都是锂电池? 锂电池,工程师对它都不会感到陌生.在电子产品项目开发的过程中,尤其是 ...
最新文章
- FCN全连接卷积网络(3)--Fully Convolutional Networks for Semantic Segmentation阅读(摘要部分)
- ocelot和nginx比较_nginx + ocelot+.net core signalr 关于websocket无法正常握手的问题
- 关于Adium近期无法添加MSN联系人的说明
- 解读直播连麦与点播加密
- python 爬虫爬取小说信息
- 深入PHP面向对象、模式与实践读书笔记:面向对象设计和过程式编程
- a类怎么引用b类java_Java中A类的数组如何传入B类???急
- Azkaban工作流调度器(1)--azkaban的安装
- [转载] Python-科赫雪花(科克曲线)
- 每一次结束只是一次新的起点,深有体会。
- redhat 生产环境版本选择
- python finally语句里面出现异常_python try except语句出现异常
- 计算机技术1000字,计算机专业毕业实习报告1000字
- 多读书多看报,少吃零食多睡觉—2014总结,2015规划
- 随便说说,我回来啦~
- 岁月的剪影【五月世界末日】
- 【渲染】解决三维出图黑白边缘溢出问题:直通(STRAIGHT)与预乘(PREMULT)ALPHA剖析
- 【HTML——盛开花朵】(效果+代码)
- ADI实验室电路:带抗混叠滤波器的宽带接收机
- 10/9 看的何向南老师团队关于bias和debias最新综述;还可以吧
热门文章
- 使用PicGo配置七牛云图床(图文步骤详细)
- 全网最全精析破解 Springboot+Jpa 对数据库增删改查
- php实现html转word
- [编辑器]KindEditor 是什么?
- mysql错误01000_错误 ORA-01000: maximum open cursors exceeded Exception
- 从你的全世界路过---陌陌X-SIGN还原
- Setup time 和 Hold time
- css技术点二:字体图标(阿里巴巴字体图标使用)
- python 过滤相似图片_Python过滤纯色图片,挑选视频封面
- 警告:关于电磁辐射对孕妇的危害。