在训练经典的数据集如cifar10,minsit等,可以用官方自带的数据集格式几行就写出来,如果是自己下载的数据集,那么我们应该如何用pytorch来读取呢?其实是有模板可以直接仿照着写的。

本次案例采用的是pokeman数据集,并用该数据集进行分类。该数据如下所示:


其中文件夹的名字便是标签。数据集大小划分为:皮卡丘 234、超梦239、杰尼龟223、小火龙 238、妙蛙种子234张图。

在深度学习中一般的流程是:加载数据—>构建模型—>训练和测试。

读取数据

在pytorch读取数据,采用3个步骤

  1. 继承torch中的通用的母类:torch.utils.data.Dataset
from torch.utils.data.Dataset
  1. __len __:这里需要返回定义数据的数量,返回整型数字
  2. __getitem __ :这里返回样本、标签等
一个简单的例子
from torch.utils.data import Dataset, DataLoader
class NumberDataset(Dataset):   #首先要继承Dataset母类def __init__(self, training=True):  #区分训练和测试if training:self.samples = list(range(1, 1001))   #加载数据,一般是存放数据的地址,不然内存爆炸else:self.samples = list(range(1001, 15001))def __len__(self):return len(self.samples)    #def __getitem__(self, idx):  # idx 是位置标号,在len(self.samples) 内,一个一个的读取该位置数据return self.samples[idx]

小结:1、首先得到所有的数据的地址名字(训练或测试);2、给出数据集长度;3、返回指定位置的数据内容,可以在该数据上进行任何预处理操作。

现在读取本次给的pokeman数据集

python代码框架为:

from torch.utils.data import Dataset, DataLoader  #自定义的母类,必须的
class Pokemon(Dataset):def __init__(self):        #去读数据路径super(Pokemon, self).__init__()passdef __len__(self):  #返回数据长度passdef __getitem__(self, idx):  #返回当前位置的数据和标签pass

接下来就是填充每一块函数里面的内容了。

1 将标签转化数字,且数据地址及其标签保存csv文件

首先需要加载数据和标签,因为标签需要转化成0,1,2,3,4,最好保存为csv文件,下次便可以直接加载csv文件。因此我们需要事先写一个函数保存csv文件,不写也可以,最好是写成csv。

下面这个函数可以单独写成一个文件,也可以放在class Pokemon(Dataset)里面。

 def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)): #如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件images = []for name in self.name2label.keys():   # 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))# 1167, 'pokemon\\bulbasaur\\00000000.png'print(len(images), images)random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images:  # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2]        #从名字就可以读取标签label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label])  #写进csv文件print('writen into csv file:', filename)# read from csv fileimages, labels = [], []with open(os.path.join(self.root, filename)) as f:reader = csv.reader(f)for row in reader:# 'pokemon\\bulbasaur\\00000000.png', 0img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labels
2 初始化函数

上面函数可以得到数据地址及其标签,接下来就是初始化,得到数据地址名和标签保存

    def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {}  # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv')  #csv文件存在 直接读取if mode == 'train':  # 60%                   self.images = self.images[:int(0.6 * len(self.images))]self.labels = self.labels[:int(0.6 * len(self.labels))]elif mode == 'val':  # 20% = 60%->80%self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else:  # 20% = 80%->100%self.images = self.images[int(0.8 * len(self.images)):]self.labels = self.labels[int(0.8 * len(self.labels)):]
3 总体样本数量
    def __len__(self):return len(self.images)
4 取出当前位置的数据内容和标签等
    def __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([   #常用的数据变换器lambda x:Image.open(x).convert('RGB'),  # string path= > image data #这里开始读取了数据的内容了transforms.Resize(   #数据预处理部分(int(self.resize * 1.25), int(self.resize * 1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label)  #转化tensorreturn img, label       #返回当前的数据内容和标签
5 加载一个bathsize数据

完成上面的步骤,我们只能得到一个一个数据,且需用迭代器表示,即iter:

    db = Pokemon('pokemon', 64, 'train')x, y = next(iter(db))print('sample:', x.shape, y.shape, y)

因此还需要DataLoader来加载批量的数据:

  loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)for x, y in loader: #此时x,y是批量的数据pass
6 可视化数据集

当我们完成数据集读取部分,可视化也是必须的。我们采用的是visdom来可视化。

    import visdomimport timefor x, y in loader:viz.images(db.denormalize(x), #因为对原始数据归一化,所以可视化需要返回去,该函数需要自己写下。nrow=8,  #每行显示8张图win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))time.sleep(10)

如果visdom连接超时,那么需要:

>python -m visdom.server

可以在网页上显示:

    def denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hat = (x-mean)/std# x = x_hat*std = mean# x: [c, h, w]# mean: [3] => [3, 1, 1]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn x
7 简单的文件分级,可以用一行代码搞定

如果文件结构是二级目录,且代码和文件夹在同一个目录:

那么可以用一行代码来写:

    tf = transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor(),])db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf) loader = DataLoader(db, batch_size=32, shuffle=True)print(db.class_to_idx)for x,y in loader:viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))time.sleep(10)

用ImageFolder即可以写,不过该情况受限,因此不建议。还是用前面的函数自己去定义,方便对数据修改,或者额外引入标签。

接下来就是如何训练了,可参考我写的训练模板:https://blog.csdn.net/lifei1229/article/details/105530012
https://blog.csdn.net/lifei1229/article/details/105527312

如何用Pytorch读取自己的数据集相关推荐

  1. PyTorch 学习笔记(一):让PyTorch读取你的数据集

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 Dataset类 ...

  2. pytorch读取VOC数据集

    简单介绍VOC数据集 首先介绍下VOC2007数据集(下图是VOC数据集格式,为了叙述方便,我这里只放了两张图像) Main文件夹内的trainval.txt中的内容如下:存储了图像的名称不加后缀. ...

  3. pytorch 读取数据集(LiTS-肝肿瘤分割挑战数据集)

    pytorch 读取数据集 我的数据集长这样: xx.png和xx_mask.png是对应的待分割图像和ground truth 读取数据集 数据集对象被抽象为Dataset类,实现自定义的数据集需要 ...

  4. 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)

    摘要:如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载.模型的加载.模型的微调.模型的验证.模 ...

  5. python读数据-如何用Python读取开放数据?

    当你开始接触丰富多彩的开放数据集时,CSV.JSON和XML等格式名词就会奔涌而来.如何用Python高效地读取它们,为后续的整理和分析做准备呢?本文为你一步步展示过程,你自己也可以动手实践. 需求 ...

  6. opencv、matplotlib、pillow和pytorch读取数据的通道顺序

    文章目录: 1 opencv读取数据的通道顺序 1.1 opencv读取数据相关说明 1.2 显示opencv读取的数据 1.3 把opencv读取的BGR转换RGB的三种方式 2 matplotli ...

  7. (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类

    pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...

  8. Pytorch打怪路(三)Pytorch创建自己的数据集2

    前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...

  9. Pytorch 目标检测和数据集

    Pytorch 目标检测和数据集 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 S ...

最新文章

  1. Java数据结构与算法(25) - ch11哈希(双重哈希)
  2. angular-oauth2 —— NG 的 OAuth2 认证模块
  3. c3p0数据源配置抛出Could not load driverClass com.mysql.jdbc.Driver的解决方案
  4. 当会打王者荣耀的AI学会踢足球,一不小心拿下世界冠军!
  5. python读音发音器-python3 - 文本读音器
  6. bogofilter notes
  7. 透过汇编另眼看世界之多继承下的虚函数函数调用
  8. python性能分析工具模块_python——关于Python Profilers性能分析器
  9. 【渝粤教育】国家开放大学2019年春季 455物流实务 参考试题
  10. 针对口令的暴力破解攻击方式
  11. 宋宝华Linux培训笔记-Linux系统开发与工具
  12. 计算机应用技术专业就业方向分析
  13. python 英文关键词提取_python 利用jieba.analyse进行 关键词提取
  14. 996工作制该取消吗?
  15. 离线语音远程遥控车控门制作教程(二)
  16. android 联系人导入iphone,4种快速将联系人导入iPhone的方法
  17. 学习笔记_ncl_读取nc文件中的变量_制作nc文件的方法
  18. 芯片开发必读 | 什么是IP设计?为什么它很重要?
  19. GitHub个人主页默认模板
  20. LGD计划扩增OLED TV面板产能

热门文章

  1. The following signatures were invalid: EXPKEYSIG F42ED6FBAB17C654 的解决方法
  2. 2021年 IOS的发布流程(企业版那 无法下载,无法安装)
  3. word页边距调整步骤
  4. SDH与SONET(整理)
  5. python程序员面试宝典:12个Python程序员面试必备问题与答案
  6. 复合函数求导定义证明_复合函数求导法则证明方法的探讨
  7. wp兼容了android应用程序,WP比安卓流畅 但为什么就不好用呢?
  8. 计算机教师教学能手演讲,教学能手经验交流发言稿6篇
  9. 同步和异步修改页面传来的时间类型
  10. postman传String类型参数时不能加双引号