Looking at the ResNeXt architecture diagram, it is clear that ResNeXt is not all that different from ResNet. For example, ResNeXt-50 and ResNet-50 can both use the bottleneck structure; you only need to change the channel widths and replace the middle 3x3 convolution with a grouped convolution. When designing the bottleneck, the two can be distinguished simply by passing in the appropriate parameters.

In fact, torchvision already takes this into account in its ResNet implementation, so ResNet and ResNeXt share the same code.

The code is shown below; I have kept the parts relevant to ResNeXt:

import torch
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
    #     model.load_state_dict(state_dict)
    return model


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


if __name__ == "__main__":
    data = torch.randn(2, 3, 224, 224)
    net = resnext50_32x4d()
    output = net(data)
    print(output.size())
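Running the file as a script should print torch.Size([2, 1000]): the batch of two 224x224 images passes through resnext50_32x4d(), which keeps the default num_classes=1000.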

This is essentially the ResNet implementation. resnext50_32x4d() calls _resnet(); besides passing Bottleneck and [3, 4, 6, 3], it also passes two extra arguments, "groups" and "width_per_group".
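For comparison, the plain ResNet-50 factory in the same torchvision source file calls the identical helper without touching these two arguments; roughly, it looks like this (a sketch of torchvision's own definition, not code extracted above):

def resnet50(pretrained=False, progress=True, **kwargs):
    # Same block type and layer counts as resnext50_32x4d, but groups and
    # width_per_group keep their defaults (1 and 64), so every Bottleneck
    # stays an ordinary, ungrouped bottleneck.
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)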

_resnet() in turn constructs a ResNet. When ResNet builds each stage, it passes "groups" and "width_per_group" down to Bottleneck, and that is where the aggregated transformation happens: the bottleneck width is computed as int(planes * (base_width / 64.)) * groups, and the middle convolution becomes conv3x3(width, width, stride, groups, dilation).

When groups=1 and base_width=64, there is no grouping and the block is exactly the ResNet bottleneck; with other values, it becomes ResNeXt.
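A quick sanity check of the width formula in Bottleneck.__init__ makes the difference concrete; the small sketch below simply evaluates that formula by hand for the first stage, where planes=64:

def bottleneck_width(planes, base_width, groups):
    # Same formula as in Bottleneck.__init__
    return int(planes * (base_width / 64.)) * groups

print(bottleneck_width(64, base_width=64, groups=1))   # ResNet-50 stage 1: 64
print(bottleneck_width(64, base_width=4, groups=32))   # ResNeXt-50 32x4d stage 1: 128

So for ResNeXt-50 32x4d the middle conv works on 128 channels split into 32 groups of 4, matching the "32x4d" in the name.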

The middle 3x3 grouped convolution does not need any special code either: nn.Conv2d supports grouping out of the box, so setting its groups parameter is all that is required.
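As a minimal illustration of what the groups argument changes, the two layers below have the same input and output channels but very different parameter counts (this is a standalone sketch, not part of the torchvision code above):

import torch.nn as nn

dense   = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
grouped = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=32, bias=False)

# Dense:   weight shape (128, 128, 3, 3) -> 128 * 128 * 3 * 3 = 147456 weights
# Grouped: weight shape (128, 128 // 32, 3, 3) -> 128 * 4 * 3 * 3 = 4608 weights
print(sum(p.numel() for p in dense.parameters()))    # 147456
print(sum(p.numel() for p in grouped.parameters()))  # 4608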
