文章目录

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 导入g_ghost_regnet.py文件
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文地址:https://arxiv.org/abs/2201.03297
代码地址:https://github.com/huawei-noah/CV-Backbones
讲解视频:https://www.zhihu.com/zvideo/1584651719241363457
上篇实战介绍了华为的GhostNet,上面的论文中,将GhostNet成为C-GhostNet,C-GhostNet中为实现轻量化,使用了一些低运算密度的操作。低运算密度使得GPU的并行计算能力无法被充分利用,从而导致C-GhostNet在GPU等设备上糟糕的延迟,因此需要设计一种适用于GPU设备的Ghost模块。

作者等人发现,现有大多数CNN架构中,一个阶段通常包括几个卷积层/块,同时在每个阶段中的不同层/块,特征图的尺寸大小相同,因此一种猜想是:特征的相似性和冗余性不仅存在于一个层内,也存在于该阶段的多个层之间。下图的可视化结果验证了这种想法(如右边第三行第二列和第七行第三列的特征图存在一定相似性)。

作者等人利用观察到的阶段性特征冗余,设计G-Ghost模块并应用于GPU等设备,实现了一个在GPU上具有SOTA性能的轻量级CNN。G-Ghost中g_ghost_regnetx_160模型在ImageNet上取的了79.9%的成绩。
我这篇文章主要讲解如何使用G-Ghost完成图像分类任务,接下来我们一起完成项目的实战。经过测试,G-Ghost在植物幼苗数据集上实现了97+%的准确率。

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现G-Ghost模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

安装包

安装timm

使用pip就行,命令:

pip install timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:

""" Exponential Moving Average (EMA) of model updatesHacked together by / Copyright 2020 Ross Wightman
"""
import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device  # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

导入g_ghost_regnet.py文件

文件的路径:https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/g_ghost_pytorch
将其导入到项目的根目录,然后,对其做修改:
由于我们使用g_ghost_regnetx_160,增加g_ghost_regnetx_160预训练模型配置字典。

default_cfgs = {'g_ghost_regnetx_160':'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/tag/g_ghost_regnet/g_ghost_regnet_16.0g_79.9.pth',
}

然后对g_ghost_regnetx_160函数做修改,增加预训练模型参数的加载。由于预训练模型比g_ghost_regnetx_160多了module.这个参数,所以要将这个参数去掉,否则无法正确加载。

def g_ghost_regnetx_160(pretrained=False,**kwargs):model=GGhostRegNet(Bottleneck, [2, 6, 13, 1], [256, 512, 896, 2048], group_width=128, **kwargs)if pretrained:url = default_cfgs['g_ghost_regnetx_160']checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")out_dict = collections.OrderedDict()for k, v in checkpoint.items():k=k.replace('module.','')out_dict[k] = vprint(out_dict.keys())model.load_state_dict(out_dict)return model

项目结构

G_Ghost_demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─mean_std.py
├─makedata.py
├─g_ghost_regnet.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。

为了能在DP方式中使用混合精度,还需要在模型的forward函数前增加@autocast()。

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

G-Ghost-RegNet实战:使用G-Ghost-RegNet实现图像分类任务(一)相关推荐

  1. 实战解决使用ghost安装系统出现的各种问题

    昨天使用ghost给人安装系统时,把另一个分区的数据都搞没了,安装完也只剩下一个分区,相信了解的同志们知道是什么原因. 今天下决心研究了一下ghost. 首先使用了一张深度的xp盘,设置为光驱优先引导 ...

  2. 联想ghost重装系统_一键ghost,详细教您使用一键ghost怎么重装win7系统

    讲起这个重装系统的方法跟操作,相信广大的用户听的最多的,用的最多的,看的最多的,就是U盘安装系统,硬盘安装系统以及渐渐退隐江湖的光盘重装系统,这几样了,那么不知道你们有没有听过,或者使用过一键重装系统 ...

  3. java中输出a个b_下面代码输出什么 ( ) var a=0,b=0; for(;a10,b7;a++,b++){ g=a+b; } console.log(g);_学小易找答案...

    [单选题]Java Script 函数说法正确的是 () [单选题]阅读下面的 Javascript 代码 , 输出结果正确的是( ) var i=0; for(i=0;i<=5;i++){ i ...

  4. g标签 怎么设置svg_svg g标签的运用

    在svg中提供了如g元素这样的将多个元素组织在一起的元素. 由g元素编组在一起的可以设置相同的颜色,可以进行坐标变换 下面是运用了snap.svg.js的实例 //  创建一个svg 对象 var   ...

  5. (离散)设函数 f:A→B,g:B→C,证明:若g °f是满射,则g是满射.

    设函数f:A→B,g:B→C,证明:若g °f是满射,则g是满射. 解析:要证满射 g:B→C,对 ∀ c ∈C,∃ b∈B,c = g(b),则g满射

  6. 6阶群的非平凡子群_当|G|=8时,群lt;G,*gt;只能有?阶非平凡子群,不能有?阶子群,平凡子群为?...

    匿名用户 1级 2013-01-02 回答 试卷二十四试题与答案 一.填空题每空1分本大题共15分 1设432aA143aB请在下列每对集合中填入适当的符号. 1a B 2 34a A. 2设10AN ...

  7. Ghost v11 for U3 and Ghost Explorer v11 for U3

    Symantec Ghost v11 for U3 and Ghost Explorer v11 for U3

  8. 联想ghost重装系统_手动ghost安装系统详细操作步骤

    ghost是什么意思?ghost在英文的意思则是魔鬼,通常我们会使用ghost来恢复gho系统镜像文件,该怎么使用ghost来恢复win7 64位系统呢?如何手动操作ghost来重装Win7 64位系 ...

  9. 服务器系统ghost蓝屏,win7系统ghost安装一半蓝屏的解决方法

    我们在操作win7系统电脑的时候,常常会遇到win7系统ghost安装一半蓝屏的困惑吧,一些朋友看过网上零散的win7系统ghost安装一半蓝屏的处理方法,并没有完完全全明白win7系统ghost安装 ...

  10. linux系统硬盘ghost吗,将linux硬盘ghost到另一颗去

    将linux硬盘ghost到另一颗去 发布时间:2005-09-29 00:04:16来源:红联作者:cha 要将linux完完全全的备份到另一颗抽取式硬盘,如果linux挂了,可以立刻升上来备援!! ...

最新文章

  1. Web 开发最有用的 jQuery 插件集锦
  2. DHCP服务器的搭建
  3. 第十三周项目一-分数类中的运算符重载
  4. 人工智能70年商业变现艰难,新基建能否催生规模化落地?
  5. 第十七部分-Python文档和测试
  6. 19 | 案例篇:为什么系统的Swap变高了(上)
  7. java cpu过高排查_涨薪秘籍:JAVA项目排查cpu负载过高
  8. linux的基础知识——shell语法
  9. SpringBoot怎么直接访问templates下的html页面
  10. 安卓--selector简单使用
  11. 小叮当的2021年年终总结
  12. 查询ES(ElasticSearch)版本
  13. 三运放差分放大电路分析_三运放组成的差分放大器电路图及特点
  14. 八皇后时间复杂度_【精神分裂症】首次发病未治疗精神分裂症患者大脑皮质复杂度改变...
  15. 信奥要学哪些数学知识 学信奥要不要先学python
  16. 计算机音乐来自天堂的魔鬼,来自天堂的魔鬼
  17. Matlab GUI编程技巧(十五):scroll滚动到组件内的位置及ScrollBar动画演示
  18. elementui 打包后icon图标加载偶尔会乱码
  19. 小说里的编程 【连载之二十三】元宇宙里月亮弯弯
  20. Go语言sqlx库操作PostgreSQL数据库增删改查

热门文章

  1. 北京博奥智源,浅谈图书馆的馆情展示系统细则
  2. Java 代码实现阿姆斯特朗炮的原理
  3. Adobe Reader 下载大全
  4. 《唐探3》口碑急转直下?看看影迷们到底都说了些啥
  5. CVPR2018 语义分割
  6. CVPR2018总结
  7. 【牛客网华为机试】HJ42 学英语
  8. java之Reader类与Writer类
  9. 25个故事性网页设计,轻松讲述网页独有的故事!!!
  10. 电脑开机自动开启程序的方法