小白记录,大神勿扰

小白入门的时候,发现,现有的基本都是直接类似这样的:

trainset = datasets.MNIST('../MNIST', download=True,train=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

这个download=True直接解决了一切问题,却不理解发生肾么事了。

而且由于网不好等原因,常常无法自动下载。

这个网上有一些方法,提前自己把数据下载好,放在download的那个目录下。

或者改源代码的下载链接为本地目录。

例如:https://zhuanlan.zhihu.com/p/129081723

有时候,大多数时候想用自己数据集,如下这样类似的写法:

class MyDataset(Dataset):def __init__(self, image_path, label_path, setid_path, train=True, transform=None):setid = scipy.io.loadmat(setid_path)labels = scipy.io.loadmat(label_path)['labels'][0]if train:trnid = setid['tstid'][0]self.labels = [labels[i - 1] - 1 for i in trnid]self.images = ['%s/image_%05d.jpg' % (image_path, i) for i in trnid]else:tstid = np.append(setid['valid'][0], setid['trnid'][0])self.labels = [labels[i - 1] - 1 for i in tstid]self.images = ['%s/image_%05d.jpg' % (image_path, i) for i in tstid]self.transform = transformdef __getitem__(self, index):label = self.labels[index]image = self.images[index]if self.transform is not None:image = self.transform(Image.open(image))return image, labeldef __len__(self):return len(self.labels)

init初始化,一般就包括加载数据啊,然后整体数据的一些基本处理之类的。数据可以来自自己定义放好的本地文件夹,也可以是自己在code之前就完成加载的numpy格式或者其他格式的数据(这时候init中就不需要加本地路径了)。

getitem,每次调用数据,其实就是调用它,后面index不要丢。内部一般就写 init之后,数据被加载之前 还需要进行的一些处理。这里,比如你要加载不一样的图像,这里return不同的就可以了。

len,就返回一个数据长度即可。

然后调用自己定义的数据集,MyDataset,再放到loader中,再从loader中直接拿数据就ok了,这时候拿到的数据就是一个batch一个batch的。

train_dataset = MyDataset(image_path, label_path, setid_path,train=True, transform=transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomRotation(30),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True).......for batch_idx, (image, label) in enumerate(train_loader):...

好,然后怎么重载官网数据集,比如说,你载loader中,希望每次加载这样的数据,image1,image2,label

又是基于现有数据集,比如minist,那么就可以重写这个官网的数据集。本质上和完全自己定义是一回事。

示例代码如下:

class CIFAR10_(datasets.CIFAR10):"""CIFAR10 Dataset."""def __getitem__(self, index):img, target = self.data[index], self.targets[index]img = Image.fromarray(img)if self.target_transform is not None:target = self.target_transform(target)if self.transform is not None:img1 = self.transform(img)if self.train:img2 = self.transform(img)if self.train:return img1, img2, target, index

然后改怎么使用,就怎么使用,可以自己下载好:

trainset = datasets.CIFAR10_(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
testset = datasets.CIFAR10_(root='./data', train=True, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4, drop_last=True)

这时候train loader中出来的 就是 这个样子的:

for batch_idx, (inputs1, inputs2, target, indexes) in enumerate(trainloader):...

ok

Pytorch 怎么构建自己的数据集。怎么重写官方数据集。相关推荐

  1. 【深度学习】在PyTorch中构建高效的自定义数据集

    文章来源于磐创AI,作者磐创AI 学习Dataset类的来龙去脉,使用干净的代码结构,同时最大限度地减少在训练期间管理大量数据的麻烦. 神经网络训练在数据管理上可能很难做到"大规模" ...

  2. 【Pytorch】构建VOC2012数据集代码详解

    目录 数据集 图片读入 预处理 crop 标签和像素点颜色 随机翻转 噪声 标准化 torch.utils.data.Dataset()和torch.utils.data.DataLoader() t ...

  3. PyTorch如何构建和实验神经网络

    点击上方"视学算法",马上关注 真爱,请设置"星标"或点个"在看" 作者 | Tirthajyoti Sarkar 来源 | Medium ...

  4. PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

    PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析 目录 输出结果 核心代码 输出结果 核心代码 #PyTorch:采用skle ...

  5. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  6. PyTorch框架:(2)使用PyTorch框架构建神经网络模型---气温预测

    目录 第一步:数据导入 第二步:将时间转换成标准格式(比如datatime格式) 第三步: 展示数据:(画了4个子图) 第四步:做独热编码 第五步:指定输入与输出 第六步:对数据做一个标准化 第七步: ...

  7. PyTorch版YOLOv4更新了,不仅适用于自定义数据集,还集成了注意力和MobileNet

    机器之心报道 作者:陈萍 距离 YOLO v4 的推出,已经过去 5 个多月.YOLO 框架采用 C 语言作为底层代码,这对于惯用 Python 的研究者来说,实在是有点不友好.因此网上出现了很多基于 ...

  8. 使用PyTorch从零开始构建Elman循环神经网络

    摘要: 循环神经网络是如何工作的?如何构建一个Elman循环神经网络?在这里,教你手把手创建一个Elman循环神经网络进行简单的序列预测. 本文以最简单的RNNs模型为例:Elman循环神经网络,讲述 ...

  9. 交互系统的构建之(一)重写Makefile编译TLD系统

    交互系统的构建之(一)重写Makefile编译TLD系统 zouxy09@qq.com http://blog.csdn.net/zouxy09 为了对TLD系统做一些功能的填充,例如添加语音合成来提 ...

最新文章

  1. 《Java大学教程》—第5章 数组
  2. linux文件需求管理,CaliberRM 需求管理系统
  3. 【译】Making Sense of Ethereum’s Layer 2 Scaling Solutions: State Channels, Plasma, and Truebit
  4. ASA 9.21 in Vmware Workstation 10
  5. SwitchHosts提示切换hosts失败!没有修改'C:\WINDOWS\system32\drivers\etc\hosts'的权限问题
  6. opencv22-直方图均衡化
  7. 使用Flask-SocketIO完成服务端和客户端的双向通信
  8. C# 把ABCD转换成数字
  9. 使用JSON实现分页
  10. python基础===open()文件处理使用介绍
  11. JSONP原理及使用
  12. Tab栏切换效果的制作
  13. python是自由开放源代码软件吗_附录:免费/自由和开放源码软件
  14. JAVA 配置文件 路径_Java配置文件读取和路径设置
  15. Springboot内置Tomcat原理
  16. 微信登录界面安卓代码_安卓Activity劫持与反劫持
  17. AWS、Azure等国外云计算如何迁移到国内阿里云上?
  18. wamp php 如何安装,WAMP的详细安装过程分享
  19. 计算机和红楼梦,电脑计算机与红楼梦的故事
  20. 差动保护类毕业论文文献包含哪些?

热门文章

  1. CentOS创建快捷按钮并设置文件图标
  2. 使用迭代器从map或vector中删除元素
  3. C# string.Format谨慎使用
  4. 潜入java内存结构
  5. b. Suffix Zeroes
  6. 动态规划(一)简单例子
  7. Hadoop权威指南pdf
  8. Android开发人员不得不收集的代码(不断更新中...)
  9. CSMA/CD协议——学习笔记
  10. poj 1106 Transmitters (枚举+叉积运用)