[Pytorch]将自己的数据集载入dataloader
一、概述
初始化DataLoader类时必须注入一个参数dataset,而dataset为自己定义。DataSet类可以继承,但是必须重载__len__()和__getitem__
使用Pytoch封装的DataLoader有以下好处:
①可以自动实现多进程加载
②自动惰性加载,不会占用过多内存
③封装有数据预处理和数据增强等操作,避免重复造轮子
二、自定义DataSet
以Faster R-CNN为例,一般建议至少传入以下参数,方便后续使用:
class FRCNNDataset(Dataset):def __init__(self, annotation_lines, input_shape = [600, 600], train = True):self.annotation_lines = annotation_lines #数据集列表self.length = len(annotation_lines) #数据集大小self.input_shape = input_shape #输出尺寸self.train = train #是否训练
然后重载__len__()和__getitem__
def __len__(self):return self.length #直接返回长度
def __getitem__(self, index):index = index % self.length#训练时候对数据进行随机增强,但验证时不进行image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)#将图片转换成矩阵image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))#编码先验框box_data = np.zeros((len(y), 5))if len(y) > 0:box_data[:len(y)] = ybox = box_data[:, :4]label = box_data[:, -1]return image, box, label
关于数据增强函数get_random_data(),其中还包含了对图片的无变形缩放功能
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):# 数据经过处理后格式为:地址——(空格)——预测框,使用split函数即可切割出地址和先验框line = annotation_line.split()# 读取图像并转换为RGB格式image = Image.open(line[0])image = cvtColor(image)# 获得图像的高宽与目标高宽iw, ih = image.sizeh, w = input_shape# 读取先验框box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
仅缩放的无变形缩放功(非训练模式)
# 在不进行随机数据增强的情况下(非训练模式),直接变形后输出
if not random:#获取变形比例scale = min(w/iw, h/ih)nw = int(iw*scale)nh = int(ih*scale)dx = (w-nw)//2dy = (h-nh)//2# 将图像多余的部分加上灰条image = image.resize((nw,nh), Image.BICUBIC)new_image = Image.new('RGB', (w,h), (128,128,128))new_image.paste(image, (dx, dy))image_data = np.array(new_image, np.float32)# 对真实框进行调整if len(box)>0:np.random.shuffle(box)box[:, [0,2]] = box[:, [0,2]]*nw/iw + dxbox[:, [1,3]] = box[:, [1,3]]*nh/ih + dybox[:, 0:2][box[:, 0:2]<0] = 0box[:, 2][box[:, 2]>w] = wbox[:, 3][box[:, 3]>h] = hbox_w = box[:, 2] - box[:, 0]box_h = box[:, 3] - box[:, 1]box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box#返回图片和先验框return image_data, box
带数据增强的无变形缩放(训练模式)
# 对图像进行缩放并且进行长和宽的扭曲new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)scale = self.rand(.25, 2)if new_ar < 1:nh = int(scale*h)nw = int(nh*new_ar)else:nw = int(scale*w)nh = int(nw/new_ar)image = image.resize((nw,nh), Image.BICUBIC)# 将图像多余的部分加上灰条dx = int(self.rand(0, w-nw))dy = int(self.rand(0, h-nh))new_image = Image.new('RGB', (w,h), (128,128,128))new_image.paste(image, (dx, dy))image = new_image# 翻转图像flip = self.rand()<.5if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)image_data = np.array(image, np.uint8)# 对图像进行色域变换# 计算色域变换的参数r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1# 将图像转到HSV上hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))dtype = image_data.dtype# 应用变换x = np.arange(0, 256, dtype=r.dtype)lut_hue = ((x * r[0]) % 180).astype(dtype)lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)lut_val = np.clip(x * r[2], 0, 255).astype(dtype)image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)# 对真实框进行调整if len(box)>0:np.random.shuffle(box)box[:, [0,2]] = box[:, [0,2]]*nw/iw + dxbox[:, [1,3]] = box[:, [1,3]]*nh/ih + dyif flip: box[:, [0,2]] = w - box[:, [2,0]]box[:, 0:2][box[:, 0:2]<0] = 0box[:, 2][box[:, 2]>w] = wbox[:, 3][box[:, 3]>h] = hbox_w = box[:, 2] - box[:, 0]box_h = box[:, 3] - box[:, 1]box = box[np.logical_and(box_w>1, box_h>1)] return image_data, box
关于collate_fn参数
__getitem__一般返回(image,label)样本对,而DataLoader需要一个batch_size用于处理batch样本,以便于批量训练。
默认的default_collate(batch)函数仅能对尺寸一致且batch_size相同的image进行整理,如将(img0,lbl0),(img1,lbl1),(img2,lbl2)整合为([img0,img1,img2],[lbl0,lbl1,lbl2]),如图像中含有box等参数则需要自定义处理
def frcnn_dataset_collate(batch):images = []bboxes = []labels = []for img, box, label in batch:images.append(img)bboxes.append(box)labels.append(label)images = torch.from_numpy(np.array(images))return images, bboxes, labels
三、语义分割与目标检测DataSet的区别
①在__getitem__中不需要获取box值,转而获取标志图png。
def __getitem__(self, index):annotation_line = self.annotation_lines[index]name = annotation_line.split()[0]# 从文件中读取图像jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))# 数据增强jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])png = np.array(png)png[png >= self.num_classes] = self.num_classes# 转化成one_hot的形式# 在这里需要+1是因为voc数据集有些标签具有白边部分seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))return jpg, png, seg_labels
②get_random_data变形时需要对两张图做同样的变换
if not random:iw, ih = image.sizescale = min(w/iw, h/ih)nw = int(iw*scale)nh = int(ih*scale)image = image.resize((nw,nh), Image.BICUBIC)new_image = Image.new('RGB', [w, h], (128,128,128))new_image.paste(image, ((w-nw)//2, (h-nh)//2))label = label.resize((nw,nh), Image.NEAREST)new_label = Image.new('L', [w, h], (0))new_label.paste(label, ((w-nw)//2, (h-nh)//2))return new_image, new_label
③collate_fn需要进行修改
def deeplab_dataset_collate(batch):images = []pngs = []seg_labels = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs = torch.from_numpy(np.array(pngs)).long()seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels
四、在训练过程中的调用
①读取文件集(经处理的txt文件)
with open(train_annotation_path, encoding='utf-8') as f:train_lines = f.readlines()
with open(val_annotation_path, encoding='utf-8') as f:val_lines = f.readlines()
#获取数据集长度
num_train = len(train_lines)
num_val = len(val_lines)
②检查数据集是否符合要求
这里一般检查数据集是否足够大,也可不检查
③将数据集装入DataSet中
train_dataset = MyDataset(train_lines, input_shape, anchors, batch_size, num_classes, train = True)
val_dataset = MyDataset(val_lines, input_shape, anchors, batch_size, num_classes, train = False)
④将DataSet放入DataLoader中
关于dataloader:一般有以下5个参数:
1.dataset:数据集对象,dataset型
2.batch_size:批大小,int型
3.shuffe:每一轮epoch是否重新洗牌,bool型
4.num_workers:多进程读取
5.drop_last:当样本不能被batch_size取整时,是否丢弃最后一批数据,bool型
gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,drop_last=True, collate_fn=ssd_dataset_collate, sampler=train_sampler)
gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=ssd_dataset_collate, sampler=val_sampler)
[Pytorch]将自己的数据集载入dataloader相关推荐
- TypeError: 'module' object is not callable (pytorch在进行MNIST数据集预览时出现的错误)
在使用pytorch在对MNIST数据集进行预览时,出现了TypeError: 'module' object is not callable的错误: 上报错信息图如下: 从图中可以看出,报错位置为第 ...
- 基于pytorch的双模态数据载入
基于pytorch的双模态数据载入 双模态数据融合 torch.utils.data.dataloader 双模态数据载入 双模态数据融合 无论是双模态,还是多模态融合,数据载入都是其重要的一环.如将 ...
- pytroch 数据集 datasets DataLoader示例
pytroch 数据集 datasets DataLoader示例 # 安装依赖包 ! pip install torchvision Looking in indexes: https://pypi ...
- Pytorch(七) --加载数据集
主要用到了Pytorch中的Dataset和DataLoadder这两个方法,其中Dataset是抽象类,不能实例化对象,只能继承用于构造数据集,DataLoader是帮助加载数据的,可以做shuff ...
- (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类
pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...
- 速成pytorch学习——6天Dataset和DataLoader
Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...
- Pytorch打怪路(三)Pytorch创建自己的数据集2
前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...
- Pytorch 目标检测和数据集
Pytorch 目标检测和数据集 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 S ...
- 编写transformers的自定义pytorch训练循环(Dataset和DataLoader解析和实例代码)
文章目录 一.Dataset和DataLoader加载数据集 1.torch.utils.data 2. 加载数据流程 3. Dataset 4. dataloader类及其参数 5. dataloa ...
最新文章
- 洛谷1216 数字三角形
- 水稻微生物组时间序列分析
- Linux在线求助 man page
- Xamarin.Android 使用Timer 并更改UI
- HUD 1043 Eight 八数码问题 A*算法 1667 The Rotation Game IDA*算法
- viewmodel+livedata+binding 实现listview+adapter
- android 车辆轨迹,Android自定义view实现车载可调整轨迹线
- addRoutes爬坑记
- 【Gamma】Scrum Meeting 6
- 【爬虫】微博数据采集
- 计算机网络说课教案,认识计算机网络说课稿PPT课件.ppt
- 2009年全国数模比赛,江苏三等奖名单
- 实用常识 | 写论文时如何引用插入脚注 / 如何自定义脚注符号 / 如何将多个脚注合并在一起
- 家用无线路由器选购指南。
- MATLAB中同一路径下同文件的末尾继续写入数据
- win7 64位 SEC S3C2410X Test B/D安装
- 夹角余弦 python_python 根据余弦定理计算两边的夹角
- flex是什么及flex布局语法
- 计算机的语言是美式英语,为什么电脑的语言栏一直有两国语言“CH中文(中国)”和“EH英语(美国)”...
- mysql中phpmyadmin安装教程_怎么安装phpMyAdmin?
热门文章
- sap crm行业解决方案_培训机构行业crm系统解决方案
- Linux 重新加载 nginx 配置命令
- 批量删除html网页,ie浏览器收藏夹网页批量删除方法
- 利用json实现vivo x20手机评论的爬取
- 5款主流智能音箱入门款测评:苹果小米华为天猫小度,谁的表现更胜一筹?
- 亲戚(relative)
- 大环境之下软件测试行业趋势能否上升?
- 生产环境RedisCPU飙高怎么办
- pycharm连接远程服务器以及踩的坑
- ultraiso刻录linux系统盘,使用UltraISO在Windows 10下刻录Ubuntu 18.04.2 U盘的方法