原文链接:Cross-Iteration Batch Normalization
代码链接:https://github.com/Howal/Cross-iterationBatchNorm

随着BN的提出,现有的网络基本都在使用。但是在显存有限或者某些任务不允许大batch(比如检测或者分割任务相比分类任务训练时的batchsize一般会小很多)的情况下,BN效果就会差很多,如下图所示,当batch小于16时分类准确率急剧下降。

为了改善小batch情况下网络性能下降的问题,有各种新的normalize方法被提出来了(LN、IN、GN),详情请看文章GN-Group Normalization。上述不同的normalize方式适用于不同的任务,其中GN就是为检测任务设计的,但是它会降低推断时的速度。

本文的思想其实很简单,既然从空间维度不好做,那么就从时间维度进行,通过以前算好的BN参数用来计算更新新的BN参数,从而改善网络性能。从上图可以看出使用新提出的CBN,即使batch size小也可以获得较好的分类准确率。

通过过去的BN参数来计算新的BN参数有个问题,就是过去的BN参数是根据过去的网络参数计算出的feature来估算的,新的BN参数计算时,参数已经更新过了,如果直接使用之前的参数来计算新的BN参数会使得参数估计不准,网络性能下降,如上图中的Naive CBN。为了改进这种缺陷,文章使用泰勒多项式来估算新的BN参数。下面来介绍具体的计算方式。

一、Batch Normalization

先来回顾一下原始的BN计算公式。
BN操作的公式如下
x^t,i(θt)=xt,i(θt)−μt(θt)σt(θt)2+ϵ\hat{x}_{t,i}(\theta_t) = \frac{x_{t,i}(\theta_t)-\mu_t(\theta_t)}{\sqrt{\sigma_t(\theta_t)^2 + \epsilon}}x^t,i​(θt​)=σt​(θt​)2+ϵ​xt,i​(θt​)−μt​(θt​)​
yt,i(θt)=γx^t,i(θt)+βy_{t,i}(\theta_t)=\gamma \hat{x}_{t,i}(\theta_t)+\betayt,i​(θt​)=γx^t,i​(θt​)+β
上式中θt\theta_tθt​表示第t个mini-batch训练时的网络参数,xt,i(θt)x_{t,i}(\theta_t)xt,i​(θt​)表示第t个mini-batch中第i个样本经过网络得到的feature map,x^t,i(θt)\hat{x}_{t,i}(\theta_t)x^t,i​(θt​)表示bn后均值为0和方差为1的特征,μt(θt)\mu_t(\theta_t)μt​(θt​)和σt(θt)\sigma_t(\theta_t)σt​(θt​)表示当前mini-batch计算出来的均值和方差,ϵ\epsilonϵ是防除零的系数,γ\gammaγ和β\betaβ表示BN中需要学习的参数。

μt(θt)\mu_t(\theta_t)μt​(θt​)和σt(θt)\sigma_t(\theta_t)σt​(θt​)计算方式如下
μt(θt)=1m∑i=1mxt,i(θt)\mu_t(\theta_t)=\frac{1}{m}\sum^m_{i=1} x_{t,i}(\theta_t)μt​(θt​)=m1​∑i=1m​xt,i​(θt​)
σt(θt)=1m∑i=1m(xt,i(θt)−μt(θt))2=νt(θt)−μt(θt)2\sigma_t(\theta_t)=\sqrt{\frac{1}{m}\sum^m_{i=1} (x_{t,i}(\theta_t)-\mu_t(\theta_t))^2} = \sqrt{\nu_t(\theta_t)-\mu_t(\theta_t)^2}σt​(θt​)=m1​∑i=1m​(xt,i​(θt​)−μt​(θt​))2​=νt​(θt​)−μt​(θt​)2​
其中νt(θt)=1m∑i=1mxt,i(θt)2\nu_t(\theta_t)=\frac{1}{m}\sum^m_{i=1} x_{t,i}(\theta_t)^2νt​(θt​)=m1​∑i=1m​xt,i​(θt​)2,m表示mini-batch里面有m个样本。

二、估计之前的均值和方差

假设现在是第t次迭代,那t之前的迭代计算时,拿t−τt-\taut−τ次迭代来说,已经计算过了对应的均值和方差了,但是之前的计算是使用之前的网络参数得到的,用符号表示为μt−τ(θt−τ)\mu_{t-\tau}(\theta_{t-\tau})μt−τ​(θt−τ​)和νt−τ(θt−τ)\nu_{t-\tau}(\theta_{t-\tau})νt−τ​(θt−τ​)。现在想要估计的是参数μt−τ(θt)\mu_{t-\tau}(\theta_{t})μt−τ​(θt​)和νt−τ(θt)\nu_{t-\tau}(\theta_{t})νt−τ​(θt​)。文章认为连续几次网络参数的变化是平滑的,所以根据泰勒展开式可以估计出上述的两个参数:
μt−τ(θt)=μt−τ(θt−τ)+∂μt−τ(θt−τ)∂θt−τ(θt−θt−τ)+O(∥θt−θt−τ∥2)\mu_{t-\tau}(\theta_{t})=\mu_{t-\tau}(\theta_{t-\tau}) + \frac{\partial \mu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}}(\theta_{t}-\theta_{t-\tau})+O(\lVert \theta_{t}-\theta_{t-\tau}\lVert ^2)μt−τ​(θt​)=μt−τ​(θt−τ​)+∂θt−τ​∂μt−τ​(θt−τ​)​(θt​−θt−τ​)+O(∥θt​−θt−τ​∥2)
νt−τ(θt)=νt−τ(θt−τ)+∂νt−τ(θt−τ)∂θt−τ(θt−θt−τ)+O(∥θt−θt−τ∥2)\nu_{t-\tau}(\theta_{t})=\nu_{t-\tau}(\theta_{t-\tau}) + \frac{\partial \nu_{t-\tau}(\theta_{t-\tau})}{\partial \theta_{t-\tau}}(\theta_{t}-\theta_{t-\tau})+O(\lVert \theta_{t}-\theta_{t-\tau}\lVert ^2)νt−τ​(θt​)=νt−τ​(θt−τ​)+∂θt−τ​∂νt−τ​(θt−τ​)​(θt​−θt−τ​)+O(∥θt​−θt−τ​∥2)

其中∂μt−τ(θt−τ)∂θt−τ\frac{\partial\mu_{t-\tau}(\theta_{t-\tau})}{\partial\theta_{t-\tau}}∂θt−τ​∂μt−τ​(θt−τ​)​和∂νt−τ(θt−τ)∂θt−τ\frac{\partial\nu_{t-\tau}(\theta_{t-\tau})}{\partial\theta_{t-\tau}}∂θt−τ​∂νt−τ​(θt−τ​)​表示对网络参数求偏导数,O(∥θt−θt−τ∥2)O(\lVert \theta_{t}-\theta_{t-\tau}\lVert ^2)O(∥θt​−θt−τ​∥2)表示泰勒展开式的高阶项,当(θt−θt−τ)(\theta_{t}-\theta_{t-\tau})(θt​−θt−τ​)较小时,高阶项可以忽略不计。

注意:要精确的获得上式中的∂μt−τ(θt−τ)∂θt−τ\frac{\partial\mu_{t-\tau}(\theta_{t-\tau})}{\partial\theta_{t-\tau}}∂θt−τ​∂μt−τ​(θt−τ​)​和∂νt−τ(θt−τ)∂θt−τ\frac{\partial\nu_{t-\tau}(\theta_{t-\tau})}{\partial\theta_{t-\tau}}∂θt−τ​∂νt−τ​(θt−τ​)​的值计算量会很大,因为网络中的第l层的参数μt−τl(θt−τ){\mu^l_{t-\tau}(\theta_{t-\tau})}μt−τl​(θt−τ​)和νt−τl(θt−τ){\nu^l_{t-\tau}(\theta_{t-\tau})}νt−τl​(θt−τ​)会依赖之前的层,例如依赖∂μt−τl(θt−τ)∂θt−τr≠0\frac{\partial\mu^l_{t-\tau}(\theta_{t-\tau})}{\partial\theta^r_{t-\tau}}\neq0∂θt−τr​∂μt−τl​(θt−τ​)​​=0和∂νt−τl(θt−τ)∂θt−τr≠0\frac{\partial\nu^l_{t-\tau}(\theta_{t-\tau})}{\partial\theta^r_{t-\tau}}\neq 0∂θt−τr​∂νt−τl​(θt−τ​)​​=0,这里r≤lr\leq lr≤l,θt−τr\theta^r_{t-\tau}θt−τr​表示第r层的参数。实际上文章发现当r≤lr\leq lr≤l时,偏导数∂μt−τl(θt)∂θtr\frac{\partial\mu^l_{t-\tau}(\theta_{t})}{\partial\theta^r_{t}}∂θtr​∂μt−τl​(θt​)​和∂νtl(θt)∂θtr\frac{\partial\nu^l_{t}(\theta_{t})}{\partial\theta^r_{t}}∂θtr​∂νtl​(θt​)​减小的非常快

所以上两等式可以近似表示为下两式
μt−τl(θt)≈μt−τl(θt−τ)+∂μt−τl(θt−τ)∂θt−τl(θtl−θt−τl)\mu^l_{t-\tau}(\theta_{t})\approx\mu^l_{t-\tau}(\theta_{t-\tau}) + \frac{\partial \mu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}}(\theta^l_{t} - \theta^l_{t-\tau})μt−τl​(θt​)≈μt−τl​(θt−τ​)+∂θt−τl​∂μt−τl​(θt−τ​)​(θtl​−θt−τl​)
νt−τl(θt)≈νt−τl(θt−τ)+∂νt−τl(θt−τ)∂θt−τl(θtl−θt−τl)\nu^l_{t-\tau}(\theta_{t})\approx\nu^l_{t-\tau}(\theta_{t-\tau}) + \frac{\partial \nu^l_{t-\tau}(\theta_{t-\tau})}{\partial \theta^l_{t-\tau}}(\theta^l_{t} - \theta^l_{t-\tau})νt−τl​(θt​)≈νt−τl​(θt−τ​)+∂θt−τl​∂νt−τl​(θt−τ​)​(θtl​−θt−τl​)

三、Cross-Iteration Batch Normalization

CBN的工作方式如下

上一节利用之前的参数估算出来了当前参数下l层在t−τt-\taut−τ次迭代的参数值,通过这些估计值可以计算出l层在当前迭代时的BN参数,计算公式如下:
μˉt,kl(θt)=1k∑τ=0k−1μt−τl(θt)\bar{\mu}^l_{t,k}(\theta_t)=\frac{1}{k}\sum^{k-1}_{\tau=0}\mu^l_{t-\tau}(\theta_t)μˉ​t,kl​(θt​)=k1​∑τ=0k−1​μt−τl​(θt​)
νˉt,kl(θt)=1k∑τ=0k−1max[νt−τl(θt),μt−τl(θt)2]\bar{\nu}^l_{t,k}(\theta_t)=\frac{1}{k}\sum^{k-1}_{\tau=0} max[\nu^l_{t-\tau}(\theta_t) , \mu^l_{t-\tau}(\theta_t)^2]νˉt,kl​(θt​)=k1​∑τ=0k−1​max[νt−τl​(θt​),μt−τl​(θt​)2]
σˉt,k(θt)=νˉt,k(θt)−μˉt,k(θt)2\bar{\sigma}_{t,k}(\theta_t)= \sqrt{\bar{\nu}_{t,k}(\theta_t)-\bar{\mu}_{t,k}(\theta_t)^2}σˉt,k​(θt​)=νˉt,k​(θt​)−μˉ​t,k​(θt​)2​
注意:在有效的统计时νt−τl(θt)≥μt−τl(θt)2\nu^l_{t-\tau}(\theta_t) \geq \mu^l_{t-\tau}(\theta_t)^2νt−τl​(θt​)≥μt−τl​(θt​)2是一直都会满足的,但是利用泰勒展开式估算不一定能满足条件,所以上述使用了max函数来保证这一点。还要说的是,代码实现的时候并没有这样写,实现时k是指所有νt−τl(θt)≥μt−τl(θt)2\nu^l_{t-\tau}(\theta_t) \geq \mu^l_{t-\tau}(\theta_t)^2νt−τl​(θt​)≥μt−τl​(θt​)2的元素,即过滤了不满足条件的值

最后CBN更新特征的方式同BN
x^t,il(θt)=xt,il(θt)−μˉtl(θt)σˉtl(θt)2+ϵ\hat{x}^l_{t,i}(\theta_t) = \frac{x^l_{t,i}(\theta_t)-\bar\mu^l_t(\theta_t)}{\sqrt{\bar\sigma^l_t(\theta_t)^2 + \epsilon}}x^t,il​(θt​)=σˉtl​(θt​)2+ϵ​xt,il​(θt​)−μˉ​tl​(θt​)​

CBN的伪代码如下

四、计算量优化

文章附录还有一个求偏导的优化方式来节省计算量,假设l和l-1层卷积的通道数为ClC^lCl和Cl−1C^{l-1}Cl−1,K表示l层卷积的kernelsize,因此μt−τl\mu^l_{t-\tau}μt−τl​和νt−τl\nu^l_{t-\tau}νt−τl​通道数为ClC^lCl,θt−τl\theta^l_{t-\tau}θt−τl​的维度数为Cl×Cl−1×KC^l \times C^{l-1} \times KCl×Cl−1×K,如果直接计算∂μt−τl(θt−τ)∂θt−τl\frac{\partial\mu^l_{t-\tau}(\theta_{t-\tau})}{\partial\theta^l_{t-\tau}}∂θt−τl​∂μt−τl​(θt−τ​)​和∂νt−τl(θt−τ)∂θt−τl\frac{\partial\nu^l_{t-\tau}(\theta_{t-\tau})}{\partial\theta^l_{t-\tau}}∂θt−τl​∂νt−τl​(θt−τ​)​,时间复杂度为O(Cl×Cl×Cl−1×K)O(C^l \times C^l \times C^{l-1} \times K)O(Cl×Cl×Cl−1×K)。通过下面的推导我们可以知道可以在O(Cl−1×K)O(C^{l-1} \times K)O(Cl−1×K)和O(Cl×Cl−1×K)O(C^l \times C^{l-1} \times K)O(Cl×Cl−1×K)时间内求出μ\muμ和ν\nuν

下面拿μ\muμ来举例说明,为了简化符号,下面用μl\mu^lμl和θl\theta^lθl代替μt−τl(θt−τ)\mu^l_{t-\tau}(\theta_{t-\tau})μt−τl​(θt−τ​)和νt−τl(θt−τ)\nu^l_{t-\tau}(\theta_{t-\tau})νt−τl​(θt−τ​)。
μjl=1m∑i=1mxi,jl\mu^l_j=\frac{1}{m}\sum^m_{i=1} x^l_{i,j}μjl​=m1​∑i=1m​xi,jl​
上式中μjl\mu^l_jμjl​表示μl\mu^lμl中第j个通道的值,xi,jlx^l_{i,j}xi,jl​表示第i个样本的第j个通道值。

xi,jl=∑n=1Cl−1∑k=1Kθj,n,kl⋅yi+offset(k),nl−1x^l_{i,j}=\sum^{C^{l-1}}_{n=1}\sum^K_{k=1}\theta^l_{j,n,k}\cdot y^{l-1}_{i + offset(k),n}xi,jl​=∑n=1Cl−1​∑k=1K​θj,n,kl​⋅yi+offset(k),nl−1​
上式中,n和k分别表示输入feature的通道维度索引和卷积的kernel维度索引,offset表示卷积时的索引值,yl−1y^{l-1}yl−1表示l-1层的输出。

∂μl∂θl∈RCl×Cl×Cl−1×K\frac{\partial{\mu^l}}{\partial{\theta^l}}\in R^{C^l \times C^l \times C^{l-1} \times K}∂θl∂μl​∈RCl×Cl×Cl−1×K计算公式如下
[∂μl∂θl]j,q,p,η=∂μjl∂θq,p,nl=∂1m∑i=1m∑n=1Cl−1∑k=1Kθj,n,kl⋅yi+offset(k),nl−1∂θq,p,nl={1m∑i=1myi+offset(η),pl−1,j=q0,j≠q[\frac{\partial{\mu^l}}{\partial{\theta^l}}]_{j,q,p,\eta}=\frac{\partial \mu^l_j}{ \partial \theta^l_{q,p,n} } \\ =\frac{\partial \frac{1}{m}\sum^m_{i=1}\sum^{C^{l-1}}_{n=1}\sum^{K}_{k=1}\theta^{l}_{j,n,k}\cdot y^{l-1}_{i+offset(k),n}}{\partial \theta^l_{q,p,n}} \\ = \left\{ \begin{array}{lr} \frac{1}{m}\sum^{m}_{i=1} y^{l-1}_{i+offset(\eta) , p} && , j=q \\ 0 && , j\neq q \end{array} \right.[∂θl∂μl​]j,q,p,η​=∂θq,p,nl​∂μjl​​=∂θq,p,nl​∂m1​∑i=1m​∑n=1Cl−1​∑k=1K​θj,n,kl​⋅yi+offset(k),nl−1​​={m1​∑i=1m​yi+offset(η),pl−1​0​​,j=q,j​=q​
从上式可以看出当j=q时才要计算,其它情况不用计算,这样就减少了计算量。

到这里CBN就介绍完了,思想很简单,但是公式较多,可以看看代码实现,代码量不多。最后要说的是,这篇文章提出来时为了解决小batch训练的情况,yolov4提出有一点就是方便用户使用一个GPU训练,所以yolov4借用了该算法的思想。

CBN(Cross-Iteration Batch Normalization)论文详解相关推荐

  1. Batch Normalization函数详解及反向传播中的梯度求导

    摘要 本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度 相关 配套代码, 请参考文章 : Python和PyTorch对比实现批标准化Batch Normal ...

  2. ShuffleNetv2论文详解

    ShuffleNet v2 论文详解 近期在研究轻量级 backbone 网络,我们所熟悉和工业界能部署的网络有 MobileNet V2.ShuffleNet V2.RepVGG 等,本篇博客是对 ...

  3. Spark 3.2.0 版本新特性 push-based shuffle 论文详解(一)概要和介绍

    前言 本文隶属于专栏<大数据技术体系>,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见大数据技术体系 目录 Spark 3.2.0 ...

  4. YOLO v1论文详解

    YOLO v1:一体化的,实时物体检测 声明:笔者翻译论文仅为学习研究,如有侵权请联系作者删除博文,谢谢! 源论文地址:https://arxiv.org/pdf/1506.02640.pdf 注:文 ...

  5. 智能城市dqn算法交通信号灯调度_博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型...

    原标题:博客 | 滴滴 KDD 2018 论文详解:基于强化学习技术的智能派单模型 国际数据挖掘领域的顶级会议 KDD 2018 在伦敦举行,今年 KDD 吸引了全球范围内共 1480 篇论文投递,共 ...

  6. Fast R-CNN论文详解

    Fast R-CNN论文详解 作者:ture_dream &创新点 规避R-CNN中冗余的特征提取操作,只对整张图像全区域进行一次特征提取: 用RoI pooling层取代最后一层max po ...

  7. 限时9.9元 | 快速领取数学建模竞赛备战必备技巧与论文详解!

    全世界只有3.14 % 的人关注了 青少年数学之旅 大家晚上好,随着美赛时间的公布以及大大小小的数学建模竞赛的进行,小天经常可以收到来自很多小伙伴们提出的问题,"竞赛中如何去考虑选题?&qu ...

  8. batchnomal_pytorch的batch normalize使用详解

    torch.nn.BatchNorm1d() 1.BatchNorm1d(num_features, eps = 1e-05, momentum=0.1, affine=True) 对于2d或3d输入 ...

  9. transfromer-XL论文详解

    transfromer-XL论文详解 – 潘登同学的NLP笔记 文章目录 transfromer-XL论文详解 -- 潘登同学的NLP笔记 Vanilla Transformer Segment-Le ...

最新文章

  1. Linux数据写操作改进
  2. Linux 下 VNC配置和使用(本机控制本机)
  3. 如何产生高斯带限白噪声数据_车间噪声对我们的身体产生巨大影响,我们该如何解决?...
  4. python降维之时间类型数据的处理_python学习笔记之使用sklearn进行PCA数据降维
  5. 阶乘取模算法java_np问题(大数阶乘取模)
  6. [安卓] 19、一个蓝牙4.0安卓DEMO
  7. NLP将迎来黄金十年,7个案例带你入门(附Python代码)
  8. 集总参数电路的判定——电源波长λ和元件尺寸L的比较
  9. JavaWeb — session+实战项目
  10. 【C++】set和multiset区别
  11. 锅炉正反平衡计算热效率
  12. python语言是不是胶水语言_不会吧,不会吧,不会还有人觉得Python是胶水语言吧?...
  13. 游戏开发 | 基于 EasyX 库开发经典90坦克大战游戏
  14. linux防火墙reject,linux 防火墙配置与REJECT导致没有生效问题(示例代码)
  15. java计算器实训报告_Java实验报告计算器
  16. python数组中查找某个值,Python实现在某个数组中查找一个值的算法示例
  17. 在CAD制图软件中标注数学公式的操作技巧
  18. 【SCSS】1300- 这些 SCSS 使用技巧真好用~
  19. 知道这六种拍摄技巧,让你玩转夕阳拍摄
  20. Altium Designer16 软件汉化步骤

热门文章

  1. 小白都能看明白的VLAN原理解释(超详细)
  2. linebreak_vue-cli构建的项目,eslint一直报CRLF/LF的linebreak错误
  3. Excel透视表如何新增自定义列以及设置值汇总方式和值呈现方式
  4. 如何将Mindjet的宏放到自定义功能区
  5. Unity查看接入的Ironsource和adapter 版本号
  6. JDK国内华为镜像下载地址
  7. CentOS7中服务模块定时检查是否启动(未启动则启动该服务)
  8. 高数_第6章无穷级数__调和级数
  9. 125、新技术之微前端
  10. 今年最火爆的商业模式,九星创客新零售模式