BN中的滑动平均是怎么做的


训练过程中的每一个batch都会进行滑动平均的计算[1]:

moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
moving_var = moving_var * momentum + batch_var * (1 - momentum)

式中的 momentum 为动量参数,在 TF/Keras 中,该值为0.99,在 Pytorch 中,这个值为0.9

初始值,moving_mean=0,moving_var=1,相当于标准正态分布,当然,理论上初始化为任意值都可以

在实际的代码中,滑动平均的计算会以一种更高效的方式,但实际上是等价的:

moving_mean -= (moving_mean - batch_mean) * (1 - momentum)
moving_var -= (moving_var - batch_var) * (1 - momentum)

滑动平均中 Momentum 参数的影响

整个训练阶段滑动平均的过程,(moving_mean, moving_var) 参数实际上是从正态分布,向训练集真实分布靠拢的一个过程。

理论上,训练步数越长是会越靠近真实分布的,实际上,因为每个batch并不能代表整个训练集的分布,所以最后的值是在真实分布附近波动。

一个更小的 momentum 值,意味着更大的更新步长,对应着滑动平均值更快的变化,能更快地向真实值靠拢,但也意味着更大的波动性,更大的 momentum 值则相反。

训练阶段使用的是 (batch_mean, batch_var),所以滑动平均并不会影响训练阶段的结果,而是影响预测阶段的效果。关于BN在训练和测试时的差别可参考[2] 。

如果训练步数很短,一个大的 momentum 值可能会导致 (moving_mean, moving_var) 还没有靠拢到真实分布就停止了,这样对预测阶段的影响是很大的,也会是欠拟合的一个状态。如果训练步数足够,一个大的 momentum 值对应小的更新步长,最后的滑动平均的值是会更接近真实值的。

如果batch size 比较小,那单个batch的 (batch_mean, batch_var) 和真实分布会比较大,此时滑动平均单次更新的步长就不应过大,适用一个大的 momentum 值,反之可类比分析。

BN 前向过程代码实现

def batchnorm_forward(x, gamma, beta, bn_param):"""Forward pass for batch normalization.During training the sample mean and (uncorrected) sample variance arecomputed from minibatch statistics and used to normalize the incoming data.During training we also keep an exponentially decaying running mean of themean and variance of each feature, and these averages are used to normalizedata at test-time.At each timestep we update the running averages for mean and variance usingan exponential decay based on the momentum parameter:running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varInput:- x: Data of shape (N, D)- gamma: Scale parameter of shape (D,)- beta: Shift paremeter of shape (D,)- bn_param: Dictionary with the following keys:- mode: 'train' or 'test'; required- eps: Constant for numeric stability- momentum: Constant for running mean / variance.- running_mean: Array of shape (D,) giving running mean of features- running_var Array of shape (D,) giving running variance of featuresReturns a tuple of:- out: of shape (N, D)- cache: A tuple of values needed in the backward pass"""mode = bn_param['mode']eps = bn_param.get('eps', 1e-5)momentum = bn_param.get('momentum', 0.9)N, D = x.shaperunning_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))running_var = bn_param.get('running_var', np.ones(D, dtype=x.dtype))if mode == 'train':sample_mean = x.mean(axis=0)sample_var = x.var(axis=0)running_mean = momentum * running_mean + (1 - momentum) * sample_meanrunning_var = momentum * running_var + (1 - momentum) * sample_varstd = np.sqrt(sample_var + eps)x_centered = x - sample_meanx_norm = x_centered / stdout = gamma * x_norm + betacache = (x_norm, x_centered, std, gamma)elif mode == 'test':x_norm = (x - running_mean) / np.sqrt(running_var + eps)out = gamma * x_norm + betaelse:raise ValueError('Invalid forward batchnorm mode "%s"' % mode)# Store the updated running means back into bn_parambn_param['running_mean'] = running_meanbn_param['running_var'] = running_varreturn out, cache

注:代码参考[3],原代码中 running_var 也初始化是 np.zeros,本文做了修改。

running_mean和 running_var 的初始值为正态分布参数值,可参考 Pytorch 代码中的 _NormBase 类

参考:

[1] https://jiafulow.github.io/blog/2021/01/29/moving-average-in-batch-normalization/

[2] https://zhuanlan.zhihu.com/p/61725100

[3] https://towardsdatascience.com/implementing-batch-normalization-in-python-a044b0369567

[4] 题图参考:https://kaixih.github.io/batch-norm/

BN/Batch Norm中的滑动平均/移动平均/Moving Average相关推荐

  1. 理解滑动平均(exponential moving average)

    1. 用滑动平均估计局部均值 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以 ...

  2. 滑动平均(Moving Average Models,MA)模型

    这里我们直接给出MA(q)模型的形式: c0为一个常数项.这里的at,是AR模型t时刻的扰动或者说新息(也就是白噪声误差项),则可以发现,该模型,使用了过去q个时期的随机干扰或预测误差来线性表达当前的 ...

  3. 【matplotlib】 移动平均(Moving Average)

    1.简介 定义见 移动平均(Moving Average,MA),又称移动平均线,简称均线.作为技术分析中一种分析时间序列的常用工具,常被应用于股票价格序列.移动平均可过滤高频噪声,反映出中长期低频趋 ...

  4. 神经网络优化:指数衰减计算平均值(滑动平均)

    Polyak平均会平均优化算法在参数空间访问中的几个点.如果t次迭代梯度下降访问了点,那么Polyak平均算法的输出是. 当应用Polyak平均于非凸问题时,通常会使用指数衰减计算平均值: 1. 用滑 ...

  5. EMA指数滑动平均(Exponential Moving Average)

    指数滑动平均(Exponential Moving Average) 指数滑动平均也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法. 假设有nn ...

  6. tensorflow || 滑动平均的理解--tf.train.ExponentialMovingAverage

    1 滑动平均的理解 滑动平均(exponential moving average),或者叫做指数加权平均(exponentially weighted moving average),可以用来估计变 ...

  7. TensorFlow滑动平均模型

    指数加权平均算法的原理 TensorFlow中的滑动平均模型使用的是滑动平均(Moving Average)算法,又称为指数加权移动平均算法(exponenentially weighted aver ...

  8. matlab实现滑动平均滤波(二)

    滑动平均(moving average):在地球物理异常图上,选定某一尺寸的窗口,将窗口内的所有异常值做算术平均,将平均值作为窗口中心点的异常值.按点距或线距移动窗口,重复此平均方法,直到对整幅图完成 ...

  9. YOLOv5的Tricks | 【Trick7】指数移动平均(Exponential Moving Average,EMA)

    如有错误,恳请指出. 文章目录 1. 移动平均法 2. 指数移动平均 3. TensorFlow中的EMA使用 4. Yolov5中的EMA使用 这篇博客主要用于整理网上对EMA(指数移动平均)的介绍 ...

最新文章

  1. 死磕Java并:J.U.C之ConcurrentHashMap红黑树转换分析
  2. python的整数类型有几种进制_(一)Python入门-2编程基本概念:08整数-不同进制-其他类型转换成整数...
  3. AS插件-Android Layout ID Converter
  4. 12306的变态验证码算得了什么?我有Python神器!
  5. 【矩阵】概念的理解 —— span、基
  6. Haproxy+多台MySQL从服务器(Slave) 实现负载均衡
  7. python 计时_Python time clock()方法
  8. python算法题排序_python-数据结构与算法- 面试常考排序算法题-快排-冒泡-堆排-二分-选择等...
  9. 上周热点回顾(10.8-10.14)
  10. CHSBO2018游记
  11. 如何清洗 Git Repo 代码仓库
  12. java Vector.toArray 与强制类型转换
  13. layui中关于重置按钮不起作用的提醒
  14. 去掉迅雷右侧内置浏览器
  15. 图片加文字用什么软件?推荐这三款软件给你
  16. 小米首页产品调研分析和设计方案介绍(详细的倒计时代码介绍)
  17. 【Java进阶营】阿里架构师加持,十分钟入门RocketMQ,就是这么简单
  18. 一套实用性最强的商业方案,让他白手起家做到全国十大财阀之一!
  19. 东北大学应用数理统计知识点总结——历年真题题型
  20. ZYNQ TTC使用方法

热门文章

  1. 介绍一个查看TCP连接的工具TCPView
  2. 准确率和召回率(precisionrecall)
  3. iphone导出视频 无法连接到设备_iPhone 使用技巧:及时关注手机储存容量
  4. 码元,数据传输速率,带宽,信噪比,信道容量
  5. 滤波算法 | 无迹卡尔曼滤波(UKF)算法及其MATLAB实现
  6. 快速分析德邦快递走件信息,并筛选代收的单号
  7. 线程同步互斥机制--互斥锁
  8. python多线程,线程锁
  9. 用c++实现贪吃蛇小游戏,初学者记录一下首次实现的经历,有超详细的思路与语法讲解,新手向
  10. SOCKET链接速度慢