本文截取自《PyTorch 模型训练实用教程》,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial

文章目录

  • Dataset类
  • 构建Dataset子类

想让PyTorch能读取我们自己的数据,首先要了解pytroch读取图片的机制和流程,然后按流程编写代码。

Dataset类

PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。

源码如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):raise NotImplementedError
def __len__(self):raise NotImplementedError
def __add__(self, other):return ConcatDataset([self, other])

这里重点看 getitem函数,getitem接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
那么读取自己数据的基本流程就是:

  1. 制作存储了图片的路径和标签信息的txt
  2. 将这些信息转化为list,该list每一个元素对应一个样本
  3. 通过getitem函数,读取数据和标签,并返回数据和标签

在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self),后面会详细讲解读取过程。在本小节,主要讲Dataset子类。
因此,要让PyTorch能读取自己的数据集,只需要两步:

  1. 制作图片数据的索引

  2. 构建Dataset子类

  3. 制作图片数据的索引
    这个比较简单,就是读取图片路径,标签,保存到txt文件中,这里注意格式就好
    特别注意的是,txt中的路径,是以训练时的那个py文件所在的目录为工作目录,所以这里需要提前算好相对路径!
    运行代码 Code/1_data_prepare/1_3_generate_txt.py,即会在/Data/文件夹下面看到 train.txt valid.txt
    txt中是这样的:

构建Dataset子类

下面是本实验构建的Dataset子类——MyDataset类:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):fh = open(txt_path, 'r')imgs = []for line in fh:line = line.rstrip()words = line.split()imgs.append((words[0], int(words[1])))self.imgs = imgs self.transform = transformself.target_transform = target_transform
def __getitem__(self, index):fn, label = self.imgs[index]img = Image.open(fn).convert('RGB') if self.transform is not None:img = self.transform(img) return img, label
def __len__(self):return len(self.imgs)

首先看看初始化,初始化中从我们准备好的txt里获取图片的路径和标签,并且存储在self.imgs,self.imgs就是上面提到的list,其一个元素对应一个样本的路径和标签,其实就是txt中的一行。

初始化中还会初始化transform,transform是一个Compose类型,里边有一个list,list中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作。

在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成输入模型的数据。这里就有一点需要注意,PyTorch的数据增强是将原始图片进行了处理,并不会生成新的一份图片,而是“覆盖”原图,当采用randomcrop之类的随机操作时,每个epoch输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。

然后看看核心的 getitem函数:

第一行:self.imgs 是一个list,也就是一开始提到的list,self.imgs的一个元素是一个元组,包含图片路径,图片标签,这些信息是从txt文件中读取

第二行:利用Image.open对图片进行读取,img类型为 Image ,mode=‘RGB’

第三行与第四行: 对图片进行处理,这个transform里边可以实现 减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等等操作,这个放在后面会详细讲解。

当Mydataset构建好,剩下的操作就交给DataLoder,在DataLoder中,会触发Mydataset中的getiterm函数读取一张图片的数据和标签,并拼接成一个batch返回,作为模型真正的输入。下一小节将会通过一个小例子,介绍DataLoder是如何获取一个batch,以及一张图片是如何被PyTorch读取,最终变为模型的输入的。

转载请注明出处:https://blog.csdn.net/u011995719/article/details/85102770

PyTorch 学习笔记(一):让PyTorch读取你的数据集相关推荐

  1. 【Pytorch学习笔记2】Pytorch的主要组成模块

    个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...

  2. PyTorch学习笔记(19) ——NIPS2019 PyTorch: An Imperative Style, High-Performance Deep Learning Library

    0. 前言 波兰小哥Adam Paszke从15年的Torch开始,到现在发表了关于PyTorch的Neurips2019论文(令我惊讶的是只中了Poster?而不是Spotlight?).中间经历了 ...

  3. 【Pytorch学习笔记三】Pytorch神经网络包nn和优化器optm(一个简单的卷积神经网络模型的搭建)

    文章目录 一, 神经网络包nn 1.1定义一个网络 1.2 损失函数 二.优化器 nn构建于 Autograd之上,可用来定义和运行神经网络, PyTorch Autograd 让我们定义计算图和计算 ...

  4. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  5. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  6. pytorch学习笔记(2):在MNIST上实现一个CNN

    参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A 这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN.我们会基于上一 ...

  7. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  8. PyTorch学习笔记(五):模型定义、修改、保存

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  9. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  10. PyTorch学习笔记(三):PyTorch主要组成模块

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

最新文章

  1. crontab 备份mysql数据库_crontab定时备份mySQL数据库
  2. linux环境切换python3版本
  3. web前端三大主流框架_小猿圈web前端之前端的主流框架都有哪些?
  4. 并不是所有SAP产品的UX,都得遵循Fiori UX风格
  5. 【CodeForces - 520B】Two Buttons (bfs或dp或时光倒流,trick)
  6. maven如何合并两个war到一个war项目中
  7. 玩转大数据可视化,推荐几个必学的工具
  8. Trace obtained enqueue information by set event 10704
  9. python实现键盘自动输入
  10. C# winform cefsharp 截取网页元素图片
  11. Serialization assertion safeVersionRead == safeSerializationVersion failed.
  12. 深度模型压缩技术在智能座舱方案的探索与实践
  13. WPF 方块按钮 仿照360
  14. 腾讯云服务器nginx安装配置
  15. Wiredtiger 存储引擎概述
  16. 国防创新小组(DIU)选择Immervision InnovationLab为Blue UAS Framework项目开发计算机视觉广角摄相机
  17. python找不到第三方安装库
  18. a人工智能b大数据c云计算_ABC时代生产工具的是()。A、人工智能B、大数据C、云计算D、物联网...
  19. Losses Can Be Blessings: Routing Self-Supervised Speech Representations Towards Efficient Multilingu
  20. [转]Half Life 2 Source 引擎介绍

热门文章

  1. 用curl发起https请求
  2. vue中使用ts后,父组件获取执行子组件方法报错问题
  3. python第6天作业
  4. https://127.0.0.1:8080/test?param={%22..报错
  5. Struts2入门到放弃
  6. 【转】 C#学习笔记14——Trace、Debug和TraceSource的使用以及日志设计
  7. S03_CH03_AXI_DMA_OV7725摄像头采集系统
  8. 防暴力破解一些安全机制
  9. Box2DWeb_03之Shape
  10. ExtJS使用总结和参考