PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset
文章目录
- 一、PyTorch数据读取机制Dataloader
一、PyTorch数据读取机制Dataloader
PyTorch数据读取在Dataloader模块下,Dataloader又可以分为DataSet与Sampler。Sampler模块的功能是生成索引(样本序号);DataSet是依据索引读取Img、Lable。我们主要学习Dataloader与Dataset。
torch.utils.data.DataLoader()
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)
功能:构建可迭代的数据装载器
- dataset: Dataset类,决定数据从哪读取及如何读取
- batch_size :批大小
- num_works:是否多进程读取数据
- shuffle:每个epoch是否乱序
- drop_last :当样本数不能被batchsize整除时,是否舍弃最后一批数据
Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个lteration
Batchsize:批大小,决定一个Epoch有多少个lteration
样本总数:80,Batchsize : 8
1 Epoch = 10 lteration
样本总数:87, Batchsize: 8
1 Epoch = 10 lteration ? drop_last = True
1 Epoch = 11 lteration drop_last = False
torch.utils.data.Dataset()
class Dataset(object):def __getitem__(self,index):raise NotImplementedErrordef __add__(self, other) :return ConcatDataset([self, other])
功能: Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
数据读取流程如下:
for i, data in enumerate(train_loader):==># 判断是单进程还是多进程def __iter__(self):# 单进程if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)# 多进程else:return _MultiProcessingDataLoaderIter(self)==>
# 以单进程为例
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_SingleProcessDataLoaderIter, self).__init__(loader)assert self.timeout == 0assert self.num_workers == 0self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)# 这个函数告诉我们每个iteration中读哪些数据def __next__(self):# index = self._next_index() # may raise StopIterationdata = self.dataset_fetcher.fetch(index) # may raise StopIterationif self.pin_memory:data = _utils.pin_memory.pin_memory(data)return datanext = __next__ # Python 2 compatibility==>def _next_index(self):return next(self.sampler_iter) # may raise StopIteration==># 利用sampler输出的index来进行采样def __iter__(self):batch = []# for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batch==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:# 这一步实现了正式的数据读取data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]return self.collate_fn(data)==>class RMBDataset(Dataset):def __init__(self, data_dir, transform=None):"""rmb面额分类任务的Dataset:param data_dir: str, 数据集所在路径:param transform: torch.transform,数据预处理"""self.label_name = {"1": 0, "100": 1}self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本self.transform = transformdef __getitem__(self, index):# 根据索引index获得数据与标签path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB') # 0~255if self.transform is not None:img = self.transform(img) # 在这里做transform,转为tensor等等return img, labeldef __len__(self):return len(self.data_info)@staticmethoddef get_img_info(data_dir):data_info = list()# 遍历一个目录内,各个子目录与子文件for root, dirs, _ in os.walk(data_dir):# 遍历类别for sub_dir in dirs:img_names = os.listdir(os.path.join(root, sub_dir))img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))# 遍历图片for i in range(len(img_names)):img_name = img_names[i]path_img = os.path.join(root, sub_dir, img_name)label = rmb_label[sub_dir]data_info.append((path_img, int(label)))return data_info==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]# 数据的整理器,将读取到的数据整理成batch的形式return self.collate_fn(data)==>for i, data in enumerate(train_loader):# forward# data由两个Tensor组成inputs, labels = data
数据整理器将数据由下面的形式:
转化为batch形式:
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!
PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset相关推荐
- 数据读取机制Dataloader和Dataset和Transforms
人民币二分类模型 数据-模型-损失函数-优化器-迭代训练 数据收集 img label 数据划分 train valid test 数据读取 Dataloader [sampler-生成索引 data ...
- 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...
- PyTorch框架学习八——PyTorch数据读取机制(简述)
PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...
- Pytorch Note52 灵活的数据读取介绍
Pytorch Note52 灵活的数据读取介绍 文章目录 Pytorch Note52 灵活的数据读取介绍 灵活的数据读取 读入数据 传入数据预处理方式 Dataset DataLoader 例子 ...
- tensorflow 1.0 学习:十图详解tensorflow数据读取机制
本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...
- linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制
导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解 ...
- Pytorch学习 - Task5 PyTorch卷积层原理和使用
Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...
- 十图详解TensorFlow数据读取机制(附代码)
在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用
Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...
最新文章
- 快速学会MySQL常用操作方法
- 大多数人对AI的理解,都是错的
- .net core发布 正在发现数据上下文_使用EF Core实现数据库读写分离
- Nginx 安装配置【必须把文件到放到机器上】
- python用http协议传数据_python基础 -- 简单实现HTTP协议
- Oracle PCTfree assm,Oracle 12C LMT ASSM 完美测试
- Kendo UI开发教程(9): Kendo UI Validator 概述
- SQL必知必会-存储过程
- [人工智能]隔墙有眼,吓屎了
- python读写磁盘扇区数据_[Win32] 直接读写磁盘扇区(磁盘绝对读写)
- 微软彻底告别移动操作系统!
- Qt常用类——Qpoint
- 22. Magento 创建新闻模块(3)
- js数据结构hashMap -----hashMap
- java调用python机器学习模型的坑
- 【AI视野·今日CV 计算机视觉论文速览 第181期】Tue, 7 Apr 2020
- 【虚拟仿真】Unity3D中如何实现让3D模型显示在UI前面
- 像素三国志在线html5小游戏,像素三国志
- 大一ACM比赛观摩感悟(比赛)
- 【第三课】UAV倾斜摄影测量三维建模软件