文章目录

  • 前因
  • 总览
  • Batch Normalization
  • Layer Normalization
  • Instance Normalization
  • Group Normalization
  • 总结
  • 参考

前因

Normalization现在已经成了神经网络中不可缺少的一个重要模块了,并且存在多种不同版本的归一化方法,把我们秀得头晕眼花,其本质都是减去均值除以方差,进行线性映射后,使得数据满足某个稳定分布,如下图所示:

出于更好地理解与区分这些方法, 写下本文。

总览

如果不理解此图,可以先看后面的小节,再来会看这幅图就会一目了然。

Batch Normalization

批次归一化是用的最多,也是提出最早、应用最广的归一化方法,其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失,并且起到一定的正则化作用。

如上图所示,Batch Norm在通道维度进行归一化,最后得到C个统计量u,δ。假设输入特征为[N, H, W, C],在C的每个维度上对[N, H, W]计算其均值、方差,用于该维度上的归一化操作。

借用参考链接[1]的一个例子,假设有N本书,每本书有C页,每页可容纳HxW个字符,Batch Norm就是页为单位:假设每本书都为C页,首先计算N本书中第1页的字符【N, H, W】均值方差,得到统计量u1,δ1,然后对N本书的第一页利用该统计量对第一页的元素进行归一化操作,剩下的C-1页同理。

关于BN的一些细节可以参看我的另一篇博客

代码理解

import torch
import torch.nn as nnfrom einops import rearrange, reducex = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
bn = nn.BatchNorm2d(c, eps=1e-10, affine=False, track_running_stats=False)
# affine=False 关闭线性映射
# track_running_stats=False 只计算当前的平均值与方差,而非更新全局统计量
y = bn(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> (b w h) c')
mean = rearrange(x_.mean(dim=0), 'c -> 1 c 1 1')
std = rearrange(x_.std(dim=0), 'c -> 1 c 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=1.9073486328125e-06

本代码包括之后的代码为了方便理解,均使用了einops库来进行爱因斯坦求和操作,如果不属于可以参考以下资源:
[1]https://github.com/arogozhnikov/einops/blob/master/docs/1-einops-basics.ipynb
[2]https://zhuanlan.zhihu.com/p/342675997

Layer Normalization

BN是按照样本数计算归一化统计量的,当样本数很少时,效果会变得很差。不适用某些场景,比如说硬件资源受限,在线学习等场景。

Batch Norm以Batch为单位计算统计量,与Batch Norm不同,Layer Norm以样本为单位计算统计量,因此最后会得到N个u,δ。假设输入特征为[N, H, W, C],在N的每个维度上对[H, W,C]计算其均值、方差,用于该维度上的归一化操作。

还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Layer Norm就是以本为单位:首先计算第一本书中的所有字符【H, W, C】均值方差,得到统计量u1,δ1,然后利用该统计量对第一本数进行归一化操作,剩下的N-1本书同理。

代码理解

x = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
ln = nn.LayerNorm([c,h,w], eps=1e-12, elementwise_affine=False)
# elementwise_affine=False 关闭映射
y = ln(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> (h w c) b')
mean = rearrange(x_.mean(dim=0), 'b -> b 1 1 1')
std = rearrange(x_.std(dim=0), 'b -> b 1 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=2.384185791015625e-05

值得注意的是:Layer Normalization操作在进行线性映射时是进行元素级别的映射,维度为输入的Normalized_Shape.

Instance Normalization

这种归一化方法最初用于图像的风格迁移。其作者发现,在生成模型中, feature map 的各个 channel 的均值和方差会影响到最终生成图像的风格,因此可以先把图像在 channel 层面归一化,然后再用目标风格图片对应 channel 的均值和标准差“去归一化”,以期获得目标图片的风格。

还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Instance Norm就是以每本书的每一页为单位:首先计算第1本书中第1页的所有字符【H, W】均值方差,得到统计量u1,δ1,然后利用该统计量对第1本书第1页进行归一化操作,剩下的NC-1页同理。

代码理解

x = torch.randn((32,3,224,224)) # [B,C,H,W]
b, c, h, w = x.shape# pytorch
In = nn.InstanceNorm2d(c, eps=1e-12, affine=False, track_running_stats=False)y = In(x)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b c h w -> b c (h w)')
mean = rearrange(x_.mean(dim=2), 'b c -> b c 1 1')
std = rearrange(x_.std(dim=2), 'b c -> b c 1 1')y_ = (x-mean)/std# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=5.340576171875e-05

Group Normalization

这种归一化方式适用于占用显存比较大的任务,例如图像分割。对这类任务,可能 batchsize 只能是个位数,再大显存就不够用了。而当 batchsize 是个位数时,BN 的表现可能很差,因为没办法通过几个样本的数据量,来近似总体的均值和标准差。

其在计算均值和标准差时,先把每一个样本feature map的 channel 分成 G 组,每组将有 C/G 个 channel,然后将这些 channel 中的元素求均值和标准差。各组 channel 用其对应的归一化参数独立地归一化。

介于Layer Norm与Instance Norm之间

还是同样的例子,有N本书,每本书有C页,每页可容纳HxW个字符,Group Norm就是以每本书的G页为单位:首先计算第1本书中第1组G页中的所有字符【H, W, G】均值方差,得到统计量u1,δ1,然后利用该统计量对第1本书第1组G页进行归一化操作,剩下的NC/G-1组同理。

代码理解

x = torch.randn((32,6,224,224)) # [B,C,H,W]
b, c, h, w = x.shape
group_num = 3# pytorch
gn = nn.GroupNorm(group_num, c, eps=1e-12, affine=False)
y = gn(x)
# print(gn.weight.shape)# 为了方便理解,这里使用了einops库实现
x_ = rearrange(x, 'b (g n) h w -> b g (n h w)', g=group_num) # [32, 3, 2x224x224]
mean = rearrange(x_.mean(dim=2), 'b g -> b g 1')  # [32, 3, 1]
std = rearrange(x_.std(dim=2), 'b g -> b g 1')y_ = (x_-mean)/std
y_ = rearrange(y_, 'b g (n h w) -> b (g n) h w', g=group_num, h=h, w=w)# 输出差别
print('diff={}'.format(torch.abs(y-y_).max()))
# diff=3.0994415283203125e-05

总结

BatchNorm:batch方向做归一化,算NxHxW的均值,对小batchsize效果不好;BN主要缺点是对batchsize的大小比较敏感,由于每次计算均值和方差是在一个batch上,所以如果batchsize太小,则计算的均值、方差不足以代表整个数据分布

LayerNorm:channel方向做归一化,算CxHxW的均值,主要对RNN(处理序列)作用明显,目前大火的Transformer也是使用的这种归一化操作;

InstanceNorm:一个channel内做归一化,算H*W的均值,用在风格化迁移;因为在图像风格化中,生成结果主要依赖于某个图像实例,所以对整个batch归一化不适合图像风格化中,因而对HW做归一化。可以加速模型收敛,并且保持每个图像实例之间的独立

GroupNorm:将channel方向分group,然后每个group内做归一化,算(C//G)HW的均值;这样与batchsize无关,不受其约束,在分割与检测领域作用较好

参考

[1] 如何区分并记住常见的几种 Normalization 算法
https://bbs.cvmart.net/topics/469/Normalization

[2] 【基础算法】六问透彻理解BN(Batch Normalization)
https://zhuanlan.zhihu.com/p/93643523

[3] 模型优化之Layer Normalization
https://zhuanlan.zhihu.com/p/54530247

[4] PyTorch学习之归一化层(BatchNorm、LayerNorm、InstanceNorm、GroupNorm)
https://blog.csdn.net/shanglianlm/article/details/85075706

一文弄懂Batch Norm / Layer Norm / Instance Norm / Group Norm 归一化方法相关推荐

  1. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  2. 一文弄懂神经网络中的反向传播法

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  3. 一文弄懂神经网络中的反向传播法——BackPropagation【转】

    本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法--BackPropagation 最近在看深度学习的东 ...

  4. 【一文弄懂】优先经验回放(PER)论文-算法-代码

    [一文弄懂]优先经验回放(PER)论文-算法-代码 文章目录 [一文弄懂]优先经验回放(PER)论文-算法-代码 前言: 综合评价: 继续前言唠叨 per论文简述: 参考博客: 背景知识 A MOTI ...

  5. 一文弄懂 Diffusion Model

    什么是 Diffusion Model 一.前向 Diffusion 过程 Diffusion Model 首先定义了一个前向扩散过程,总共包含T个时间步,如下图所示: 最左边的蓝色圆圈 x0 表示真 ...

  6. 一文弄懂各种loss function

    有模型就要定义损失函数(又叫目标函数),没有损失函数,模型就失去了优化的方向.大家往往接触的损失函数比较少,比如回归就是MSE,MAE,分类就是log loss,交叉熵.在各个模型中,目标函数往往都是 ...

  7. 一文弄懂String的所有小秘密

    文章目录 简介 String是不可变的 传值还是传引用 substring() 导致的内存泄露 总结 一文弄懂String的所有小秘密 简介 String是java中非常常用的一个对象类型.可以说ja ...

  8. 一文弄懂EnumMap和EnumSet

    文章目录 简介 EnumMap 什么时候使用EnumMap EnumSet 总结 一文弄懂EnumMap和EnumSet 简介 一般来说我们会选择使用HashMap来存储key-value格式的数据, ...

  9. CAD2010 为了保护_一文弄懂,锂电池的充电电路,以及它的保护电路方案设计

    原标题:一文弄懂,锂电池的充电电路,以及它的保护电路方案设计 锂电池特性 首先,芯片哥问一句简单的问题,为什么很多电池都是锂电池? 锂电池,工程师对它都不会感到陌生.在电子产品项目开发的过程中,尤其是 ...

最新文章

  1. FCN全连接卷积网络(3)--Fully Convolutional Networks for Semantic Segmentation阅读(摘要部分)
  2. ocelot和nginx比较_nginx + ocelot+.net core signalr 关于websocket无法正常握手的问题
  3. 关于Adium近期无法添加MSN联系人的说明
  4. 解读直播连麦与点播加密
  5. python 爬虫爬取小说信息
  6. 深入PHP面向对象、模式与实践读书笔记:面向对象设计和过程式编程
  7. a类怎么引用b类java_Java中A类的数组如何传入B类???急
  8. Azkaban工作流调度器(1)--azkaban的安装
  9. [转载] Python-科赫雪花(科克曲线)
  10. 每一次结束只是一次新的起点,深有体会。
  11. redhat 生产环境版本选择
  12. python finally语句里面出现异常_python try except语句出现异常
  13. 计算机技术1000字,计算机专业毕业实习报告1000字
  14. 多读书多看报,少吃零食多睡觉—2014总结,2015规划
  15. 随便说说,我回来啦~
  16. 岁月的剪影【五月世界末日】
  17. 【渲染】解决三维出图黑白边缘溢出问题:直通(STRAIGHT)与预乘(PREMULT)ALPHA剖析
  18. 【HTML——盛开花朵】(效果+代码)
  19. ADI实验室电路:带抗混叠滤波器的宽带接收机
  20. 10/9 看的何向南老师团队关于bias和debias最新综述;还可以吧

热门文章

  1. 使用PicGo配置七牛云图床(图文步骤详细)
  2. 全网最全精析破解 Springboot+Jpa 对数据库增删改查
  3. php实现html转word
  4. [编辑器]KindEditor 是什么?
  5. mysql错误01000_错误 ORA-01000: maximum open cursors exceeded Exception
  6. 从你的全世界路过---陌陌X-SIGN还原
  7. Setup time 和 Hold time
  8. css技术点二:字体图标(阿里巴巴字体图标使用)
  9. python 过滤相似图片_Python过滤纯色图片,挑选视频封面
  10. 警告:关于电磁辐射对孕妇的危害。