文章目录

  • 1. 准备数据集
    • 1.1 数据集存放格式
    • 1.2 config配置文件
  • 2. 训练
    • 2.1 代码中调整了的部分
    • 2.2 训练命令
  • 3. 评估
  • 4. 推理
    • 4.1 推理脚本
    • 4.2 推理命令
    • 4.3 推理结果

源代码地址:Swin-Transformer
本机为Ubuntu系统,为了训练自己的数据集,在原代码的基础上做了一点小调整:

  • 原代码中每个epoch保存一个模型,调整为只保存表现最佳的模型最后一个epoch的模型
  • 原代码训练的ImageNet数据集,数据类别比较多,输出了两个评估指标:Top-1 Acc和Top-5 Acc,但我自己数据集只有3个类别,调整为输出Top-1 Acc和Top-2 Acc(其实Top-2 Acc没啥用,不输出也可以的)
  • 原代码未细化每个类别的Acc,简单补充了下该信息在终端的输出
  • 原代码没有推理脚本,简单补充了一个

1. 准备数据集

1.1 数据集存放格式

── imagenet
├── train
│   ├── class1
│   │   ├── cat0001.jpg
│   │   ├── cat0002.jpg
│   │   └── ...
│   ├── class2
│   │   ├── dog0001.jpg
│   │   ├── dog0002.jpg
│   │   └── ...
│   └── class3
│       ├── bird0001.jpg
│       ├── bird0002.jpg
│       └── ...
└── val├── class1├── class2└── class3

1.2 config配置文件

swinv2_base_patch4_window12_192_22k.yaml为例

DATA:# 为了配合上方的数据集存放格式,DATASET的value需设置为imagenetDATASET: imagenetIMG_SIZE: 384# NAME_CLASSES是自己增加的,在推理阶段可视化时使用NAME_CLASSES: ["cat", "dog", "bird"]
MODEL:TYPE: swinv2NAME: swinv2_base_patch4_window12_192_22kDROP_PATH_RATE: 0.2# NUM_CLASSES是增加进来的默认是1000NUM_CLASSES: 3SWINV2:EMBED_DIM: 128DEPTHS: [ 2, 2, 18, 2 ]NUM_HEADS: [ 4, 8, 16, 32 ]WINDOW_SIZE: 12
TRAIN:EPOCHS: 90WARMUP_EPOCHS: 5WEIGHT_DECAY: 0.1BASE_LR: 1.25e-4 # 4096 batch-sizeWARMUP_LR: 1.25e-7MIN_LR: 1.25e-6

针对上方的调整相应地需要修改config.py文件

_C.DATA = CN()
# 增加NAME_CLASSES字段的默认值
_C.DATA.NAME_CLASSES = []

2. 训练

2.1 代码中调整了的部分

  • main.py
if __name__ == '__main__':args, config = parse_option()# 训练环境为本地单机单卡,手动写入环境变量中一些字段os.environ['WORLD_SIZE'] = '1'os.environ['RANK'] = '0'os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12345'# ...
if config.TRAIN.AUTO_RESUME:resume_file = auto_resume_helper(config.OUTPUT, get_best=True)
# 原代码中计算acc时输出的是top-1 acc和top-5 acc,但我自己的数据集只有3个类别
# 所以调整为输出top-1 acc和top-2 acc
# 增加了每个类别的acc的输出
def validate(config, data_loader, model):criterion = torch.nn.CrossEntropyLoss()model.eval()batch_time = AverageMeter()loss_meter = AverageMeter()acc1_meter = AverageMeter()acc2_meter = AverageMeter()cla_num_meter = np.zeros(config.MODEL.NUM_CLASSES)pre_num_meter = np.zeros(config.MODEL.NUM_CLASSES)end = time.time()for idx, (images, target) in enumerate(data_loader):images = images.cuda(non_blocking=True)target = target.cuda(non_blocking=True)# compute outputwith torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):output = model(images)# measure accuracy and record lossloss = criterion(output, target)acc1, acc2 = accuracy(output, target, topk=(1, 2))cla_num, pre_num = cla_accuracy(output, target, config.MODEL.NUM_CLASSES)cla_num_meter += cla_numpre_num_meter += pre_numacc1 = reduce_tensor(acc1)acc2 = reduce_tensor(acc2)loss = reduce_tensor(loss)loss_meter.update(loss.item(), target.size(0))acc1_meter.update(acc1.item(), target.size(0))acc2_meter.update(acc2.item(), target.size(0))# measure elapsed timebatch_time.update(time.time() - end)end = time.time()if idx % config.PRINT_FREQ == 0:memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)logger.info(f'Test: [{idx}/{len(data_loader)}]\t'f'Time{batch_time.val:.3f}({batch_time.avg:.3f})\t'f'Loss{loss_meter.val:.4f}({loss_meter.avg:.4f})\t'f'Acc@1{acc1_meter.val:.3f}({acc1_meter.avg:.3f})\t'f'Acc@2{acc2_meter.val:.3f}({acc2_meter.avg:.3f})\t'f'Mem{memory_used:.0f}MB')logger.info(f' * Acc@1{acc1_meter.avg:.3f}Acc@2{acc2_meter.avg:.3f}')ans = ''acc_each_class = [pre_num_meter[i] / cla_num_meter[i] for i in range(config.MODEL.NUM_CLASSES)]for i in range(config.MODEL.NUM_CLASSES):ans += f'Acc of{config.DATA.NAME_CLASSES[i]}:{acc_each_class[i]}\t'logger.info(ans)return acc1_meter.avg, acc2_meter.avg, loss_meter.avgdef cla_accuracy(output, target, num_class):# 计算每个类别的实际数目和识别正确数目_, pred = output.topk(1, 1, True, True)pred = pred.t()[0]sam_nums = np.zeros(num_class)pre_cor_nums = np.zeros(num_class)for i in range(len(target)):sam_nums[int(target[i])] += 1if int(target[i]) == int(pred[i]):pre_cor_nums[int(target[i])] += 1return sam_nums, pre_cor_nums
# 原代码每个epoch保存一个模型,调整为只保存best_ckpt.pth和last_epoch_ckpt.pth
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):data_loader_train.sampler.set_epoch(epoch)train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,loss_scaler)acc1, acc2, loss = validate(config, data_loader_val, model)if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):if acc1 > max_accuracy:ckpt_name = "best_ckpt"else:ckpt_name = "last_epoch_ckpt"save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,logger, ckpt_name)logger.info(f"Accuracy of the network on the{len(dataset_val)}test images:{acc1:.1f}%")max_accuracy = max(max_accuracy, acc1)logger.info(f'Max accuracy:{max_accuracy:.2f}%')
  • data/build.py
def build_loader(config):config.defrost()# 原代码为dataset_train, config.MODEL.NUM_CLASSES =# 我们在config文件中已经指明了数据集类别数dataset_train, _ = build_dataset(is_train=True, config=config)
  • utils.py
# 修改代码resume时调用的是best_ckpt.pth
def auto_resume_helper(output_dir, get_best=False):checkpoints = os.listdir(output_dir)checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]print(f"All checkpoints founded in{output_dir}:{checkpoints}")#  原本的代码是采用时间最近的模型,调整为读取best_ckpt.pthif len(checkpoints) > 0 and not get_best:latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)print(f"The latest checkpoint founded:{latest_checkpoint}")resume_file = latest_checkpointelif get_best and "best_ckpt.pth" in checkpoints:print(f"The best checkpoint founded:{os.path.join(output_dir, 'best_ckpt.pth')}")resume_file = os.path.join(output_dir, 'best_ckpt.pth')else:resume_file = Nonereturn resume_file
def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, ckpt_name):save_state = {'model': model.state_dict(),'optimizer': optimizer.state_dict(),'lr_scheduler': lr_scheduler.state_dict(),'max_accuracy': max_accuracy,'scaler': loss_scaler.state_dict(),'epoch': epoch,'config': config}save_path = os.path.join(config.OUTPUT, f'{ckpt_name}.pth')logger.info(f"{save_path}saving......")torch.save(save_state, save_path)logger.info(f"{save_path}saved !!!")

2.2 训练命令

python main.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --batch-size 4 --data-path imagenet --pretrained swinv2_base_patch4_window12_192_22k.pth --local_rank 0

3. 评估

python main.py --eval --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --resume output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --data-path imagenet --local_rank 0

评估阶段的终端输出:

4. 推理

4.1 推理脚本

原作者没有提供inference代码,根据evaluate流程写一个简单的推理脚本。

import os
import argparse
from torch.autograd import Variable
import cv2import torch
from torchvision import transformsfrom config import get_config
from models import build_model
from PIL import Imagefrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STDtry:from torchvision.transforms import InterpolationModedef _pil_interp(method):if method == 'bicubic':return InterpolationMode.BICUBICelif method == 'lanczos':return InterpolationMode.LANCZOSelif method == 'hamming':return InterpolationMode.HAMMINGelse:# default bilinear, do we want to allow nearest?return InterpolationMode.BILINEARimport timm.data.transforms as timm_transformstimm_transforms._pil_interp = _pil_interp
except:from timm.data.transforms import _pil_interpdef parse_option():parser = argparse.ArgumentParser('Swin Transformer inference script', add_help=False)parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )parser.add_argument("--opts",help="Modify config options by adding 'KEY VALUE' pairs. ",default=None,nargs='+',)# easy config modificationparser.add_argument('--batch-size', type=int, help="batch size for single GPU")parser.add_argument('--data-path', type=str, help='path to dataset')parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],help='no: no cache, ''full: cache all data, ''part: sharding the dataset into nonoverlapping pieces and only cache one piece')parser.add_argument('--pretrained',help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')parser.add_argument('--resume', help='resume from checkpoint')parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")parser.add_argument('--use-checkpoint', action='store_true',help="whether to use gradient checkpointing to save memory")parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],help='mixed precision opt level, if O0, no amp is used (deprecated!)')parser.add_argument('--output', default='output', type=str, metavar='PATH',help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')parser.add_argument('--tag', help='tag of experiment')parser.add_argument('--eval', action='store_true', help='Perform evaluation only')parser.add_argument('--throughput', action='store_true', help='Test throughput only')# distributed trainingparser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')# for accelerationparser.add_argument('--fused_window_process', action='store_true',help='Fused window shift & window partition, similar for reversed part.')parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lambparser.add_argument('--optim', type=str,help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')args, unparsed = parser.parse_known_args()config = get_config(args)return args, configif __name__ == '__main__':args, config = parse_option()transform_test = transforms.Compose([transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), interpolation=_pil_interp(config.DATA.INTERPOLATION)),transforms.ToTensor(),transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])classes = config.DATA.NAME_CLASSESDEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = build_model(config)checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')model.load_state_dict(checkpoint['model'], strict=False)model.eval()model.to(DEVICE)path = config.DATA.DATA_PATHtestList = os.listdir(path)for file in testList:img = Image.open(os.path.join(path + file))img = transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out = model(img)_,pred = torch.max(out.data, 1)ori_img = cv2.imread(os.path.join(path + file))text = 'ImageName:{}, predict:{}'.format(file, classes[pred.data.item()])font = cv2.FONT_HERSHEY_SIMPLEXtxt_size = cv2.getTextSize(text, font, 0.7, 1)[0]x0 = int(ori_img.shape[1] / 2.0)cv2.putText(ori_img, text, (x0 - int(txt_size[0] / 2.0), int(0 + txt_size[1])), font, 0.7, (0, 0, 255), thickness=1)cv2.imshow(os.path.join(path, file), ori_img)cv2.waitKey(0)cv2.destroyAllWindows()

4.2 推理命令

python inference.py --cfg configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml --data-path images/ --pretrained output/swinv2_base_patch4_window12_192_22k/default/best_ckpt.pth --local_rank 0

4.3 推理结果

Swin-Transformer图像分类相关推荐

  1. Swin Transformer实战:使用 Swin Transformer实现图像分类。

    Swin Transformer简介 目标检测刷到58.7 AP! 实例分割刷到51.1 Mask AP! 语义分割在ADE20K上刷到53.5 mIoU! 今年,微软亚洲研究院的Swin Trans ...

  2. Swin Transformer v2实战:使用Swin Transformer v2实现图像分类(一)

    Swin Transformer V2实战 摘要 安装包 安装timm 数据增强Cutout和Mixup EMA 项目结构 计算mean和std 生成数据集 摘要 Swin Transformer v ...

  3. Swin Transformer升级版来了!30亿参数,刷榜多项视觉任务,微软亚研原班人马打造...

    视学算法报道 编辑:杜伟.陈萍 微软亚洲研究院升级了 Swin Transformer,新版本具有 30 亿个参数,可以训练分辨率高达 1,536×1,536 的图像,并在四个具有代表性的基准上刷新纪 ...

  4. 《预训练周刊》第29期:Swin Transformer V2:扩大容量和分辨率、SimMIM:用于遮蔽图像建模的简单框架...

    No.29 智源社区 预训练组 预 训 练 研究 观点 资源 活动 关于周刊 本期周刊,我们选择了10篇预训练相关的论文,涉及图像处理.图像屏蔽编码.推荐系统.语言模型解释.多模态表征.多语言建模.推 ...

  5. ICCV 2021 Best Paper | Swin Transformer何以屠榜各大CV任务!

    作者:陀飞轮@知乎(已授权) 来源:https://zhuanlan.zhihu.com/p/360513527 编辑:智源社区 近日,Swin Transformer拿到2021 ICCV Best ...

  6. 当Swin Transformer遇上DCN,效果惊人!

    来源:机器之心 Transformer 近来在各种视觉任务上表现出卓越的性能,感受野赋予 Transformer 比 CNN 更强的表征能力.然而,简单地扩大感受野会引起一些问题.一方面,使用密集注意 ...

  7. 【知乎热议】如何看待swin transformer成为ICCV2021的 best paper?

    编辑:深度学习技术前沿 转载请注明来源,谢谢! [导读]今年ICCV2021, 在所有被接收的论文中,来自中国的论文数量占比最高,达到了 43.2%,约为第二位美国(23.6%)的两倍.中国学者凭借S ...

  8. 超越Swin Transformer!谷歌提出了收敛更快、鲁棒性更强、性能更强的NesT

    [导读]谷歌&罗格斯大学的研究员对ViT领域的分层结构设计进行了反思与探索,提出了一种简单的结构NesT,方法凭借68M参数取得了超越Swin Transformer的性能. 文章链接:htt ...

  9. 专访 Swin Transformer 作者胡瀚:面向计算机视觉中的「开放问题」 原创

    文 | 刘冰一.Echo 编辑 | 极市平台 本文原创首发于极市平台,转载请获得授权并标明出处. 胡瀚,湖北潜江人,本博均毕业于清华大学自动化系,曾就职于百度研究院深度学习实验室,目前任职于微软亚洲研 ...

  10. 霸榜各大CV任务榜单,Swin Transformer横空出世!

    1. ImageNet-1K的图像分类 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 一元@炼丹笔记 ...

最新文章

  1. vue 函数 路由跳转_vue中通过路由跳转的三种方式
  2. SQLServer:用户自定义数据类型用法
  3. 我发现我的Java重拍了!
  4. protected访问权限_权限修饰符 /重写
  5. 2020 年开发者生态报告:Python超越Java,Go、Kotlin强势崛起
  6. Weblogic常用监控指标
  7. Apache 配置多端口网站
  8. ubuntu linux u盘安装教程,U盘安装ubuntu的详细教程
  9. tomcat8安装及配置详细步骤(win10)
  10. 差分进化算法python_L单目标差分进化算法
  11. 【转】 Pro Android学习笔记(五八):Preferences(2):CheckBoxPreference
  12. 基于三维激光雷达的二维占据栅格地图构建-简介
  13. flash 与动画 轮盘旋转
  14. 几个鲜为人知但很有用的 HTML 属性
  15. Flink 清理 Checkpoint的原理和机制
  16. Graphics2D图片合成
  17. pin limiting the speed
  18. linux 的一些脑洞操作
  19. cae计算机仿真分析技术,cae分析.doc
  20. 移动通讯中的2G和2.5G以及3G概念

热门文章

  1. 回首系列01: 假如我的人生就像是在炒股
  2. 【数学】一些数学概念
  3. 【netcore基础】wwwroot下静态资源文件访问权限控制
  4. Meta半年亏损57.7亿美元也要搞元宇宙,听听扎克伯格自己是怎么说的
  5. 进程的休眠与唤醒(等待队列)
  6. python使用turtle绘制简单五角星图案
  7. FLY--互联网经典语录
  8. 求N个数的最大公因数(算法)
  9. ANSYS Fluent 电子产品散热计算案例
  10. 如何在Github快速找到资源(资源快速检索)