透彻理解BN(Batch Normalization)层
什么是BN
Batch Normalization是2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》中提出的数据归一化方法,往往用在深度神经网络中激活层之前。其作用可以加快模型训练时的收敛速度,使得模型训练过程更加稳定,避免梯度爆炸或者梯度消失。并且起到一定的正则化作用,几乎代替了Dropout
批量归一化:通过减少内部协变量偏移来加速深度网络训练
由于每层输入的分布在训练过程中随着前一层的参数发生变化而发生变化,因此训练深度神经网络很复杂。由于需要较低的学习率和仔细的参数初始化,这会减慢训练速度,并且使得训练具有饱和非线性的模型变得非常困难。我们将这种现象称为内部协变量偏移,并通过归一化层输入来解决该问题。我们的方法的优势在于将标准化作为模型架构的一部分,并为每个训练小批量执行标准化。 Batch Normalization 允许我们使用更高的学习率,并且在初始化时不那么小心。它还充当正则化器,在某些情况下消除了 Dropout 的需要。应用于最先进的图像分类模型,批量归一化在训练步骤减少 14 倍的情况下实现了相同的精度,并且以显着的优势击败了原始模型。使用一组批量归一化网络,我们改进了 ImageNet 分类的最佳发布结果:达到 4.9% 的前 5 名验证错误(和 4.8% 的测试错误),超过了人工评估员的准确性
机器学习领域有个很重要的假设:IID独立同分布假设,就是假设训练数据和测试数据是满足相同分布的,这是通过训练数据获得的模型能够在测试集获得好的效果的一个基本保障。那BatchNorm的作用是什么呢?BatchNorm就是在深度神经网络训练过程中使得每一层神经网络的输入保持相同分布的。
BN解决“Internal Covariate Shift”问题
在训练的过程中,即使对输入层做了归一化处理使其变成标准正态,随着网络的加深,函数变换越来越复杂,许多隐含层的分布还是会彻底放飞自我,变成各种奇奇怪怪的正态分布,并且整体分布逐渐往非线性函数(也就是激活函数)的取值区间的上下限两端靠近。对于sigmoid函数来说,就意味着输入值是大的负数或正数,这导致反向传播时底层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因。
为了解决上述问题,又想到网络的某个隐含层相对于之后的网络就相当于输入层,所以BN的基本思想就是:把网络的每个隐含层的分布都归一化到标准正态。其实就是把越来越偏的分布强制拉回到比较标准的分布,这样使得激活函数的输入值落在该激活函数对输入比较敏感的区域,这样一来输入的微小变化就会导致损失函数较大的变化。通过这样的方式可以使梯度变大,就避免了梯度消失的问题,而且梯度变大意味着收敛速度快,能大大加快训练速度。
简单说来就是:传统的神经网络只要求第一个输入层归一化,而带BN的神经网络则是把每个输入层(把隐含层也理解成输入层)都归一化。
BN的核心公式理解
pytorch BatchNorm2d
BATCHNORM2D
参数介绍
- num_features,输入数据的通道数,归一化时需要的均值和方差是在每个通道中计算的
- eps,用来防止归一化时除以0
- momentum,滑动平均的参数,用来计算running_mean和running_var
- affine,是否进行仿射变换,即缩放操作
- track_running_stats,是否记录训练阶段的均值和方差,即running_mean和running_var
BN层的状态包含五个参数
- weight,缩放操作的γ
- bias,缩放操作的β
- running_mean,训练阶段统计的均值,测试阶段会用到。
- running_var,训练阶段统计的方差,测试阶段会用到
- num_batches_tracked,训练阶段的batch的数目,如果没有指定momentum,则用它来计算running_mean和running_var。一般momentum默认值为0.1,所以这个属性暂时没用。
weight和bias这两个参数需要训练,而running_mean、running_val和num_batches_tracked不需要训练,它们只是训练阶段的统计值
训练与推理时BN中的均值、方差分别是什么
训练时,均值、方差分别是该批次内数据相应维度的均值与方差;
推理时,均值、方差是基于所有批次的期望计算所得,公式如
BN两大效果
- 收敛速率增加
- 可以达到更好的精度
参考文档
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
【基础算法】六问透彻理解BN(Batch Normalization)
Pytorch-BN层详细解读
【深度学习】深入理解Batch Normalization批标准化
透彻理解BN(Batch Normalization)层相关推荐
- 理解BN(Batch Normalization)
https://www.cnblogs.com/king-lp 转自:参数优化方法 1. 深度学习流程简介 1)一次性设置(One time setup) - 激活函数(Activ ...
- 什么是BN(Batch Normalization)
什么是BN(Batch Normalization)? 在之前看的深度学习的期刊里,讲到了BN,故对BN做一个详细的了解.在网上查阅了许多资料,终于有一丝明白. 什么是BN? 2015年的论文< ...
- 残差结构Residual、BN(Batch Normalization)
残差结构Residual 初次接触残差结构是在ResNets的网络中,可以随着网络深度的增加,训练误差会越来越多(被称为网络退化)的问题,引入残差结构即使网络再深吗,训练的表现仍表现很好.它有助于 ...
- BN(Batch Normalization)
批量归一化(BN: Batch Normalization) batch size=8样本,每个样本4维度,左边是数字是第l层输出,即每个神经元输出8个响应值,再经过计算均值,方差后: 值都在0附近, ...
- 【论文理解】Batch Normalization论文中关于BN背景和减少内部协变量偏移的解读(论文第1、2节)
最近在啃Batch Normalization的原论文(Title:Batch Normalization: Accelerating Deep Network Training by Reducin ...
- 『教程』Batch Normalization 层介绍
原文链接 思考 YJango的前馈神经网络--代码LV3的数据预处理中提到过:在数据预处理阶段,数据会被标准化(减掉平均值.除以标准差),以降低不同样本间的差异性,使建模变得相对简单. 我们又知道神经 ...
- 当卷积层后跟batch normalization层时为什么不要偏置b
起因 之前使用的是inception-v2模型,在v2中,标准的卷积模块为: * conv_w_b->bn->relu ->pooling* 即,在v2中,尽管在卷积之后.激活函数之 ...
- 卷积神经网络CNN(2)—— BN(Batch Normalization) 原理与使用过程详解
前言 Batch Normalization是由google提出的一种训练优化方法.参考论文:Batch Normalization Accelerating Deep Network Trainin ...
- BN(Batch Normalization) 原理与使用过程详解
论文名字:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 论 ...
最新文章
- 基于关键帧的RGB-D视觉惯性里程计
- JS中for循环的两种写法
- 2021年 第12届 蓝桥杯 第4次模拟赛真题详解及小结【Java版】
- sql 计算两个小数乘积_数学篇|学会这些数学计算技巧,想不满分都难!
- Spring boot 2.4开启静态资源缓存
- 从xml数据集到FairMOT数据集转换
- php导出指定格式excel,php导出excel格式文件的例子
- 自用推荐【浏览器网页监控插件 Distill Web Monitor】
- 系统、驱动相关软件下载
- 自动控制原理9.3---线性定常系统的反馈结构及状态观测器
- 第4章 程序的控制结构(单元测试题Python含答案)
- 腾讯马化腾:公司拥有大量探索和开发元宇宙的技术和能力
- excel单元格内容拆分_Excel分列解决不了的问题,VBA轻松搞定之拆分单元格
- matlab神经网络 股票预测模型,基于BP神经网络的股票预测模型
- pandas DatetimeIndex indexing
- 曾扬言 机器人合法公民_曾扬言“摧毁人类”的机器人索菲亚,现状如何?如果失控了咋办?...
- Play Framework
- vscode占内存太大问题
- 我终于刷完了《觉醒年代》,对PMP有了新的思考...
- UI设计和平面设计的区别
热门文章
- mysql数据库配置_mysql数据库怎么配置
- intell idea 使用mave打springboot包的插件
- 老男孩python全栈第9期
- 7-2 海盗分赃 (25 分)(PTA)
- 黑马程序员java学习打卡----程序流程控制
- delphi中增加FastMM4有效管理你的内存使用
- halcon与C#混合编程(转)
- SEO面试题与面试攻略
- 有没有免费的视频剪辑软件?快来看看这些视频裁剪软件
- ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接