目录

1.简介

2.ResNeXt网络结构

2.1 ResNeXt block

2.2 实验对比

2.3 ResNeXt整体结构

3.ResNeXt模型代码


1.简介

ResNeXt实际上就是对Resnet网络的结构做出了一些调整,性能有些许提升,主要的改变就是将之前得普通卷积改成了组卷积(如果不知道组卷积的建议先去了解下),减少了一定的参数量。

上图为论文中给出的ResNeXt的性能,可以发现ResNeXt-101对比ResNet-101来说,错误率明显下降,在输入尺寸为320x320和299x299时的错误率也比同时期的其他网络要低。

2.ResNeXt网络结构

2.1 ResNeXt block

其实ResNeXt block就是将ResNet block中的普通卷积替换成了组卷积,并且将通道数扩大一倍,也就是卷积核个数是之前得两倍,其他都一样,没什么变化。

上图左就是非常熟悉的ResNet block,上图右就是该文章中提出的ResNeXt block,就是将之前Resnet block中的3x3普通卷积替换成了group=32的组卷积,并且前两层卷积的out_channel(卷积核个数)都是之前得两倍,如64—>128。如果上图看不懂可以看下图,以下三张图在数学层面是一样的,可以直接看第三张图,理解起来比较容易,稍微推理一下就能得到前两张图了。

2.2 实验对比

现在你应该就比较清楚了ResNeXt block的结构了,将Resnet中的block替换一下就是ResNeXt网络了。可能你有一个疑问,为什么group一定是32呢?作者也是通过一系列的实验得到的,如下图

上图中就是作者调整group数得到的实验结果,其中setting栏的第一个数字代表着group数,第二个xx d代表着每组的卷积核个数。可以看到ResNet的group=1,卷积核个数为64。作者对比了group为1、2、4、8、32时的top1错误率,发现group=32时错误率最低,所以选择了group为32,这时每组卷积核个数为4,总共的out_channel就为resnet的两倍了。

PS:经过本人实验,DW卷积的效果更好,精度更高,收敛更快,不知道作者为什么不继续扩大group的值,或者这篇文章的主要目的就是凸显组卷积的好处?

上图为网络参数量为一倍和两倍时与ResNeXt网络的错误率的对比,可以发现,在两个部分中ResNeXt网络的错误率都是最低的。

2.3 ResNeXt整体结构

3.ResNeXt模型代码

代码很简单,我就不做什么过多的注释了,不会的评论或私信我。

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

快速理解ResNeXt(结合代码)相关推荐

  1. 如何快速理解递归——蓝桥杯 试题 基础练习 FJ的字符串(递归与非递归解法)——10行代码AC

    励志用少的代码做高效的表达. 注意点: 1.规律 2.非递归解法:string重载了+=运算符,因此用string会方便很多.并且string动态扩充,防浪费,更高效. 3.递归解法:官方的标签就是递 ...

  2. 一篇文章带你快速理解JVM运行时数据区 、程序计数器详解 (手画详图)值得收藏!!!

    受多种情况的影响,又开始看JVM 方面的知识. 1.Java 实在过于内卷,没法不往深了学. 2.面试题问的多,被迫学习. 3.纯粹的好奇. 很喜欢一句话:"八小时内谋生活,八小时外谋发展. ...

  3. 网络编程懒人入门(一):快速理解网络通信协议(上篇)

    1.写在前面 论坛和群里常会有技术同行打算自已开发IM或者消息推送系统,很多时候连基本的网络编程理论(如网络协议等)都不了解,就贸然定方案.写代码,显得非常盲目且充满技术风险. 即时通讯网论坛里精心整 ...

  4. 两个相邻盒子的边框怎么只显示一个_一篇文章带你快速理解盒子模型「经典案例」...

    今天带大家快速理解盒子模型,直接上代码: css盒子 我的css盒子测试模型 上面代码没有任何难度,只是写了一个div标签,大家已经知道,div标签是块级元素,所以会占满一行: 但是我们也注意到了图片 ...

  5. 快速理解Spark Dataset

    1. 前言 RDD.DataFrame.Dataset是Spark三个最重要的概念,RDD和DataFrame两个概念出现的比较早,Dataset相对出现的较晚(1.6版本开始出现),有些开发人员对此 ...

  6. 快速理解ASP.NET Core的认证与授权

    ASP.NET Core的认证与授权已经不是什么新鲜事了,微软官方的文档对于如何在ASP.NET Core中实现认证与授权有着非常详细深入的介绍.但有时候在开发过程中,我们也往往会感觉无从下手,或者由 ...

  7. 快速傅里叶变换及python代码实现

    文章来源:https://www.cnblogs.com/LXP-Never/p/11558302.html 快速傅里叶变换及python代码实现 目录 一.前言   傅里叶变换相关函数   基于傅里 ...

  8. 安卓逆向_7 --- 六种快速定位关键 Smali 代码的方法 ( 去掉 RE 广告 )

    哔哩哔哩:https://www.bilibili.com/video/BV1UE411A7rW?p=34 具体用法,看视频教程( 去掉 RE 的 结束广告 ) 6 种定位关键代码的方法,当然还有其他 ...

  9. 三分钟快速理解javascript内存管理

    javascript中具有垃圾自动回收机制(Garbage Collection),也就是执行环境会负责管理代码执行过程中使用的内存,在开发过程中就可以不考虑内存的分配,以及无用内存释放的问题.但是触 ...

最新文章

  1. 化学人学python有前途吗-课堂上老师不讲的有趣物理知识,才是孩子最感兴趣的!...
  2. MacOs中Docker与宿主机网络互通问题解决
  3. 为了方便在微博上看小黄图,我写了一段JS
  4. Mac屏幕录制GIF动图
  5. 局域网弱口令扫描工具_“菜鸟黑客”必用兵器之“扫描篇”
  6. 批量将一个 PDF 文件按固定页数拆分成多个小的 PDF 文件
  7. JDBC 操作数据库步骤
  8. 极客时间课程笔记:业务安全
  9. java语音播报天气_语音播报实时天气
  10. proxychains替代品polipo
  11. 计算几何——向量的叉乘、点乘、夹角
  12. 苹果手机如何找回id密码_iPhone手机ID总是忘记密码,轻松一招帮你找回,原来这么简单...
  13. 从战争到外包软件开发:如何赢得最后胜利
  14. 模拟赛20200228(yyq)【右链+dfs序,子树管辖,聚集水流问题】
  15. python读写csv常用方法
  16. win10 无法连接打印机 报0x00000520错误解决办法!
  17. 置信度和置信度区间理解
  18. 计算机毕业设计Node.js+Vue办公用品管理系统(程序+源码+LW+部署)
  19. linux pppoe拨号上网
  20. PowerPoint VBA批量格式转换:pptx转pdf、ppt以及反向转换

热门文章

  1. Huggingface的from pretrained的下载代理服务器方法设置
  2. execve()函数的研究
  3. visio中直线交叉处消除跨线的方法
  4. 公司合伙人股权的进入和退出机制
  5. 长尾理论读书笔记:第一章 长尾市场
  6. 【Python自动化Excel】pandas处理Excel的拆分、合并
  7. 啊哈添柴挑战Java1581. 填数游戏(入门版)
  8. 【如何使用idea合并当前分支的代码到主分支】
  9. SQL数据库权限禁止授予deny
  10. java分布式服务框架Dubbo的介绍与使用