Background

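This is the legendary "copy-the-homework" handbook for CV algorithm engineers. The big names have burned a lot of electricity distilling tricks that actually work, so there is no reason not to copy them. This article does not go into the details of the paper, which is available here: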

https://arxiv.org/abs/2201.03545

and here is an expert's detailed walkthrough of it:

https://mp.weixin.qq.com/s/c6MRbzQE9ErFUWdWKh8PQA

This article only covers integrating ConvNeXt into an engineering project.

ConvNeXt-YoloV5

Again using the classic object detector YOLOv5 as the example, make the following changes and additions to its source code.

common.py
    # Add the following code.
    # These imports are also needed at the top of common.py (DropPath and trunc_normal_ come from timm):
    import torch.nn.functional as F
    from timm.models.layers import DropPath, trunc_normal_

    #-------------------------------------ConvNeXt------------------------------------------------------
    class Block(nn.Module):
        """ConvNeXt block: 7x7 depthwise conv -> LayerNorm -> 1x1 expand (4x) -> GELU -> 1x1 project,
        with optional layer scale (gamma) and stochastic depth on the residual branch."""
        def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
            super().__init__()
            self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
            self.norm = LayerNorm(dim, eps=1e-6)
            self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
            self.act = nn.GELU()
            self.pwconv2 = nn.Linear(4 * dim, dim)
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True) if layer_scale_init_value > 0 else None
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def forward(self, x):
            input = x
            x = self.dwconv(x)
            x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
            x = self.norm(x)
            x = self.pwconv1(x)
            x = self.act(x)
            x = self.pwconv2(x)
            if self.gamma is not None:
                x = self.gamma * x
            x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
            x = input + self.drop_path(x)
            return x


    class LayerNorm(nn.Module):
        """LayerNorm for channels_last (N, H, W, C) or channels_first (N, C, H, W) inputs."""
        def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))
            self.eps = eps
            self.data_format = data_format
            if self.data_format not in ["channels_last", "channels_first"]:
                raise NotImplementedError
            self.normalized_shape = (normalized_shape,)

        def forward(self, x):
            if self.data_format == "channels_last":
                return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
            elif self.data_format == "channels_first":
                # normalize over the channel dimension (dim 1) by hand
                u = x.mean(1, keepdim=True)
                s = (x - u).pow(2).mean(1, keepdim=True)
                x = (x - u) / torch.sqrt(s + self.eps)
                x = self.weight[:, None, None] * x + self.bias[:, None, None]
                return x


    class ConvNeXt_Block(nn.Module):  # index 0~3
        # Every instance builds all four downsample layers and stages (the whole backbone), but forward()
        # only runs downsample_layers[index] and stages[index]. The weight-loading step later in this
        # article copies the full pretrained backbone into each instance, so it relies on this layout.
        def __init__(self, index, in_chans, depths, dims, drop_path_rate=0., layer_scale_init_value=1e-6):
            super().__init__()
            self.index = index
            self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
            stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
                                 LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
            self.downsample_layers.append(stem)
            for i in range(3):
                downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                                                 nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2))
                self.downsample_layers.append(downsample_layer)

            self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
            dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
            cur = 0
            for i in range(4):
                stage = nn.Sequential(*[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                                              layer_scale_init_value=layer_scale_init_value)
                                        for j in range(depths[i])])
                self.stages.append(stage)
                cur += depths[i]
            self.apply(self._init_weights)

        def _init_weights(self, m):
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                trunc_normal_(m.weight, std=.02)
                nn.init.constant_(m.bias, 0)

        def forward(self, x):
            x = self.downsample_layers[self.index](x)
            x = self.stages[self.index](x)
            return x
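ConvNeXt_Block spreads stochastic-depth rates linearly over all blocks with `torch.linspace(0, drop_path_rate, sum(depths))` and hands each stage its own slice through the running offset `cur`. The sketch below reproduces only that bookkeeping in plain Python so the per-stage rates can be inspected without building the model; `linspace` and `stage_drop_path_rates` are hypothetical helpers, and the rate 0.1 in the usage is purely illustrative. Note that the yaml in this article never passes `drop_path_rate`, so it keeps its default of 0 and every `DropPath` becomes `nn.Identity`.

```python
def linspace(start, stop, num):
    """Pure-Python stand-in for torch.linspace(start, stop, num)."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def stage_drop_path_rates(depths, drop_path_rate):
    """Split one linear drop-path schedule into per-stage lists,
    mirroring the `cur` offset logic in ConvNeXt_Block.__init__."""
    dp_rates = linspace(0.0, drop_path_rate, sum(depths))
    per_stage, cur = [], 0
    for depth in depths:
        per_stage.append(dp_rates[cur:cur + depth])
        cur += depth
    return per_stage


if __name__ == "__main__":
    # convnext_tiny depths; 0.1 is an illustrative rate, not the article's setting
    for i, rates in enumerate(stage_drop_path_rates([3, 3, 9, 3], 0.1)):
        print(i, [round(r, 4) for r in rates])
```

The rates grow with depth, so later blocks are dropped more often during training; with the default rate of 0 every entry is 0.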
yolo.py
    # Modify the parse_model function
    def parse_model(d, ch):  # model_dict, input_channels(3)
        LOGGER.info('\n%3s%18s%3s%10s  %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
        anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
        na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
        no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

        layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
        for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
            m = eval(m) if isinstance(m, str) else m  # eval strings
            for j, a in enumerate(args):
                try:
                    args[j] = eval(a) if isinstance(a, str) else a  # eval strings
                except NameError:
                    pass

            n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
            if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                     BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
                c1, c2 = ch[f], args[0]
                if c2 != no:  # if not output
                    c2 = make_divisible(c2 * gw, 8)

                args = [c1, c2, *args[1:]]
                if m in [BottleneckCSP, C3, C3TR, C3Ghost]:
                    args.insert(2, n)  # number of repeats
                    n = 1
            elif m is nn.BatchNorm2d:
                args = [ch[f]]
            elif m is Concat:
                c2 = sum([ch[x] for x in f])
            elif m is Detect:
                args.append([ch[x] for x in f])
                if isinstance(args[1], int):  # number of anchors
                    args[1] = [list(range(args[1] * 2))] * len(f)
            elif m is Contract:
                c2 = ch[f] * args[0] ** 2
            elif m is Expand:
                c2 = ch[f] // args[0] ** 2
            # Add this part ---------------------------------------------
            elif m is ConvNeXt_Block:
                c2 = args[0]     # output channels, taken as-is from the yaml
                args = args[1:]  # index, in_chans, depths, dims
            # ------------------------------------------------------------
            else:
                c2 = ch[f]

            m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args)  # module
            t = str(m)[8:-2].replace('__main__.', '')  # module type
            np = sum([x.numel() for x in m_.parameters()])  # number params
            m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
            LOGGER.info('%3s%18s%3s%10.0f  %-40s%-30s' % (i, f, n_, np, t, args))  # print
            save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
            layers.append(m_)
            if i == 0:
                ch = []
            ch.append(c2)
        return nn.Sequential(*layers), sorted(save)
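Because the new branch skips `make_divisible(c2 * gw, 8)` and takes `c2` straight from `args[0]`, the first yaml argument must equal the channel count of the selected stage, `dims[index]`; otherwise every later `ch[f]` lookup is wrong. The sketch below mirrors the branch in plain Python and adds that consistency check. `convnext_branch` is a hypothetical helper, not part of YOLOv5, and the check is an addition that the branch in parse_model does not perform.

```python
def convnext_branch(args):
    """Hypothetical mirror of the `elif m is ConvNeXt_Block` branch in parse_model."""
    c2 = args[0]          # declared output channels, used as-is (no width_multiple scaling)
    ctor_args = args[1:]  # positional args for ConvNeXt_Block(index, in_chans, depths, dims)
    index, dims = ctor_args[0], ctor_args[3]
    # Extra sanity check (not in the original branch): the declared channels must
    # match what stage `index` actually outputs, or downstream ch[f] lookups go wrong
    if c2 != dims[index]:
        raise ValueError(f"c2={c2}, but stage {index} outputs {dims[index]} channels")
    return c2, ctor_args


if __name__ == "__main__":
    row_args = [384, 2, 3, [3, 3, 9, 3], [96, 192, 384, 768]]  # third backbone row of the yaml
    print(convnext_branch(row_args))
```

For a different variant (for example convnext_small or convnext_base), the `depths`/`dims` lists and every leading `c2` in the yaml would have to change together, which is exactly what this check guards.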
yolov5_ConvNeXt.yaml

    # Parameters
    # Using convnext_tiny_1k as an example
    nc: 80  # number of classes
    depth_multiple: 1.0  # model depth multiple (parse_model reads this key)
    width_multiple: 1.0  # layer channel multiple (parse_model reads this key)
    anchors:
      - [10,13, 16,30, 33,23]  # P3/8
      - [30,61, 62,45, 59,119]  # P4/16
      - [116,90, 156,198, 373,326]  # P5/32

    # YOLOv5 backbone
    backbone:
      # [from, number, module, args]; ConvNeXt_Block args = [c2, index, in_chans, depths, dims]
      [[-1, 1, ConvNeXt_Block, [96, 0, 3, [3, 3, 9, 3], [96, 192, 384, 768]]],   # 0-P2/4
       [-1, 1, ConvNeXt_Block, [192, 1, 3, [3, 3, 9, 3], [96, 192, 384, 768]]],  # 1-P3/8
       [-1, 1, ConvNeXt_Block, [384, 2, 3, [3, 3, 9, 3], [96, 192, 384, 768]]],  # 2-P4/16
       [-1, 1, ConvNeXt_Block, [768, 3, 3, [3, 3, 9, 3], [96, 192, 384, 768]]],  # 3-P5/32
      ]

    # YOLOv5 head
    head:
      [[-1, 1, Conv, [768, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 2], 1, Concat, [1]],
       [-1, 3, C3, [768, False]],

       [-1, 1, Conv, [384, 1, 1]],
       [-1, 1, nn.Upsample, [None, 2, 'nearest']],
       [[-1, 1], 1, Concat, [1]],
       [-1, 3, C3, [384, False]],  # 11 (P3/8-small)

       [-1, 1, Conv, [384, 3, 2]],
       [[-1, 8], 1, Concat, [1]],
       [-1, 3, C3, [768, False]],  # 14 (P4/16-medium)

       [-1, 1, Conv, [768, 3, 2]],
       [[-1, 4], 1, Concat, [1]],  # cat head P5
       [-1, 3, C3, [768, False]],  # 17 (P5/32-large)

       [[11, 14, 17], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
      ]
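With four backbone rows and fourteen head rows before Detect, the layers are numbered 0 to 17, so the Detect inputs [11, 14, 17] should sit at strides 8, 16 and 32. The plain-Python sketch below propagates (channels, stride) through the yaml rows to confirm that. It is a hypothetical checker, not YOLOv5 code: it models only the modules used here, ignores C3 repeats, and assumes width_multiple = 1.0.

```python
# (from, module, args) for each row of yolov5_ConvNeXt.yaml; repeats and
# arguments that do not affect channels or stride are omitted
BACKBONE = [
    (-1, "ConvNeXt_Block", [96, 0]),
    (-1, "ConvNeXt_Block", [192, 1]),
    (-1, "ConvNeXt_Block", [384, 2]),
    (-1, "ConvNeXt_Block", [768, 3]),
]
HEAD = [
    (-1, "Conv", [768, 1, 1]),
    (-1, "Upsample", [None, 2]),
    ([-1, 2], "Concat", [1]),
    (-1, "C3", [768]),
    (-1, "Conv", [384, 1, 1]),
    (-1, "Upsample", [None, 2]),
    ([-1, 1], "Concat", [1]),
    (-1, "C3", [384]),
    (-1, "Conv", [384, 3, 2]),
    ([-1, 8], "Concat", [1]),
    (-1, "C3", [768]),
    (-1, "Conv", [768, 3, 2]),
    ([-1, 4], "Concat", [1]),
    (-1, "C3", [768]),
]


def trace(layers):
    """Return a (channels, stride) pair per layer, assuming width_multiple = 1.0."""
    out = []
    for f, m, args in layers:
        sources = [f] if isinstance(f, int) else f
        # -1 means "previous layer"; the first layer reads the 3-channel image at stride 1
        inputs = [out[x] if out else (3, 1) for x in sources]
        c_in, s_in = inputs[0]
        if m == "ConvNeXt_Block":
            c, s = args[0], 4 * 2 ** args[1]  # stage `index` outputs at stride 4 * 2**index
        elif m == "Conv":
            c, s = args[0], s_in * (args[2] if len(args) > 2 else 1)  # args[2] is the conv stride
        elif m == "Upsample":
            c, s = c_in, s_in // args[1]
        elif m == "Concat":
            if len({st for _, st in inputs}) != 1:
                raise ValueError("Concat inputs must share a stride")
            c, s = sum(ch for ch, _ in inputs), s_in
        elif m == "C3":
            c, s = args[0], s_in
        else:
            raise ValueError(f"unhandled module {m}")
        out.append((c, s))
    return out


if __name__ == "__main__":
    shapes = trace(BACKBONE + HEAD)
    for i in (11, 14, 17):  # the layers fed to Detect
        print(i, shapes[i])
    # prints:
    # 11 (384, 8)
    # 14 (768, 16)
    # 17 (768, 32)
```

The Concat stride check is the useful part: pointing a `from` index at the wrong backbone stage would raise here instead of failing later with a tensor-size error.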
Loading pretrained backbone weights

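Once the network is built it has no pretrained weights. Training it from scratch would take a lot of time and compute, and would waste the backbone weights the authors spent all that electricity producing for us. So we copy the pretrained backbone weights into the detector we just built. The code is as follows (using convnext_tiny_1k as an example):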

    import torch

    # Official ConvNeXt checkpoints (ImageNet-1k and ImageNet-22k)
    model_urls = {
        "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
        "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
        "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
        "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
        "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
        "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
        "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
    }
    url = model_urls['convnext_tiny_1k']
    checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)

    # Keep only the backbone weights: skip the classifier's final norm and head, and copy
    # the rest into each of the four ConvNeXt_Block layers (model.0 to model.3)
    init_dict = {}
    for index in range(4):
        for k, v in checkpoint['model'].items():
            if not k.startswith(('norm', 'head')):
                init_dict['.'.join(['model', str(index), k])] = v

    # Create model (Model is defined in models/yolo.py; opt and device come from train.py)
    model = Model(opt.cfg).to(device)
    model.train()

    # Overwrite the backbone part of the network weights
    model_dict = model.state_dict()
    model_dict.update(init_dict)
    model.load_state_dict(model_dict)
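The `init_dict` loop copies the same pretrained backbone into model.0 through model.3 and drops only keys whose top-level name starts with `norm` or `head`. The plain-Python sketch below isolates that key mapping so the resulting names can be inspected without downloading a checkpoint. `remap_backbone_keys` is a hypothetical helper, and the dictionary values are placeholders standing in for tensors.

```python
def remap_backbone_keys(ckpt_state, num_stages=4, skip_prefixes=("norm", "head")):
    """Hypothetical helper mirroring the init_dict loop: prefix every backbone
    key with 'model.<index>.' for each ConvNeXt_Block layer, skipping keys whose
    top-level name is the classifier's final norm or head."""
    init_dict = {}
    for index in range(num_stages):
        for k, v in ckpt_state.items():
            # str.startswith accepts a tuple, so one call covers both prefixes
            if k.startswith(skip_prefixes):
                continue
            init_dict[f"model.{index}.{k}"] = v
    return init_dict


if __name__ == "__main__":
    # Placeholder values; key names follow the downsample_layers / stages /
    # norm / head layout of the ConvNeXt checkpoint
    fake_ckpt = {
        "downsample_layers.0.0.weight": "stem_conv",
        "stages.0.0.norm.weight": "block_norm",  # per-block norm: kept
        "norm.weight": "final_norm",             # classifier norm: dropped
        "head.weight": "fc",                     # classifier head: dropped
    }
    for key in sorted(remap_backbone_keys(fake_ckpt)):
        print(key)
```

Because the filter uses `startswith` on the full key, per-block LayerNorm weights such as `stages.0.0.norm.weight` are kept, as they should be. If a remapped name or shape does not match the model, the strict `model.load_state_dict(model_dict)` call in the snippet above raises an error rather than silently ignoring it.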

The complete project is on GitHub:

https://github.com/OutBreak-hui/YoloV5-Flexible-and-Inference
