resnet50网络结构_学习笔记(一):分析resnet源码理解resnet网络结构
最近在跑实验的过程中一直在使用resnet50和resnet34,为了弄清楚网络的结构和原理的实现,打开resnet的源码进行了学习。
残差网络学习的原理
针对神经网络过深而导致的学习准确率饱和甚至是退化现象,resnet通过将若干个卷积层前的输入x直接与经过卷积层卷积学习过的特征进行叠加,假如经过卷积层学习到的特征为H(x),那么经过若干卷积层后得到的特征F(x)=H(x) + x,那该网络需要学习的仅为H(x)。
Resnet源码分析
Resnet网络主要是若干个网络的堆叠,Resnet内部实现了两种网络,一个是Basicblock,一个是Bottleneck。
class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
可以看到在BasicBlock的__init__方法里主要是两个3*3的卷积层以及downsample,downsample起到一个下采样的作用,在out和原始输入x的通道数相同时,downsample为None,out与原始输入x可以直接叠加,而当out和原始输入x的通道数不同时,通过downsample下采样增加通道数然后再与out叠加。如下图所示
class Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None):super(Bottleneck, self).__init__()self.conv1 = conv1x1(inplanes, planes)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = conv3x3(planes, planes, stride)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = conv1x1(planes, planes * self.expansion)self.bn3 = nn.BatchNorm2d(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
Bottleneck里面的网络为1*1卷积,3*3卷积和1*1卷积。如下图所示
下面分别分析一下resnet34与resnet50的结构
Resnet34
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
第一个参数是选用的基础网络,resnet34使用BasicBlock,第二个参数以列表的形式传入每一层内部包含的层数。
class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):super(ResNet, self).__init__()self.inplanes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)
可以看到,resnet网络的基本结构是第一层由普通的卷积,BN,Relu和最大池化构成的卷积层,接下来是4个layer,每一个layer里面都由n个block构成,n由之前传入的列表中的值决定。然后使用global average pool来代替全连接层进行展平,最后是一个全连接层进行分类。
下面看一下中间的4个layer是如何实现的
def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return nn.Sequential(*layers)
对于layer1来说,通过判断的条件可以知道layer1中的downsample为None,之前传入的list第一个值为3,也就是说在layer1中有3个BasicBlock,并且在卷积过后输入和输出的维度一致。
对于layers2来说,downsample由1*1的卷积完成。因为每经过一个layer,通道数就会增加一倍,在layer1中的输入为64通道,输出也为64通道,但在layer2中输入为64通道,输出为128通道,所以需要通过下采样来使通道数翻倍。而在layers2中有4个block,只在第一个block中用到了downsample,使得在layer2中的通道数改变为128,其余与layers1中的block一致。对于layer3和layer4也是分别在第一个block之中将通道数进行了翻倍。layer3由6个block组成,layer4由3个block组成。
在4个layer之后就是一个简单的全局平均池化和全连接层分类,默认分1000类。
Resnet50
resnet50和resnet34的区别就在于使用的block块不一致,resnet50使用的是Bottleneck,其余地方没有什么大区别。
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
总结
1、写的文章当做个人学习笔记来看,忘记的时候翻看一下
2、由于第一次写文章,有写的不规范的地方希望指正。
3、对于resnet这些都是我自己的理解,有错误的地方希望批评指正。
参考资料
小小将:你必须要知道CNN模型:ResNetzhuanlan.zhihu.com
resnet50网络结构_学习笔记(一):分析resnet源码理解resnet网络结构相关推荐
- Android学习笔记-常用的一些源码,防止忘记了
Android学习笔记-常用的一些源码,防止忘记了... 设置拨打电话 StringdialUri="tell:"+m_currentTelNumble; IntentcallIn ...
- Netty学习笔记(一)Netty客户端源码分析
最近在学些BIO,NIO相关的知识,也学习了下Netty和它的源码,做个记录,方便以后继续学习,如果有错误的地方欢迎指正 如果不了解BIO,NIO这些基础知识,可以看下我的如下博客 IO中的阻塞.非阻 ...
- Nginx学习笔记(五) 源码分析内存模块内存对齐
Nginx源码分析&内存模块 今天总结了下C语言的内存分配问题,那么就看看Nginx的内存分配相关模型的具体实现.还有内存对齐的内容~~不懂的可以看看~~ src/os/unix/Ngx_al ...
- Shiro学习笔记(三)源码解析
Shiro作为轻量级的权限框架,Shiro的认证流程是怎样的一个过程. 如果没有对Shiro进行了解的话,建议先对Shiro学习笔记(一)学习一下Shiro基本的组 成. 1,几大重要组件解析 1.1 ...
- 【Redis学习笔记】2018-05-30 Redis源码学习之Ziplist、Server
作者:施洪宝 顺风车运营研发团队 一. 压缩列表 压缩列表是Redis的关键数据结构之一.目前已经有大量的相关资料,下面几个链接都已经对Ziplist进行了详细的介绍. http://origin.r ...
- Laravel 学习笔记之 Query Builder 源码解析(下)
说明:本文主要学习下Query Builder编译Fluent Api为SQL的细节和执行SQL的过程.实际上,上一篇聊到了\Illuminate\Database\Query\Builder这个非常 ...
- PyQt5学习笔记03----Qt Designer生成源码
下面来分析一下Qt Designer生成的源码. Qt Designer制作的图形界面为 生成的代码如下 [python] view plaincopy from PyQt5 import QtCor ...
- [Java Path Finder][JPF学习笔记][4]将JPF源码导入Eclipse
这篇日志很简单,考虑到有些师弟在学习JPF,这里总结些经验. 在Eclipse中新建"Java Project",在新建的Project的src图标上点击右键--"Imp ...
- 【SLAM学习笔记】12-ORB_SLAM3关键源码分析⑩ Optimizer(七)地图融合优化
2021SC@SDUSC 目录 1.前言 2.代码分析 1.前言 这一部分代码量巨大,查阅了很多资料结合来看的代码,将分为以下部分进行分析 单帧优化 局部地图优化 全局优化 尺度与重力优化 sim3优 ...
最新文章
- mysql5.6下主主复制的配置实现
- 转:Jquery AJAX POST与GET之间的区别
- 使用mybatis自动生成指定规则的编号
- 34、Power Query-中国式排名
- HTML文本下划线效果,聊聊CSS中文本下划线_CSS, SVG, masking, clip-path, 会员专栏, text-decoration 教程_W3cplus...
- [POI2009]SLO
- 一台交换机不同vlan如何通信
- Rancher 2.0集群与工作负载告警
- 小米网技术架构变迁实践
- (翻译)Importing models-FBX Importer - Animations Tab
- [转]McAfee 病毒库最新离线升级包下载 VirusScan SuperDAT
- fdfs-文件上传信息返回详情
- JS保留小数 去尾法 进一法 四舍五入法
- 2016二级java题型分数_2016年英语六级考试题型、试卷结构及分值比例
- unity3D AR涂涂乐制作浅谈
- android ibinder类接口编辑
- 深度学习制作自己的样本
- 关键词(快排)刷词原理和方法
- Web渗透攻击之vega
- 磁盘阵列RAID技术超详细解读