torch中Dataset的构造与解读
torch中Dataset的构造与解读
Dataset的构造
要自定义自己的数据集,首先需要继承Dataset(torch.utils.data.Dataset)
类.
继承Dataset
类之后,必须重写三个方法:__init__(), __getitem__(), __len__()
class ModelNet40(Dataset):def __init__(self, xxx):...def __getitem__(self, item):...def __len()__(self):...
解读
单看上面的构造结构与三个需要重写的方法可能会一头雾水。我们详细分析其作用:
__init__的作用
__init__的作用与所有构造函数都一样,初始化一个类的实例。定义类的实际属性,如点云数据集中的unseen, guassian_noise
等,是True
还是False
, 取出所有数据存储为成员变量等等。__getitem__的作用
__getitem__的作用是,根据item的值取出数据。 item实际上就是索引值,会由Dataloader自动从0一直递增到__len__中取出的值。__len__的作用
__len__的作用是,相当于返回整体数据data的shape[0], 即给item的递增指定一个范围。
例子
class ModelNet40(Dataset):def __init__(self, num_points, partition='train', gaussian_noise=False, unseen=False, factor=4):self.data, self.label = load_data(partition)self.num_points = num_pointsself.partition = partitionself.gaussian_noise = gaussian_noiseself.unseen = unseenself.label = self.label.squeeze()self.factor = factorif self.unseen:######## simulate testing on first 20 categories while training on last 20 categoriesif self.partition == 'test':self.data = self.data[self.label>=20]self.label = self.label[self.label>=20]elif self.partition == 'train':self.data = self.data[self.label<20]self.label = self.label[self.label<20]def __getitem__(self, item):pointcloud = self.data[item][:self.num_points] # 核心代码,就是用item取出的数据if self.gaussian_noise:pointcloud = jitter_pointcloud(pointcloud)if self.partition != 'train':np.random.seed(item)anglex = np.random.uniform() * np.pi / self.factorangley = np.random.uniform() * np.pi / self.factoranglez = np.random.uniform() * np.pi / self.factorcosx = np.cos(anglex)cosy = np.cos(angley)cosz = np.cos(anglez)sinx = np.sin(anglex)siny = np.sin(angley)sinz = np.sin(anglez)Rx = np.array([[1, 0, 0],[0, cosx, -sinx],[0, sinx, cosx]])Ry = np.array([[cosy, 0, siny],[0, 1, 0],[-siny, 0, cosy]])Rz = np.array([[cosz, -sinz, 0],[sinz, cosz, 0],[0, 0, 1]])R_ab = Rx.dot(Ry).dot(Rz)R_ba = R_ab.Ttranslation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5),np.random.uniform(-0.5, 0.5)])translation_ba = -R_ba.dot(translation_ab)pointcloud1 = pointcloud.Trotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])pointcloud2 = rotation_ab.apply(pointcloud1.T).T + np.expand_dims(translation_ab, axis=1)euler_ab = np.asarray([anglez, angley, anglex])euler_ba = -euler_ab[::-1]pointcloud1 = np.random.permutation(pointcloud1.T).Tpointcloud2 = np.random.permutation(pointcloud2.T).Tprint(item)print(pointcloud1.shape)return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \euler_ab.astype('float32'), euler_ba.astype('float32')def __len__(self):return self.data.shape[0] # 给item一个范围
进一步理解其执行逻辑
if __name__ == '__main__':dataset1 = ModelNet40(num_points=1024, partition='train', gaussian_noise=True)dataloader = DataLoader(dataset1, batch_size=64, shuffle=False)count = 0for src_pointcloud, tgt_pointcloud, Rotation, translation, _, _, _, _ in dataloader:print(src_pointcloud.shape)count += 1print(count)
首先需要说明的是,在ModelNet40中,getitem中会打印item的当前值。
如果执行这段代码,在shuffle=False
的情况下,其结果为:
item从0一直增加到__len__
返回的那个值-1, 也就是data的第一维(姑且称为batch维)。
在getitem中取出的pointcloud的shape为(3, 1024),只有两个axis.
而最后输出的count,也就是main函数中整个for循环执行的次数,会是__len__() / batch_size
.
比如len是9480,即self.data的shape为(9480, 2048, 3),那么item就会从0一直增加到9479. 在batch_size为64的情况下,for循环一共执行(即count为) 9480/64 = 148.125, 那么最终会执行149次。 也就是说,每次for循环实质上调用了getitem方法64次,最后在第一维上stack,使之shape变为(64, 3, 1024).
torch中Dataset的构造与解读相关推荐
- .Net 中DataSet和DataTable的 区别与联系
1.简要说明二者关系 在我们编写代码的时候从数据库里取出数据,填充到dataset里,再根据表的名字,实例化到 datatable 中.其实使用 dataset 相当于所使用数据库中数据的副本,保存在 ...
- PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快
PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...
- 【PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快】
文章目录 一.引言 二.背景与需求 三.方法的实现 四.代码与数据测试 五.测试结果 5.1.Max elapse 5.2.Multi Load Max elapse 5.3.Min elapse 5 ...
- java中的静态、动态代理模式以及Spring中的CgLib动态代理解读(面试必问)
java中的静态.动态代理模式以及Spring中的CgLib动态代理解读(面试必问) 静态代理 动态代理 CgLib动态代理 基础知: 反射知识 代理(Proxy)是一种设计模式,提供了对目标 ...
- Torch中的矩阵相乘分类
矩阵相乘在torch中的几种情况 1.矩阵逐元素(Element-wise)乘法 torch.mul(mat1, other) mat和other可以是标量也可以是任意维度的矩阵,只要满足最终相乘是可 ...
- 【增强学习】Torch中的增强学习层
要想在Torch框架下解决计算机视觉中的增强学习问题(例如Visual Attention),可以使用Nicholas Leonard提供的dpnn包.这个包对Torch中原有nn包进行了强大的扩展, ...
- torch中的copy()和clone()
torch中的copy()和clone() 1.torch中的copy()和clone() y = torch.Tensor(2,2):copy(x) --- 1 修改y并不改变原来的x y = x: ...
- C++中的二阶构造模式
1 C++中的二阶构造模式 1.1 半成品对象 首先回顾下构造函数: 类的构造函数用于对象的初始化. 构造函数与类同名并且没有返回值. 构造函数在对象定义时自动被调用. 思考如下几个问题: 如何判断构 ...
- LPS25HB 气压计 参考手册中关于FIFO功能的解读
文章目录 LPS25HB 气压计 参考手册中关于FIFO功能的解读 FIFO 普通模式 FIFO Stream 模式 Stream-to-FIFO 模式 Bypass-to-Stream 模式 FIF ...
最新文章
- Matlab绘制包含双Y轴的图
- Java基础-方法(2)和数组
- hdu4011(水贪心)
- 陈序猿,你敢创业吗?怎么才算成功?
- python描述器 有限状态机_笨办法学 Python · 续 练习 30:有限状态机
- concatenate python_python中numpy.concatenate()函数的使用
- 如何设置xampp的phpmyadmin外网访问?
- 一文学会如何使用Java的交互式编程环境 JShell
- 推荐5个JAVA前后端分离项目
- html页面加声音,HTML5 肿么给网页加屏幕点击声音。
- 什么叫定向广告?定向传播有哪些好处
- 解决win10安装失败原因和方法
- ValueError: Sample larger than population or is negative...
- java求矩阵的逆矩阵_Java逆矩阵计算
- celery英语_蔬菜介绍:芹菜 Celery
- 30个精美的简单网站
- 智能化引领中国铁路发展
- 2022年全球与中国LED嵌入式照明行业发展趋势及投资战略分析报告
- RTK差分共享猫共享后中海达不能固定解决办法
- c语言2010软件下载,Access2010官方下载免费完整版|Access2010官方下载-太平洋下载中心...