1. Overview of the ResNet Architecture

The central idea of ResNet, the residual neural network, is this: the input x passes through a series of operations such as convolutions, producing F(x), which is then added to the identity of the input itself (i.e., F(x) + x); the sum then passes through a ReLU layer. A unit built this way is called a residual block, and a residual network is composed of many residual blocks stacked together with a few other layers.
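As a minimal sketch of this idea (assuming, purely for illustration, that F(x) is two 3x3 convolutions with batch normalization and a fixed channel count of 64, as in the shallower variants):

import torch
from torch import nn

# Illustrative residual branch F: two 3x3 convs with BN (channel count kept at 64)
F = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(64),
)

x = torch.randn(1, 64, 56, 56)
out = torch.relu(F(x) + x)  # the residual block: F(x) + x, then ReLU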

ResNet comes in variants with 18, 34, 50, or more layers. Notably, from the 50-layer version onward the residual block changes: it gains 1x1 convolution layers, which first shrink and then restore the channel dimension. This keeps the dimensions compatible for the final F(x) + x addition and also reduces the network's computational cost, as the FLOPs column in the paper's table shows.
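A quick parameter count (ignoring biases and batch-norm parameters) shows why the bottleneck is cheaper. For a 256-channel input, two plain 3x3 convolutions cost 2 x (3 * 3 * 256 * 256) ≈ 1.18M weights. The bottleneck replaces them with a 1x1 conv reducing 256→64 (about 16.4K), a 3x3 conv on 64 channels (about 36.9K), and a 1x1 conv restoring 64→256 (about 16.4K): roughly 69.6K weights in total, about 17x fewer.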

2. ResNet Code Walkthrough
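The classes below are excerpted from torchvision's resnet.py. To run the snippets standalone they need the following imports, plus the two convolution helpers they call (conv3x3 and conv1x1, defined in that same file):

from typing import Type, Callable, Union, List, Optional

import torch
from torch import Tensor
import torch.nn as nn


def conv3x3(in_planes: int, out_planes: int, stride: int = 1,
            groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)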

a. The BasicBlock class

class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
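A quick shape check (a usage sketch, not part of the torchvision source): with stride 1 and no downsample, a BasicBlock preserves both the spatial size and the channel count, so the identity can be added directly.

block = BasicBlock(64, 64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])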

b. The Bottleneck class

class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
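A usage sketch (not part of the torchvision source): because expansion = 4, a Bottleneck outputs planes * 4 channels, so when inplanes != planes * 4 the caller must supply a downsample module that brings the identity to the same shape before the addition.

downsample = nn.Sequential(conv1x1(64, 64 * Bottleneck.expansion), nn.BatchNorm2d(256))
block = Bottleneck(64, 64, downsample=downsample)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 256, 56, 56])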

Classes a and b above are the residual blocks used in the shallower and deeper network variants, respectively; their forward methods map directly onto the F(x) + x structure described at the start.

c. The ResNet class

class ResNet(nn.Module):

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
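With these three classes in place, the standard variants are just different (block, layers) combinations; for example, the following configurations match torchvision's resnet18 and resnet50 factory functions:

model18 = ResNet(BasicBlock, [2, 2, 2, 2])   # resnet18
model50 = ResNet(Bottleneck, [3, 4, 6, 3])   # resnet50
x = torch.randn(1, 3, 224, 224)
print(model18(x).shape)  # torch.Size([1, 1000])
print(model50(x).shape)  # torch.Size([1, 1000])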
