BN、CBN、CmBN 的对比与总结

最近看到了关于 Yolo 系列 trick 的总结文章 【Make YOLO Great Again】YOLOv1-v7全系列大解析(Tricks篇),其中提到了 YoloV4 中使用了 CmBN,这是对 CBN 的改进,可以较好的适应小 batch 的情形。论文中给出了一个简要的对比图:

这里结合此图对 BN 和其两种改进策略进行说明。所以需要注意的是,这里存在两个 batch 相关的概念:

  • batch:指代与 BN 层的统计量 实际想要相对应的数据池,也就是图片样本数。
  • mini-batch:由于整个 batch 独立计算时,受到资源限制可能不现实,于是需要将 batch 拆分成数个 mini-batch,每个 mini-batch 单独计算后汇总得到整个 batch 的统计量。从而归一化特征。

我们日常在分割或者检测中使用 BN 时,此时如果不使用特殊的设定,那么 batch 与 mini-batch 是一样的。CBN 和 CmBN 所做的就是如何使用多个独立的 mini-batch 的数据获得一个近似于更大 batch 的统计量以提升学习效果。

CBN 与 CmBN

CmBN(Cross mini-Batch Normalization)是 CBN(Cross-Iteration Batch Normalization)的修改版。

CBN 主要用来解决在 Batch-Size 较小时,BN 的效果不佳问题。CBN 连续利用多个迭代的数据来变相扩大 batch size 从而改进模型的效果。这种用前几个 iteration 计算好的统计量来计算当前迭代的 BN 统计量的方法会有一个问题:过去的 BN 参数是由过去的网络参数计算出来的特征而得到的,而本轮迭代中计算 BN 时,它们的模型参数其实已经过时了

假定 batch=4*mini batch,CBN 在 ttt 次迭代:

  • 模型基于之前的梯度被更新。此时的 BN 的仿射参数也是最新的。
  • 除了本次迭代的统计量,也会使用通过补偿后的前 3 次迭代得到的统计量。这 4 次的统计量会被一起用来得到近似于整个窗口的近似 batch 的 BN 的统计量。
  • 使用得到的近似统计量归一化特征。
  • 使用当前版本的仿射参数放缩和偏移。

CmBN 是基于 CBN 改进的,按照论文的图示的意思,主要的差异在于从滑动窗口变为固定窗口。每个 batch 中的统计不会使用 batch 之前的迭代的信息,仅会累积该窗口内的 4 次迭代以用于最后一次迭代的更新。这一策略基本与梯度累积策略仍有不同,梯度累加仅仅累加了梯度,但是前面的图中明显可以看到 BN 的统计量实际上也累积了起来,而图 4 中的展现的 BN 似乎更像是梯度累积。

CBN 的实现

# https://github.com/Howal/Cross-iterationBatchNorm/blob/f6d35301789c96e52699a9cbc8d2de8681547770/mmdet/models/utils/CBN.py#L74
def forward(self, input, weight):# deal with wight and grad of self.pre_dxdw!self._check_input_dim(input)y = input.transpose(0, 1)return_shape = y.shapey = y.contiguous().view(input.size(1), -1)# burninif self.training and self.burnin > 0:self.iter_count += 1self._update_buffer_num()if self.buffer_num > 0 and self.training and input.requires_grad:  # some layers are frozen!# cal current batch mu and sigmacur_mu = y.mean(dim=1)cur_meanx2 = torch.pow(y, 2).mean(dim=1)cur_sigma2 = y.var(dim=1)# cal dmu/dw dsigma2/dwdmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]# update cur_mu and cur_sigma2 with presmu_all = torch.stack([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])meanx2_all = torch.stack([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])sigma2_all = meanx2_all - torch.pow(mu_all, 2)# with considering countre_mu_all = mu_all.clone()re_meanx2_all = meanx2_all.clone()re_mu_all[sigma2_all < 0] = 0re_meanx2_all[sigma2_all < 0] = 0count = (sigma2_all >= 0).sum(dim=0).float()mu = re_mu_all.sum(dim=0) / countsigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]tmp_weight = torch.zeros_like(weight.data)tmp_weight.copy_(weight.data)self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]else:x = ymu = x.mean(dim=1)cur_mu = musigma2 = x.var(dim=1)cur_sigma2 = sigma2if not self.training or self.FROZEN:y = y - self.running_mean.view(-1, 1)# TODO: outside **0.5?if self.out_p:y = y / (self.running_var.view(-1, 1) + self.eps)**.5else:y = y / (self.running_var.view(-1, 1)**.5 + self.eps)else:if self.track_running_stats is True:with torch.no_grad():self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * cur_muself.running_var = (1 - self.momentum) * self.running_var + self.momentum * cur_sigma2y = y - mu.view(-1, 1)# TODO: outside **0.5?if self.out_p:y = y / (sigma2.view(-1, 1) + self.eps)**.5else:y = y / (sigma2.view(-1, 1)**.5 + self.eps)y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)return y.view(return_shape).transpose(0, 1)

BN、CBN、CmBN 的对比与总结相关推荐

  1. 【深度学习】深度学习的归一化方法的演变(局部响应LRN,BN,LN, IN, GN, FRN, WN, BRN, CBN, CmBN)

    [深度学习]深度学习的归一化方法的演变(局部响应LRN,BN,LN, IN, GN, FRN, WN, BRN, CBN, CmBN) 文章目录 [深度学习]深度学习的归一化方法的演变(局部响应LRN ...

  2. YOLOv4重磅发布,五大改进,二十多项技巧实验,堪称最强目标检测万花筒

    今年2月22日,知名的 DarkNet 和 YOLO 系列作者 Joseph Redmon 宣布退出 CV 界面,这也就意味着 YOLOv3 不会再有官方更新了.但是,CV 领域进步的浪潮仍在滚滚向前 ...

  3. 归一化方法总结 | 又名“BN和它的后浪们“

    前言: 归一化相关技术已经经过了几年的发展,目前针对不同的应用场合有相应的方法,在本文将这些方法做了一个总结,介绍了它们的思路,方法,应用场景.主要涉及到:LRN,BN,LN, IN, GN, FRN ...

  4. Pytorch中BN层入门思想及实现

    批归一化层-BN层(Batch Normalization) 作用及影响: 直接作用:对输入BN层的张量进行数值归一化,使其成为均值为零,方差为一的张量. 带来影响: 1.使得网络更加稳定,结果不容易 ...

  5. 对象检测目标小用什么模型好_[目标检测] YOLO4论文中文版

    点击上方 蓝字 关注我呀! [目标检测] YOLO4论文中文版 文章目录 YOLO4论文中文版 摘要 1.介绍 2.相关工作 2.1.目标检测模型 2.2.Bag of freebies 2.3.Ba ...

  6. YOLOv4论文(中文版)

    摘要 据说有大量的特征可以提高卷积神经网络(CNN)的准确性.需要在大数据集上对这些特征的组合进行实际测试,并对结果进行理论验证.有些特征专门针对某些模型和某些问题,或者只针对小规模数据集;而一些特性 ...

  7. Paper:《YOLOv4: Optimal Speed and Accuracy of Object Detection》的翻译与解读

    Paper:<YOLOv4: Optimal Speed and Accuracy of Object Detection>的翻译与解读 目录 YOLOv4的评价 1.四个改进和一个创新 ...

  8. 【读点论文】YOLOv4: Optimal Speed and Accuracy of Object Detection,讲明目标检测结构,分析先进的涨点tricks,实现一种精度与速度的平衡

    YOLOv4: Optimal Speed and Accuracy of Object Detection Abstract 据说有大量的特征可以提高卷积神经网络(CNN)的准确性.需要在大型数据集 ...

  9. 深度学习阅读导航 | 15 YOLOv4:最佳速度与精确度的目标检测器

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

最新文章

  1. Leetcode 134. 加油站 解题思路及C++实现
  2. 数据链路层---使用集线器的星型拓扑_传统以太网传输介质的改变_总线型--->双绞线为介质的以太网采用星型拓扑_集线器的特点_集线器之间的远程连接
  3. AI应用开发实战系列之四 - 定制化视觉服务的使用
  4. (DFS+DP)滑雪(poj1088)
  5. log4j的使用 20220228
  6. 解压.solitairetheme8文件
  7. tensorflow使用object detection实现目标检测超详细全流程(视频+图像集检测)
  8. GDAL更新至1.8.1后,通过属性查询矢量出错问题的解决方式
  9. 多摩川读写EEPROM以及并口实现
  10. 2021年10月数学一及第十三届大数赛部分复习
  11. Python实现最近邻nearest、双线性bilinear、双三次bicubic插值
  12. 带你玩转软件项目测试管理——项目研发管理模式(一)
  13. 解决Jenkins一直用户名或密码错误
  14. 云服务器连接手机本地文件在哪里,云服务器如何连接本地文件
  15. Windows 10配置CUDA 9.2
  16. 如何利用免费工具轻松实现个人号裂变?
  17. 设置 XShell 的默认全局配色方案
  18. 假如不小心因病去世,怎么给家人留下足够的财富呢?
  19. GOJS入门三-如何设置节点间的连线
  20. 工作室多wifi软路由指南

热门文章

  1. 有没有python搜题_python搜题公众号
  2. 关于SBUF读两次的问题
  3. 是否有标准函数来检查 JavaScript 中的 null、未定义或空白变量?
  4. UVA 11134 - Fabled Rooks(经典贪心)
  5. micropython 串口 wifi_MicroPython实现wifi干扰与抓包
  6. 拨开字符编码的迷雾--编译器如何处理文件编码
  7. 基于Redis解决业务场景中延迟队列的应用实践
  8. selenium的安装和下载谷歌浏览器镜像驱动
  9. 图像分类竞赛——添翼杯人工智能应用创新大赛——rank4解决方案
  10. opencv各lib库的功能