一、概述

初始化DataLoader类时必须注入一个参数dataset,而dataset为自己定义。DataSet类可以继承,但是必须重载__len__()和__getitem__

使用Pytoch封装的DataLoader有以下好处:

①可以自动实现多进程加载

②自动惰性加载,不会占用过多内存

③封装有数据预处理和数据增强等操作,避免重复造轮子

二、自定义DataSet

以Faster R-CNN为例,一般建议至少传入以下参数,方便后续使用:

class FRCNNDataset(Dataset):def __init__(self, annotation_lines, input_shape = [600, 600], train = True):self.annotation_lines   = annotation_lines        #数据集列表self.length             = len(annotation_lines)   #数据集大小self.input_shape        = input_shape             #输出尺寸self.train              = train                   #是否训练

然后重载__len__()和__getitem__

def __len__(self):return self.length    #直接返回长度
def __getitem__(self, index):index = index % self.length#训练时候对数据进行随机增强,但验证时不进行image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)#将图片转换成矩阵image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))#编码先验框box_data = np.zeros((len(y), 5))if len(y) > 0:box_data[:len(y)] = ybox = box_data[:, :4]label = box_data[:, -1]return image, box, label

关于数据增强函数get_random_data(),其中还包含了对图片的无变形缩放功能

def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):# 数据经过处理后格式为:地址——(空格)——预测框,使用split函数即可切割出地址和先验框line = annotation_line.split()# 读取图像并转换为RGB格式image = Image.open(line[0])image = cvtColor(image)# 获得图像的高宽与目标高宽iw, ih = image.sizeh, w = input_shape# 读取先验框box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

仅缩放的无变形缩放功(非训练模式)

# 在不进行随机数据增强的情况下(非训练模式),直接变形后输出
if not random:#获取变形比例scale = min(w/iw, h/ih)nw = int(iw*scale)nh = int(ih*scale)dx = (w-nw)//2dy = (h-nh)//2#   将图像多余的部分加上灰条image       = image.resize((nw,nh), Image.BICUBIC)new_image   = Image.new('RGB', (w,h), (128,128,128))new_image.paste(image, (dx, dy))image_data  = np.array(new_image, np.float32)#   对真实框进行调整if len(box)>0:np.random.shuffle(box)box[:, [0,2]] = box[:, [0,2]]*nw/iw + dxbox[:, [1,3]] = box[:, [1,3]]*nh/ih + dybox[:, 0:2][box[:, 0:2]<0] = 0box[:, 2][box[:, 2]>w] = wbox[:, 3][box[:, 3]>h] = hbox_w = box[:, 2] - box[:, 0]box_h = box[:, 3] - box[:, 1]box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box#返回图片和先验框return image_data, box

带数据增强的无变形缩放(训练模式)

        #   对图像进行缩放并且进行长和宽的扭曲new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)scale = self.rand(.25, 2)if new_ar < 1:nh = int(scale*h)nw = int(nh*new_ar)else:nw = int(scale*w)nh = int(nw/new_ar)image = image.resize((nw,nh), Image.BICUBIC)#   将图像多余的部分加上灰条dx = int(self.rand(0, w-nw))dy = int(self.rand(0, h-nh))new_image = Image.new('RGB', (w,h), (128,128,128))new_image.paste(image, (dx, dy))image = new_image#   翻转图像flip = self.rand()<.5if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)image_data      = np.array(image, np.uint8)#   对图像进行色域变换#   计算色域变换的参数r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1#   将图像转到HSV上hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))dtype           = image_data.dtype#   应用变换x       = np.arange(0, 256, dtype=r.dtype)lut_hue = ((x * r[0]) % 180).astype(dtype)lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)lut_val = np.clip(x * r[2], 0, 255).astype(dtype)image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)#   对真实框进行调整if len(box)>0:np.random.shuffle(box)box[:, [0,2]] = box[:, [0,2]]*nw/iw + dxbox[:, [1,3]] = box[:, [1,3]]*nh/ih + dyif flip: box[:, [0,2]] = w - box[:, [2,0]]box[:, 0:2][box[:, 0:2]<0] = 0box[:, 2][box[:, 2]>w] = wbox[:, 3][box[:, 3]>h] = hbox_w = box[:, 2] - box[:, 0]box_h = box[:, 3] - box[:, 1]box = box[np.logical_and(box_w>1, box_h>1)] return image_data, box

关于collate_fn参数

__getitem__一般返回(image,label)样本对,而DataLoader需要一个batch_size用于处理batch样本,以便于批量训练。

默认的default_collate(batch)函数仅能对尺寸一致且batch_size相同的image进行整理,如将(img0,lbl0),(img1,lbl1),(img2,lbl2)整合为([img0,img1,img2],[lbl0,lbl1,lbl2]),如图像中含有box等参数则需要自定义处理

def frcnn_dataset_collate(batch):images = []bboxes = []labels = []for img, box, label in batch:images.append(img)bboxes.append(box)labels.append(label)images = torch.from_numpy(np.array(images))return images, bboxes, labels

三、语义分割与目标检测DataSet的区别

①在__getitem__中不需要获取box值,转而获取标志图png

    def __getitem__(self, index):annotation_line = self.annotation_lines[index]name            = annotation_line.split()[0]#   从文件中读取图像jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))png         = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))#   数据增强jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])png         = np.array(png)png[png >= self.num_classes] = self.num_classes#   转化成one_hot的形式#   在这里需要+1是因为voc数据集有些标签具有白边部分seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))return jpg, png, seg_labels

get_random_data变形时需要对两张图做同样的变换

        if not random:iw, ih  = image.sizescale   = min(w/iw, h/ih)nw      = int(iw*scale)nh      = int(ih*scale)image       = image.resize((nw,nh), Image.BICUBIC)new_image   = Image.new('RGB', [w, h], (128,128,128))new_image.paste(image, ((w-nw)//2, (h-nh)//2))label       = label.resize((nw,nh), Image.NEAREST)new_label   = Image.new('L', [w, h], (0))new_label.paste(label, ((w-nw)//2, (h-nh)//2))return new_image, new_label

collate_fn需要进行修改

def deeplab_dataset_collate(batch):images      = []pngs        = []seg_labels  = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels

四、在训练过程中的调用

①读取文件集(经处理的txt文件)

with open(train_annotation_path, encoding='utf-8') as f:train_lines = f.readlines()
with open(val_annotation_path, encoding='utf-8') as f:val_lines   = f.readlines()
#获取数据集长度
num_train   = len(train_lines)
num_val     = len(val_lines)  

②检查数据集是否符合要求

这里一般检查数据集是否足够大,也可不检查

③将数据集装入DataSet中

train_dataset   = MyDataset(train_lines, input_shape, anchors, batch_size, num_classes, train = True)
val_dataset     = MyDataset(val_lines, input_shape, anchors, batch_size, num_classes, train = False)

④将DataSet放入DataLoader中

关于dataloader:一般有以下5个参数:

1.dataset:数据集对象,dataset型

2.batch_size:批大小,int型

3.shuffe:每一轮epoch是否重新洗牌,bool型

4.num_workers:多进程读取

5.drop_last:当样本不能被batch_size取整时,是否丢弃最后一批数据,bool型

gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,drop_last=True, collate_fn=ssd_dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset  , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=ssd_dataset_collate, sampler=val_sampler)

[Pytorch]将自己的数据集载入dataloader相关推荐

  1. TypeError: 'module' object is not callable (pytorch在进行MNIST数据集预览时出现的错误)

    在使用pytorch在对MNIST数据集进行预览时,出现了TypeError: 'module' object is not callable的错误: 上报错信息图如下: 从图中可以看出,报错位置为第 ...

  2. 基于pytorch的双模态数据载入

    基于pytorch的双模态数据载入 双模态数据融合 torch.utils.data.dataloader 双模态数据载入 双模态数据融合 无论是双模态,还是多模态融合,数据载入都是其重要的一环.如将 ...

  3. pytroch 数据集 datasets DataLoader示例

    pytroch 数据集 datasets DataLoader示例 # 安装依赖包 ! pip install torchvision Looking in indexes: https://pypi ...

  4. Pytorch(七) --加载数据集

    主要用到了Pytorch中的Dataset和DataLoadder这两个方法,其中Dataset是抽象类,不能实例化对象,只能继承用于构造数据集,DataLoader是帮助加载数据的,可以做shuff ...

  5. (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类

    pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...

  6. 速成pytorch学习——6天Dataset和DataLoader

    Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...

  7. Pytorch打怪路(三)Pytorch创建自己的数据集2

    前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...

  8. Pytorch 目标检测和数据集

    Pytorch 目标检测和数据集 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 S ...

  9. 编写transformers的自定义pytorch训练循环(Dataset和DataLoader解析和实例代码)

    文章目录 一.Dataset和DataLoader加载数据集 1.torch.utils.data 2. 加载数据流程 3. Dataset 4. dataloader类及其参数 5. dataloa ...

最新文章

  1. 洛谷1216 数字三角形
  2. 水稻微生物组时间序列分析
  3. Linux在线求助 man page
  4. Xamarin.Android 使用Timer 并更改UI
  5. HUD 1043 Eight 八数码问题 A*算法 1667 The Rotation Game IDA*算法
  6. viewmodel+livedata+binding 实现listview+adapter
  7. android 车辆轨迹,Android自定义view实现车载可调整轨迹线
  8. addRoutes爬坑记
  9. 【Gamma】Scrum Meeting 6
  10. 【爬虫】微博数据采集
  11. 计算机网络说课教案,认识计算机网络说课稿PPT课件.ppt
  12. 2009年全国数模比赛,江苏三等奖名单
  13. 实用常识 | 写论文时如何引用插入脚注 / 如何自定义脚注符号 / 如何将多个脚注合并在一起
  14. 家用无线路由器选购指南。
  15. MATLAB中同一路径下同文件的末尾继续写入数据
  16. win7 64位 SEC S3C2410X Test B/D安装
  17. 夹角余弦 python_python 根据余弦定理计算两边的夹角
  18. flex是什么及flex布局语法
  19. 计算机的语言是美式英语,为什么电脑的语言栏一直有两国语言“CH中文(中国)”和“EH英语(美国)”...
  20. mysql中phpmyadmin安装教程_怎么安装phpMyAdmin?

热门文章

  1. sap crm行业解决方案_培训机构行业crm系统解决方案
  2. Linux 重新加载 nginx 配置命令
  3. 批量删除html网页,ie浏览器收藏夹网页批量删除方法
  4. 利用json实现vivo x20手机评论的爬取
  5. 5款主流智能音箱入门款测评:苹果小米华为天猫小度,谁的表现更胜一筹?
  6. 亲戚(relative)
  7. 大环境之下软件测试行业趋势能否上升?
  8. 生产环境RedisCPU飙高怎么办
  9. pycharm连接远程服务器以及踩的坑
  10. ultraiso刻录linux系统盘,使用UltraISO在Windows 10下刻录Ubuntu 18.04.2 U盘的方法