最近打算认真看下SwinTransformer算法,这里记录下今天看的build_loader部分的代码,步骤均加了汉语注释,希望可以帮到一些正在学习的朋友。

前言

build_loader是Swin Transformer代码中main.py的第一句,所以这里主要记录下Swin Transformer加载数据的过程。


在swin transformer 中,数据预处理部分是在下面的文件夹中

build.py

# Compatibility shim: newer torchvision releases replaced the integer PIL
# interpolation constants with the InterpolationMode enum, which breaks
# older timm versions.  If the enum is available, monkey-patch timm's
# _pil_interp with an enum-aware version; otherwise fall back to timm's
# own implementation.
try:
    from torchvision.transforms import InterpolationMode

    def _pil_interp(method):
        """Map an interpolation-method name ('bicubic', 'lanczos', 'hamming',
        anything else -> bilinear) to a torchvision InterpolationMode."""
        if method == 'bicubic':
            return InterpolationMode.BICUBIC
        elif method == 'lanczos':
            return InterpolationMode.LANCZOS
        elif method == 'hamming':
            return InterpolationMode.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return InterpolationMode.BILINEAR

    import timm.data.transforms as timm_transforms

    timm_transforms._pil_interp = _pil_interp
except ImportError:
    # Narrowed from a bare `except:` — only a missing InterpolationMode
    # should trigger the fallback; any other error ought to surface.
    from timm.data.transforms import _pil_interp
def build_loader(config):
    """Build train/val datasets, distributed samplers, data loaders and the
    optional mixup/cutmix function described by *config*.

    Returns:
        (dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn)
        where ``mixup_fn`` is ``None`` when mixup/cutmix is disabled.
    """
    # Unfreeze the config so config.MODEL.NUM_CLASSES can be overwritten below.
    config.defrost()
    # Build the training set; build_dataset also reports the class count.
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
    config.freeze()  # re-lock the config
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
    dataset_val, _ = build_dataset(is_train=False, config=config)  # validation set
    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")

    # Distributed-training parameters.
    num_tasks = dist.get_world_size()
    global_rank = dist.get_rank()
    # Cache mode 'part': each rank only caches its own stripe of the data,
    # so sample only the indices this rank owns.
    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
        sampler_train = SubsetRandomSampler(indices)
    else:
        # Spread the dataset evenly across GPUs,
        # e.g. GPU0: [0, 2, 4, 6, 8], GPU1: [1, 3, 5, 7, ...].
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )

    # Validation sampling: sequential (0, 1, 2, ...) or distributed.
    if config.TEST.SEQUENTIAL:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    else:
        sampler_val = torch.utils.data.distributed.DistributedSampler(
            dataset_val, shuffle=config.TEST.SHUFFLE
        )

    # Training data loader.
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
    )

    # Validation data loader.
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=config.DATA.BATCH_SIZE,
        shuffle=False,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False
    )

    # setup mixup / cutmix
    mixup_fn = None
    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
# is_train: True -> training split, False -> validation split
def build_dataset(is_train, config):
    """Build the dataset named by ``config.DATA.DATASET``.

    Returns:
        (dataset, nb_classes) — the dataset object and its class count.

    Raises:
        NotImplementedError: for datasets other than imagenet / imagenet22K.
    """
    # Data transforms (augmentation / preprocessing) for this split.
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        # Zip-packed dataset layout; see the dataset example in get_start.md.
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            # Plain folder-per-class layout.
            root = os.path.join(config.DATA.DATA_PATH, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000  # ImageNet-1k class count
    elif config.DATA.DATASET == 'imagenet22K':
        prefix = 'ILSVRC2011fall_whole'
        if is_train:
            ann_file = prefix + "_map_train.txt"
        else:
            ann_file = prefix + "_map_val.txt"
        dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
        nb_classes = 21841  # ImageNet-22k class count
    else:
        raise NotImplementedError("We only support ImageNet Now.")

    return dataset, nb_classes
def build_transform(is_train, config):#判断输入图像尺寸大小是否大于32 相关配置在 config.py 中resize_im = config.DATA.IMG_SIZE > 32if is_train:# this should always dispatch to transforms_imagenet_train#相关参数可以在 config.py 找到transform = create_transform(input_size=config.DATA.IMG_SIZE,is_training=True,color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,re_prob=config.AUG.REPROB,re_mode=config.AUG.REMODE,re_count=config.AUG.RECOUNT,interpolation=config.DATA.INTERPOLATION,)#如果输入图像大小过小  则采用随即裁剪的方式if not resize_im:# replace RandomResizedCropAndInterpolation with# RandomCroptransform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)return transform# 非训练数据集的扩增方式t = []if resize_im:#如果测试过程中使用中心裁剪 先放大输入图像的尺寸 再中心裁剪#否则直接使用对应的插值策略进行设置输入图像大小if config.TEST.CROP:size = int((256 / 224) * config.DATA.IMG_SIZE)t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),# to maintain same ratio w.r.t. 224 images)t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))else:t.append(transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),interpolation=_pil_interp(config.DATA.INTERPOLATION)))t.append(transforms.ToTensor())#将数据转为 tensor类型 (H*W*C)->(C*H*W)#进行归一化t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))return transforms.Compose(t)#把多个数据处理步骤整合到一起

cached_image_folder.py


import io
import os
import time
import torch.distributed as dist
import torch.utils.data as data
from PIL import Imagefrom .zipreader import is_zip_path, ZipReader#检验文件是否符合所要求的文件后缀格式
def has_file_allowed_extension(filename, extensions):"""Checks if a file is an allowed extension.Args:filename (string): path to a fileReturns:bool: True if the filename ends with a known image extension"""filename_lower = filename.lower()return any(filename_lower.endswith(ext) for ext in extensions)#获得数据集中的种类
def find_classes(dir):#通过数据集文件夹下的子文件夹名称来获取类名classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]classes.sort()#进行类名排序#将类名进行映射为从0开始的类别索引 {类名:类索引}class_to_idx = {classes[i]: i for i in range(len(classes))}#返回类别列表和类别索引字典return classes, class_to_idx#获取数据集中的所有图像数据
def make_dataset(dir, class_to_idx, extensions):images = []#如果是"~"开始的则进入系统根目录,如果不为"~"开头的则保持原路径dir = os.path.expanduser(dir)#获取文件列表for target in sorted(os.listdir(dir)):d = os.path.join(dir, target)#如果不为文件夹则跳过if not os.path.isdir(d):continue#便利文件子文件夹for root, _, fnames in sorted(os.walk(d)):for fname in sorted(fnames):#遍历所有文件#判断文件后缀是否正确if has_file_allowed_extension(fname, extensions):#图片的绝对路径path = os.path.join(root, fname)#构建图片路径和类别索引元组item = (path, class_to_idx[target])images.append(item)#返回images列表 元素为图片路径和类别索引元组#[("./img1.jpg",0),("./img2.jpg",0),...]return images#生成带有标注文件的图像列表
def make_dataset_with_ann(ann_file, img_prefix, extensions):
    """Build the (image_path, class_index) sample list from an annotation file.

    Each line of *ann_file* is "<relative_image_path>\t<class_index>".
    *img_prefix* is prepended to every image path.

    Returns:
        list of (image_path, class_index) tuples, e.g.
        [("./img1.jpg", 0), ("./img2.jpg", 0), ...]
    """
    images = []
    with open(ann_file, "r") as f:
        contents = f.readlines()  # read all annotation lines
    for line_str in contents:
        # Each line is tab-separated: image file name, then class index.
        path_contents = [c for c in line_str.split('\t')]
        im_file_name = path_contents[0]
        class_index = int(path_contents[1])
        # Sanity-check the file extension (NOTE: stripped under `python -O`).
        assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
        images.append((os.path.join(img_prefix, im_file_name), class_index))
    return images


class DatasetFolder(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (list[string]): A list of allowed extensions.
        ann_file (string): annotation file name; empty -> plain folder mode.
        img_prefix (string): prefix joined onto annotated image paths.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        cache_mode (string): "no" | "part" | "full" — see ``init_cache``.

    Attributes:
        samples (list): List of (sample path, class_index) tuples
    """

    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
                 cache_mode="no"):
        # Choose how to enumerate samples based on the dataset layout.
        # image folder mode
        if ann_file == '':
            _, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
        # zip mode (annotation-file driven)
        else:
            samples = make_dataset_with_ann(os.path.join(root, ann_file),
                                            os.path.join(root, img_prefix),
                                            extensions)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
                               "Supported extensions are: " + ",".join(extensions))

        self.root = root
        self.loader = loader  # sample-loading callable
        self.extensions = extensions
        # samples: [(path, class_index), ...]
        self.samples = samples
        self.labels = [y_1k for _, y_1k in samples]  # per-sample class indices
        self.classes = list(set(self.labels))  # de-duplicated class indices
        self.transform = transform  # sample transform
        self.target_transform = target_transform  # label transform
        self.cache_mode = cache_mode
        if self.cache_mode != "no":
            self.init_cache()

    def init_cache(self):
        """Pre-read sample bytes into memory.

        "full": every rank caches all samples; "part": each rank caches
        only the samples whose index is congruent to its rank (others keep
        the plain path).
        """
        assert self.cache_mode in ["part", "full"]
        n_sample = len(self.samples)
        global_rank = dist.get_rank()
        world_size = dist.get_world_size()

        samples_bytes = [None for _ in range(n_sample)]
        start_time = time.time()
        # Log progress roughly every 10% of the dataset; max(1, ...) guards
        # against ZeroDivisionError when the dataset has fewer than 10 samples.
        log_interval = max(1, n_sample // 10)
        for index in range(n_sample):
            if index % log_interval == 0:
                t = time.time() - start_time
                print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
                start_time = time.time()
            path, target = self.samples[index]
            if self.cache_mode == "full":
                samples_bytes[index] = (ZipReader.read(path), target)
            elif self.cache_mode == "part" and index % world_size == global_rank:
                samples_bytes[index] = (ZipReader.read(path), target)
            else:
                samples_bytes[index] = (path, target)
        self.samples = samples_bytes

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def pil_loader(path):# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)if isinstance(path, bytes):img = Image.open(io.BytesIO(path))elif is_zip_path(path):data = ZipReader.read(path)img = Image.open(io.BytesIO(data))else:with open(path, 'rb') as f:img = Image.open(f)return img.convert('RGB')return img.convert('RGB')#使用accimage 进行加载图像数据(是比PIL Image更快的一种库)
def accimage_loader(path):import accimagetry:return accimage.Image(path)except IOError:# Potentially a decoding problem, fall back to PIL.Imagereturn pil_loader(path)def default_img_loader(path):from torchvision import get_image_backendif get_image_backend() == 'accimage':return accimage_loader(path)else:return pil_loader(path)#DatasetFolder的继承类
class CachedImageFolder(DatasetFolder):"""A generic data loader where the images are arranged in this way: ::root/dog/xxx.pngroot/dog/xxy.pngroot/dog/xxz.pngroot/cat/123.pngroot/cat/nsdf3.pngroot/cat/asd932_.pngArgs:root (string): Root directory path.transform (callable, optional): A function/transform that  takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``target_transform (callable, optional): A function/transform that takes in thetarget and transforms it.loader (callable, optional): A function to load an image given its path.Attributes:imgs (list): List of (image path, class_index) tuples"""def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,loader=default_img_loader, cache_mode="no"):super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,ann_file=ann_file, img_prefix=img_prefix,transform=transform, target_transform=target_transform,cache_mode=cache_mode)self.imgs = self.samplesdef __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is class_index of the target class."""path, target = self.samples[index]image = self.loader(path)if self.transform is not None:img = self.transform(image)else:img = imageif self.target_transform is not None:target = self.target_transform(target)return img, target

【SwinTransformer源码阅读一】build_loader部分代码相关推荐

  1. Soul 网关源码阅读(二)代码初步运行

    Soul 源码阅读(二)代码初步运行 简介     基于上篇:Soul 源码阅读(一) 概览,这部分跑一下Soul网关的示例 过程记录     现在我们可以根据地图,稍微探索一下周边,摸一摸      ...

  2. Linux源码阅读——PCI总线驱动代码(一)整体框架

    目录 一.前言 二.概述 三.整体流程 四.PCI相关入口函数 4.1 pcibus_class_init 4.2 pci_driver_init 4.3 pci_arch_init 4.4 pci_ ...

  3. 【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分

    先放一下SwinTransformer的整体结构,图片源于原论文,可以发现,在Transformer的Block中 W-MSA(Window based multi-head self attenti ...

  4. Linux源码阅读——PCI总线驱动代码(三)PCI设备枚举过程

    目录 前言 1.枚举过程 1.1 acpi_pci_root_add 1.2 pci_acpi_scan_root(枚举开始) 1.3 acpi_pci_root_create 1.4 pci_sca ...

  5. surefire 拉起 junit 单元测试类 源码阅读(一)

    根据surefire 拉起Junit单元测试类 输出的报错日志 跟踪执行过程: 日志1: java.lang.reflect.InvocationTargetExceptionat sun.refle ...

  6. Soul网关源码阅读(十)自定义简单插件编写

    Soul网关源码阅读(十)自定义简单插件编写 简介     综合前面所分析的插件处理流程相关知识,此次我们来编写自定义的插件:统计请求在插件链中的经历时长 编写准备     首先我们先探究一下,一个P ...

  7. Soul网关源码阅读(九)插件配置加载初探

    Soul网关源码阅读(九)插件配置加载初探 简介     今日来探索一下插件的初始化,及相关的配置的加载 源码Debug 插件初始化     首先来到我们非常熟悉的插件链调用的类: SoulWebHa ...

  8. Soul网关源码阅读(八)路由匹配初探

    Soul网关源码阅读(八)路由匹配初探 简介      今日看看路由的匹配相关代码,查看HTTP的DividePlugin匹配 示例运行      使用HTTP的示例,运行Soul-Admin,Sou ...

  9. Soul网关源码阅读(七)限流插件初探

    Soul网关源码阅读(七)限流插件初探 简介     前面的文章中对处理流程探索的差不多了,今天来探索下限流插件:resilience4j 示例运行 环境配置     启动下MySQL和redis d ...

最新文章

  1. linux入门(三)常见Linux指令及其用法
  2. API编程基本控件使用
  3. pycharm 调试程序时如何监控、监视变量?
  4. python之路-day18-反射
  5. 在package-lock.json中指定node-mass版本+独立编译flink中的flink-runtime-web模块
  6. 实验:PIO外部中断
  7. 怎么运行c语言_C语言 原来是这样调用硬件的
  8. unity3d collider自动调整大小_自动网格组合建模工具Unity游戏素材资源
  9. IT女性必备——5个方法变身小腰精
  10. bash 将二进制转换为十进制_用‘栈的思想编写一个十进制转换二进制、八进制或十六进制的程序...
  11. 对Moss 2007中访问群体的设置和使用补充
  12. 大前端页面布局插件收藏
  13. mysql界面导出数据库有乱码_导出的MYSQL数据库是乱码还可以变回中文吗
  14. 曼妙音色要靠煲 多媒体音箱煲机大法
  15. (译)《科学美国人》:多样的人际网络导致繁荣的本地经济
  16. 【毕设记录日记】深度学习|铝型材表面缺陷视觉检测算法:YOLOv5环境搭建、基础知识、问题解决、优化方法
  17. 在win7的iis下部署asp网站
  18. 0基础看-最大似然函数,原理,基本概念,例子
  19. 这款Python视频剪辑神器,牛逼!
  20. Q 2:真的是格局不够吗?

热门文章

  1. 达梦数据库通过dmp文件导入数据
  2. 【☠️️社死现场の老板来了☠️️】小伙,搞C语言嵌入式开发这么久了,还不知道u8、u16、u32、s8、s16、s32是什么意思啊?
  3. css 下划线_(06)CSS 给文本加样式:② 文本属性 | CSS
  4. 做一个派发工单的微信小程序
  5. 黑龙江省计算机三本学校,黑龙江省的公办三表大学有哪些?
  6. android锁屏(三)
  7. 如何用python的turtle画五角星_Python turtle 绘制五角星
  8. 算法思想记录:给定一个整数数组 nums 和一个目标值 target
  9. 深度神经网络(训练集,验证集,测试集), 提升模型效果,交叉验证
  10. Endnote引用文献时期刊名称不缩写问题-论文投稿经验总结-第1期