目录在这里

  • 前言
  • 简单介绍
  • Show you the Code
  • 路径哦
  • 整份代码
  • 备注
  • 参考

前言

纯粹想学习一下torch的数据集类,可能后面会用到吧。

简单介绍

我们在训练过程中除了写模型、训练等程序外,还会用到数据加载,而官方就几个数据加载的类,不足以满足个人的需求

我想要将个人的npy文件数据集加载进来,还有每一个npy文件对应了一个target,于是便想动手写个Dataset

参考pytorch文档可以看到,所有其他数据集都会进行经子类化,且所有的子类都应该override__len____getitem__

  • __len__ 提供了数据集的大小
  • __getitem__ 支持整数索引,范围从0到len(self)

其实这里我在马代码过程中感觉到有点不对劲,如果__len__提供的是真的个数(比如存在6个,而真是就是6个,而不是6-1个),那么__getitem__的0到len(self)(因为从0开始的话,最后一个索引应该是5,而不是6,即len(self))就会超出索引,所以在我的代码里面为了不出现错误,所幸将数据集进行了掐头去尾

Show you the Code

class OFDataset(Dataset):def __init__(self, path_dataset, seq):self.path_OF = r'{}flow_data/'.format(path_dataset)self.path_pose = r'{}gt_poses_frames/'.format(path_dataset)self.seq = seqself.get_len_of_each_seq()passdef __getitem__(self, item):'''支持整数索引,范围从0到len(self)'''path_index = self.decode(item)sample, target = self.loader(path_index)return sample, targetdef __len__(self):'''提供了数据集的大小'''return self.index_range[-1] - 1def get_len_of_each_seq(self):'''获取每一个seq的数据长度'''self.len_seq = [] # [4540, 1100, ...]for i in self.seq:pose_name = r'{}{}.txt'.format(self.path_pose, i)with open(pose_name, 'r') as f:len_i = len(f.readlines())self.len_seq.append(len_i)self.index_range = [0] # [0, 4540, 5640, ...]for j in range(len(self.len_seq)):if j == 0:max_range_j = self.index_range[0] + self.len_seq[j]else:max_range_j = self.index_range[j] + self.len_seq[j]self.index_range.append(max_range_j)def decode(self, item):'''对item进行解码,并获取相应的path和index'''path_index = {}for i in range(len(self.index_range)):if item >= self.index_range[i] and item < self.index_range[i+1]:path_index['path'] = self.seq[i]path_index['index'] = item - self.index_range[i]return path_indexelse:passdef loader(self, path_index):'''通过索引返回tansor'''sample_path = r'{}{}/{}-{}.npy'.format(self.path_OF,path_index['path'],path_index['index'],path_index['index']+1)sample = np.load(sample_path)sample = torch.from_numpy(sample)target_path = r'{}{}.txt'.format(self.path_pose,path_index['path'])with open(target_path, 'r') as f:readline = f.readlines()[path_index['index']]readline = readline.split(',')readline = list(map(float, readline))readline = np.array(readline)target = readline[:6]target = torch.from_numpy(target)return sample, target

其实整个代码中,我们只需要在意的就是下面三个函数:

  • __init__(self, path_dataset, seq)
    初始化函数,我这里传入的是数据集的地址和需要读取的序列
  • __getitem__(self, item)
    这个函数是最重要的之一,还有一个是下面那个,这两个是必须要要有的,而且必须返回相应的数据,像这个函数就需要返回的是sample和target(样本和标签,有些子类标签不是必须的),然回的数据必须是tensor类型的数据,这样子才能直接被输入网络里面进行训练
  • __len__(self)
    这个说的很清楚,就是所有数据的总量,是一个int类型的

我的代码里还有其他的函数,其实都是为上面的两个重要的函数作准备的,里面有注释,我也就不介绍了。

路径哦

npy数据保存在flow_data文件夹下的00到10文件夹里

而gt_poses_frames文件夹里面的txt文件就是对应每一个[00,01,02,…,10]文件夹里面的npy数据的标签,以行为单位

然后00.txt文件夹第一行就对应data_flow/00/文件夹下第一个npy数据,如图

整份代码

'''
@Author: Astrophil (luo19902567292@163.com)
@Date: 2022-03-18
@LastEditTime: 2022-03-23
@LastEditors: Astrophil
@Description:
'''import os
import matplotlib.pyplot as pltimport numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models, datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as Fclass OFDataset(Dataset):def __init__(self, path_dataset, seq):self.path_OF = r'{}flow_data/'.format(path_dataset)self.path_pose = r'{}gt_poses_frames/'.format(path_dataset)self.seq = seqself.get_len_of_each_seq()passdef __getitem__(self, item):'''支持整数索引,范围从0到len(self)'''path_index = self.decode(item)sample, target = self.loader(path_index)return sample, targetdef __len__(self):'''提供了数据集的大小'''return self.index_range[-1] - 1def get_len_of_each_seq(self):'''获取每一个seq的数据长度'''self.len_seq = [] # [4540, 1100, ...]for i in self.seq:pose_name = r'{}{}.txt'.format(self.path_pose, i)with open(pose_name, 'r') as f:len_i = len(f.readlines())self.len_seq.append(len_i)self.index_range = [0] # [0, 4540, 5640, ...]for j in range(len(self.len_seq)):if j == 0:max_range_j = self.index_range[0] + self.len_seq[j]else:max_range_j = self.index_range[j] + self.len_seq[j]self.index_range.append(max_range_j)def decode(self, item):'''对item进行解码,并获取相应的path和index'''path_index = {}for i in range(len(self.index_range)):if item >= self.index_range[i] and item < self.index_range[i+1]:path_index['path'] = self.seq[i]path_index['index'] = item - self.index_range[i]return path_indexelse:passdef loader(self, path_index):'''通过索引返回tansor'''sample_path = r'{}{}/{}-{}.npy'.format(self.path_OF,path_index['path'],path_index['index'],path_index['index']+1)sample = np.load(sample_path)sample = torch.from_numpy(sample)target_path = r'{}{}.txt'.format(self.path_pose,path_index['path'])with open(target_path, 'r') as f:readline = f.readlines()[path_index['index']]readline = readline.split(',')readline = list(map(float, readline))readline = np.array(readline)target = readline[:6]target = torch.from_numpy(target)return sample, target# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(     # 输入大小(2, 192, 640)nn.Conv2d(in_channels=2,          # ofout_channels=16,        # 要得到多少个特征图kernel_size=5,          # 卷积核大小stride=1,               # 步长padding=2,# 如果希望卷积后大小跟原来一样,需要配置padding=(kernal_size-1)/2 if stride=1),nn.ReLU(),                  # relunn.MaxPool2d(kernel_size=2),# 进行池化操作(2x2区域),输出结果为(16, 81, 320))self.conv2 = nn.Sequential(     # [16, 96, 320]nn.Conv2d(in_channels=16,out_channels=32,kernel_size=5,stride=1,padding=2,),nn.ReLU(),nn.MaxPool2d(kernel_size=2),# [32, 48, 160])self.conv3 = nn.Sequential(     # [32, 48, 160]nn.Conv2d(in_channels=32,out_channels=64,kernel_size=5,stride=1,padding=2,),nn.ReLU(),nn.MaxPool2d(kernel_size=2),# [64, 24, 80])self.fc1 = nn.Linear(in_features=64*24*80, out_features=1024)self.fc2 = nn.Linear(in_features=1024, out_features=6)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)#x = x.view(x.size(0), -1)       # flatten操作,结果为:(batch_size, 31*7*7)x = x.view(-1, 64*24*80)fc1_out = F.relu(self.fc1(x))fc2_out = F.relu(self.fc2(fc1_out))return fc2_out'''dataset'''
path_dataset = 'dataset/kitti_odom/'
seq_train = ['00', '01', '03', '10']
seq_test = ['02', '04']
dataset_train = OFDataset(path_dataset, seq_train)
data_loader_train = DataLoader(dataset_train, batch_size=8)
dataset_test = OFDataset(path_dataset, seq_test)
data_loader_test = DataLoader(dataset_test, batch_size=8)train_on_gpu = torch.cuda.is_available()
if not train_on_gpu:print('CUDA is not available.   Training on CPU...')
else:print('CUDA is available!   Training on GPU...')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")'''准确率作为评估标准'''
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1]rights = pred.eq(labels.data.view_as(pred)).sum()return rights, len(labels)'''训练网络模型'''
# 实例化
net = CNN()
net = net.to(device)
# 损失函数
# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
# 优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)num_epochs = 50
for epoch in range(num_epochs):loss_list = []print('begin {}th batch'.format(epoch))for batch_idx, (data, target) in enumerate(data_loader_train):data = data.to(device)target = target.to(device)net.train()output = net(data)output = output.to(torch.float32)target = target.to(torch.float32)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 20 == 0:print(epoch, batch_idx)print(loss)loss_list.append(loss.data)print('loss: {}'.format(loss_list[-1]))

里面的CNN类代码是我随便乱写的,只是为了测试以下Dataset类是否能够进行加载

里面的损失很奇怪,我估计是因为label中的6个回归真值没有进行标准化

备注

个人纯深度学习小白,可能存在不对的地方,还望大佬们指正
哦哦,可能后续会上传到git吧哈哈

参考

Pytorch中文文档

pytorch自定义Dataset,torch加载自己的numpy数据集,torch-cnn训练numpy进行回归相关推荐

  1. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  2. Pytorch中的数据加载

    Pytorch中的数据加载 1. 模型中使用数据加载器的目的 在前面的线性回归模型中,使用的数据很少,所以直接把全部数据放到模型中去使用. 但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的 ...

  3. 【学习系列7】Pytorch中的数据加载

    目录 1. 模型中使用数据加载器的目的 2. 数据集类 3. 迭代数据集 1. 模型中使用数据加载器的目的 在前面的线性回归横型中,我们使用的数据很少,所以直接把全部数据放到锁型中去使用. 但是在深度 ...

  4. 使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

    使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作 总共分为四步 构造一个my_dataset类,继承自torch.utils.data.Dataset 重写__getite ...

  5. pytorch中的数据加载(dataset基类,以及pytorch自带数据集)

    目录 pytorch中的数据加载 模型中使用数据加载器的目的 数据集类 Dataset基类介绍 数据加载案例 数据加载器类 pytorch自带的数据集 torchvision.datasets MIN ...

  6. PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快

    PyTorch训练中Dataset多线程加载数据,而不是在DataLoader 背景与需求 现在做深度学习的越来越多人都有用PyTorch,他容易上手,而且API相对TF友好的不要太多.今天就给大家带 ...

  7. 【PyTorch训练中Dataset多线程加载数据,比Dataloader里设置多个workers还要快】

    文章目录 一.引言 二.背景与需求 三.方法的实现 四.代码与数据测试 五.测试结果 5.1.Max elapse 5.2.Multi Load Max elapse 5.3.Min elapse 5 ...

  8. Pytorch自定义Dataset和DataLoader去除不存在和空的数据

    Pytorch自定义Dataset和DataLoader去除不存在和空的数据 [源码GitHub地址]:https://github.com/PanJinquan/pytorch-learning-t ...

  9. 十九、Pytorch中的数据加载

    1. Pytorch中DataSet的使用方法 1.1 DataSet加载数据的方法 DataSet是Pytorch中用来表示数据集的一个抽象类,在torch中提供了数据集的基类torch.utils ...

  10. PyTorch模型保存与加载

    torch.save:保存序列化的对象到磁盘,使用了Python的pickle进行序列化,模型.张量.所有对象的字典. torch.load:使用了pickle的unpacking将pickled的对 ...

最新文章

  1. 关于Linux的inode和dentry的一组文章
  2. 深度优先搜索(dfs),城堡问题
  3. pandas 知识点补充:绘图plot
  4. 【杂文】Do A Slash
  5. 三菱fx5u编程手册_实用分享 | 三菱FX 5U特点是什么?
  6. 基于jQuery带图标的多级下拉菜单
  7. OpenCASCADE绘制测试线束:拓扑命令之拓扑和曲面创建
  8. java计算器简单吗,java简单计算器
  9. MySQL的权限分配
  10. python开发效率高吗_从运行效率与开发效率比较Python和C++
  11. 流程型企业SCM、ERP、MES、PCS如何集成?
  12. linux 内存泄露 工具,Linux Kernel模块内存泄露分析
  13. Flash 安全策略配置(1)
  14. 嵌入式工作笔记0007---对讲机嵌入式开发记录---认识对讲机的功能--随时更新
  15. solr相关配置(搜索novel案例)
  16. Mybatis使用技巧
  17. Fedora 14 下成功驱动BCM4312的步骤
  18. java gui即时聊天工具
  19. 按键精灵手机助手如何连接安卓版按键精灵如何连接手机助手
  20. 软件测试面试经验之如何测试刷抖音

热门文章

  1. Java 汉字 转 拼音/首字母
  2. 中环混改尚存变数 高调的TCL要上演“资本魔输”?
  3. Java进阶(一) Java高效读取大文件,占内存少
  4. golang 文件命名规则
  5. 手机怎么查看连接过的wifi密码
  6. 破解封杀ADSL路由器解决办法全面剖析
  7. WIN2K XP 2K3 下红警不能联机的完美解决方案(转)
  8. 对比度 css_更好的颜色和对比度可访问性CSS技巧
  9. 熊出没全集光头强的机器人_熊出没:其实光头强早就不想当伐木工了,这些细节足以说明一切...
  10. 再生龙clonezilla备份系统全过程