ConvMAE实战

  • 摘要
  • 安装包
    • 1、安装timm
  • 数据增强Cutout和Mixup
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

本文通过对植物幼苗分类的实际例子来感受一下ConvMAE模型的效果。模型来自官方,我自己写了train和test部分。从得分情况来看,这个模型不是很好,我训练了1000个epoch,最终得分是93.X%,而且收敛很慢。

所以就不详细介绍这个模型,带领大家从实战的角度体验一下。

这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何调用自定义的模型?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?

安装包

1、安装timm

使用pip就行,命令:

pip install timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

项目结构

ConvMAE_demo
├─data
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  ├─models_convvit.py
│  └─vision_transformer.py
├─checkpoint.pth
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。

makedata.py:生成数据集。

models_convvit.py和vision_transformer:来自官方的pytorch版本的代码。

checkpoint.pth:预训练权重。作者把预训练权重放到谷歌网盘上,不方便下载。我把它传到了CSDN上。
下载链接:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85467960?spm=1001.2014.3001.5503
为了能在DP方式中使用混合精度,还需要在模型的forward函数前增加@autocast().

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))

数据集结构:

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了,详见下面的链接:

ConvMAE实战:使用ConvMAE实现对植物幼苗的分类(非官方)(一)相关推荐

  1. MobileNetV1实战:使用MobileNetV1实现植物幼苗分类

    文章目录 摘要 数据增强Cutout和Mixup 项目结构 导入项目使用的库 设置全局参数 图像预处理与增强 读取数据 设置模型 定义训练和验证函数 测试 摘要 本例提取了植物幼苗数据集中的部分数据做 ...

  2. SENet实战详解:使用SE-ReSNet50实现对植物幼苗的分类

    摘要 1.SENet概述 ​ Squeeze-and-Excitation Networks(简称 SENet)是 Momenta 胡杰团队(WMW)提出的新的网络结构,利用SENet,一举取得最后一 ...

  3. RepLKNet实战:使用RepLKNet实现对植物幼苗的分类(非官方)(二)

    训练 完成上面的步骤后,就开始train脚本的编写,新建train.py. 导入项目使用的库 import json import os import shutil import matplotlib ...

  4. MobileNetV3 实战:植物幼苗分类(pytorch)

    文章目录 摘要 mobilenetv3简介 数据增强Cutout和Mixup 项目结构 导入项目使用的库 设置全局参数 图像预处理与增强 读取数据 设置模型 定义训练和验证函数 测试 摘要 本例提取了 ...

  5. Kaggle图像识别竞赛 Plant Seedlings Classification(植物幼苗分类)具体实现

    目录 0. 前言 1. 总体设计 2. import部分 3. 具体实现步骤 一.数据预处理 (一)均衡化 (二)提取图片中叶子(绿色)的部分 二.提取特征 (一)SIFT提取关键点 (二)BOW(B ...

  6. matlab幼苗识别,基于MATLAB的植物幼苗识别

    基于MATLAB的植物幼苗识别(论文11000字,外文翻译) 摘要:杂草种类繁多,严重影响了农作物的生产与产量,使用图像处理技术识别区分杂草和作物幼苗已成为一种最科学最有效的方法.通过提取植物图像的有 ...

  7. 深度学习图像分类:植物幼苗图像分类入门(Plant Seedlings Classification)

    前言:深度学习考试期末的题目,植物幼苗分类,可以帮助农业领域的进步. 题目介绍:kaggle原题:可以下载数据集,查看一些参与者的思路等. 易用的深度学习框架Keras简介及使用 部分图片如下: 思路 ...

  8. 自然语言处理入门实战2:基于深度学习的文本分类

    自然语言处理入门实战2:基于深度学习的文本分类 数据集 数据预处理 模型 模型训练 模型测试 参考 本文参考复旦大学自然语言处理入门练习,主要是实现基于深度学习的文本分类. 环境:python3.7 ...

  9. 深度学习实战——利用卷积神经网络对手写数字二值图像分类(附代码)

    系列文章目录 深度学习实战--利用卷积神经网络对手写数字二值图像分类(附代码) 目录 系列文章目录 前言 一.案例需求 二.MATLAB算法实现 三.MATLAB源代码 参考文献 前言 本案例利用MA ...

最新文章

  1. 大一计算机在线考试,Word 大一计算机考试操作题
  2. 工艺路线和工序有差别吗_ERP-工序与工艺路线
  3. 自建MySQL5.6数据库查询优化
  4. python培训中心-想学python,上海Python培训中心哪个好?
  5. python环境管理命令_conda管理Python环境
  6. java类加载过程_面试官:java类的加载过程
  7. MySQL 中 AUTO_INCREMENT 的“坑”--id不连续
  8. matlab中关于程序运行的快捷键
  9. dbforge studio for oracle,dbForge Studio for Oracle(数据库管理软件)官方版
  10. [论文阅读] Stereoscopically Attentive Multi-scale Network for Lightweight Salient Object Detection
  11. Tomcat 系统架构与设计模式之设计模式篇
  12. DBSCAN密度聚类算法
  13. Ignite SQL网格
  14. 计算机程序员的英文简历,电脑程序员个人英文简历范文
  15. 真正了解gets() fgets() getc() fgetc()的区别
  16. MySql必知必会学习
  17. 安卓端录像并将视频分享给微信好友
  18. 网络 - VXLAN
  19. JQuery快速入门之插件
  20. Mybatis入门(二)

热门文章

  1. 【leetcode 简单】 第一百五十题 两个列表的最小索引总和
  2. h5游戏接入googleplay时遇到的问题总结
  3. SE2431L-R高性能 完全集成的RF前端模块 ZigBee 低功耗 蓝牙1.0
  4. centos7开启网卡命令_CentOS7 开启网卡,设置开机启用网卡
  5. 访问我在BLOGBUS的博客吧
  6. 【华人学者风采】聂礼强 山东大学
  7. Docker多主机管理Docker Machine
  8. Http的get和post请求
  9. 图像增强——伽马变换
  10. Matlab笔记(二):Matlab实现高斯函数的三维显示