简介
SeNet是一种通道注意力卷积,通俗的讲,就是对每层特征图求平均值,得到(N,C,1,1),然后经过全连接得到对应系数(N,C,1,1),再乘到原始的特征图上(N,C,H,W);
原理上是挑选出更需要关注的特征,特征图是由卷积核卷积得到,有的卷积核关注横向的梯度,有的关注纵向的梯度,每个卷积核提取的特征不同,如果一个卷积核的系数值比较小,特征图的值也就比较小,也就意味着这个卷积核没有那么重要,那么SENet就希望缩小不重要的特征的系数,增大重要特征的系数;
SeNet 模块源码介绍:

from torch import nn
class SELayer(nn.Module):def __init__(self, channel, reduction=16):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)#先pooling获取通道维度,也就是把H,w=>1y = self.fc(y).view(b, c, 1, 1)#全连接层对通道系数进行调整#tensor_1.expand_as(tensor_2) :把tensor_1扩展成和tensor_2一样的形状return x * y.expand_as(x)#原值乘以对应系数

网络流程图介绍
SeNet是放在卷积之后,对不同通道以不同系数相乘的操作,可以方便的移植叠加到通用模块后面,比如Resnet或者inception net,既然是加法操作,就会一定程度上增加计算量;

可以将SeNet用在对Inception的输出上

Inception v3本身的网络结构如下图所示:

Inception是将特征图经过不同大小的卷积核卷积后再concat,Inceptionv1是将特征图先降维再经过不同大小卷积核卷积,Inceptionv2是用多个3X3的卷积核去替换5X5、7X7的卷积核;Inceptionv3使用1Xn,nX1的卷积去替换nXn的卷积;因为这样可以减少参数量的同时,加深网络的深度,增加非线性,增强网络的表达能力;Inception系列讲的比较好的博文有:https://www.jianshu.com/p/6d66fa4ca9d7
Inception的输出特征图更多,也就需要接在不同的输出后面进行Se,具体实现源码如下图:

from senet.se_module import SELayer
from torch import nn
from torchvision.models.inception import Inception3class SEInception3(nn.Module):def __init__(self, num_classes, aux_logits=True, transform_input=False):super(SEInception3, self).__init__()model = Inception3(num_classes=num_classes, aux_logits=aux_logits,transform_input=transform_input)model.Mixed_5b.add_module("SELayer", SELayer(192))model.Mixed_5c.add_module("SELayer", SELayer(256))model.Mixed_5d.add_module("SELayer", SELayer(288))model.Mixed_6a.add_module("SELayer", SELayer(288))model.Mixed_6b.add_module("SELayer", SELayer(768))model.Mixed_6c.add_module("SELayer", SELayer(768))model.Mixed_6d.add_module("SELayer", SELayer(768))model.Mixed_6e.add_module("SELayer", SELayer(768))if aux_logits:model.AuxLogits.add_module("SELayer", SELayer(768))model.Mixed_7a.add_module("SELayer", SELayer(768))model.Mixed_7b.add_module("SELayer", SELayer(1280))model.Mixed_7c.add_module("SELayer", SELayer(2048))self.model = modeldef forward(self, x):_, _, h, w = x.size()if (h, w) != (299, 299):raise ValueError("input size must be (299, 299)")return self.model(x)def se_inception_v3(**kwargs):return SEInception3(**kwargs)

也可以将SeNet用在对Resnet的残差叠加之前(原因细想一下就可以理解,因为Se是对卷积核的压制);

resnet+se的实现:

import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet
from senet.se_module import SELayerdef conv3x3(in_planes, out_planes, stride=1):return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)class SEBasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):super(SEBasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes, 1)self.bn2 = nn.BatchNorm2d(planes)self.se = SELayer(planes, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)#在残差叠加之前进行通道注意力out = self.se(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outclass SEBottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None,*, reduction=16):super(SEBottleneck, self).__init__()self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(planes * 4)self.relu = nn.ReLU(inplace=True)self.se = SELayer(planes * 4, reduction)self.downsample = downsampleself.stride = stridedef forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out = self.se(out)if self.downsample is not None:residual = self.downsample(x)out += residualout = self.relu(out)return outdef se_resnet18(num_classes=1_000):"""Constructs a ResNet-18 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet34(num_classes=1_000):"""Constructs a ResNet-34 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet50(num_classes=1_000, pretrained=False):"""Constructs a ResNet-50 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)if pretrained:model.load_state_dict(load_state_dict_from_url("https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl"))return modeldef se_resnet101(num_classes=1_000):"""Constructs a ResNet-101 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return modeldef se_resnet152(num_classes=1_000):"""Constructs a ResNet-152 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)model.avgpool = nn.AdaptiveAvgPool2d(1)return model

SeNet--通道注意力卷积相关推荐

  1. SEnet 通道注意力模块

    SEnet 通道注意力模块 开篇一张图: 变量和图片解释: 三个正方体:特征向量,比如说图像的特征,H表示图片高度.W表示图片宽.C表示通道(黑白1通道.彩色3通道) 字母: X表示输入特征: Ftr ...

  2. 《Squeeze-and-Excitation Networks》SE-Net通道注意力机制

    前言 在阅读了一系列Activation Funtion论文之后,其中Dynamic Relu的论文提到了基于注意力机制的网络,因此先来看看经典的SE-Net的原理 Introduction 对于CN ...

  3. SEnet 通道注意力机制

    SENet 在于通过网络根据loss去学习特征权重,获取到每个feature map的重要程度,然后用这个重要程度去给每一个特征通道赋予一个权重值,从而让神经网络去重点关注某些feature map ...

  4. 【深度学习】(8) CNN中的通道注意力机制(SEnet、ECAnet),附Tensorflow完整代码

    各位同学好,今天和大家分享一下attention注意力机制在CNN卷积神经网络中的应用,重点介绍三种注意力机制,及其代码复现. 在我之前的神经网络专栏的文章中也使用到过注意力机制,比如在MobileN ...

  5. ECA-Net:深度卷积神经网络的高效通道注意力

    ECA-Net:深度卷积神经网络的高效通道注意力 1.什么是注意力机制? 2.简介 3.ECANet注意力模块 3.1 回顾SENet模块 3.2 ECANet模块 3.3 ECANet代码复现 4. ...

  6. 基于改进通道注意力和多尺度卷积模块的蛋白质二级结构预测

    一.背景: 传统的蛋白质三维结构预测可以通过一些传统方法预测,但是此类方法过于昂贵和耗费时间. 蛋白质二级结构是三维结构和序列的桥梁,其由多肽链中氢键的作用决定.许多研究表明,我们可以通过蛋白质的二级 ...

  7. Squeeze-and-Excitation Networks(SENet:GAP+2个FC)通道注意力模型公式解析

    文章目录 Squeeze-and-Excitation Networks 代码 论文 SENet:通道注意力机制 SENet基线网络图 SE-inception Module SE-ResNet Mo ...

  8. 深度学习卷积神经网络重要结构之通道注意力和空间注意力模块

    #主要原理 提出CBAM的作者主要对分类网络和目标检测网络进行了实验,证明了CBAM模块确实是有效的. 以ResNet为例,论文中提供了改造的示意图,如下图所示: #CMAB模块实现,依据上面原理 # ...

  9. 最强通道注意力来啦!金字塔分割注意力模块,即插即用,效果显著,已开源!...

    导读 本文是通道注意力机制的又一重大改进,主要是在通道注意力的基础上,引入多尺度思想,本文主要提出了金字塔分割注意力模块,即PSA module.进一步,基于PSA ,我们将PSA注意力模块替换Res ...

最新文章

  1. UBUNTU adb连接android设备
  2. Jexus部署.Net Core项目
  3. Unity3D 之NGUI各种脚本及应用
  4. Android --- allowBackup 属性的含义和危险性实例讲解
  5. [js开源组件开发]图片放大镜
  6. python math库函数源码_11. math库函数
  7. eclipse 插件扩展新建java页面_java-Eclipse插件:创建动态菜单和相应的处理...
  8. matlab中if语句中的结果返回,matlab中if 语句后面的判别式不能是算术表达式?或者说变量?...
  9. 如何使用清理垃圾软件优化苹果电脑
  10. 从零开始学习makefile(7) makefile的filter的作用
  11. CF probabilities 自制题单
  12. 复杂材料棱柱体单站RCS
  13. STM32LL库系列教程(一)—— LL库概览及资料
  14. 一些牛逼哄哄的javascript面试题
  15. win32 API函数大全
  16. python基础(1)---python简介
  17. 【学习笔记】H5性能测试
  18. 空闲时间不要接私活,要提升自己
  19. 简单理解float和double、单精度和双精度
  20. EPB电子驻车制动系统Simulink模型 模型包括:有刷直流电机+执行器模型,电机参数m文件,SSM模块,PBC模块,数据处理模块,与Carsim联防进行过验证

热门文章

  1. 前端vue+element ie兼容性问题
  2. 优化总结:有哪些APP启动提速方法?
  3. 在手机上安装Ubuntu(Termux)
  4. 关于工作与生活zz —— 转载
  5. 直播app源代码,手机屏幕截取并保存到手机相册
  6. 管理学定律五:二八定律与木桶理论
  7. qq三国挂机云服务器,云服务器挂机QQ三国游戏的流程和实际操作概况记录
  8. mysql数据库怎么导出导入表
  9. 在C语言中使用else if判断数字是正数还是负数或是零。
  10. CCF-20170902-公共钥匙盒(30分)