【GiantPandaCV导语】

前段时间debug LayerNorm的时候,看见Pytorch LayerNorm计算方差的方式与我们并不一样。它使用了一种在线更新算法,速度更快,数值稳定性更好,这篇笔记就当一篇总结。

1回顾常见的方差计算方法

Two-pass方法

这种方法就是方差的定义式了:

简单来说就是样本减去均值,取平方,然后再累加起来除以样本数量(这里就不再具体分总体方差和样本方差了)。

那为什么他叫Two-pass方法呢?因为他需要循环两遍原始数据:

  • 第一遍统计,计算均值

  • 第二遍再将样本值和均值计算,得到方差 当数据比较大的时候,两遍循环耗时也比较多

Naive方法

我们还知道方差和均值的一个关系式子

相比Two-pass方法,这种方法仅仅只需要遍历一遍数据。我们只需要在外面统计两个变量,sumsum_square

最后再分别计算两者的均值,通过上述关系式子得到结果

根据维基百科的介绍,前面这两种方法的一个共同缺点是,其结果依赖于数据的排序,存在累加的舍入误差,对于大数据集效果较差

Welford算法

此前大部分深度学习框架都采用的是Naive的计算方法,后续Pytorch转用了这套算法。

首先给出结果,我们再来进行一步步的推导:

其中表示前n个元素的均值

推导

首先我们推导均值的计算:

当为n+1的情况下:

方差的推导稍微有点复杂,做好心理准备!

首先我们回到Naive公式

我们看下n+1时候的情况

我们把n+1乘到左边,并把n+1的平方项单独拆出来

而根据前面计算我们可以把替换掉

而我们前面推导均值的时候推导过,此时替换进来

左右两遍,同时乘上N+1,并进行化简,可以得到:

把挪到右边就可以得到

而根据平方公式的特性有

我们将其中一项用前面推导得到的均值来进行转换

然后替换到前面的公式进行化简就可以得到最终结果

额外拓展:

这样子更新方差,每一次都可能会加一个较小的数字,也会导致舍入误差,因此又做了个变换:

每次统计:

最后再得到方差:

这个转换是一个等价转换,感兴趣的读者可以从头一项一项的推导。

2实现代码

简单用python写了个脚本

import numpy as npdef welford_update(count, mean, M2, currValue):count += 1delta = currValue - meanmean += delta / countdelta2 = currValue - meanM2 += delta * delta2return (count, mean, M2)def naive_update(sum, sum_square, currValue):sum = sum + currValuesum_square = sum_square + currValue * currValuereturn (sum, sum_square)x_arr = np.random.randn(100000).astype(np.float32)welford_mean = 0
welford_m2 = 0
welford_count = 0
for i in range(len(x_arr)):new_val = x_arr[i]welford_count, welford_mean, welford_m2 = welford_update(welford_count, welford_mean, welford_m2, new_val)
print("Welford mean: ", welford_mean)
print("Welford var: ", welford_m2 / welford_count)naive_sum = 0
naive_sum_square = 0
for i in range(len(x_arr)):new_val = x_arr[i]naive_sum, naive_sum_square = naive_update(naive_sum, naive_sum_square, new_val)
naive_mean = naive_sum / len(x_arr)
naive_var = naive_sum_square/ len(x_arr) - naive_mean*naive_mean
print("Naive mean: ", naive_mean)
print("Naive var: ", naive_var)

更多的代码可以参考pytorch和apex实现:

pytorch moments实现:https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SharedReduceOps.h#L95-L113

apex实现:https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu#L11-L24

3参考资料

  • wiki:https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm

  • https://changyaochen.github.io/welford/

笔者主要是根据上面这两个材料进行学习,第二个博客写的十分详细,还有配套的jupyter notebook代码跑,十分推荐。

欢迎加入GiantPandaCV交流群

用Welford算法实现LN的方差更新相关推荐

  1. 机器学习算法工程师面试集锦(更新中)

    机器学习算法工程师面试集锦(更新中) 面试问题汇总 常用的损失函数 介绍交叉验证 特征选择方法 机器学习项目的一般步骤 经验风险最小化与结构风险最小化 训练决策树时的参数是什么 在决策树的节点处分割标 ...

  2. leetcode贪心算法题集锦(持续更新中)

    leetcode贪心算法题集锦 leetcode贪心算法题集锦(持续更新中).python 和C++编写. 文章目录 leetcode贪心算法题集锦 一.贪心算法 1.盛最多水的容器 2.买股票的最佳 ...

  3. 【强化学习入门】梯度赌博机算法中,偏好函数更新:梯度上升公式是精确梯度上升的随机近似的证明

    本文证明强化学习入门问题:K摇臂赌博机的梯度赌博机算法中,偏好函数更新公式:Ht+1(At)=Ht(At)+α(Rt−Rt‾)(1−πt(At))H_{t+1}(A_t) = H_t(A_t) + \ ...

  4. BP算法误差逆传播参数更新公式推导

    BP算法误差逆传播参数更新公式推导

  5. BP算法,用梯度下降法更新权值W与偏置项b

    Bp算法实际是输出的误差函数对每一个参数求导,输出层可以直接求出,非输出层则有链式法则求导.这里以上图5层神经网络为例进行说明. 一   符号说明: 1)这里使用激活函数为sigmoid函数:     ...

  6. 捷联惯导算法(四)姿态更新算法

    前言 本文是对姿态更新算法的理解. 一.姿态更新算法 地心惯性坐标系(i系)绝对不动的惯性参考坐标系,与时间无关. 个人理解: 地心惯性坐标系是绝对不动的,因此机体系到导航系的姿态矩阵可以先由机体系到 ...

  7. 捷联惯导算法(二)位置更新算法的理解

    前言 文中算法公式摘自<捷联惯导算法与组合导航原理>(严恭敏.翁浚 编著).<惯性导航>(秦永元 编著),其他理解仅代表个人观点.本文是对位置更新算法,按照自己学习的思路整理得 ...

  8. 算法与数据结构模版(AcWing算法基础课笔记,持续更新中)

    AcWing算法基础课笔记 文章目录 AcWing算法基础课笔记 第一章 基础算法 1. 排序 快速排序: 归并排序: 2. 二分 整数二分 浮点数二分 3. 高精度 高精度加法 高精度减法 高精度乘 ...

  9. 【自适应波束形成算法】 ---- 线性约束最小方差准则(公式推导)

    波束形成是阵列信号处理中的一个重要领域.常规的波束形成,可以通过FFT是实现,其权矢量一般由期望方向的导向矢量加窗后得到. 假设有一个由N个阵元组成的线阵,有一来自方向的来波信号入射到阵元上,其导向矢 ...

  10. 五种最短路径算法的总结(待更新)

    最短路径算法: 1:Dijkstra     2:Floyd     3:Bellman-Ford     4:SPFA     5:A* 这五种最短路径算法初学的时候非常容易混淆,因为他们的松弛方法 ...

最新文章

  1. 计算机四级分数怎么查,计算机三四级成绩查询正确打开方式
  2. C++ Primer 5th笔记(chap 16 模板和泛型编程)类模板部分特例化
  3. 真恶心,用安卓模拟器开微信不能找附近的人
  4. android 虚拟键盘改变单个按键颜色_这款机械键盘很特别!一亿次按键寿命还有高颜值...
  5. [14-01] 闭包
  6. 树莓派开机后画面一闪之后黑屏的解决
  7. Beijing54坐标系——Y坐标(6位数和8位数)区别
  8. 计算机论文中期考核报告,(硕士学位论文中期考核报告范文.doc
  9. 好好说话 -简单概括
  10. windows删除注册表(通用方法)
  11. 谷歌商店上架APP被拒绝
  12. APP分享微信小程序
  13. Python 取模运算(取余)%误区及详解
  14. DB2 SQLSTATE 讯息(二)
  15. React Native带你一步步实现热更新(CodePush-Android)
  16. 02. Java环境搭建
  17. MySQL基础知识面试选择题40
  18. Word 2010 自定义首行缩进的快捷键
  19. 渗透分支写脚本_给小白的黑盒渗透测试作业——漏洞分析测试到安全加固建议...
  20. Go语言_通神路之灵胎篇(4)

热门文章

  1. 前端面试题2016--CSS
  2. Container类型元素累加
  3. 八.创建型设计模式——Singleton Pattern(单例模式)
  4. 常用排序算法之插入排序 ( 直接插入排序、希尔排序 )
  5. 列出选定月份的时间序列
  6. socket编程-阻塞和非阻塞
  7. Python的几个相关实例
  8. ZipArchive是一个开源的zip开发包工具。
  9. centos6.5安装sublime text 2
  10. 转:百度又开始踢新浪屁股了