torch中Dataset的构造与解读

Dataset的构造

要自定义自己的数据集,首先需要继承Dataset(torch.utils.data.Dataset)类.

继承Dataset类之后,必须重写三个方法:__init__(), __getitem__(), __len__()

class ModelNet40(Dataset):def __init__(self, xxx):...def __getitem__(self, item):...def __len()__(self):...

解读

单看上面的构造结构与三个需要重写的方法可能会一头雾水。我们详细分析其作用:

  1. __init__的作用
    __init__的作用与所有构造函数都一样,初始化一个类的实例。定义类的实际属性,如点云数据集中的unseen, guassian_noise等,是True还是False, 取出所有数据存储为成员变量等等。

  2. __getitem__的作用
    __getitem__的作用是,根据item的值取出数据。 item实际上就是索引值,会由Dataloader自动从0一直递增到__len__中取出的值。

  3. __len__的作用
    __len__的作用是,相当于返回整体数据data的shape[0], 即给item的递增指定一个范围。

例子

class ModelNet40(Dataset):def __init__(self, num_points, partition='train', gaussian_noise=False, unseen=False, factor=4):self.data, self.label = load_data(partition)self.num_points = num_pointsself.partition = partitionself.gaussian_noise = gaussian_noiseself.unseen = unseenself.label = self.label.squeeze()self.factor = factorif self.unseen:######## simulate testing on first 20 categories while training on last 20 categoriesif self.partition == 'test':self.data = self.data[self.label>=20]self.label = self.label[self.label>=20]elif self.partition == 'train':self.data = self.data[self.label<20]self.label = self.label[self.label<20]def __getitem__(self, item):pointcloud = self.data[item][:self.num_points]          # 核心代码,就是用item取出的数据if self.gaussian_noise:pointcloud = jitter_pointcloud(pointcloud)if self.partition != 'train':np.random.seed(item)anglex = np.random.uniform() * np.pi / self.factorangley = np.random.uniform() * np.pi / self.factoranglez = np.random.uniform() * np.pi / self.factorcosx = np.cos(anglex)cosy = np.cos(angley)cosz = np.cos(anglez)sinx = np.sin(anglex)siny = np.sin(angley)sinz = np.sin(anglez)Rx = np.array([[1, 0, 0],[0, cosx, -sinx],[0, sinx, cosx]])Ry = np.array([[cosy, 0, siny],[0, 1, 0],[-siny, 0, cosy]])Rz = np.array([[cosz, -sinz, 0],[sinz, cosz, 0],[0, 0, 1]])R_ab = Rx.dot(Ry).dot(Rz)R_ba = R_ab.Ttranslation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5),np.random.uniform(-0.5, 0.5)])translation_ba = -R_ba.dot(translation_ab)pointcloud1 = pointcloud.Trotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])pointcloud2 = rotation_ab.apply(pointcloud1.T).T + np.expand_dims(translation_ab, axis=1)euler_ab = np.asarray([anglez, angley, anglex])euler_ba = -euler_ab[::-1]pointcloud1 = np.random.permutation(pointcloud1.T).Tpointcloud2 = np.random.permutation(pointcloud2.T).Tprint(item)print(pointcloud1.shape)return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \euler_ab.astype('float32'), euler_ba.astype('float32')def __len__(self):return self.data.shape[0]       # 给item一个范围

进一步理解其执行逻辑

if __name__ == '__main__':dataset1 = ModelNet40(num_points=1024, partition='train', gaussian_noise=True)dataloader = DataLoader(dataset1, batch_size=64, shuffle=False)count = 0for src_pointcloud, tgt_pointcloud, Rotation, translation, _, _, _, _ in dataloader:print(src_pointcloud.shape)count += 1print(count)

首先需要说明的是,在ModelNet40中,getitem中会打印item的当前值。

如果执行这段代码,在shuffle=False的情况下,其结果为:

item从0一直增加到__len__返回的那个值-1, 也就是data的第一维(姑且称为batch维)。

在getitem中取出的pointcloud的shape为(3, 1024),只有两个axis.

而最后输出的count,也就是main函数中整个for循环执行的次数,会是__len__() / batch_size.

比如len是9480,即self.data的shape为(9480, 2048, 3),那么item就会从0一直增加到9479. 在batch_size为64的情况下,for循环一共执行(即count为) 9480/64 = 148.125, 那么最终会执行149次。 也就是说,每次for循环实质上调用了getitem方法64次,最后在第一维上stack,使之shape变为(64, 3, 1024).

torch中Dataset的构造与解读相关推荐

  1. .Net 中DataSet和DataTable的 区别与联系

    1.简要说明二者关系 在我们编写代码的时候从数据库里取出数据,填充到dataset里,再根据表的名字,实例化到 datatable 中.其实使用 dataset 相当于所使用数据库中数据的副本,保存在 ...

  2. PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快

    PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...

  3. 【PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快】

    文章目录 一.引言 二.背景与需求 三.方法的实现 四.代码与数据测试 五.测试结果 5.1.Max elapse 5.2.Multi Load Max elapse 5.3.Min elapse 5 ...

  4. java中的静态、动态代理模式以及Spring中的CgLib动态代理解读(面试必问)

    java中的静态.动态代理模式以及Spring中的CgLib动态代理解读(面试必问) 静态代理 动态代理 CgLib动态代理     基础知: 反射知识 代理(Proxy)是一种设计模式,提供了对目标 ...

  5. Torch中的矩阵相乘分类

    矩阵相乘在torch中的几种情况 1.矩阵逐元素(Element-wise)乘法 torch.mul(mat1, other) mat和other可以是标量也可以是任意维度的矩阵,只要满足最终相乘是可 ...

  6. 【增强学习】Torch中的增强学习层

    要想在Torch框架下解决计算机视觉中的增强学习问题(例如Visual Attention),可以使用Nicholas Leonard提供的dpnn包.这个包对Torch中原有nn包进行了强大的扩展, ...

  7. torch中的copy()和clone()

    torch中的copy()和clone() 1.torch中的copy()和clone() y = torch.Tensor(2,2):copy(x) --- 1 修改y并不改变原来的x y = x: ...

  8. C++中的二阶构造模式

    1 C++中的二阶构造模式 1.1 半成品对象 首先回顾下构造函数: 类的构造函数用于对象的初始化. 构造函数与类同名并且没有返回值. 构造函数在对象定义时自动被调用. 思考如下几个问题: 如何判断构 ...

  9. LPS25HB 气压计 参考手册中关于FIFO功能的解读

    文章目录 LPS25HB 气压计 参考手册中关于FIFO功能的解读 FIFO 普通模式 FIFO Stream 模式 Stream-to-FIFO 模式 Bypass-to-Stream 模式 FIF ...

最新文章

  1. Matlab绘制包含双Y轴的图
  2. Java基础-方法(2)和数组
  3. hdu4011(水贪心)
  4. 陈序猿,你敢创业吗?怎么才算成功?
  5. python描述器 有限状态机_笨办法学 Python · 续 练习 30:有限状态机
  6. concatenate python_python中numpy.concatenate()函数的使用
  7. 如何设置xampp的phpmyadmin外网访问?
  8. 一文学会如何使用Java的交互式编程环境 JShell
  9. 推荐5个JAVA前后端分离项目
  10. html页面加声音,HTML5 肿么给网页加屏幕点击声音。
  11. 什么叫定向广告?定向传播有哪些好处
  12. 解决win10安装失败原因和方法
  13. ValueError: Sample larger than population or is negative...
  14. java求矩阵的逆矩阵_Java逆矩阵计算
  15. celery英语_蔬菜介绍:芹菜 Celery
  16. 30个精美的简单网站
  17. 智能化引领中国铁路发展
  18. 2022年全球与中国LED嵌入式照明行业发展趋势及投资战略分析报告
  19. RTK差分共享猫共享后中海达不能固定解决办法
  20. c语言2010软件下载,Access2010官方下载免费完整版|Access2010官方下载-太平洋下载中心...

热门文章

  1. Oracle系列:Oracle RAC集群体系结构
  2. 给控件做数字签名之三:进行数字签名
  3. 方法 注释_在IDEA中配置类和方法的文档注释
  4. 【Python】政府工作报告词云
  5. 关于在Visual Studio 2019预览版中的用户体验和界面的变化
  6. windows系统托盘tray
  7. 使用Redux-Saga进行异步操作
  8. Java中的equals和==的差别 以及Java中等价性和同一性的讨论
  9. 数字签名加密过程举例
  10. matlab设置背景颜色