作为一个2年多的不资深keraser和tfer,被boss要求全员换成pytorch。不得不说,pytorch还是真香的。之前用keras,总会发现多GPU使用的情况下不太好,对计算资源的利用率不太高。把模型改成pytorch以后,发现资源利用率非常可观。

非常看好pytorch的前途,到时候能制衡一下tf就好了。闲话不多扯,我来讲讲初入pytorch最重要的东西:dataset

网上有很多介绍pytorch dataset类的文章,不过大多数都是讲解某一类任务的数据集模型建立。不太具有泛化性,本文将提出一个通用的数据集接口解决技巧,供大家参考

实验环境:

python==3.7.3

ubuntu==16.04

pytorch==1.1.0


dataset类

为什么木盏会说dataset是初入pytorch最重要的东西?因为我们复现项目的时候,最需要改的就是数据集。其他调调参改改模型问题都不大。

如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口

所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。

对于图像分类任务,图像+分类

对于目标检测任务,图像+bbox、分类

对于超分辨率任务,低分辨率图像+超分辨率图像

对于文本分类任务,文本+分类

...

你只需定义好这个项目的x和y是什么。好了,上面都是扯闲篇,我们直接看dataset代码:

class Dataset(object):"""An abstract class representing a Dataset.All other datasets should subclass it. All subclasses should override``__len__``, that provides the size of the dataset, and ``__getitem__``,supporting integer indexing in range from 0 to len(self) exclusive."""def __getitem__(self, index):raise NotImplementedErrordef __len__(self):raise NotImplementedErrordef __add__(self, other):return ConcatDataset([self, other])

上面的代码是pytorch给出的官方代码,其中__getitem__和__len__是子类必须继承的。

很好解释,pytorch给出的官方代码限制了标准,你要按照它的标准进行数据集建立。首先,__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}。__len__是指数据集长度。

自己建立一个dataset试试:

class MyDataSet(Dataset):def __init__(self):self.sample_list = ...def __getitem__(self, index):x= ...y= ...return x, ydef __len__(self):return len(self.sample_list)

上面这个模板是本人定义好的,史称“木盏模板”。咱只需按照需求把模板填完就Ok了,那么为什么说这个模板使用于各种任务的数据集建造呢?还得依靠一个trick:通过txt文件映射

举个实例,假设我要给一个分类器训练喂数据,我的数据是images+number的组合,比如{img:3},这代表这个图像应该分在“3”类。我怎么写代码呢?

from torch.utils.data import Datasetclass MyDataSet(Dataset):def __init__(self, dataset_type, transform=None, update_dataset=False):"""dataset_type: ['train', 'test']"""dataset_path = '/home/muzhan/projects/dataset/'if update_dataset:make_txt_file(dataset_path)  # update datalistself.transform = transformself.sample_list = list()self.dataset_type = dataset_typef = open(dataset_path + self.dataset_type + '/datalist.txt')lines = f.readlines()for line in lines:self.sample_list.append(line.strip())f.close()def __getitem__(self, index):item = self.sample_list[index]# img = cv2.imread(item.split(' _')[0])img = Image.open(item.split(' _')[0])if self.transform is not None:img = self.transform(img)label = int(item.split(' _')[-1])return img, labeldef __len__(self):return len(self.sample_list)

上面有个transform参数,用于对数据集进行预处理的,可以根据项目选择使用。

上面有一个make_txt_file的函数需要说明一下,这个函数可以在数据集目录下创建一个txt文件,代表x和y的映射关系。这个函数大家可以自己写,一个简单脚本而已,我就不共享代码了 。(如有需要,留言告知)

我给大家看一下我的datalist.txt中的几行:

/home/muzhan/projects/dataset/test/250_04.png _0
/home/muzhan/projects/dataset/test/250_05.png _7
/home/muzhan/projects/dataset/test/250_06.png _3
/home/muzhan/projects/dataset/test/250_07.png _2
/home/muzhan/projects/dataset/test/250_08.png _2
/home/muzhan/projects/dataset/test/250_09.png _3
/home/muzhan/projects/dataset/test/250_10.png _4
/home/muzhan/projects/dataset/test/250_11.png _0
/home/muzhan/projects/dataset/test/250_12.png _9

这样就可以理解我在__getitem__函数中解析x和y的方法吧,在文本中用字符串' _'隔开,当然你可以用其他字符,能够保证剪切字符串不出错即可。

我们需要测试这个dataset类是否成功:

if __name__ == '__main__':ds = MyDataSet()print(ds.__len__())img, gt = ds.__getitem__(34) # get the 34th sampleprint(type(img))print(gt)

上面有输出,并且和你数据集一致,那证明这个dataset类是成功的。

有了这个,用DataLoader函数就可以加载我们的数据集了。

Pytorch中的dataset类——创建适应任意模型的数据集接口相关推荐

  1. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  2. 在pytorch中自定义dataset读取数据2021-1-8学习笔记

    在pytorch中自定义dataset读取数据 utils import os import json import pickle import randomimport matplotlib.pyp ...

  3. Java中创建线程需要使用的类_如何通过使用Java中的匿名类创建线程?

    甲线程是可以同时与该程序的其他部分被执行的功能.所有Java程序都有至少一个称为主线程的线程,该线程由Java虚拟机(JVM)在程序启动时由主线程调用main()方法创建. 在Java中,我们可以通过 ...

  4. 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型

    作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...

  5. gpu处理信号_在PyTorch中使用DistributedDataParallel进行多GPU分布式模型训练

    先进的深度学习模型参数正以指数级速度增长:去年的GPT-2有大约7.5亿个参数,今年的GPT-3有1750亿个参数.虽然GPT是一个比较极端的例子但是各种SOTA模型正在推动越来越大的模型进入生产应用 ...

  6. 通过‘PyQt6‘中的QWidget类创建一个含有按钮的窗口 1

    1.首先搭建一个基本的窗口 ,代码如下: import sys from PyQt6.QtWidgets import QApplication,QWidgetclass Add_func(QWidg ...

  7. PyTorch中nn.Module类简介

    torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中.你的模型也应该继承这个类,主要重载__init__.for ...

  8. 在PyTorch中使用Seq2Seq构建的神经机器翻译模型

    在这篇文章中,我们将构建一个基于LSTM的Seq2Seq模型,使用编码器-解码器架构进行机器翻译. 本篇文章内容: 介绍 数据准备和预处理 长短期记忆(LSTM) - 背景知识 编码器模型架构(Seq ...

  9. pytorch-构建自己的dataset类

    如今,与keras .tf相比,pytorch高效的资源利用率,越来越多的Aier应用pytorch.我来讲讲初入pytorch最重要的东西:dataset 网上有很多介绍pytorch datase ...

最新文章

  1. OFRecord 数据格式
  2. python类中self是什么
  3. 北京区域赛I题,Uva7676,A Boring Problem,前缀和差分
  4. centos7.3安装mysql5.7 解决 Access denied for user 'root'@'localhost' (using password: NO)
  5. [Python] L1-008. 求整数段和-PAT团体程序设计天梯赛GPLT
  6. libevent源码深度剖析-张亮
  7. ​从 Spark Streaming 到 Apache Flink:bilibili 实时平台的架构与实践
  8. 深度学习——常用数据标注工具总结
  9. SSM医院挂号就诊预约系统 毕业设计-附源码250853
  10. Win7 没有声音的解决方法
  11. linux中设置中英文语言
  12. MySQL万字总结(含测试代码)
  13. 逆矩阵与矩阵的特征值的关系
  14. vosk实时语音识别
  15. python就业前景不好_担心学习Python就业情况不好?来看看Python发展前景
  16. 手工纸盒子_折纸盒子大全_10多种折纸盒子制作图解教程|怎么折纸盒子 - 聚巧网...
  17. 关于对数函数的引入理解
  18. JQuery EasyUI Datagrid 清空排序状态(箭头)代码
  19. 离线密码破解之John the Ripper
  20. 捷俊通地磅称重软件在垃圾处理厂中的应用

热门文章

  1. 物联网变身黑暗森林:僵尸网络、守护者、毁灭者层出不穷
  2. 国内杀毒软件的发展史
  3. html 全场开场动画,HTML5 星际大战电影开场字幕动画
  4. ctf流量分析练习二
  5. 运用Ntop监控网络流量(视频Demo)
  6. Centos7.X修改hostname立刻生效-修改/etc/hostname后立刻生效-Centos7.x修改hostname永久生效
  7. mysql 历史数据迁移,MySQL 历史数据表迁移方法
  8. 塞拉利昂首次秘密进行基于区块链的总统选举
  9. 项目经理的岗位职责和要求
  10. uboot分析之Loopback接口