文章目录

  • 一、PyTorch数据读取机制Dataloader

一、PyTorch数据读取机制Dataloader

  PyTorch数据读取在Dataloader模块下,Dataloader又可以分为DataSet与Sampler。Sampler模块的功能是生成索引(样本序号);DataSet是依据索引读取Img、Lable。我们主要学习Dataloader与Dataset。
  torch.utils.data.DataLoader()

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None)

功能:构建可迭代的数据装载器

  • dataset: Dataset类,决定数据从哪读取及如何读取
  • batch_size :批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last :当样本数不能被batchsize整除时,是否舍弃最后一批数据

Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个lteration
Batchsize:批大小,决定一个Epoch有多少个lteration
样本总数:80,Batchsize : 8
1 Epoch = 10 lteration
样本总数:87, Batchsize: 8
1 Epoch = 10 lteration ? drop_last = True
1 Epoch = 11 lteration drop_last = False

  torch.utils.data.Dataset()

class Dataset(object):def __getitem__(self,index):raise NotImplementedErrordef __add__(self, other) :return ConcatDataset([self, other])

功能: Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
数据读取流程如下:

for i, data in enumerate(train_loader):==># 判断是单进程还是多进程def __iter__(self):# 单进程if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)# 多进程else:return _MultiProcessingDataLoaderIter(self)==>
# 以单进程为例
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):def __init__(self, loader):super(_SingleProcessDataLoaderIter, self).__init__(loader)assert self.timeout == 0assert self.num_workers == 0self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)# 这个函数告诉我们每个iteration中读哪些数据def __next__(self):# index = self._next_index()  # may raise StopIterationdata = self.dataset_fetcher.fetch(index)  # may raise StopIterationif self.pin_memory:data = _utils.pin_memory.pin_memory(data)return datanext = __next__  # Python 2 compatibility==>def _next_index(self):return next(self.sampler_iter)  # may raise StopIteration==># 利用sampler输出的index来进行采样def __iter__(self):batch = []# for idx in self.sampler:batch.append(idx)if len(batch) == self.batch_size:yield batchbatch = []if len(batch) > 0 and not self.drop_last:yield batch==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:# 这一步实现了正式的数据读取data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]return self.collate_fn(data)==>class RMBDataset(Dataset):def __init__(self, data_dir, transform=None):"""rmb面额分类任务的Dataset:param data_dir: str, 数据集所在路径:param transform: torch.transform,数据预处理"""self.label_name = {"1": 0, "100": 1}self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本self.transform = transformdef __getitem__(self, index):# 根据索引index获得数据与标签path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等return img, labeldef __len__(self):return len(self.data_info)@staticmethoddef get_img_info(data_dir):data_info = list()# 遍历一个目录内,各个子目录与子文件for root, dirs, _ in os.walk(data_dir):# 遍历类别for sub_dir in dirs:img_names = os.listdir(os.path.join(root, sub_dir))img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))# 遍历图片for i in range(len(img_names)):img_name = img_names[i]path_img = os.path.join(root, sub_dir, img_name)label = rmb_label[sub_dir]data_info.append((path_img, int(label)))return data_info==>class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]# 数据的整理器,将读取到的数据整理成batch的形式return self.collate_fn(data)==>for i, data in enumerate(train_loader):# forward# data由两个Tensor组成inputs, labels = data

数据整理器将数据由下面的形式:
转化为batch形式:


如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset相关推荐

  1. 数据读取机制Dataloader和Dataset和Transforms

    人民币二分类模型 数据-模型-损失函数-优化器-迭代训练 数据收集 img label 数据划分 train valid test 数据读取 Dataloader [sampler-生成索引 data ...

  2. 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  3. PyTorch框架学习八——PyTorch数据读取机制(简述)

    PyTorch框架学习八--PyTorch数据读取机制(简述) 一.数据 二.DataLoader与Dataset 1.torch.utils.data.DataLoader 2.torch.util ...

  4. Pytorch Note52 灵活的数据读取介绍

    Pytorch Note52 灵活的数据读取介绍 文章目录 Pytorch Note52 灵活的数据读取介绍 灵活的数据读取 读入数据 传入数据预处理方式 Dataset DataLoader 例子 ...

  5. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  6. linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制

    导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解 ...

  7. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  8. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  9. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  10. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

最新文章

  1. 快速学会MySQL常用操作方法
  2. 大多数人对AI的理解,都是错的
  3. .net core发布 正在发现数据上下文_使用EF Core实现数据库读写分离
  4. Nginx 安装配置【必须把文件到放到机器上】
  5. python用http协议传数据_python基础 -- 简单实现HTTP协议
  6. Oracle PCTfree assm,Oracle 12C LMT ASSM 完美测试
  7. Kendo UI开发教程(9): Kendo UI Validator 概述
  8. SQL必知必会-存储过程
  9. [人工智能]隔墙有眼,吓屎了
  10. python读写磁盘扇区数据_[Win32] 直接读写磁盘扇区(磁盘绝对读写)
  11. 微软彻底告别移动操作系统!
  12. Qt常用类——Qpoint
  13. 22. Magento 创建新闻模块(3)
  14. js数据结构hashMap -----hashMap
  15. java调用python机器学习模型的坑
  16. 【AI视野·今日CV 计算机视觉论文速览 第181期】Tue, 7 Apr 2020
  17. 【虚拟仿真】Unity3D中如何实现让3D模型显示在UI前面
  18. 像素三国志在线html5小游戏,像素三国志
  19. 大一ACM比赛观摩感悟(比赛)
  20. 【第三课】UAV倾斜摄影测量三维建模软件

热门文章

  1. PHP (20140510)深入浅出 JavaScript 变量、作用域和内存 v 0.5
  2. FastSpring.NET V2.05 final 发布[集成Spring.net NHibernate Ajax]
  3. CSS hover 改变另外一个元素状态
  4. HDU-5441-离线化并查集
  5. 容器系列之虚拟化网络
  6. pycharm 激活相关
  7. 自学编程的12个网站
  8. C# typeof()实例详解
  9. 关于unity如何制作mmo
  10. 二级缓存:EHCache的使用