见:D:\pythonCodes\深度学习实验\4.1_经典分类网络\inference代码汇总\models\se_resnet.py

一、SE-ResNet的实现方法

读了senet这篇论文之后,可以知道senet并没有提出一个新的网络,而是提出了一个即插即用的模块。这个模块叫做SE Block(在实现的时候,为了防止与SEBasicBlock这个名字混淆,叫做SELayer)。

本文希望实现se_resnet网络,也就是将SE Block嵌入到ResNet中形成的网络。se_resnet与resnet的差别就是,就是在BasicBlock(resnet18/34使用的是BasicBlock堆叠,而resnet50/101/152使用的是Bottleneck进行堆叠,这里就以BasicBlock举例,Bottleneck完全一样)中增加了SE Block这个操作。

比如下图,上面是BasicBlock的结构,下图就是SEBasic的结构,就是多出来了一个小圈圈。

通过读resnet的源码,我们知道是通过Resnet()这个类来组织成整个网络的。比如:

resnet34 = ResNet(BasicBlock, [3,4,6,3])

resnet18 = ResNet(BasicBlock, [2,2,2,2])

resnet50 = ResNet(Bottleneck, [3,4,6,3])

resnet101 = ResNet(Bottleneck, [3,4,23,3])

resnet152 = ResNet(Bottleneck, [3,8,36,3])

ResNet()接收两个参数,一个是block,另一个是堆叠的次数layers。只要传入参数,就能组织成一个网络了。比如传入的是BasicBlock,[3,4,6,3]就能得到resnet34了。这个函数就会自动地用3个BasicBlock组成layer1,用4个BasicBlock组成layer2,用6个BasicBlock组成layer3,用3个BasicBlock组成layer3,然后加上头尾等,组成一个网络。

我们可以利用ResNet()函数来构建我们的se_resent网络。只要给ResNet()传入SEBasicBlock和[3,4,6,3]就可以得到se_resnet34了。。

因此最关键的就是实现SEBasicBlock。而SEBasicBlock代码简直就是照抄BasicBlock代码,只要加上SELayer就行了。

(1) SELayer的实现

就是论文中的SE Block,在实现的时候,为了防止与SEBasicBlock这个名字混淆,叫做SELayer。

就是实现下面这个操作:

代码:

class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)        #全局平均池化,输入BCHW -> 输出 B*C*1*1self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),   #可以看到channel得被reduction整除,否则可能出问题nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)     #得到B*C*1*1,然后转成B*C,才能送入到FC层中。y = self.fc(y).view(b, c, 1, 1)     #得到B*C的向量,C个值就表示C个通道的权重。把B*C变为B*C*1*1是为了与四维的x运算。return x * y.expand_as(x)           #先把B*C*1*1变成B*C*H*W大小,其中每个通道上的H*W个值都相等。*表示对应位置相乘。

(2)SEBasicBlock的写法

有了SELayer之后,就可以很容易地写SEBasicBlock了。就是照抄BasicBlock代码,只要加上SELayer就行了。

class SEBasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):# 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。super(SEBasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes, 1)self.bn2 = nn.BatchNorm2d(planes)self.se = SELayer(planes, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.se(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return out

就是多了以下两句:

其余的与senet中的basicBlock完全一致。SEBottleneck也是比Bottleneck多了这么一点。

二、完整代码

有了SELayer和SEBasicBlock之后就可以,借用resnet中的ResNet类来搭建SENet网络了。完整的代码如下,包含se_resnet18/34/50/101/152的实现。

# -*- coding: utf-8 -*-
"""
# @file name  : se_resnet.py
# @author     : https://github.com/moskomule/senet.pytorch
# @date       : 2020-08-07
# @brief      : se_resnet 模型搭建
"""
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet# 论文核心 SE Block, 这里称为 SE layer
class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)        #全局平均池化,输入BCHW -> 输出 B*C*1*1self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False), #可以看到channel得被reduction整除,否则可能出问题nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)   #得到B*C*1*1,然后转成B*C,才能送入到FC层中。y = self.fc(y).view(b, c, 1, 1)   #得到B*C的向量,C个值就表示C个通道的权重。把B*C变为B*C*1*1是为了与四维的x运算。return x * y.expand_as(x)         #先把B*C*1*1变成B*C*H*W大小,其中每个通道上的H*W个值都相等。*表示对应位置相乘。def conv3x3(in_planes, out_planes, stride=1):return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)class SEBasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):# 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。super(SEBasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes, 1)self.bn2 = nn.BatchNorm2d(planes)self.se = SELayer(planes, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.se(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass SEBottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):# 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。super(SEBottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.se = SELayer(planes * 4, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out = self.se(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outdef se_resnet18(num_classes=1_000):model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet34(num_classes=1_000):model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet50(num_classes=1_000, pretrained=False):model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)if pretrained:model.load_state_dict(load_state_dict_from_url("https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl"))return modeldef se_resnet101(num_classes=1_000):model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet152(num_classes=1_000):model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modelif __name__ == "__main__":inputs = torch.randn(2, 3, 224, 224)model = se_resnet50(pretrained=False)outputs = model(inputs)print(outputs.size())

运行结果:

SE-ResNet的实现相关推荐

  1. Image1000优秀网络简介(目-标-分-类)

    文章目录 历年image1000优秀网络汇总 AlexNet VGGNet(2014亚军) GoogleNet(2014冠军) InceptionV1 Inception V2 Inception V ...

  2. AI-DPL, you should know

    CV 分类:LeNet, Alexnet,Vggnet, Googlenet(Inceptionv1234),Resnet, ResnetXt, InceptionResnet, DarkNet, D ...

  3. 谈一谈场景文本图片的超分辨

    引言 文本图像的超分辨任务做的不是很多,有专门针对文本识别的也有针对文本检测的,总而言之,带有文本序列的图像和在imangeNet里的图像是不一样的,那我们来仔细看一看文本图像大家都是怎么做的 Tex ...

  4. 快手如何玩转复杂场景下的说话人识别?| ASRU 2021

    快手是一个短视频社区,短视频和直播中通常混合各种形式的声音,如语音.音乐.特效音和背景噪声等,这些声音很好的提升了短视频和直播的用户消费体验,但同时也为音频内容理解带来极大的困难和挑战.如何在复杂场景 ...

  5. 【SENet】Squeeze-and-Excitation Networks (2017) 全文翻译

    作者 Jie Hu,Li Shen,Samuel Albanie,Gang Sun,Enhua Wu 摘要 卷积神经网络(CNNs)的核心组成部分是卷积算子,它使网络能够通过融合每层局部感受野中的空间 ...

  6. Squeeze-and-Excitation Networks(译)

    摘要 卷积神经网络建立在卷积运算的基础上,通过融合局部感受野内的空间信息和通道信息来提取信息特征.为了提高网络的表示能力,许多现有的工作已经显示出增强空间编码的好处.在这项工作中,我们专注于通道,并提 ...

  7. ResNet家族:ResNet、ResNeXt、SE Net、SE ResNeXt

    目录 ResNet DenseNet ResNeXt SE-ResNet, SE-ResNeXt (2018 Apr) 涉及到的其他知识: Global average pooling (GAP) 梯 ...

  8. 从ResNet、DenseNet、ResNeXt、SE Net、SE ResNeXt 演进学习总结

    本文主要总结一下最近学习ResNet.DenseNet.ResNeXt.SE Net.SE ResNeXt 的演进,归纳了一下整个特点,话不多说先上图: 1.ResNet 1.1 结构特点 1.sho ...

  9. 【论文分享】不平衡流量分类方法 DeepFE:ResNet+SE+non-local:Let Imbalance Have Nowhere to Hide

    论文:Let Imbalance Have Nowhere to Hide Class-Sensitive Feature Extraction for Imbalanced Traffic Clas ...

  10. 如何使ResNet优于EfficientNet?

    视学算法报道 转载自:机器之心 编辑:魔王.维度 架构变化.训练方法和扩展策略是影响模型性能的不可或缺的重要因素,而当前的研究只侧重架构的变化.谷歌大脑和 UC 伯克利的一项最新研究重新审视了 Res ...

最新文章

  1. 自动调整速率的Actor设计模式
  2. 使用 vue.js 的一些操作记录
  3. 吴恩达:我们说人工智能时,实际在说些什么?
  4. 卷积的物理意义(经典)
  5. eureka对比Zookeeper:
  6. Docker平台的基本使用方法
  7. (备忘)卸载微软自带输入法
  8. powershell快捷键_关于powershell的知识你知道多少呢
  9. tomcat start 无法启动_解密Springboot内嵌Tomcat
  10. 接口自动化测试框架搭建(9、自动化测试case的编写)--python+HTMLTestRunnerCN+request+unittest+mock+db
  11. Atitit 流水线子线程异常处理 1.1. 大概原理是 FutureTask排除异常 FutureTask.get can throw ExecutionException,can catc
  12. linux 卸载jdk和安装
  13. 基于javaweb,springboot银行管理系统
  14. qq空间留言板删除 php,怎么批量删除QQ空间的说说
  15. Kubernetes 学习总结(27)—— Kubernetes 安装 Redis 集群的两个方案
  16. 【C语言蓝桥杯每日一题】——跑步锻炼
  17. ESP8266-01 MQTT固件烧录并连接阿里云服务器
  18. nginx+uwsgi+django1.9+mysql+python2.7部署到CentOS6.5
  19. CGAL 凹包(alpha-Shape)
  20. 2020年全球便携式储能行业发展现状、竞争格局及未来发展趋势分析,市场规模呈现高速增长,行业潜力巨大「图」

热门文章

  1. C++类和对象的使用之对象指针
  2. 低成本2.4G SOC(NYA054E)灯控遥控芯片方案-CI2454/CI2451
  3. SCI-15种投稿状态
  4. 大英百科挂了,维基百科赢了
  5. 如何同步化本地svn库到googlecode
  6. 【git与github交互之主分支和次分支切换、合并等】
  7. 数字孪生技术海上风电场解决方案
  8. 海龟如何保留米帝手机号
  9. 小程序页面之间跳转的方式
  10. 游戏设计的艺术:一本透镜的书——第十五章 其中一种体验是故事