ConvNeXt-T architecture diagram

Building the ConvNeXt Block module

import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    # ConvNeXt Block: depthwise conv -> LayerNorm -> 1x1 conv -> GELU -> 1x1 conv,
    # with a layer-scale parameter and a DropPath (stochastic depth) residual branch
    def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_last")
        # pointwise/1x1 convs, implemented with linear layers
        # (a 1x1 convolution and a fully connected layer compute the same thing)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()  # GELU activation
        self.pwconv2 = nn.Linear(4 * dim, dim)  # note that pwconv1 and pwconv2 have different in/out channels
        # layer scale: a learnable per-channel scaling factor
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()  # DropPath layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.dwconv(x)  # depthwise conv
        x = x.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
        x = self.norm(x)  # LayerNorm
        x = self.pwconv1(x)  # 1x1 conv (as Linear)
        x = self.act(x)  # GELU activation
        x = self.pwconv2(x)  # 1x1 conv (as Linear)
        if self.gamma is not None:
            x = self.gamma * x  # scale each channel
        x = x.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
        x = shortcut + self.drop_path(x)  # apply DropPath and add the shortcut
        return x
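The Block above depends on two helpers that are not shown in this excerpt: a LayerNorm that accepts either channels_last or channels_first tensors, and DropPath (stochastic depth). The sketches below follow the official ConvNeXt repository and the widely used timm implementation; they are reconstructions added here, not part of the original excerpt.

class LayerNorm(nn.Module):
    # LayerNorm supporting two data formats: channels_last (N, H, W, C)
    # and channels_first (N, C, H, W); sketch following the official ConvNeXt repo
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(f"unsupported data_format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize over the channel dimension manually
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class DropPath(nn.Module):
    # stochastic depth: randomly drop the whole residual branch, per sample
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # mask of shape (N, 1, 1, ...) so it broadcasts over all but the batch dim
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        return x.div(keep_prob) * random_tensor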

Building the overall ConvNeXt network

class ConvNeXt(nn.Module):
    def __init__(self, in_chans: int = 3, num_classes: int = 1000, depths: list = None,
                 dims: list = None, drop_path_rate: float = 0., layer_scale_init_value: float = 1e-6,
                 head_init_scale: float = 1.):
        super().__init__()
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
        self.downsample_layers.append(stem)
        # the 3 downsampling layers placed before stage2-stage4
        for i in range(3):
            downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                                             nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple blocks
        # linearly increasing drop-path rates across all blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        # build the blocks stacked in each stage
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_rate=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value)
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(self._init_weights)  # initialize the weights
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)  # std=0.02 as in the official implementation
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # order: downsample_layers[0] -> stages[0] -> downsample_layers[1] -> stages[1]
        #        -> downsample_layers[2] -> stages[2] -> downsample_layers[3] -> stages[3]
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # global average pooling followed by the final LayerNorm, (N, C, H, W) -> (N, C)
        return self.norm(x.mean([-2, -1]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.head(x)  # the final fully connected layer produces the output logits
        return x
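As a quick sanity check (an illustrative snippet added here, not from the original post), the following traces the feature-map shapes through the four stages using ConvNeXt-T hyper-parameters: the stride-4 stem maps a 224x224 input to 56x56, and each subsequent downsample layer halves the resolution while doubling the channels.

net = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_classes=1000)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    for i in range(4):
        x = net.downsample_layers[i](x)
        x = net.stages[i](x)
        print(f"stage{i + 1}: {tuple(x.shape)}")
# expected output:
# stage1: (1, 96, 56, 56)
# stage2: (1, 192, 28, 28)
# stage3: (1, 384, 14, 14)
# stage4: (1, 768, 7, 7)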

For the ConvNeXt network, the authors propose five variants: T/S/B/L/XL.

Here C denotes the channel dimensions of the four stages and B denotes the number of blocks stacked in each stage (these values match the factory functions below):

ConvNeXt-T:  C = (96, 192, 384, 768),    B = (3, 3, 9, 3)
ConvNeXt-S:  C = (96, 192, 384, 768),    B = (3, 3, 27, 3)
ConvNeXt-B:  C = (128, 256, 512, 1024),  B = (3, 3, 27, 3)
ConvNeXt-L:  C = (192, 384, 768, 1536),  B = (3, 3, 27, 3)
ConvNeXt-XL: C = (256, 512, 1024, 2048), B = (3, 3, 27, 3)

def convnext_tiny(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 9, 3],
                     dims=[96, 192, 384, 768],
                     num_classes=num_classes)
    return model


def convnext_small(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[96, 192, 384, 768],
                     num_classes=num_classes)
    return model


def convnext_base(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[128, 256, 512, 1024],
                     num_classes=num_classes)
    return model


def convnext_large(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
    # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[192, 384, 768, 1536],
                     num_classes=num_classes)
    return model


def convnext_xlarge(num_classes: int):
    # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
    model = ConvNeXt(depths=[3, 3, 27, 3],
                     dims=[256, 512, 1024, 2048],
                     num_classes=num_classes)
    return model
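A brief usage sketch follows. It assumes the checkpoints at the URLs above store their weights under a "model" key (as the official repository's checkpoints do) and that the 1000-class head must be discarded when num_classes differs; verify both against the checkpoint you actually download.

model = convnext_tiny(num_classes=5)

# load pretrained weights downloaded from the URL above
ckpt = torch.load("convnext_tiny_1k_224_ema.pth", map_location="cpu")
weights = ckpt["model"] if "model" in ckpt else ckpt
# drop the pretrained 1000-class head, since our num_classes differs
weights = {k: v for k, v in weights.items() if not k.startswith("head.")}
print(model.load_state_dict(weights, strict=False))  # should only report the missing head keys

out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 5])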

Reference

ConvNeXt网络详解, 太阳花的小绿豆, CSDN blog

13.1 ConvNeXt网络讲解, bilibili
