一、introduction

Res2Net由南开大学程明明组2019年提出,主要贡献是对ResNet模型中的block模块进行了改进,计算负载不增加,特征提取能力更强大。
论文地址:Res2Net: A New Multi-scale Backbone Architecture

二、网络结构

回顾ResNet网络结构:https://blog.csdn.net/qq_40356092/article/details/109024375

左图是ResNet网络中的block模块,右图是论文中新提出来的Res2Net模块。简单来说,Res2Net就是将3×3卷积层的输入分成了四个部分,网络内部又以残差式的风格进行连接。

计算公式如下所示:

y1 = x1;

y2 = x2*(3x3)= K2;

y3 =(K2 + x3)*(3x3)= K3 ;

y4 =(K3 + x4)*(3x3)= K4

三、设计思路

1、首先我们来讲讲ResNet残差网络的设计思路。早期的神经网络结构设计非常简单,通过卷积层、池化层和全连接层的线性堆叠来提取图片中的特征;我们可以将神经网络看作是一个大型的数学公式,输入是一个矩阵X,输出是Y,对X进行的操作主要有线性变换和非线性变换,将其统一看作F,那么就有Y=F(X)

为了使得训练之后的神经网络能够具有更好地识别效果,在输入X之前我们往往会对图片进行预处理、归一化等操作,将像素值归一到0 ~ 1(或者-1 ~ 1)之间,使得输入在同一个量级上。

同样地,我们也可以在网络之中引入类似的处理方法,也就是ResNet中提出的残差连接,Y=F(X)+X,这样网络在训练时就只需要学习到关于X的一个偏差就可以(F(X)=Y-X),极大增加了网络的可训练性。
这种思想在深度学习领域的应用极其广泛,并且都取得了不错的结果。比如yolo v3中只预测box框的偏差而不直接预测box框的坐标;在卷积层之后加上BN层;

2、Res2Net的贡献
Res2Net提出了一个新的概念:尺度(scale)
CNN网络中除了深度,宽度和基数等现有维度之外,尺度也是一个必不可少的因素。将输入X拆分成四个部分,每个部分通过不同的卷积层之后再融合到一起,得到的输出会获得更大的感受野,而且一些额外的计算开销可以忽略。(在实际实验中,运行速度会慢20%左右)
Res2Net模块可以很好地与现有模型进行融合

比如3×3卷积层的数量可以任意调整,在1×1网络的最后可以加上SE block(关于SE block后续会进行讲解)。

四、代码实现

import torch
from torch import nn#需要分类的类别数
classes=5#SE模块
class SEModule(nn.Module):def __init__(self, channels, reduction=16):super(SEModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, input):x = self.avg_pool(input)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.sigmoid(x)return input * xclass Res2NetBottleneck(nn.Module):expansion = 4  #残差块的输出通道数=输入通道数*expansiondef __init__(self, inplanes, planes, downsample=None, stride=1, scales=4, groups=1, se=True,  norm_layer=True):#scales为残差块中使用分层的特征组数,groups表示其中3*3卷积层数量,SE模块和BN层super(Res2NetBottleneck, self).__init__()if planes % scales != 0: #输出通道数为4的倍数raise ValueError('Planes must be divisible by scales')if norm_layer:  #BN层norm_layer = nn.BatchNorm2dbottleneck_planes = groups * planesself.scales = scalesself.stride = strideself.downsample = downsample#1*1的卷积层,在第二个layer时缩小图片尺寸self.conv1 = nn.Conv2d(inplanes, bottleneck_planes, kernel_size=1, stride=stride)self.bn1 = norm_layer(bottleneck_planes)#3*3的卷积层,一共有3个卷积层和3个BN层self.conv2 = nn.ModuleList([nn.Conv2d(bottleneck_planes // scales, bottleneck_planes // scales,kernel_size=3, stride=1, padding=1, groups=groups) for _ in range(scales-1)])self.bn2 = nn.ModuleList([norm_layer(bottleneck_planes // scales) for _ in range(scales-1)])#1*1的卷积层,经过这个卷积层之后输出的通道数变成self.conv3 = nn.Conv2d(bottleneck_planes, planes * self.expansion, kernel_size=1, stride=1)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)#SE模块self.se = SEModule(planes * self.expansion) if se else Nonedef forward(self, x):identity = x#1*1的卷积层out = self.conv1(x)out = self.bn1(out)out = self.relu(out)#scales个(3x3)的残差分层架构xs = torch.chunk(out, self.scales, 1) #将x分割成scales块ys = []for s in range(self.scales):if s == 0:ys.append(xs[s])elif s == 1:ys.append(self.relu(self.bn2[s-1](self.conv2[s-1](xs[s]))))else:ys.append(self.relu(self.bn2[s-1](self.conv2[s-1](xs[s] + ys[-1]))))out = torch.cat(ys, 1)#1*1的卷积层out = self.conv3(out)out = self.bn3(out)#加入SE模块if self.se is not None:out = self.se(out)#下采样if self.downsample:identity = self.downsample(identity)out += identityout = self.relu(out)return outclass Res2Net(nn.Module):def __init__(self, layers, num_classes, width=16, scales=4, groups=1,zero_init_residual=True, se=True, norm_layer=True):super(Res2Net, self).__init__()if norm_layer:  #BN层norm_layer = nn.BatchNorm2d#通道数分别为64,128,256,512planes = [int(width * scales * 2 ** i) for i in range(4)]self.inplanes = planes[0]#7*7的卷积层,3*3的最大池化层self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = norm_layer(planes[0])self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)#四个残差块self.layer1 = self._make_layer(Res2NetBottleneck, planes[0], layers[0], stride=1, scales=scales, groups=groups, se=se, norm_layer=norm_layer)self.layer2 = self._make_layer(Res2NetBottleneck, planes[1], layers[1], stride=2, scales=scales, groups=groups, se=se, norm_layer=norm_layer)self.layer3 = self._make_layer(Res2NetBottleneck, planes[2], layers[2], stride=2, scales=scales, groups=groups, se=se, norm_layer=norm_layer)self.layer4 = self._make_layer(Res2NetBottleneck, planes[3], layers[3], stride=2, scales=scales, groups=groups, se=se, norm_layer=norm_layer)#自适应平均池化,全连接层self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(planes[3] * Res2NetBottleneck.expansion, num_classes)#初始化for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)#零初始化每个剩余分支中的最后一个BN,以便剩余分支从零开始,并且每个剩余块的行为类似于一个恒等式if zero_init_residual:for m in self.modules():if isinstance(m, Res2NetBottleneck):nn.init.constant_(m.bn3.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, scales=4, groups=1, se=True, norm_layer=True):if norm_layer:norm_layer = nn.BatchNorm2ddownsample = None  #下采样,可缩小图片尺寸if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, downsample, stride=stride, scales=scales, groups=groups, se=se, norm_layer=norm_layer))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes, scales=scales, groups=groups, se=se, norm_layer=norm_layer))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)logits = self.fc(x)probas = nn.functional.softmax(logits, dim=1)return probas

使用pytorch搭建自己的网络之Res2Net相关推荐

  1. 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解

    目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...

  2. 【PyTorch】PyTorch搭建基础VGG16网络

    vgg16网络结构: 源码: import torch import torch.nn as nn from torch.autograd import Variablecfg = {'vgg16': ...

  3. pytorch 搭建cnn resnet50网络进行图片分类 代码详解

    数据样式: 直接上代码: import pathlib import tensorflow as tf import matplotlib.pyplot as plt import os, PIL, ...

  4. 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

    实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...

  5. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  6. 使用 PyTorch 搭建网络 - predict_py篇

    predict_py篇 python中采用驼峰书写法且首字母大写的变量符号一般表示类名. 学习网络步骤:看原论文+看别人对原论文的理解,学习网络结构,看损失函数计算,看数据集,看别人写的代码,复现代码 ...

  7. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  8. 使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负

    使用pytorch搭建MLP多层感知器分类网络判断LOL比赛胜负 1. 数据集 百度网盘链接,提取码:q79p 数据集文件格式为CSV.数据集包含了大约5万场英雄联盟钻石排位赛前15分钟的数据集合,总 ...

  9. ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

    1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...

最新文章

  1. 谢文睿:西瓜书 + 南瓜书 吃瓜系列 3. 对数几率回归
  2. 2021-03-19Tomcat源码学习--WebAppClassLoader类加载机制
  3. 文本数据分析——主题提取+词向量化
  4. 86. Leetcode 264. 丑数 II (动态规划-基础题)
  5. 从Tronbull引狂欢,看APENFT与波场带来的新可能
  6. 帆软日期控件变灰_FineReport-JS脚本常见日期使用整理
  7. Android复合控件创建与使用Demo
  8. D - Maximum Value Problem FZU - 2037
  9. 【CHM】.chm文件无法正常显示的解决方案
  10. Linux下启动启动tomcat 服务器报错 The file is absent or does not have execute permission
  11. c语言statistics函数,Logistic回归中C-Statistics计算方法
  12. Linux命令之diff
  13. linux nginx php 启动命令,linux nginx启动,重启,关闭命令
  14. 滴滴自动驾驶首轮融资超5亿美元 加大研发投入 助力“新基建”
  15. 【腾讯内部干货分享】分析Dalvik字节码进行减包优化
  16. kepware mysql_Kepware EX6与MySQL数据库通讯(上篇)
  17. EMP电磁脉冲射频发射器制作教程
  18. eclipse常用快捷键和设置
  19. HTML中 常见的浏览器内核有哪些,主流浏览器的内核以及内核前缀是什么?
  20. (搬砖)Epic/Feature/Story/Task/Bug到底是什么

热门文章

  1. 看不见的共享电单车战争
  2. 苹果备忘录分享不了微信提示无法连接服务器,微信分享提示universal link 校验不通过...
  3. 职工线上健步走活动小程序方案,通过微信小程序实现功能,getElementById(“demo“).innerHTML=x
  4. MFC 右键菜单呼出
  5. 计算机教育专业的专业任选课,什么叫自由选修课 又什么叫全校任选课
  6. Kitty猫基因编码
  7. 没毕业就3次跳槽的经历,走不平凡的路,让人跌破眼镜。
  8. 【弘成基】运用资料整理
  9. MFC在对话框中绘制图像
  10. # 计算圆周长和面积