点击我爱计算机视觉标星,更快获取CVML新技术


本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士。

https://zhuanlan.zhihu.com/p/54289848

温故而知新,理论结合实际,值得收藏的好文!

0、前言

何恺明等人在2015年提出的ResNet,在ImageNet比赛classification任务上获得第一名,获评CVPR2016最佳论文。因为它“简单与实用”并存,之后许多目标检测、图像分类任务都是建立在ResNet的基础上完成的,成为计算机视觉领域重要的基石结构。

  • 本文对ResNet的论文进行简单梳理,并对其网络结构进行分析,然后对Torchvision版的ResNet代码进行解读,最后对ResNet训练自有网络进行简单介绍

  • 论文连接:

    https://arxiv.org/abs/1512.03385

  • 代码链接:

    https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

  • 本文所有代码解读均基于PyTroch 1.0,Python3;

  • 本文为原创文章,完成时间2019.01,相关参考链接附在文末;

1、ResNet要解决什么问题?

自从深度神经网络在ImageNet大放异彩之后,后来问世的深度神经网络就朝着网络层数越来越深的方向发展。直觉上我们不难得出结论:增加网络深度后,网络可以进行更加复杂的特征提取,因此更深的模型可以取得更好的结果。

但事实并非如此,人们发现随着网络深度的增加,模型精度并不总是提升,并且这个问题显然不是由过拟合(overfitting)造成的,因为网络加深后不仅测试误差变高了,它的训练误差竟然也变高了。作者提出,这可能是因为更深的网络会伴随梯度消失/爆炸问题,从而阻碍网络的收敛。作者将这种加深网络深度但网络性能却下降的现象称为退化问题(degradation problem)。

Is learning better networks as easy as stacking more layers? An obstacle to answering this question was the notorious problem of vanishing/exploding gradients [1, 9], which hamper convergence from the beginning.

Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error.

文中给出的实验结果进一步描述了这种退化问题:当传统神经网络的层数从20增加为56时,网络的训练误差和测试误差均出现了明显的增长,也就是说,网络的性能随着深度的增加出现了明显的退化。ResNet就是为了解决这种退化问题而诞生的。

图1 20层与56层传统神经网络在CIFAR上的训练误差和测试误差

2、ResNe怎么解决网络退化问题

随着网络层数的增加,梯度爆炸和梯度消失问题严重制约了神经网络的性能,研究人员通过提出包括Batch normalization在内的方法,已经一定程度上缓解了这个问题,但依然不足以满足需求。

This problem,however, has been largely addressed by normalized initialization [23, 9, 37, 13] and intermediate normalization layers[16], which enable networks with tens of layers to start converging for stochastic gradient descent (SGD) with backpropagation [22].

作者想到了构建恒等映射(Identity mapping)来解决这个问题,问题解决的标志是:增加网络层数,但训练误差不增加。为什么是恒等映射呢,我是这样子想的:20层的网络是56层网络的一个子集,如果我们将56层网络的最后36层全部短接,这些层进来是什么出来也是什么(也就是做一个恒等映射),那这个56层网络不就等效于20层网络了吗,至少效果不会相比原先的20层网络差吧。那不引入恒等映射的56层网络为什么不行呢?因为梯度消失现象使得网络难以训练,虽然网络的深度加深了,但是实际上无法有效训练网络,训练不充分的网络不但无法提升性能,甚至降低了性能。

There exists a solution by construction to the deeper model: the added layers are identity mapping, and the other layers are copied from the learned shallower model. The existence of this constructed solution indicates that a deeper model should produce no higher training error than its shallower counterpart.

那怎么构建恒等映射呢?简单地说,原先的网络输入x,希望输出H(x)。现在我们改一改,我们令H(x)=F(x)+x,那么我们的网络就只需要学习输出一个残差F(x)=H(x)-x。作者提出,学习残差F(x)=H(x)-x会比直接学习原始特征H(x)简单的多。

图2 残差学习基本单元

3、ResNet网络结构与代码实现

ResNet主要有五种变形:Res18,Res34,Res50,Res101,Res152。

如下图所示,每个网络都包括三个主要部分:输入部分、输出部分和中间卷积部分(中间卷积部分包括如图所示的Stage1到Stage4共计四个stage)。尽管ResNet的变种形式丰富,但是都遵循上述的结构特点,网络之间的不同主要在于中间卷积部分的block参数和个数存在差异。下面我们以ResNet18为例,看一下整个网络的实现代码是怎样的。

图3.1 ResNet结构总览
  • 网络整体结构

我们通过调用resnet18( )函数来生成一个具体的model,而resnet18函数则是借助ResNet类来构建网络的。

class ResNet(nn.Module):def forward(self, x):# 输入x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# 中间卷积x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)# 输出x = self.avgpool(x)x = x.view(x.size(0), -1)x = self.fc(x)return x# 生成一个res18网络
def resnet18(pretrained=False, **kwargs):model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))return model

在ResNet类中的forward( )函数规定了网络数据的流向:

(1)数据进入网络后先经过输入部分(conv1, bn1, relu, maxpool);

(2)然后进入中间卷积部分(layer1, layer2, layer3, layer4,这里的layer对应我们之前所说的stage);

(3)最后数据经过一个平均池化和全连接层(avgpool, fc)输出得到结果;

具体来说,resnet18和其他res系列网络的差异主要在于layer1~layer4,其他的部件都是相似的。

  • 网络输入部分

所有的ResNet网络输入部分是一个size=7x7, stride=2的大卷积核,以及一个size=3x3, stride=2的最大池化组成,通过这一步,一个224x224的输入图像就会变56x56大小的特征图,极大减少了存储所需大小。

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

输入层特征图数据变化
  • 网络中间卷积部分

中间卷积部分主要是下图中的蓝框部分,通过3*3卷积的堆叠来实现信息的提取。红框中的[2, 2, 2, 2]和[3, 4, 6, 3]等则代表了bolck的重复堆叠次数。

ResNet结构细节

刚刚我们调用的resnet18( )函数中有一句 ResNet(BasicBlock, [2, 2, 2, 2], **kwargs),这里的[2, 2, 2, 2]与图中红框是一致的,如果你将这行代码改为 ResNet(BasicBlock, [3, 4, 6, 3], **kwargs), 那你就会得到一个res34网络。

  • 残差块实现

下面我们来具体看一下一个残差块是怎么实现的,如下图所示的basic-block,输入数据分成两条路,一条路经过两个3*3卷积,另一条路直接短接,二者相加经过relu输出,十分简单。

basic_block
class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

bascic_block 数据走向

代码比较清晰,不做分析了,主要提一个点:downsample,它的作用是对输入特征图大小进行减半处理,每个stage都有且只有一个downsample。后面我们再详细介绍。

  • 网络输出部分

网络输出部分很简单,通过全局自适应平滑池化,把所有的特征图拉成1*1,对于res18来说,就是1x512x7x7 的输入数据拉成 1x512x1x1,然后接全连接层输出,输出节点个数与预测类别个数一致。

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)

至此,整体网络结构代码分析结束,更多细节,请看torchvision源码。

4、Bottleneck结构和1*1卷积

ResNet50起,就采用Bottleneck结构,主要是引入1x1卷积。我们来看一下这里的1x1卷积有什么作用:

  • 对通道数进行升维和降维(跨通道信息整合),实现了多个特征图的线性组合,同时保持了原有的特征图大小;

  • 相比于其他尺寸的卷积核,可以极大地降低运算复杂度;

  • 如果使用两个3x3卷积堆叠,只有一个relu,但使用1x1卷积就会有两个relu,引入了更多的非线性映射;

Basicblock和Bottleneck结构

我们来计算一下1*1卷积的计算量优势:首先看上图右边的bottleneck结构,对于256维的输入特征,参数数目:1x1x256x64+3x3x64x64+1x1x64x256=69632,如果同样的输入输出维度但不使用1x1卷积,而使用两个3x3卷积的话,参数数目为(3x3x256x256)x2=1179648。简单计算下就知道了,使用了1x1卷积的bottleneck将计算量简化为原有的5.9%,收益超高。

5、ResNet的网络设计规律

整个ResNet不使用dropout,全部使用BN。此外,回到最初的这张细节图,我们不难发现一些规律和特点:

  • 受VGG的启发,卷积层主要是3×3卷积;

  • 对于相同的输出特征图大小的层,即同一stage,具有相同数量的3x3滤波器;

  • 如果特征地图大小减半,滤波器的数量加倍以保持每层的时间复度。

  • 每个stage通过步长为2的卷积层执行下采样,而却这个下采样只会在每一个satge的第一个卷积完成,有且仅有一次。

  • 网络以全局平均池化层和softmax的1000路全连接层结束。

6、如何改造得到自己的ResNet?

我举一个简单的例子

from torchvision.models.resnet import *def get_net():model = resnet18(pretrained=True)model.avgpool = nn.AdaptiveAvgPool2d((1, 1))model.fc = nn.Sequential(nn.BatchNorm1d(512*1),nn.Linear(512*1, 你的分类类别数),)return model

代码简单解读一下:

  • 首先,通过torchvision导入相关的函数

  • 通过resnet18( )实例化一个模型,并使用imagenet预训练权重

  • 将平均池化修改为自适应全局平均池化,避免输入特征尺寸不匹配

  • 修改全连接层,主要是修改分类类别墅,并加入BN1d

这样子,不仅可以根据自己的需求改造网络,还能最大限度的使用现成的预训练权重。需要注意的是,这里的nn.BatchNorm1d(512*1)是很必要的,初学者可以尝试删除这个部件感受一下区别。在我曾经的实验里面,loss会直接爆炸。

7、ResNet的常见改进

  • 改进一:改进downsample部分,减少信息流失。前面说过了,每个stage的第一个conv都有下采样的步骤,我们看左边第一张图左侧的通路,input数据进入后在会经历一个stride=2的1*1卷积,将特征图尺寸减小为原先的一半,请注意1x1卷积和stride=2会导致输入特征图3/4的信息不被利用,因此ResNet-B的改进就是就是将下采样移到后面的3x3卷积里面去做,避免了信息的大量流失。ResNet-C则是另一种思路。ResNet-D则是在ResNet-B的基础上将identity部分的下采样交给avgpool去做,避免出现1x1卷积和stride同时出现造成信息流失。

ResNet的三种改进
  • 改进二:ResNet V2。这是由ResNet原班人马打造的,主要是对ResNet部分组件的顺序进行了调整。各种魔改中常见的预激活ResNet就是出自这里。

ResNet V2

原始的resnet是上图中的a的模式,我们可以看到相加后需要进入ReLU做一个非线性激活,这里一个改进就是砍掉了这个非线性激活,不难理解,如果将ReLU放在原先的位置,那么残差块输出永远是非负的,这制约了模型的表达能力,因此我们需要做一些调整,我们将这个ReLU移入了残差块内部,也就是图e的模式。这里的细节比较多,建议直接阅读原文:Identity Mappings in Deep Residual Networks (https://arxiv.org/abs/1603.05027),就先介绍这么多。

8.1、从模型集成角度理解ResNet的有效性

ResNet 中其实是存在着很多路径的集合,整个ResNet类似于多个网络的集成学习,证据是删除部分ResNet的网络结点,不影响整个网络的性能,但是在VGG上做同样的事请网络立刻崩溃,由此可见相比其他网络ResNet对于部分路径的缺失不敏感。更多细节具体可参见NIPS论文: Residual Networks Behave Like Ensembles of Relatively Shallow Networks  (http://papers.nips.cc/paper/6556-residual-networks-behave-like-ensembles-of-relatively-shallow-networks);

模型集成假说

破坏性实验

8.2、从梯度反向传播角度理解ResNet的有效性

残差结构使得梯度反向传播时,更不易出现梯度消失等问题,由于Skip Connection的存在,梯度能畅通无阻地通过各个Res blocks,下面我们来推导一下 ResNet v2 的反向传播过程。

原始的残差公式是这样子的,函数F表示一个残差函数,函数f表示激活函数,:

ResNet v2 使用恒等映射,且相加后不使用激活函数,因此可得到:

递归得到第L层的表达式:

反向传播求第l层梯度:

我们从这个表达式可以看出来:第l层的梯度里,包含了第L层的梯度,通俗的说就是第L层的梯度直接传递给了第l层。因为梯度消失问题主要是发生在浅层,这种将深层梯度直接传递给浅层的做法,有效缓解了深度神经网络梯度消失的问题。

9、总结

ResNet是当前计算机视觉领域的基石结构,是初学者无法绕开的网络模型,仔细阅读论文和源码并进行实验是极有必要的。

参考资料

1.你必须要知道CNN模型:ResNet

https://zhuanlan.zhihu.com/p/31852747

2.一文读懂卷积神经网络中的1x1卷积核

https://zhuanlan.zhihu.com/p/40050371

3.为什么ResNet和DenseNet可以这么深?一文详解残差块为何有助于解决梯度弥散问题

https://zhuanlan.zhihu.com/p/28124810

4.论文笔记:Residual Network内部结构剖析

https://zhuanlan.zhihu.com/p/37820282

5.[论文阅读]Identity Mappings in Deep Residual Networks

https://zhuanlan.zhihu.com/p/47766814

加群交流

关注计算机视觉与机器学习技术,欢迎加入52CV群,扫码添加52CV君拉你入群,

(请务必注明:加群)

喜欢在QQ交流的童鞋,可以加52CV官方QQ群:928997753。

(不会时时在线,如果没能及时通过验证还请见谅)


长按关注我爱计算机视觉

麻烦给我一个好看

ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)相关推荐

  1. 编译原理语义分析代码_Pix2Pix原理分析与代码解读

    原理分析: 图像.视觉中很多问题都涉及到将一副图像转换为另一幅图像(Image-to-Image Translation Problem),这些问题通常都使用特定的方法来解决,不存在一个通用的方法.但 ...

  2. ResNet及其变体结构梳理与总结

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 [导读]2020年,在各大CV顶会上又出现了许多基于ResNet改 ...

  3. html中并列式的应用,并列式结构梳理

    本篇文章中公事业单位(www.zgsydw.com)提供言语理解知识:<并列式结构梳理>. 一.并列式结构的概述 并列式结构的文段没有中心句,而是并列枚举观点或事例,考生需要归纳这些观点或 ...

  4. 管理类综合-论证有效性分析思路总结

    阅读--①找错--(②分解组合/筛选/评估)--③析错--④成文 原来是要求找 3 点漏洞,现在是确保至少 4 点 3个错别字扣1分,重复的不计,至多扣2分,书写酌情扣1-2分 标点符号要占一格,两个 ...

  5. 需求结构化与分析约束影响

    第4章需求结构化与分析约束影响 心念不同,判断力自然不同. --严定暹,<格局决定结局> 全面认识需求,是生产出高质量软件所必须的"第一项修炼". --温昱,<软 ...

  6. 英文Essay写作标准结构及风格分析

    在英文Essay写作中,通常有以下标准的结构: 标题页:包含Essay题目.作者信息.机构或学校名称和日期等. 摘要和关键词:摘要是Essay的简短概括,包含研究目的.方法.结果和结论等内容,关键词是 ...

  7. 面向过程(或者叫结构化)分析方法与面向对象分析方法到底区别在哪里?

    AutoSAR入门到精通系列讲解 将从2019年开始更新关于AutoSAR的知识,从入门到精通,博主xyfx和大家一起进步 雪云飞星 ¥29.90 去订阅 简单地说结构化分析方法主要用来分析系统的功能 ...

  8. 【原创】【专栏】《Linux设备驱动程序》--- LDD3源码目录结构和源码分析经典链接

    http://blog.csdn.net/geng823/article/details/37567557 [原创][专栏]<Linux设备驱动程序>--- LDD3源码目录结构和源码分析 ...

  9. 【深度学习】Swin Transformer结构和应用分析

    [深度学习]Swin Transformer结构和应用分析 文章目录 1 引言 2 Swin Transformer结构 3 分析3.1 Hierarchical Feature Representa ...

最新文章

  1. centos7.4安装mysql5.7_CentOS7.4手动安装MySQL5.7的方法
  2. 硬回车与软回车[转]
  3. 【Android应用开发】Android 蓝牙低功耗 (BLE) ( 第一篇 . 概述 . 蓝牙低功耗文档 翻译)
  4. 《深入理解C++11:C++ 11新特性解析与应用》——3.2 委派构造函数
  5. MyEclipse使用总结——在MyEclipse中设置jsp页面为默认utf-8编码
  6. hadoop centos 安装
  7. [云炬创业管理笔记]第五章打磨最有效的商业模式测试4
  8. mysql中sex设置男女_MYSQL常用命令(3)
  9. Linux删除重复内容命令uniq笔记
  10. 一个C++程序执行main函数前和执行完main函数后会发生什么。
  11. 爬虫-视频资源的爬取
  12. 数组中的两个常见异常
  13. 写在2012的最后一天
  14. MyBatis Review——多对多映射
  15. 【Prometheus】Prometheus联邦的一次优化记录[续]
  16. C语言中 1%3,算术什么意思啊 算数什么意思
  17. 嵩天老师python爬虫笔记整理week3
  18. mysql级联删除_近百道MySQL面试题和答案(2020收藏版)(完结篇)
  19. 游戏领域的“抄袭”与“借鉴”之分,无耻与致敬仅有一步之遥
  20. DTMF三种模式(SIPINFO,RFC2833,INBAND)

热门文章

  1. 122 - Trees on the level(模拟内存池解法)
  2. java 虚拟机初始堆_了解java虚拟机—堆相关参数设置(3)
  3. ad17 pcb扇孔_PCB设计中为什么需要先进行扇孔
  4. c语言练习书,谁有C语言入门的练习题?
  5. mysql ha 安装 配置文件_Linux下环境安装配置Rose HA全攻略(图)
  6. mysql proxy ro-pooling.lua_MySQL读写分离
  7. 快速锁屏电脑快捷键_电脑小技巧
  8. nginx 判断手机端跳转_Nginx系列:配置跳转的常用方式
  9. python物理模拟_在Python游戏中模拟重力【Programming(Python)】
  10. python类和对象的定义_python类与对象基本语法