最近在跑实验的过程中一直在使用resnet50和resnet34,为了弄清楚网络的结构和原理的实现,打开resnet的源码进行了学习。

残差网络学习的原理

针对神经网络过深而导致的学习准确率饱和甚至是退化现象,resnet通过将若干个卷积层前的输入x直接与经过卷积层卷积学习过的特征进行叠加,假如经过卷积层学习到的特征为H(x),那么经过若干卷积层后得到的特征F(x)=H(x) + x,那该网络需要学习的仅为H(x)。

图1 resnet的工作原理

Resnet源码分析

Resnet网络主要是若干个网络的堆叠,Resnet内部实现了两种网络,一个是Basicblock,一个是Bottleneck。

class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

可以看到在BasicBlock的__init__方法里主要是两个3*3的卷积层以及downsample,downsample起到一个下采样的作用,在out和原始输入x的通道数相同时,downsample为None,out与原始输入x可以直接叠加,而当out和原始输入x的通道数不同时,通过downsample下采样增加通道数然后再与out叠加。如下图所示

图2 BasicBlock块
class Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = conv1x1(inplanes, planes)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = conv3x3(planes, planes, stride)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = conv1x1(planes, planes * self.expansion)self.bn3 = nn.BatchNorm2d(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

Bottleneck里面的网络为1*1卷积,3*3卷积和1*1卷积。如下图所示

图3 Bottleneck块

下面分别分析一下resnet34与resnet50的结构

Resnet34

model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)

第一个参数是选用的基础网络,resnet34使用BasicBlock,第二个参数以列表的形式传入每一层内部包含的层数。

class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):super(ResNet, self).__init__()self.inplanes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)

可以看到,resnet网络的基本结构是第一层由普通的卷积,BN,Relu和最大池化构成的卷积层,接下来是4个layer,每一个layer里面都由n个block构成,n由之前传入的列表中的值决定。然后使用global average pool来代替全连接层进行展平,最后是一个全连接层进行分类。

下面看一下中间的4个layer是如何实现的

    def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return nn.Sequential(*layers)

对于layer1来说,通过判断的条件可以知道layer1中的downsample为None,之前传入的list第一个值为3,也就是说在layer1中有3个BasicBlock,并且在卷积过后输入和输出的维度一致。

对于layers2来说,downsample由1*1的卷积完成。因为每经过一个layer,通道数就会增加一倍,在layer1中的输入为64通道,输出也为64通道,但在layer2中输入为64通道,输出为128通道,所以需要通过下采样来使通道数翻倍。而在layers2中有4个block,只在第一个block中用到了downsample,使得在layer2中的通道数改变为128,其余与layers1中的block一致。对于layer3和layer4也是分别在第一个block之中将通道数进行了翻倍。layer3由6个block组成,layer4由3个block组成。

图4 layer1的第一个block网络结构
图5 layer2的第一个block网络结构

在4个layer之后就是一个简单的全局平均池化和全连接层分类,默认分1000类。

Resnet50

resnet50和resnet34的区别就在于使用的block块不一致,resnet50使用的是Bottleneck,其余地方没有什么大区别。

model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

总结

1、写的文章当做个人学习笔记来看,忘记的时候翻看一下

2、由于第一次写文章,有写的不规范的地方希望指正。

3、对于resnet这些都是我自己的理解,有错误的地方希望批评指正。

参考资料

小小将:你必须要知道CNN模型:ResNet​zhuanlan.zhihu.com

resnet50网络结构_学习笔记(一):分析resnet源码理解resnet网络结构相关推荐

  1. Android学习笔记-常用的一些源码,防止忘记了

    Android学习笔记-常用的一些源码,防止忘记了... 设置拨打电话 StringdialUri="tell:"+m_currentTelNumble; IntentcallIn ...

  2. Netty学习笔记(一)Netty客户端源码分析

    最近在学些BIO,NIO相关的知识,也学习了下Netty和它的源码,做个记录,方便以后继续学习,如果有错误的地方欢迎指正 如果不了解BIO,NIO这些基础知识,可以看下我的如下博客 IO中的阻塞.非阻 ...

  3. Nginx学习笔记(五) 源码分析内存模块内存对齐

    Nginx源码分析&内存模块 今天总结了下C语言的内存分配问题,那么就看看Nginx的内存分配相关模型的具体实现.还有内存对齐的内容~~不懂的可以看看~~ src/os/unix/Ngx_al ...

  4. Shiro学习笔记(三)源码解析

    Shiro作为轻量级的权限框架,Shiro的认证流程是怎样的一个过程. 如果没有对Shiro进行了解的话,建议先对Shiro学习笔记(一)学习一下Shiro基本的组 成. 1,几大重要组件解析 1.1 ...

  5. 【Redis学习笔记】2018-05-30 Redis源码学习之Ziplist、Server

    作者:施洪宝 顺风车运营研发团队 一. 压缩列表 压缩列表是Redis的关键数据结构之一.目前已经有大量的相关资料,下面几个链接都已经对Ziplist进行了详细的介绍. http://origin.r ...

  6. Laravel 学习笔记之 Query Builder 源码解析(下)

    说明:本文主要学习下Query Builder编译Fluent Api为SQL的细节和执行SQL的过程.实际上,上一篇聊到了\Illuminate\Database\Query\Builder这个非常 ...

  7. PyQt5学习笔记03----Qt Designer生成源码

    下面来分析一下Qt Designer生成的源码. Qt Designer制作的图形界面为 生成的代码如下 [python] view plaincopy from PyQt5 import QtCor ...

  8. [Java Path Finder][JPF学习笔记][4]将JPF源码导入Eclipse

    这篇日志很简单,考虑到有些师弟在学习JPF,这里总结些经验. 在Eclipse中新建"Java Project",在新建的Project的src图标上点击右键--"Imp ...

  9. 【SLAM学习笔记】12-ORB_SLAM3关键源码分析⑩ Optimizer(七)地图融合优化

    2021SC@SDUSC 目录 1.前言 2.代码分析 1.前言 这一部分代码量巨大,查阅了很多资料结合来看的代码,将分为以下部分进行分析 单帧优化 局部地图优化 全局优化 尺度与重力优化 sim3优 ...

最新文章

  1. mysql5.6下主主复制的配置实现
  2. 转:Jquery AJAX POST与GET之间的区别
  3. 使用mybatis自动生成指定规则的编号
  4. 34、Power Query-中国式排名
  5. HTML文本下划线效果,聊聊CSS中文本下划线_CSS, SVG, masking, clip-path, 会员专栏, text-decoration 教程_W3cplus...
  6. [POI2009]SLO
  7. 一台交换机不同vlan如何通信
  8. Rancher 2.0集群与工作负载告警
  9. 小米网技术架构变迁实践
  10. (翻译)Importing models-FBX Importer - Animations Tab
  11. [转]McAfee 病毒库最新离线升级包下载 VirusScan SuperDAT
  12. fdfs-文件上传信息返回详情
  13. JS保留小数 去尾法 进一法 四舍五入法
  14. 2016二级java题型分数_2016年英语六级考试题型、试卷结构及分值比例
  15. unity3D AR涂涂乐制作浅谈
  16. android ibinder类接口编辑
  17. 深度学习制作自己的样本
  18. 关键词(快排)刷词原理和方法
  19. Web渗透攻击之vega
  20. 磁盘阵列RAID技术超详细解读

热门文章

  1. 相变仿真之固液相变问题结合comsol案例
  2. 亚信科技两方案入围工信部“数字技术融合创新解决方案”评选
  3. 恶意软件与反病毒网关
  4. 【远程办公】vmware horizon client 安装失败
  5. AD9中怎么建立多个部分的组成的单个器件
  6. 【财务危机】--2018.9债务
  7. 中国文化垃圾论(zt)--作为镜子仅供反省
  8. ubuntu16.04 ROS安装
  9. 多普勒频率的推导(纯公式版)
  10. python的sql注入