论文提出NF-ResNet,根据网络的实际信号传递进行分析,模拟BatchNorm在均值和方差传递上的表现,进而代替BatchNorm。论文实验和分析十分足,出来的效果也很不错。一些初始化方法的理论效果是对的,但实际使用会有偏差,论文通过实践分析发现了这一点进行补充,贯彻了实践出真知的道理

来源:晓飞的算法工程笔记 公众号

论文: Characterizing signal propagation to close the performance gap in unnormalized ResNets

  • 论文地址:https://arxiv.org/abs/2101.08692

Introduction


  BatchNorm是深度学习中核心计算组件,大部分的SOTA图像模型都使用它,主要有以下几个优点:

  • 平滑损失曲线,可使用更大的学习率进行学习。
  • 根据minibatch计算的统计信息相当于为当前的batch引入噪声,有正则化作用,防止过拟合。
  • 在初始阶段,约束残差分支的权值,保证深度残差网络有很好的信息传递,可训练超深的网络。

  然而,尽管BatchNorm很好,但还是有以下缺点:

  • 性能受batch size影响大,batch size小时表现很差。
  • 带来训练和推理时用法不一致的问题。
  • 增加内存消耗。
  • 实现模型时常见的错误来源,特别是分布式训练。
  • 由于精度问题,难以在不同的硬件上复现训练结果。

  目前,很多研究开始寻找替代BatchNorm的归一化层,但这些替代层要么表现不行,要么会带来新的问题,比如增加推理的计算消耗。而另外一些研究则尝试去掉归一化层,比如初始化残差分支的权值,使其输出为零,保证训练初期大部分的信息通过skip path进行传递。虽然能够训练很深的网络,但使用简单的初始化方法的网络的准确率较差,而且这样的初始化很难用于更复杂的网络中。
  因此,论文希望找出一种有效地训练不含BatchNorm的深度残差网络的方法,而且测试集性能能够媲美当前的SOTA,论文主要贡献如下:

  • 提出信号传播图(Signal Propagation Plots, SPPs),可辅助观察初始阶段的推理信号传播情况,确定如何设计无BatchNorm的ResNet来达到类似的信号传播效果。
  • 验证发现无BatchNorm的ResNet效果不好的关键在于非线性激活(ReLU)的使用,经过非线性激活的输出的均值总是正数,导致权值的均值随着网络深度的增加而急剧增加。于是提出Scaled Weight Standardization,能够阻止信号均值的增长,大幅提升性能。
  • 对ResNet进行normalization-free改造以及添加Scaled Weight Standardization训练,在ImageNet上与原版的ResNet有相当的性能,层数达到288层。
  • 对RegNet进行normalization-free改造,结合EfficientNet的混合缩放,构造了NF-RegNet系列,在不同的计算量上都达到与EfficientNet相当的性能。

Signal Propagation Plots


  许多研究从理论上分析ResNet的信号传播,却很少会在设计或魔改网络的时候实地验证不同层数的特征缩放情况。实际上,用任意输入进行前向推理,然后记录网络不同位置特征的统计信息,可以很直观地了解信息传播状况并尽快发现隐藏的问题,不用经历漫长的失败训练。于是,论文提出了信号传播图(Signal Propagation Plots,SPPs),输入随机高斯输入或真实训练样本,然后分别统计每个残差block输出的以下信息:

  • Average Channel Squared Mean,在NHW维计算均值的平方(平衡正负均值),然后在C维计算平均值,越接近零是越好的。
  • Average Channel Variance,在NHW维计算方差,然后在C维计算平均值,用于衡量信号的幅度,可以看到信号是爆炸抑或是衰减。
  • Residual Average Channel Variance,仅计算残差分支输出,用于评估分支是否被正确初始化。

  论文对常见的BN-ReLU-Conv结构和不常见的ReLU-BN-Conv结构进行了实验统计,实验的网络为600层ResNet,采用He初始化,定义residual block为xl+1=fl(xl)+xlx_{l+1}=f_{l}(x_{l}) + x_{l}xl+1​=fl​(xl​)+xl​,从SPPs可以发现了以下现象:

  • Average Channel Variance随着网络深度线性增长,然后在transition block处重置为较低值。这是由于在训练初始阶段,residual block的输出的方差为Var(xl+1)=Var(fl(xl))+Var(xl)Var(x_{l+1})=Var(f_{l}(x_{l})) + Var(x_{l})Var(xl+1​)=Var(fl​(xl​))+Var(xl​),不断累积residual branch和skip path的方差。而在transition block处,skip path的输入被BatchNorm处理过,所以block的输出的方差直接被重置了。

  • BN-ReLU-Conv的Average Squared Channel Means也是随着网络深度不断增加,虽然BatchNorm的输出是零均值的,但经过ReLU之后就变成了正均值,再与skip path相加就不断地增加直到transition block的出现,这种现象可称为mean-shift。

  • BN-ReLU的Residual Average Channel Variance大约为0.68,ReLU-BN的则大约为1。BN-ReLU的方差变小主要由于ReLU,后面会分析到,但理论应该是0.34左右,而且这里每个transition block的残差分支输出却为1,有点奇怪,如果知道的读者麻烦评论或私信一下。

  假如直接去掉BatchNorm,Average Squared Channel Means和Average Channel Variance将会不断地增加,这也是深层网络难以训练的原因。所以要去掉BatchNorm,必须设法模拟BatchNorm的信号传递效果。

Normalizer-Free ResNets(NF-ResNets)


  根据前面的SPPs,论文设计了新的redsidual blockxl+1=xl+αfl(xl/βl)x_{l+1}=x_l+\alpha f_l(x_l/\beta_l)xl+1​=xl​+αfl​(xl​/βl​),主要模拟BatchNorm在均值和方差上的表现,具体如下:

  • f(⋅)f(\cdot)f(⋅)为residual branch的计算函数,该函数需要特殊初始化,保证初期具有保持方差的功能,即Var(fl(z))=Var(z)Var(f_l(z))=Var(z)Var(fl​(z))=Var(z),这样的约束能够帮助更好地解释和分析网络的信号增长。
  • βl=Var(xl)\beta_l=\sqrt{Var(x_l)}βl​=Var(xl​)​为固定标量,值为输入特征的标准差,保证fl(⋅)f_l(\cdot)fl​(⋅)为单位方差。
  • α\alphaα为超参数,用于控制block间的方差增长速度。

  根据上面的设计,给定Var(x0)=1Var(x_0)=1Var(x0​)=1和βl=Var(xl)\beta_l=\sqrt{Var(x_l)}βl​=Var(xl​)​,可根据Var(xl)=Var(xl−1)+α2Var(x_l)=Var(x_{l-1})+\alpha^2Var(xl​)=Var(xl−1​)+α2直接计算第lll个residual block的输出的方差。为了模拟ResNet中的累积方差在transition block处被重置,需要将transition block的skip path的输入缩小为xl/βlx_l/\beta_lxl​/βl​,保证每个stage开头的transition block输出方差满足Var(xl+1)=1+α2Var(x_{l+1})=1+\alpha^2Var(xl+1​)=1+α2。将上述简单缩放策略应用到残差网络并去掉BatchNorm层,就得到了Normalizer-Free ResNets(NF-ResNets)。

ReLU Activations Induce Mean Shifts

  论文对使用He初始化的NF-ResNet进行SPPs分析,结果如图2,发现了两个比较意外的现象:

  • Average Channel Squared Mean随着网络变深不断增加,值大到超过了方差,有mean-shift现象。
  • 跟BN-ReLU-Conv类似,残差分支输出的方差始终小于1。

  为了验证上述现象,论文将网络的ReLU去掉再进行SPPs分析。如图7所示,当去掉ReLU后,Average Channel Squared Mean接近于0,而且残差分支输出的接近1,这表明是ReLU导致了mean-shift现象。
  论文也从理论的角度分析了这一现象,首先定义转化z=Wg(x)z=Wg(x)z=Wg(x),WWW为任意且固定的矩阵,g(⋅)g(\cdot)g(⋅)为作用于独立同分布输入xxx上的elememt-wise激活函数,所以g(x)g(x)g(x)也是独立同分布的。假设每个维度iii都有E(g(xi))=μg\mathbb{E}(g(x_i))=\mu_gE(g(xi​))=μg​以及Var(g(xi))=σg2Var(g(x_i))=\sigma^2_gVar(g(xi​))=σg2​,则输出zi=∑jNWi,jg(xj)z_i=\sum^N_jW_{i,j}g(x_j)zi​=∑jN​Wi,j​g(xj​)的均值和方差为:

  其中,μwi,.\mu w_{i,.}μwi,.​和σwi,.\sigma w_{i,.}σwi,.​为WWW的iii行(fan-in)的均值和方差:

  当g(⋅)g(\cdot)g(⋅)为ReLU激活函数时,则g(x)≥0g(x)\ge 0g(x)≥0,意味着后续的线性层的输入都为正均值。如果xi∼N(0,1)x_i\sim\mathcal{N}(0,1)xi​∼N(0,1),则μg=1/2π\mu_g=1/\sqrt{2\pi}μg​=1/2π​。由于μg>0\mu_g>0μg​>0,如果μwi\mu w_iμwi​也是非零,则ziz_izi​同样有非零均值。需要注意的是,即使WWW从均值为零的分布中采样而来,其实际的矩阵均值肯定不会为零,所以残差分支的任意维度的输出也不会为零,随着网络深度的增加,越来越难训练。

Scaled Weight Standardization

  为了消除mean-shift现象以及保证残差分支fl(⋅)f_l(\cdot)fl​(⋅)具有方差不变的特性,论文借鉴了Weight Standardization和Centered Weight Standardization,提出Scaled Weight Standardization(Scaled WS)方法,该方法对卷积层的权值重新进行如下的初始化:

  μ\muμ和σ\sigmaσ为卷积核的fan-in的均值和方差,权值WWW初始为高斯权值,γ\gammaγ为固定常量。代入公式1可以得出,对于z=W^g(x)z=\hat{W}g(x)z=W^g(x),有E(zi)=0\mathbb{E}(z_i)=0E(zi​)=0,去除了mean-shift现象。另外,方差变为Var(zi)=γ2σg2Var(z_i)=\gamma^2\sigma^2_gVar(zi​)=γ2σg2​,γ\gammaγ值由使用的激活函数决定,可保持方差不变。
  Scaled WS训练时增加的开销很少,而且与batch数据无关,在推理的时候更是无额外开销的。另外,训练和测试时的计算逻辑保持一致,对分布式训练也很友好。从图2的SPPs曲线可以看出,加入Scaled WS的NF-ResNet-600的表现跟ReLU-BN-Conv十分相似。

Determining Nonlinerity-Specific Constants

  最后的因素是γ\gammaγ值的确定,保证残差分支输出的方差在初始阶段接近1。γ\gammaγ值由网络使用的非线性激活类型决定,假设非线性的输入x∼N(0,1)x\sim\mathcal{N}(0,1)x∼N(0,1),则ReLU输出g(x)=max(x,0)g(x)=max(x,0)g(x)=max(x,0)相当于从方差为σg2=(1/2)(1−(1/π))\sigma^2_g=(1/2)(1-(1/\pi))σg2​=(1/2)(1−(1/π))的高斯分布采样而来。由于Var(W^g(x))=γ2σg2Var(\hat{W}g(x))=\gamma^2\sigma^2_gVar(W^g(x))=γ2σg2​,可设置γ=1/σg=21−1π\gamma=1/\sigma_g=\frac{\sqrt{2}}{\sqrt{1-\frac{1}{\pi}}}γ=1/σg​=1−π1​​2​​来保证Var(W^g(x))=1Var(\hat{W}g(x))=1Var(W^g(x))=1。虽然真实的输入不是完全符合x∼N(0,1)x\sim \mathcal{N}(0,1)x∼N(0,1),在实践中上述的γ\gammaγ设定依然有不错的表现。
  对于其他复杂的非线性激活,如SiLU和Swish,公式推导会涉及复杂的积分,甚至推出不出来。在这种情况下,可使用数值近似的方法。先从高斯分布中采样多个NNN维向量xxx,计算每个向量的激活输出的实际方差Var(g(x))Var(g(x))Var(g(x)),再取实际方差均值的平方根即可。

Other Building Block and Relaxed Constraints

  本文的核心在于保持正确的信息传递,所以许多常见的网络结构都要进行修改。如同选择γ\gammaγ值一样,可通过分析或实践判断必要的修改。比如SE模块y=sigmoid(MLP(pool(h)))∗hy=sigmoid(MLP(pool(h)))*hy=sigmoid(MLP(pool(h)))∗h,输出需要与[0,1][0,1][0,1]的权值进行相乘,导致信息传递减弱,网络变得不稳定。使用上面提到的数值近似进行单独分析,发现期望方差为0.5,这意味着输出需要乘以2来恢复正确的信息传递。
  实际上,有时相对简单的网络结构修改就可以保持很好的信息传递,而有时候即便网络结构不修改,网络本身也能够对网络结构导致的信息衰减有很好的鲁棒性。因此,论文也尝试在维持稳定训练的前提下,测试Scaled WS层的约束的最大放松程度。比如,为Scaled WS层恢复一些卷积的表达能力,加入可学习的缩放因子和偏置,分别用于权值相乘和非线性输出相加。当这些可学习参数没有任何约束时,训练的稳定性没有受到很大的影响,反而对大于150层的网络训练有一定的帮助。所以,NF-ResNet直接放松了约束,加入两个可学习参数。
  论文的附录有详细的网络实现细节,有兴趣的可以去看看。

Summary

  总结一下,Normalizer-Free ResNet的核心有以下几点:

  • 计算前向传播的期望方差βl2\beta^2_lβl2​,每经过一个残差block稳定增加α2\alpha^2α2,残差分支的输入需要缩小βl\beta_lβl​倍。
  • 将transition block中skip path的卷积输入缩小βl\beta_lβl​倍,并在transition block后将方差重置为βl+1=1+α2\beta_{l+1}=1+\alpha^2βl+1​=1+α2。
  • 对所有的卷积层使用Scaled Weight Standardization初始化,基于x∼N(0,1)x\sim\mathcal{N}(0,1)x∼N(0,1)计算激活函数g(x)g(x)g(x)对应的γ\gammaγ值,为激活函数输出的期望标准差的倒数1Var(g(x))\frac{1}{\sqrt{Var(g(x))}}Var(g(x))​1​。

Experiments


  对比RegNet的Normalizer-Free变种与其他方法的对比,相对于EfficientNet还是差点,但已经十分接近了。

Conclusion


  论文提出NF-ResNet,根据网络的实际信号传递进行分析,模拟BatchNorm在均值和方差传递上的表现,进而代替BatchNorm。论文实验和分析十分足,出来的效果也很不错。一些初始化方法的理论效果是对的,但实际使用会有偏差,论文通过实践分析发现了这一点进行补充,贯彻了实践出真知的道理。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

NF-ResNet:去掉BN归一化,值得细读的网络信号分析 | ICLR 2021相关推荐

  1. 【卷积神经网络结构专题】ResNet及其变体的结构梳理、有效性分析

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要20分钟 Follow小博主,每天更新前沿干货 [导读]2020年,在各大CV顶会上又出现了许多基于ResNet改进的工作,比如:Res2Ne ...

  2. ICLR 2021投稿中值得一读的NLP相关论文

    我们从 ICLR 2021开放投稿的3000篇论文中,粗略筛选了近100篇与自然语言处理领域中也许值得一读的论文,供大家查阅. 理论.模型与经验性分析:38篇 问答与阅读理解:4篇 知识图谱:4篇 文 ...

  3. ResNet改进版来了!可训练网络超过3000层!相同深度精度更高

    来自阿联酋起源人工智能研究院(IIAI)的研究人员公布了一篇论文Improved Residual Networks for Image and Video Recognition,深入研究了残差网络 ...

  4. 142页ICML会议强化学习笔记整理,值得细读

    作者 | David Abel 编辑 | DeepRL 来源 | 深度强化学习实验室(ID: Deep-RL) ICML 是 International Conference on Machine L ...

  5. qt中实现左右分割线_Qt项目中,实现屏幕截图并生成gif的详细示例(值得细读)...

    总第50篇 平时我们在工作和学习的过程中,有时需要将桌面的某些动作截图生成gif动图,以更生动地呈现出来.目前有很多这样的软件,并且方便易使用,比如我经常使用的GifCam,软件小巧,生成的图片文件也 ...

  6. 深度学习论文导航 | 08 ResNet:用于图像识别的深度残差网络

    写在前面:大家好!我是[AI 菌],一枚爱弹吉他的程序员.我热爱AI.热爱分享.热爱开源! 这博客是我对学习的一点总结与记录.如果您也对 深度学习.机器视觉.算法.Python.C++ 感兴趣,可以关 ...

  7. 怎么把整个网站的代码中的一个词去掉_【杭州南牛网络】网站优化的最新优化方法...

    [杭州南牛网络]如果你是一名企业主,你有建立企业官方网站的经验,在2-3年的运营过程中,我相信你至少对网站做了一次修改,甚至对SEO战略进行了重大调整. 原因很简单:当我们刚开始建一家公司时,很多时间 ...

  8. ICLR 2021 有什么值得关注的投稿?

    链接:https://www.zhihu.com/question/423975807 编辑:深度学习与计算机视觉 声明:仅做学术分享,侵删 作者:简单名 https://www.zhihu.com/ ...

  9. 信号归一化功率_UE低发射功率余量分析

    1.功率余量基础原理 1.1 功率余量PH 云南 (图1) 1.2 功率余量报告PHR PHR,全称是Power Headroom Report,中文为功率余量报告,即UE向网侧报告功率余量的过程.这 ...

最新文章

  1. 【MySQL】 如何在“海啸”下保命
  2. Memcache的分布式应用
  3. 左右mysql事务提交
  4. nlp-tutorial代码注释3-3,双向RNN简介
  5. 模拟ARP报文发送,通过改变拓扑结构,观察报文发送方法以及途径
  6. $.ajax data怎么处理_不用jsp怎么实现前后端交互?给萌新后端的ajax教程(2)
  7. android pdf阅读器推荐,四款好用的PDF阅读器推荐,建议收藏!
  8. linux安装Telnet工具
  9. 计算机输入输出设计原则,交互设计精髓4中的104条设计原则
  10. 基于Material Studio软件使用第一性原理预测AlAs的晶格参数
  11. Android播放音乐的代码,android源代码(完整的音乐播放器)
  12. Android开发随手记1
  13. SQL注入之floor报错注入
  14. 罗克韦尔自动化帮助简化工业生产力分析
  15. PCB布局布线中地的设计(地与地使用跨接)。
  16. 揭秘“短视频创业”:一年亏50万,一个人就是一支团队
  17. 鸿蒙——通用设计基础(未完待续)
  18. XP停止服务:不必难过 千里相送终有一别
  19. Vue的全局事件总线实现任意组件间通信
  20. 单片机 c语言 d,单片机89C51与A/D转换器MAX - 控制/MCU - 电子发烧友网

热门文章

  1. R语言分组画条形图——qplot
  2. 微商分销系统哪家好,要怎么做?
  3. 令人不寒而栗的黄蓉(转)
  4. php怎样规定密码混合,PHP产生随机字串,可用来自动生成密码 默认长度6位 字母和数字混合...
  5. a href=javascript作用
  6. Interlocked.Increment 方法 和Interlocked.Decrement 方法作用
  7. VMbox 如何显示控制菜单,不显示控制菜单了
  8. 互联网春招和秋招的区别
  9. MyBatis高效同步百万级数据
  10. android模拟器装包,逍遥模拟器如何安装本地应用包apk?