Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法
由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文将介绍两种在小batchsize也可以发挥BN性能的方法。
本文首发自极市平台,作者 @皮特潘,转载需获授权。
前言
BN(Batch Normalization)几乎是目前神经网络的必选组件,但是使用BN有两个前提要求:
- batchsize不能太小;
- 每一个minibatch和整体数据集同分布。
不然的话,非但不能发挥BN的优势,甚至会适得其反。但是由于算力的限制,有时我们无法使用足够大的batchsize,此时该如何使用BN呢?本文介绍两篇在小batchsize也可以发挥BN性能的方法。解决思路为:既然batchsize太小的情况下,无法保证当前minibatch收集到的数据和整体数据同分布。那么能否多收集几个batch的数据进行统计呢?这两篇工作分别分别是:
- BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
- CBN:Cross-Iteration Batch Normalization
另外,本文也会给出代码解析,帮助大家理解。
batchsize过小的场景
通常情况下,大家对CNN任务的研究一般为公开的数据集指标负责。分类任务为ImageNet数据集负责,其尺度为224X224。检测任务为coco数据集负责,其尺度为640X640左右。分割任务一般为coco或PASCAL VOC数据集负责,后者的尺度大概在500X500左右。再加上例如resize的前处理操作,真正送入网络的图片的分辨率都不算太大。一般性能的GPU也很容易实现大的batchsize(例如大于32)的支持。
但是实际的项目中,经常遇到需要处理的图片尺度过大的场景,例如我们使用500w像素甚至2000w像素的工业相机进行数据采集,500w的相机采集的图片尺度就是2500X2000左右。而对于微小的缺陷检测、高精度的关键点检测或小物体的目标检测等任务,我们一般不太想粗暴降低输入图片的分辨率,这样违背了我们使用高分辨率相机的初衷,也可能导致丢失有用特征。在算力有限的情况下,我们的batchsize就无法设置太大,甚至只能为1或2。小的batchsize会带来很多训练上的问题,其中BN问题就是最突出的。虽然大batchsize训练是一个共识,但是现实中可能无法具有充足的资源,因此我们需要一些处理手段。
BN回顾
首先Batch Normalization 中的Normalization被称为标准化,通过将数据进行平和缩放拉到一个特定的分布。BN就是在batch维度上进行数据的标准化。BN的引入是用来解决 internal covariate shift 问题,即训练迭代中网络激活的分布的变化对网络训练带来的破坏。BN通过在每次训练迭代的时候,利用minibatch计算出的当前batch的均值和方差,进行标准化来缓解这个问题。虽然How Does Batch Normalization Help Optimization 这篇文章探究了BN其实和Internal Covariate Shift (ICS)问题关系不大,本文不深入讨论,这个会在以后的文章中细说。
一般来说,BN有两个优点:
- 降低对初始化、学习率等超参的敏感程度,因为每层的输入被BN拉成相对稳定的分布,也能加速收敛过程。
- 应对梯度饱和和梯度弥散,主要是对于使用sigmoid和tanh的激活函数的网络。
当然,BN的使用也有两个前提:
- minibatch和全部数据同分布。因为训练过程每个minibatch从整体数据中均匀采样,不同分布的话minibatch的均值和方差和训练样本整体的均值和方差是会存在较大差异的,在测试的时候会严重影响精度。
- batchsize不能太小,否则效果会较差,论文给的一般性下限是32。
再来回顾一下BN的具体做法:
- 训练的时候:使用当前batch统计的均值和方差对数据进行标准化,同时优化优化gamma和beta两个参数。另外利用指数滑动平均收集全局的均值和方差。
- 测试的时候:使用训练时收集全局均值和方差以及优化好的gamma和beta进行推理。
可以看出,要想BN真正work,就要保证训练时当前batch的均值和方差逼近全部数据的均值和方差。
BRN
论文题目:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
论文地址: https://arxiv.org/pdf/1702.03275.pdf
代码地址: https://github.com/ludvb/batchrenorm
核心解析:
本文的核心思想就是:训练过程中,由于batchsize较小,当前minibatch统计到的均值和方差与全部数据有差异,那么就对当前的均值和方差进行修正。修正的方法主要是利用到通过滑动平均收集到的全局均值和标准差。看公式:
xi−μσ=xi−μBσB⋅r+d,where r=σBσ,d=μB−μσ\frac{x_{i}-\mu}{\sigma}=\frac{x_{i}-\mu_{\mathcal{B}}}{\sigma_{\mathcal{B}}} \cdot r+d, \quad \text { where } r=\frac{\sigma_{\mathcal{B}}}{\sigma}, \quad d=\frac{\mu_{\mathcal{B}}-\mu}{\sigma} σxi−μ=σBxi−μB⋅r+d, where r=σσB,d=σμB−μ
上面公式中,i表示网络的第i层。μ和σ表示网络推理时的均值和标准差,也就是训练过程通过滑动平均收集的到均值和方差。μB和σb表示当前训练迭代过程中的实际统计到的均值和标准差。BN在小batch不work的根本原因就是这两组参数存在较大的差异。通过r和d对训练过程中数据进行线性变换,在该变化下,上公式左右两端就严格相等了。其实标准的BN就是r=1,d=0的一种情况。对于某一个特定的minibatch,其中r和d可以看成是固定的,是直接计算出来的,不需要梯度优化的。
具体流程:
统计当前batch数据的均值和标注差,和标准BN做法一致。
根据当前batch的均值和标准差结合全局的均值和标准差利用上面的公式计算r和d;注意该运算是不参与梯度反向传播的。另外,r和d需要增加一个限制,直接clip操作就好。
利用当前的均值和标准差对当前数据执行Normalization操作,利用上面计算得到的r和d对当前batch进行线性变换。
滑动平均收集全局均值和标注差。
测试过程和标准BN一样。其实本质上,就是训练的过程中使用全局的信息进行更新当前batch的数据。间接利用了全局的信息,而非当前这一个batch的信息。
实验效果:
在较大的batchsize(32)的时候,与标准BN相比,不会丢失效果,训练过程一如既往稳定高效。如下:
在小的batchsize(4)下, 本文做法依然接近batchsize为32的时候,可见在小batchsize下是work的。
代码解析:
def forward(self, x):if x.dim() > 2:x = x.transpose(1, -1)if self.training: # 训练过程dims = [i for i in range(x.dim() - 1)batch_mean = x.mean(dims) # 计算均值batch_std = x.std(dims, unbiased=False) + self.eps # 计算标准差# 按照公式计算r和dr = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean))/ self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)# 对当前数据进行标准化和线性变换x = (x - batch_mean) / batch_std * r + d# 滑动平均收集全局均值和标注差self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)self.running_std += self.momentum * (batch_std.detach() - self.running_std)self.num_batches_tracked += 1else: # 测试过程x = (x - self.running_mean) / self.running_stdreturn x
CBN
论文题目:Cross-Iteration Batch Normalization
论文地址:https://arxiv.org/abs/2002.05712
代码地址:https://github.com/Howal/Cross-iterationBatchNorm
本文认为BRN的问题在于它使用的全局均值和标准差不是当前网络权重下获取的,因此不是exactly正确的,所以batchsize再小一点,例如为1或2时就不太work了。本文使用泰勒多项式逼近原理来修正当前的均值和标准差,同样也是间接利用了全局的均值和方差信息。简述就是:当前batch的均值和方差来自之前的K次迭代均值和方差的平均,由于网络权重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估计前面的迭代在当前权重下的数值。
泰勒公式:
泰勒公式是一个用函数在某点的信息描述其附近取值的公式。如果函数满足一定的条件,泰勒公式可以用函数在某一点的各阶导数值做系数构建一个多项式来近似表达这个函数。教科书介绍如下:
核心解析:
本文做法,由于网络一般使用SGD更新权重,因此网络权重的变化是平滑的,所以适用泰勒公式。如下,t为训练过程中当前迭代时刻,t-τ为t时刻向前τ时刻。θ为网络权重,权重下标代表该权重的时刻。μ为当前minibatch均值,v为当强minibatch平方的均值,是为了计算标准差。因此直接套用泰勒公式得到:
μt−τ(θt)=μt−τ(θt−τ)+∂μt−τ(θt−τ)∂θt−τ(θt−θt−τ)+O(∥θt−θt−τ∥2)(5)\begin{aligned} \mu_{t-\tau}\left(\theta_{t}\right)=& \mu_{t-\tau}\left(\theta_{t-\tau}\right)+\frac{\partial \mu_{t-\tau}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}}\left(\theta_{t}-\theta_{t-\tau}\right) \\ &+\mathbf{O}\left(\left\|\theta_{t}-\theta_{t-\tau}\right\|^{2}\right) \end{aligned}\tag{5} μt−τ(θt)=μt−τ(θt−τ)+∂θt−τ∂μt−τ(θt−τ)(θt−θt−τ)+O(∥θt−θt−τ∥2)(5)
νt−τ(θt)=νt−τ(θt−τ)+∂νt−τ(θt−τ)∂θt−τ(θt−θt−τ)+O(∥θt−θt−τ∥2)(6)\begin{aligned} \nu_{t-\tau}\left(\theta_{t}\right)=& \nu_{t-\tau}\left(\theta_{t-\tau}\right)+\frac{\partial \nu_{t-\tau}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}}\left(\theta_{t}-\theta_{t-\tau}\right) \\ &+\mathbf{O}\left(\left\|\theta_{t}-\theta_{t-\tau}\right\|^{2}\right) \end{aligned}\tag{6} νt−τ(θt)=νt−τ(θt−τ)+∂θt−τ∂νt−τ(θt−τ)(θt−θt−τ)+O(∥θt−θt−τ∥2)(6)
上面这两个公式就是为了估计在t-τ时刻,t时刻的权重下的均值和方差的参数估计。BRN可以看作没有进行该方法估计,使用的依然是t-τ时刻权重的参数估计。其中O为高阶项,因为该式主要由一阶项控制,因此高阶项目可以忽略。上面的公式还要进一步简化,主要是偏导项的求法。假设当前层为l,实际上∂μ/ ∂θ 和 ∂ν/∂θ依赖与所有l层之前层的权重,求导计算量极大。不过经验观察到,l层之前层的偏数下降很快,因此可以忽略掉,仅仅计算当前层的权重偏导。
因此化简为如下,可以看出,求偏导的部分,只考虑对当前层的偏导数,注意上标l表示网络层的意思。至此,之前时刻在当前权重下的均值和方差已经估计出来了。
μt−τl(θt)≈μt−τl(θt−τ)+∂μt−τl(θt−τ)∂θt−τl(θtl−θt−τl)(7)\mu_{t-\tau}^{l}\left(\theta_{t}\right) \approx \mu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)+\frac{\partial \mu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}^{l}}\left(\theta_{t}^{l}-\theta_{t-\tau}^{l}\right)\tag{7} μt−τl(θt)≈μt−τl(θt−τ)+∂θt−τl∂μt−τl(θt−τ)(θtl−θt−τl)(7)
νt−τl(θt)≈νt−τl(θt−τ)+∂νt−τl(θt−τ)∂θt−τl(θtl−θt−τl)(8)\nu_{t-\tau}^{l}\left(\theta_{t}\right) \approx \nu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)+\frac{\partial \nu_{t-\tau}^{l}\left(\theta_{t-\tau}\right)}{\partial \theta_{t-\tau}^{l}}\left(\theta_{t}^{l}-\theta_{t-\tau}^{l}\right)\tag{8} νt−τl(θt)≈νt−τl(θt−τ)+∂θt−τl∂νt−τl(θt−τ)(θtl−θt−τl)(8)
下面穿插代码解析整个计算过程。
首先是统计计算当前batch的数据,和标准BN没有差别。代码为:
cur_mu = y.mean(dim=1) # 当前层的均值
cur_meanx2 = torch.pow(y, 2).mean(dim=1) # 当前值平方的均值,计算标准差使用
cur_sigma2 = y.var(dim=1) # 当前值的方差
对当前网络层求偏导,直接使用torch的内置函数。代码:
# 注意 grad_outputs = self.ones : 不同值的梯度对结果影响程度不同,类似torch.sum()的作用。
dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]
使用公式(7)和(8)继续下面的计算,也就是向前累计K次估计数值,更新到当前batch的均值和方差的计算上,这里引入了一个超参就是k的大小,它表示当前的迭代向后回溯到多长的步长的迭代。实验探究k=8是一个比较折中的选择。k=1的时候,RBN退化成了原始的BN:
μˉt,kl(θt)=1k∑τ=0k−1μt−τl(θt)(9)\bar{\mu}_{t, k}^{l}\left(\theta_{t}\right)=\frac{1}{k} \sum_{\tau=0}^{k-1} \mu_{t-\tau}^{l}\left(\theta_{t}\right)\tag{9} μˉt,kl(θt)=k1τ=0∑k−1μt−τl(θt)(9)
νˉt,kl(θt)=1k∑τ=0k−1max[νt−τl(θt),μt−τl(θt)2](10)\bar{\nu}_{t, k}^{l}\left(\theta_{t}\right)=\frac{1}{k} \sum_{\tau=0}^{k-1} \max \left[\nu_{t-\tau}^{l}\left(\theta_{t}\right), \mu_{t-\tau}^{l}\left(\theta_{t}\right)^{2}\right]\tag{10} νˉt,kl(θt)=k1τ=0∑k−1max[νt−τl(θt),μt−τl(θt)2](10)
σˉt,kl(θt)=νˉt,kl(θt)−μˉt,kl(θt)2(11)\bar{\sigma}_{t, k}^{l}\left(\theta_{t}\right)=\sqrt{\bar{\nu}_{t, k}^{l}\left(\theta_{t}\right)-\bar{\mu}_{t, k}^{l}\left(\theta_{t}\right)^{2}}\tag{11} σˉt,kl(θt)=νˉt,kl(θt)−μˉt,kl(θt)2(11)
代码如下,其中这里的self.pre_mu, self.pre_dmudw, self.pre_weight是前面每次迭代收集到了窗口k大小的数值,分别代表均值、均值对权重的偏导、权重。self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight同理,是对应平方均值的。
# 利用泰勒公式估计
mu_all = torch.stack \([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])meanx2_all = torch.stack \([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])
上面所说的变量收集迭代过程如下:
# 动态维护buffer_num长度的均值、均值平方、偏导、权重
self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]
tmp_weight = torch.zeros_like(weight.data)
tmp_weight.copy_(weight.data)
self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]
计算获取当前batch的均值和方差,取修正后的K次迭代数据的平均即可。
# 利用收集到的一定窗口长度的均值和平方均值,计算当前均值和方差
sigma2_all = meanx2_all - torch.pow(mu_all, 2)
re_mu_all = mu_all.clone()
re_meanx2_all = meanx2_all.clone()
re_mu_all[sigma2_all < 0] = 0
re_meanx2_all[sigma2_all < 0] = 0
count = (sigma2_all >= 0).sum(dim=0).float()
mu = re_mu_all.sum(dim=0) / count # 平均操作
sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)
均值和方差使用过程,和标准BN没有区别。
# 标准化过程,和原始BN没有区别
y = y - mu.view(-1, 1)
if self.out_p: # 仅仅控制开平方的位置y = y / (sigma2.view(-1, 1) + self.eps) ** .5
else:y = y / (sigma2.view(-1, 1) ** .5 + self.eps)
最后再理解一下:mu_0是当前batch统计获取的均值,mu_1是上一batch统计获取的均值。 当前batch计算BN的时候也想利用到mu_1,但是统计mu_1的时候利用到网络的权重也是上一次的,直接使用肯定有问题,所以本文使用泰勒公式估计出mu_1在当前权重下应该是什么样子。方差估计同理。
实验效果:
这里的Naive CBN 是上一篇论文BRN的做法,可以认为是CBN不使用泰勒估计的一种特例。在batchsize下降的过程中,CBN指标依然坚挺,甚至超过了GN(不过也侧面反应了GN确实厉害)。而原始BN和其改进版BRN在batchsize更小的时候都不太work了。
Batchsize不够大,如何发挥BN性能?探讨神经网络在小Batch下的训练方法相关推荐
- linux发挥显卡性能,Linux Kernel 2.6.30下Intel显卡性能有大幅提升!
我是Intel GMA950集成显卡的用户,而且一直也是Compiz的用户的. 早在Beryl时代,GMA950就给3D桌面提供了相当的动力.可惜在最新发布的Ubuntu 9.04中,Compiz的性 ...
- 计算机电源影响吗,电脑主机电源瓦数太低会不会影响显卡和处理器发挥最高性能?...
电脑电源相当于人体的血液,为每一个硬件提供血液(供电),因此可见电源的重要性,它决定了电脑稳定性.不过装机用户在搭配电源的时候,就怕将电源的功率配小了.那么电脑主机电源瓦数太低会不会影响显卡和处理器发 ...
- Java HashMap 遍历方式性能探讨
转载自 Java HashMap 遍历方式性能探讨 关于HashMap的实现这里就不展开了,具体可以参考JDK7与JDK8中HashMap的实现 JDK8之前,可以使用keySet或者entrySet ...
- 大泥王怎么调时区_卡西欧大泥王手表性能怎么样 卡西欧手表为什么叫泥王
卡西欧的手表有很多不同的款式.一般的价格都在几百元左右,但是卡西欧大泥王腕表却高4千元左右.当然外表上非常的帅气好看.卡西欧大泥王手表性能怎么样 卡西欧手表为什么叫泥王.八宝网带来相关介绍. 卡西欧大 ...
- 好用的综合revit软件丨Revit专业模型太大如何提高电脑性能
好用的综合revit软件丨Revit专业模型太大如何提高电脑性能 在实际项目中(以暖通专业为例),目前的专业模型太大,如下图1所示,会导致计算机资源不足.如何减少计算机资源的消耗,提高工作效率? 以上 ...
- java mx150显卡够了吗_MX150性能究竟怎么样,与GTX1050相比差别有多大,白话评测性能...
MX150性能究竟怎么样,与GTX1050相比差别有多大,白话评测性能 MX150自从出世后,被各大厂商宣传坏了,"高性能独立显卡","2GB独立显卡",&qu ...
- 区块链与大数据共生共长 帮助大数据发挥出更大的价值
自2015年以来,区块链技术迅猛发展,其应用场景日益广泛.与此同时,大数据的发展却越来越受到数据孤岛.数据质量.数据安全等问题的制约.区块链技术会替代大数据技术吗?二者将此消彼长吗?本文将讨论这一问题 ...
- Intel官宣两大全新CPU 一性能猛增8.8倍!
一年一度的台北电脑展马上开启,各大巨头都在摩拳擦掌.Intel今天官方宣布,公司高级副总裁兼客户端计算事业部总经理Gregory Bryant,将于5月28日发表2019年台北国际电脑展开幕主题演讲, ...
- HTML标签以及各大浏览器份额、性能
总目录 HTML标签 声明 基础标签 格式相关 表格 列表 多媒体 交互 内容交互: 菜单交互: 状态交互: 表单 超链接 各大浏览器 浏览器内核 性能方面: HTML标签 声明 <!DOCTY ...
最新文章
- 【微信Java开发 --番外篇】错误解析
- python操作系统-python 操作系统
- Java语言编码规范(1)
- [bzoj5405]platform
- mino文件服务器删除文件,Spring-minio
- Python3 中的 asyncio async await 概念(实例)(ValueError: too many file descriptors in select())
- 作家如何利用Git更好地完成工作
- Python学习笔记:初探NumPy世界
- java 摸拟qq消息提示_java 仿qq消息提示框
- MySQL数据库 --基础
- Qt4_十六进制微调框
- [Guava源码日报](8)ImmutableCollection
- rosdep init 和rosdep update的解决方法,亲测有效
- Ubuntu下使用Atom将Markdown文件转换为PDF的一个异常
- java音乐网站论文_基于Java web的音乐网站的设计与实现论文(含源文件).doc
- ZOJ 3527 Shinryaku! Kero Musume 【树形DP[带简单环]】
- 斯特林(Stirling)数
- python函数查询工具_布同:Python函数帮助查询小工具[v1和v2]
- java无法下载jnlp_java-JNLP下载期间FileNotFoundException
- ATFX:道琼斯指数的反弹,11月能否突破35000关口?