如何用Pytorch读取自己的数据集
在训练经典的数据集如cifar10,minsit等,可以用官方自带的数据集格式几行就写出来,如果是自己下载的数据集,那么我们应该如何用pytorch来读取呢?其实是有模板可以直接仿照着写的。
本次案例采用的是pokeman数据集,并用该数据集进行分类。该数据如下所示:
其中文件夹的名字便是标签。数据集大小划分为:皮卡丘 234、超梦239、杰尼龟223、小火龙 238、妙蛙种子234张图。
在深度学习中一般的流程是:加载数据—>构建模型—>训练和测试。
读取数据
在pytorch读取数据,采用3个步骤
- 继承torch中的通用的母类:torch.utils.data.Dataset
from torch.utils.data.Dataset
- __len __:这里需要返回定义数据的数量,返回整型数字
- __getitem __ :这里返回样本、标签等
一个简单的例子
from torch.utils.data import Dataset, DataLoader
class NumberDataset(Dataset): #首先要继承Dataset母类def __init__(self, training=True): #区分训练和测试if training:self.samples = list(range(1, 1001)) #加载数据,一般是存放数据的地址,不然内存爆炸else:self.samples = list(range(1001, 15001))def __len__(self):return len(self.samples) #def __getitem__(self, idx): # idx 是位置标号,在len(self.samples) 内,一个一个的读取该位置数据return self.samples[idx]
小结:1、首先得到所有的数据的地址名字(训练或测试);2、给出数据集长度;3、返回指定位置的数据内容,可以在该数据上进行任何预处理操作。
现在读取本次给的pokeman数据集
python代码框架为:
from torch.utils.data import Dataset, DataLoader #自定义的母类,必须的
class Pokemon(Dataset):def __init__(self): #去读数据路径super(Pokemon, self).__init__()passdef __len__(self): #返回数据长度passdef __getitem__(self, idx): #返回当前位置的数据和标签pass
接下来就是填充每一块函数里面的内容了。
1 将标签转化数字,且数据地址及其标签保存csv文件
首先需要加载数据和标签,因为标签需要转化成0,1,2,3,4,最好保存为csv文件,下次便可以直接加载csv文件。因此我们需要事先写一个函数保存csv文件,不写也可以,最好是写成csv。
下面这个函数可以单独写成一个文件,也可以放在class Pokemon(Dataset)里面。
def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)): #如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件images = []for name in self.name2label.keys(): # 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))# 1167, 'pokemon\\bulbasaur\\00000000.png'print(len(images), images)random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images: # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2] #从名字就可以读取标签label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label]) #写进csv文件print('writen into csv file:', filename)# read from csv fileimages, labels = [], []with open(os.path.join(self.root, filename)) as f:reader = csv.reader(f)for row in reader:# 'pokemon\\bulbasaur\\00000000.png', 0img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labels
2 初始化函数
上面函数可以得到数据地址及其标签,接下来就是初始化,得到数据地址名和标签保存:
def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {} # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv') #csv文件存在 直接读取if mode == 'train': # 60% self.images = self.images[:int(0.6 * len(self.images))]self.labels = self.labels[:int(0.6 * len(self.labels))]elif mode == 'val': # 20% = 60%->80%self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else: # 20% = 80%->100%self.images = self.images[int(0.8 * len(self.images)):]self.labels = self.labels[int(0.8 * len(self.labels)):]
3 总体样本数量
def __len__(self):return len(self.images)
4 取出当前位置的数据内容和标签等
def __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([ #常用的数据变换器lambda x:Image.open(x).convert('RGB'), # string path= > image data #这里开始读取了数据的内容了transforms.Resize( #数据预处理部分(int(self.resize * 1.25), int(self.resize * 1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label) #转化tensorreturn img, label #返回当前的数据内容和标签
5 加载一个bathsize数据
完成上面的步骤,我们只能得到一个一个数据,且需用迭代器表示,即iter:
db = Pokemon('pokemon', 64, 'train')x, y = next(iter(db))print('sample:', x.shape, y.shape, y)
因此还需要DataLoader来加载批量的数据:
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)for x, y in loader: #此时x,y是批量的数据pass
6 可视化数据集
当我们完成数据集读取部分,可视化也是必须的。我们采用的是visdom来可视化。
import visdomimport timefor x, y in loader:viz.images(db.denormalize(x), #因为对原始数据归一化,所以可视化需要返回去,该函数需要自己写下。nrow=8, #每行显示8张图win='batch',opts=dict(title='batch'))viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))time.sleep(10)
如果visdom连接超时,那么需要:
>python -m visdom.server
可以在网页上显示:
def denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hat = (x-mean)/std# x = x_hat*std = mean# x: [c, h, w]# mean: [3] => [3, 1, 1]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn x
7 简单的文件分级,可以用一行代码搞定
如果文件结构是二级目录,且代码和文件夹在同一个目录:
那么可以用一行代码来写:
tf = transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor(),])db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf) loader = DataLoader(db, batch_size=32, shuffle=True)print(db.class_to_idx)for x,y in loader:viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))time.sleep(10)
用ImageFolder即可以写,不过该情况受限,因此不建议。还是用前面的函数自己去定义,方便对数据修改,或者额外引入标签。
接下来就是如何训练了,可参考我写的训练模板:https://blog.csdn.net/lifei1229/article/details/105530012
https://blog.csdn.net/lifei1229/article/details/105527312
如何用Pytorch读取自己的数据集相关推荐
- PyTorch 学习笔记(一):让PyTorch读取你的数据集
本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial 文章目录 Dataset类 ...
- pytorch读取VOC数据集
简单介绍VOC数据集 首先介绍下VOC2007数据集(下图是VOC数据集格式,为了叙述方便,我这里只放了两张图像) Main文件夹内的trainval.txt中的内容如下:存储了图像的名称不加后缀. ...
- pytorch 读取数据集(LiTS-肝肿瘤分割挑战数据集)
pytorch 读取数据集 我的数据集长这样: xx.png和xx_mask.png是对应的待分割图像和ground truth 读取数据集 数据集对象被抽象为Dataset类,实现自定义的数据集需要 ...
- 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)
摘要:如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载.模型的加载.模型的微调.模型的验证.模 ...
- python读数据-如何用Python读取开放数据?
当你开始接触丰富多彩的开放数据集时,CSV.JSON和XML等格式名词就会奔涌而来.如何用Python高效地读取它们,为后续的整理和分析做准备呢?本文为你一步步展示过程,你自己也可以动手实践. 需求 ...
- opencv、matplotlib、pillow和pytorch读取数据的通道顺序
文章目录: 1 opencv读取数据的通道顺序 1.1 opencv读取数据相关说明 1.2 显示opencv读取的数据 1.3 把opencv读取的BGR转换RGB的三种方式 2 matplotli ...
- (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类
pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...
- Pytorch打怪路(三)Pytorch创建自己的数据集2
前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用 ...
- Pytorch 目标检测和数据集
Pytorch 目标检测和数据集 0. 环境介绍 环境使用 Kaggle 里免费建立的 Notebook 教程使用李沐老师的 动手学深度学习 网站和 视频讲解 小技巧:当遇到函数看不懂的时候可以按 S ...
最新文章
- Java数据结构与算法(25) - ch11哈希(双重哈希)
- angular-oauth2 —— NG 的 OAuth2 认证模块
- c3p0数据源配置抛出Could not load driverClass com.mysql.jdbc.Driver的解决方案
- 当会打王者荣耀的AI学会踢足球,一不小心拿下世界冠军!
- python读音发音器-python3 - 文本读音器
- bogofilter notes
- 透过汇编另眼看世界之多继承下的虚函数函数调用
- python性能分析工具模块_python——关于Python Profilers性能分析器
- 【渝粤教育】国家开放大学2019年春季 455物流实务 参考试题
- 针对口令的暴力破解攻击方式
- 宋宝华Linux培训笔记-Linux系统开发与工具
- 计算机应用技术专业就业方向分析
- python 英文关键词提取_python 利用jieba.analyse进行 关键词提取
- 996工作制该取消吗?
- 离线语音远程遥控车控门制作教程(二)
- android 联系人导入iphone,4种快速将联系人导入iPhone的方法
- 学习笔记_ncl_读取nc文件中的变量_制作nc文件的方法
- 芯片开发必读 | 什么是IP设计?为什么它很重要?
- GitHub个人主页默认模板
- LGD计划扩增OLED TV面板产能
热门文章
- The following signatures were invalid: EXPKEYSIG F42ED6FBAB17C654 的解决方法
- 2021年 IOS的发布流程(企业版那 无法下载,无法安装)
- word页边距调整步骤
- SDH与SONET(整理)
- python程序员面试宝典:12个Python程序员面试必备问题与答案
- 复合函数求导定义证明_复合函数求导法则证明方法的探讨
- wp兼容了android应用程序,WP比安卓流畅 但为什么就不好用呢?
- 计算机教师教学能手演讲,教学能手经验交流发言稿6篇
- 同步和异步修改页面传来的时间类型
- postman传String类型参数时不能加双引号