• 论文题目:

RepVGG:Making VGG-style ConvNets Great Again

  • 论文下载:

httsp://arxiv.org/abs/2101.03697

  • 代码开源:

https://github.com/DingXiaoH/RepVGG

该文是清华大学&旷视科技等提出的一种新颖的CNN设计范式,它将ACNet的思想与VGG架构进行了巧妙的结合,即避免了VGG类方法训练所得精度低的问题,又保持了VGG方案的高效推理优点,它首次将plain模型的精度在ImageNet上提升到了超过80%top1精度。相比ResNet、RegNet、EfficientNet等网络,RepVGG具有更好的精度-速度均衡。

本文提出一种简单而强有力的CNN架构RepVGG,在推理阶段,它具有与VGG类似的架构,而在训练阶段,它则具有多分支架构体系,这种训练-推理解耦的架构设计源自一种称之为“重参数化(re-parameterization)”的技术。

在ImageNet数据集上,RepVGG取得了超过80%的top-1精度,这是plain模型首次达到如此高的精度。在NVIDIA 1080TiGPU上,RepVGG比ResNet50快83%,比ResNet101快101%,同时具有更高的精度;相比EfficientNet与RegNet,RepVGG表现出了更好的精度-速度均衡。

该文的主要贡献包含以下三个方面:

  • 提出了一种简单有强有的CNN架构RepVGG,相比EfficientNet、RegNet等架构,RepVGG具有更佳的精度-速度均衡;

  • 提出采用重参数化技术对plain架构进行训练-推理解耦;

  • 在图像分类、语义分割等任务上验证了RepVGG的有效性。

Simple is Fast, Memory-economical, Flexible

简单的ConvNet具有这样三点优势:

  • Fast:相比VGG,现有的多分支架构理论上具有更低的Flops,但推理速度并未更快。比如VGG16的参数量为EfficientNetB3的8.4倍,但在1080Ti上推理速度反而快1.8倍。这就意味着前者的计算密度是后者的15倍。Flops与推理速度的矛盾主要源自两个关键因素:(1) MAC(memory access cose),比如多分支结构的Add与Cat的计算很小,但MAC很高; (2)并行度,已有研究表明:并行度高的模型要比并行度低的模型推理速度更快。

  • Memory-economical:多分支结构是一种内存低效的架构,这是因为每个分支的结构都需要在Add/Concat之前保存,这会导致更大的峰值内存占用;而plain模型则具有更好的内存高效特征。

  • Flexible:多分支结构会限制CNN的灵活性,比如ResBlock会约束两个分支的tensor具有相同的形状;与此同时,多分支结构对于模型剪枝不够友好。

Training-time Multi-branch Architecture

Palin模型具有多种优势但存在一个重要的弱势:性能差。比如VGG16在ImageNet仅能达到72%的top-1指标。

本文所设计的RepVGG则是受ResNet启发得到,ResNet的ResBlock显示的构建了一个短连接模型信息流 ,当 的维度不匹配时,上述信息流则转变为 。

尽管多分支结构对于推理不友好,但对于训练友好,作者将RepVGG设计为训练时的多分支,推理时单分支结构。作者参考ResNet的identity与 分支,设计了如下形式模块:

其中, 分别对应 卷积。在训练阶段,通过简单的堆叠上述模块构建CNN架构;而在推理阶段,上述模块可以轻易转换为 形式,且 的参数可以通过线性组合方式从已训练好的模型中转换得到。

Re-param for Plain Inference-time Model

接下来,我们将介绍如何将已训练模块转换成单一的 卷积用于推理。下图给出了参数转换示意图。

我们采用 表示输入 ,输出 ,卷积核为3的卷积;采用 表示输入 ,输出 ,卷积核为1的卷积;采用 表示 卷积后的BatchNorm的参数;采用 表示 卷积后的BatchNorm的参数;采用 表示identity分支的BatchNorm的参数。假设 分别表示输入与输出,当 时,

否则,简单的采用无identity分支的模块,也就是说只有前两项。注:bn表示推理时的BN。

首先,我们可以将每个BN与其前接Conv层合并:

注:identity分支可以视作 卷积。通过上述变换,此时上述模块仅仅具有一个 卷积核,两个 卷积核以及三个bias参数。此时,三个bias参数可以通过简单的add方式合并为一个bias;而卷积核则可以将 卷积核参数加到 卷积核的中心点得到。说起来复杂,其实看一下code就非常简单了,见文末code。

Architectural Specification

前面介绍了RepVGG的核心模块设计方式,接下来就要介绍RepVGG的网络结构如何设计了。下表给出了RepVGG的配置信息,包含深度与宽度。

RepVGG是一种类VGG的架构,在推理阶段它仅仅采用 卷积与ReLU,且未采用MaxPool。对于分类任务,采用GAP+全连接层作为输出头。

对于每个阶段的层数按照如下三种简单的规则进行设计:

  • 第一个阶段具有更大的分辨率,故而更为耗时,为降低推理延迟仅仅采用了一个卷积层;

  • 最后一个阶段因为具有更多的通道,为节省参数量,故而仅设计一个卷积层;

  • 在倒数第二个阶段,类似ResNet,RepVGG放置了更多的层。

基于上述考量,RepVGG-A不同阶段的层数分别为1-2-4-14-1;与此同时,作者还构建了一个更深的RepVGG-B,其层数配置为1-4-6-16-1。RepVGG-A用于与轻量型网络和中等计算量网络对标,而RepVGG-B用于与高性能网络对标。

在不同阶段的通道数方面,作者采用了经典的配置64-128-256-512。与此同时,作者采用因子 控制前四个阶段的通道,因子 控制最后一个阶段的通道,通常 (我们期望最后一层具有更丰富的特征)。为避免大尺寸特征的高计算量,对于第一阶段的输出通道做了约束 。基于此得到的不同RepVGG见下表。

为进一步降低计算量与参数量,作者还设计了可选的 组卷积替换标准卷积。具体地说,在RepVGG-A的3-5-7-...-21卷积层采用了组卷积;此外,在RepVGG-B的23-25-27卷积层同样采用了组卷积。

接下来,我们将在不同任务上验证所提方案的有效性,这里主要在ImageNet图像分类任务上进行了实验分析。

上表给出了RepVGG与不同计算量的ResNe及其变种在精度、速度、参数量等方面的对比。可以看到:RepVGG表现出了更好的精度-速度均衡,比如

  • RepVGG-A0比ResNet18精度高1.25%,推理速度快33%;

  • RepVGG-A1比Resnet34精度高0.29%,推理速度快64%;

  • RepVGG-A2比ResNet50精度高0.17%,推理速度快83%;

  • RepVGG-B1g4比ResNet101精度高0.37%,推理速度快101%;

  • RepVGG-B1g2比ResNet152精度相当,推理速度快2.66倍。

另外需要注意的是:RepVGG同样是一种参数高效的方案。比如:相比VGG16,RepVGG-B2b168.com仅需58%参数量,推理快10%,精度高6.57%。

与此同时,还与EfficientNet、RegNet等进行了对比,对比如下:

  • RepVGG-A2比EfficientNet-B0精度高1.37%,推理速度快59%;

  • RepVGG-B1比RegNetX-3.2GF精度高0.39%,推理速度稍快;

此外需要注意:RepVGG仅需200epoch即可取得超过80%的top1精度,见上表对比。这应该是plain模型首次在精度上达到SOTA指标。相比RegNetX-12GF,RepVGG-B3的推理速度快31%,同时具有相当的精度。

尽管RepVGG是一种简单而强有力的ConvNet架构,它在GPU端具有更快的推理速度、更少的参数量和理论FLOPS;但是在低功耗的端侧,MobileNet、ShuffleNet会更受关注。

话说在半个月多之前就听说xiangyu等人把ACNet的思想与Inception相结合设计了一种性能更好的重参数化方案RepVGG,即可取得训练时的性能提升,又可以保持推理高效,使得VGG类网络可以达到ResNet的高性能。

在初次看到RepVGG架构后,笔者就曾尝试将其用于VDSR图像超分方案中,简单一试,确实有了提升,而且不需要进行梯度裁剪等额外的一些操作,赞。

从某种程度上讲,RepVGG应该是ACNet的的一种极致精简,比如上图给出了ACNet的结构示意图,它采用了 三种卷积设计;而RepVGG则是仅仅采用了 三个分支设计。ACNet与RepVGG的另外一点区别在于:ACNet是将上述模块用于替换ResBlock或者Inception中的卷积,而RepVGG则是采用所设计的模块用于替换VGG中的卷积。

最后附上作者提供的RepVGG的核心模块实现code,如下所示。

# code from https://github.com/DingXiaoH/RepVGG
class RepVGGBlock(nn.Module):def __init__(self, in_channels, out_channels, kernel_size,stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False):super(RepVGGBlock, self).__init__()self.deploy = deployself.groups = groupsself.in_channels = in_channelsassert kernel_size == 3assert padding == 1padding_11 = padding - kernel_size // 2self.nonlinearity = nn.ReLU()if deploy:self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)else:self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else Noneself.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)print('RepVGG Block, identity = ', self.rbr_identity)def forward(self, inputs):if hasattr(self, 'rbr_reparam'):return self.nonlinearity(self.rbr_reparam(inputs))if self.rbr_identity is None:id_out = 0else:id_out = self.rbr_identity(inputs)return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)def _fuse_bn(self, branch):if branch is None:return 0, 0if isinstance(branch, nn.Sequential):kernel = branch.conv.weight.detach().cpu().numpy()running_mean = branch.bn.running_mean.cpu().numpy()running_var = branch.bn.running_var.cpu().numpy()gamma = branch.bn.weight.detach().cpu().numpy()beta = branch.bn.bias.detach().cpu().numpy()eps = branch.bn.epselse:assert isinstance(branch, nn.BatchNorm2d)kernel = np.zeros((self.in_channels, self.in_channels, 3, 3))for i in range(self.in_channels):kernel[i, i, 1, 1] = 1running_mean = branch.running_mean.cpu().numpy()running_var = branch.running_var.cpu().numpy()gamma = branch.weight.detach().cpu().numpy()beta = branch.bias.detach().cpu().numpy()eps = branch.epsstd = np.sqrt(running_var + eps)t = gamma / stdt = np.reshape(t, (-1, 1, 1, 1))t = np.tile(t, (1, kernel.shape[1], kernel.shape[2], kernel.shape[3]))return kernel * t, beta - running_mean * gamma / stddef _pad_1x1_to_3x3(self, kernel1x1):if kernel1x1 is None:return 0kernel = np.zeros((kernel1x1.shape[0], kernel1x1.shape[1], 3, 3))kernel[:, :, 1:2, 1:2] = kernel1x1return kerneldef repvgg_convert(self):kernel3x3, bias3x3 = self._fuse_bn(self.rbr_dense)kernel1x1, bias1x1 = self._fuse_bn(self.rbr_1x1)kernelid, biasid = self._fuse_bn(self.rbr_identity)return kernel3x3 + self._pad_1x1_to_3x3(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

欢迎关注旷视研究院,一起聊聊技术那些事儿!

RepVGG | 让你的ConVNet一卷到底,plain网络首次超过80%top1精度相关推荐

  1. 注意力CBMA到底在网络中做了什么事

    注意力CBAM到底在网络中做了什么事 CBAM网络架构 通道注意力 空间注意力: 分析 通道注意力: 1.将特征图进行最大池化和平均池化 ​ SENet也使用了通道注意力, 但SENet只采用了平均池 ...

  2. docker入门,镜像,容器,数据卷,dockerfile,docker网络,springboot微服务打包docker镜像[狂神yyds]

    docker学习大纲 docker概述 docker安装 docker命令 镜像命令 容器命令 操作命令 - docker镜像 容器数据卷 dockerfile docker网络原理 IDEA整合do ...

  3. 【Docker 系列】我们来看看容器数据卷到底是个啥

    什么是容器数据卷 思考一个问题,我们为什么要使用 Docker? 主要是为了可以将应用和环境进行打包成镜像,一键部署. 再思考一个问题,容器之间是相互隔离的,如果我们在容器中部署类似 mysql 这样 ...

  4. 目标检测一卷到底之后,终于有人为它挖了个新坑|CVPR2021 Oral

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨二玖 审稿|邓富城 报道丨极市平台 导读 本文解决了两个挑战:一是在没有明确监督的情况下,将尚未 ...

  5. 【论文解读】目标检测一卷到底之后,终于有人为它挖了个新坑|CVPR2021 Oral

    作者丨二玖 审稿|邓富城 报道丨极市平台 极市导读 本文解决了两个挑战:一是在没有明确监督的情况下,将尚未引入的目标识别为"未知",二是让网络进行N+1式增量学习. 虽然目标检测技 ...

  6. 目标检测一卷到底之后,终于又有人给它挖了个新坑|CVPR2021 Oral

    目标检测技术虽然已经发展得较为成熟,但是如果要说让计算机能够像人眼一样进行识别,有个特征一直没有被解锁--识别现实世界中的所有物体,并且能够逐渐学习认知新的未知物体. 来自澳大利亚国立大学和瑞典林雪平 ...

  7. 前途到底是网络工程还是程序设计

    本人89年年底生的,现在快满21了,大二的时候过的国家网络工程师考试,并不是cisco的网络支持工程师,大三也就是现在,在学校花销太大,想自己赚点钱,于是在学校招聘会上应聘了一家通信公司,从事不是专业 ...

  8. PyTorch笔记 - A ConvNet for the 2020s (ConvNeXt) 网络

    欢迎关注我的CSDN:https://blog.csdn.net/caroline_wendy 本文地址:https://blog.csdn.net/caroline_wendy/article/de ...

  9. 华师大计算机入门模拟卷,华东师范大学网络本科计算机基础考试大纲

    计算机基础考试需要了解微型计算机的基础知识,微型计算机系统的组成和各组成部分的功能等方面的知识,对于报读了华东师范大学本科的网络考生而言,若是有个复习大纲那么复习起来会更高效,为此明德教育为大家整理了 ...

最新文章

  1. hadoop 基准测试与读写测试
  2. php 中curd表达啥,CURD语句的基本语法和PDO中操作数据表的基本步骤实例演示增删改查命令 2019年07月24日 23时10分...
  3. 如何修改snmp的监听端口
  4. 人工智能诗歌写作平台_智能写作VS人工写作,Giiso写作机器人解放你的创造力...
  5. vue 打包上线后字体图标不显示
  6. python查询oracle数据库_python针对Oracle常见查询操作实例分析
  7. 安装一直初始化_3D max 软件安装问题大全
  8. 用pfx证书java双向认证_把CA证书生成的crt的证书和pem的私钥转换成java能够使用的keystore和pcks12的证书,实现https双向认证...
  9. java线程异步传值_Java 多线程传值的四种方法
  10. php按条件修改xml,php 修改、增加xml结点属性的实现代码
  11. Redis的复制(Master/Slave)
  12. 初学RubyOnRails的推荐书籍
  13. python模板是什么意思_python – 这个模板中的正确包含路径是什么?
  14. sel4 手册总结之介绍与内核服务和对象
  15. android 中拦截home键
  16. 全球 26 个主流视频网站高清视频下载全搞定,包括 P 站!
  17. 面向过程(PO)和面向对象(OO)的区别(思维导图)
  18. 4.9. 相等的多项式
  19. 在maven启动时tomcat端口冲突问题 Address already in use: JVM_Bind null:8080
  20. 全球顶级的14位程序员

热门文章

  1. 七个基本量纲_量纲是什么?
  2. iOS15.2 注册相册变化通知未给相册权限导致崩溃 [PHPhotoLibrary.sharedPhotoLibrary registerChangeObserver:self]
  3. GrowingIO服务端开发面试、以及对测试开发、趣头条的一点个人看法。仅供参考
  4. Linux——软件安装
  5. java获取网络图片(比如微信授权后的头像)上传至linux服务器
  6. ROC曲线绘制(Python)
  7. oracle vm virtualBox中配置独立的ip给ubuntu
  8. Node / v8 1gb memory limit?
  9. Java工具类------Math类的详解及使用
  10. Android多种方式实现相机圆形预览 看这一篇就够了,Android开发面试书籍