目录

0、目标:

1、数据的预处理

2、修改数据处理部分的代码

2.1 复制对数据集进行处理的文件

2.2 对kitti_lidar_dataset.py进行修改

2.2.1 头文件修改

2.2.2 数据集对象名称修改

2.2.3 get_info函数修改

2.2.4 .yaml文件修改

2.2.5 运行

3、修改数据集加载

3.1 去掉测试

3.2 修改__getitem__函数

3.3 前后连起来

3.4 .yaml文件修改

3.5运行


·本文还存在错误,对点云并未进行坐标转化,选择性阅读

0、目标:

本文立足于pointpillars算法的训练,这里通过处理kitti数据集展示对自定义数据集的训练方法。

在源代码中对pointpillars的训练需要很多的数据(不晓得咋直接训练可以进入这篇博客OpenPCDet 在KITTI 训练PointPillar_辉e的博客-CSDN博客_openpcdet训练kitti)这里尤其是calib,我们对点云进行目标检测的训练,不需要啥坐标转换的,所以我这里想去除这个文件夹,只依靠velodyne和label来进行训练

1、数据的预处理

这里写了一个代码进行数据的预处理,其目的主要是对label的第12、13、14位进行处理,因为kitti数据集中这个标注的意思是在相机坐标系下其标注框的位置(x ,y ,z),而我们在使用过程中需要获得雷达坐标系下的标注,所以在这里进行预先的转化。

1、该代码写在tools文件夹中,kitti数据集在data文件夹中

2、运行下面的py文件会建立一个文件夹data/kitti/training/new_label_2,并将处理过然后产生的txt文件放入其中

3、运行完代码后将new_label_2名字改为label_2(原谅我是懒蛋,如果不改这个地方,会有很多其他地方要改)

import numpy as np
from pathlib import Path
import osdef get_calib_from_file(calib_file):with open(calib_file) as f:lines = f.readlines()obj = lines[2].strip().split(' ')[1:]P2 = np.array(obj, dtype=np.float32)obj = lines[3].strip().split(' ')[1:]P3 = np.array(obj, dtype=np.float32)obj = lines[4].strip().split(' ')[1:]R0 = np.array(obj, dtype=np.float32)obj = lines[5].strip().split(' ')[1:]Tr_velo_to_cam = np.array(obj, dtype=np.float32)return {'P2': P2.reshape(3, 4),'P3': P3.reshape(3, 4),'R0': R0.reshape(3, 3),'Tr_velo2cam': Tr_velo_to_cam.reshape(3, 4)}class Calibration(object):def __init__(self, calib_file):if not isinstance(calib_file, dict):calib = get_calib_from_file(calib_file)else:calib = calib_fileself.P2 = calib['P2']  # 3 x 4self.R0 = calib['R0']  # 3 x 3self.V2C = calib['Tr_velo2cam']  # 3 x 4# Camera intrinsics and extrinsicsself.cu = self.P2[0, 2]self.cv = self.P2[1, 2]self.fu = self.P2[0, 0]self.fv = self.P2[1, 1]self.tx = self.P2[0, 3] / (-self.fu)self.ty = self.P2[1, 3] / (-self.fv)def cart_to_hom(self, pts):""":param pts: (N, 3 or 2):return pts_hom: (N, 4 or 3)"""pts_hom = np.hstack((pts, np.ones((pts.shape[0], 1), dtype=np.float32)))return pts_hom#对R0_rect进行拓展,然后与Tr_velo_to_cam进行相乘求相反数后再求逆 R0_rect * Tr_velo_to_cam * y=x(y是雷达,x是照相机)def rect_to_lidar(self, pts_rect):""":param pts_lidar: (N, 3):return pts_rect: (N, 3)"""pts_rect_hom = self.cart_to_hom(pts_rect)  # (N, 4)R0_ext = np.hstack((self.R0, np.zeros((3, 1), dtype=np.float32)))  # (3, 4)R0_ext = np.vstack((R0_ext, np.zeros((1, 4), dtype=np.float32)))  # (4, 4)R0_ext[3, 3] = 1V2C_ext = np.vstack((self.V2C, np.zeros((1, 4), dtype=np.float32)))  # (4, 4)V2C_ext[3, 3] = 1pts_lidar = np.dot(pts_rect_hom, np.linalg.inv(np.dot(R0_ext, V2C_ext).T))return pts_lidar[:, 0:3]class Object3d(object):def __init__(self, line):label = line.strip().split(' ')self.top=np.array([])for i in range(0,11):self.top=np.append(self.top,label[i])self.loc = np.array((float(label[11]), float(label[12]), float(label[13])), dtype=np.float32)self.last=np.array([label[14]])def get_calib(root_split_path, idx):calib_file = root_split_path / 'calib' / ('%s.txt' % idx)assert calib_file.exists()return Calibration(calib_file)def get_objects_from_label(label_file):with open(label_file, 'r') as f:lines = f.readlines()objects = [Object3d(line) for line in lines]return objectsdef get_label(root_split_path, idx):label_file = root_split_path / 'label_2' / ('%s.txt' % idx)assert label_file.exists()return get_objects_from_label(label_file)def write_new_libel(root_split_path, idx, save_num):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)with open(new_libel_file, "a")as f:f.write(str(save_num[0]))for i in range(1,save_num.shape[0]):f.write(' '+str(save_num[i]))f.write('\r\n')#去掉文件最后的换行符
def del_n(root_split_path,idx):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)file_object = open(new_libel_file, "rb+")file_object.seek(-2,2)file_object.truncate()file_object.close()def get_allfile(path):  # 获取所有文件all_file = []files =sorted(os.listdir(path))for f in files :  #listdir返回文件中所有目录#f_name = os.path.join(path, f)#f_name=os.path.basename(f_name)#去掉路径f=os.path.splitext(f)[0]#去掉文件名后缀all_file.append(f)return all_filedef clean_file(root_split_path,idx):new_libel_file=root_split_path / 'new_label_2' / ('%s.txt' % idx)file_object = open(new_libel_file, "w")file_object.close()def mkdir_new_label_2(root_split_path):new_libel_2=root_split_path / 'new_label_2'if os.path.exists(new_libel_2) is False:print("-------mkdir%s-------"%new_libel_2) os.mkdir(new_libel_2)root_split_path=Path('../data/kitti/training')mkdir_new_label_2(root_split_path)
all_file=get_allfile(root_split_path/'label_2')  #tickets要获取文件夹名
print("-------All name loaded-------")
#print(all_file)for file_idx in all_file:clean_file(root_split_path,file_idx)print("This is the %s.txt"%file_idx)calib=get_calib(root_split_path,file_idx)obj_list=get_label(root_split_path,file_idx)annotations = {}for obj in obj_list:annotations['location'] = np.concatenate([obj.loc.reshape(1, 3)], axis=0)#print(annotations['location'])loc_lidar = calib.rect_to_lidar(annotations['location'])loc_lidar=loc_lidar.reshape(-1)#print("top",obj.top[0])temp=np.concatenate([obj.top,loc_lidar,obj.last],axis=0)#print("concatenate",temp)write_new_libel(root_split_path, file_idx, temp)#del_n(root_split_path, file_idx)

2、修改数据处理部分的代码

OpenPCDet中首先对数据进行了一波预处理,我们仿照着写一下,这一步主要是对pcdet/datasets这个文件夹进行处理

2.1 复制对数据集进行处理的文件

把pcdet/datasets/kitti文件夹复制并改名为pcdet/datasets/kitti_lidar,然后把pcdet/utils/object3d_kitti.py复制为pcdet/utils/object3d_kitti_lidar.py

2.2 对kitti_lidar_dataset.py进行修改

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py

2.2.1 头文件修改

这一行修改最后的object3d_kitti为object3d_kitti_lidar

from ...utils import box_utils, calibration_kitti, common_utils, object3d_kitti_lidar

2.2.2 数据集对象名称修改

头文件下面一行修改为(原类名为KittiDataset)

class KittiLidarDataset(DatasetTemplate):

2.2.3 get_info函数修改

这里其他的地方不要改,直接到这个函数,然后替换为下面的代码

    def get_infos(self, num_workers=4, has_label=True, count_inside_pts=True, sample_id_list=None):import concurrent.futures as futuresdef process_single_scene(sample_idx):print('%s sample_idx: %s' % (self.split, sample_idx))info = {}pc_info = {'num_features': 4, 'lidar_idx': sample_idx}info['point_cloud'] = pc_infoif has_label:obj_list = self.get_label(sample_idx)annotations = {}annotations['name'] = np.array([obj.cls_type for obj in obj_list])annotations['truncated'] = np.array([obj.truncation for obj in obj_list])annotations['occluded'] = np.array([obj.occlusion for obj in obj_list])annotations['alpha'] = np.array([obj.alpha for obj in obj_list])annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list])  # lhw(camera) formatannotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)annotations['rotation_y'] = np.array([obj.ry for obj in obj_list])annotations['score'] = np.array([obj.score for obj in obj_list])annotations['difficulty'] = np.array([obj.level for obj in obj_list], np.int32)num_objects = len([obj.cls_type for obj in obj_list if obj.cls_type != 'DontCare'])num_gt = len(annotations['name'])index = list(range(num_objects)) + [-1] * (num_gt - num_objects)annotations['index'] = np.array(index, dtype=np.int32)loc = annotations['location'][:num_objects]dims = annotations['dimensions'][:num_objects]rots = annotations['rotation_y'][:num_objects]#loc_lidar = calib.rect_to_lidar(loc)#获得一个变换矩阵loc_lidar=locl, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]loc_lidar[:, 2] += h[:, 0] / 2gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)annotations['gt_boxes_lidar'] = gt_boxes_lidarinfo['annos'] = annotationsreturn infosample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_listwith futures.ThreadPoolExecutor(num_workers) as executor:infos = executor.map(process_single_scene, sample_id_list)return list(infos)

2.2.4 .yaml文件修改

老规矩,先cv,将tools/cfgs/dataset_configs/kitti_dataset.yaml复制为tools/cfgs/dataset_configs/kitti_lidar.yaml。然后修改一下第一行,修改为

DATASET: 'KittiLidarDataset'

2.2.5 运行

终端输入

python -m pcdet.datasets.kitti_lidar.kitti_lidar_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_lidar.yaml

结果展示:

然后我们的pkl文件就存放在data/kitti里面啦

3、修改数据集加载

3.0 复制数据集加载的文件

把pcdet/datasets/kitti_lidar文件夹复制并改名为pcdet/datasets/kitti_lidar,里面的文件相应改名

3.1 去掉测试

tools/train.py这个文件夹内部,去掉测试的代码,我们少修改一点

"""logger.info('**********************Start evaluation %s/%s(%s)**********************' %(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))test_set, test_loader, sampler = build_dataloader(dataset_cfg=cfg.DATA_CONFIG,class_names=cfg.CLASS_NAMES,batch_size=args.batch_size,dist=dist_train, workers=args.workers, logger=logger, training=False)eval_output_dir = output_dir / 'eval' / 'eval_with_train'eval_output_dir.mkdir(parents=True, exist_ok=True)args.start_epoch = max(args.epochs - args.num_epochs_to_eval, 0)  # Only evaluate the last args.num_epochs_to_eval epochsrepeat_eval_ckpt(model.module if dist_train else model,test_loader, args, eval_output_dir, logger, ckpt_dir,dist_test=dist_train)logger.info('**********************End evaluation %s/%s(%s)**********************' %(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
"""

3.2 修改__getitem__函数

pcdet/datasets/kitti_lidar/kitti_lidar_dataset.py,修改数据加载的文件,这里主要把图像和calib的加载去掉,然后把我们新的数据集文件(label_2)导入

    def __getitem__(self, index):# index = 4if self._merge_all_iters_to_one_epoch:index = index % len(self.kitti_infos)info = copy.deepcopy(self.kitti_infos[index])sample_idx = info['point_cloud']['lidar_idx']#img_shape = info['image']['image_shape']#calib = self.get_calib(sample_idx)get_item_list = self.dataset_cfg.get('GET_ITEM_LIST', ['points'])input_dict = {'frame_id': sample_idx,#'calib': calib,}if 'annos' in info:annos = info['annos']annos = common_utils.drop_info_with_name(annos, name='DontCare')loc, dims, rots = annos['location'], annos['dimensions'], annos['rotation_y']gt_names = annos['name']#gt_boxes_camera = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1).astype(np.float32)gt_boxes_lidar = annos['gt_boxes_lidar']input_dict.update({'gt_names': gt_names,'gt_boxes': gt_boxes_lidar})#if "gt_boxes2d" in get_item_list:#    input_dict['gt_boxes2d'] = annos["bbox"]road_plane = self.get_road_plane(sample_idx)if road_plane is not None:input_dict['road_plane'] = road_planeif "points" in get_item_list:points = self.get_lidar(sample_idx)input_dict['points'] = pointsdata_dict = self.prepare_data(data_dict=input_dict)#data_dict['image_shape'] = img_shapereturn data_dict

3.3 前后连起来

pcdet/datasets/__init__.py,将前面的部分和后面的部分连起来

#头文件中加入,我们2.2.2
from .kitti_lidar.kitti_lidar_dataset import KittiLidarDataset__all__ = {'DatasetTemplate': DatasetTemplate,'KittiDataset': KittiDataset,'KittiLidarDataset':KittiLidarDataset,#相应的这里也加入'NuScenesDataset': NuScenesDataset,'WaymoDataset': WaymoDataset,'PandasetDataset': PandasetDataset,'LyftDataset': LyftDataset
}

3.4 .yaml文件修改

将tools/cfgs/kitti_models/pointpillar.yaml复制到tools/cfgs/kitti_lidar_models/pointpillar.yaml,kitti_lidar_models这个文件夹自己建立

其中修改_BASE_CONFIG_

DATA_CONFIG: _BASE_CONFIG_: cfgs/dataset_configs/kitti_lidar.yaml

3.5运行

cd tools
python train.py --cfg_file=cfgs/kitti_lidar_models/pointpillar.yaml --batch_size=3 --epochs=100

运行结果:

OpenPCDet 自定义数据集训练相关推荐

  1. YOLOv5自定义数据集训练

    YOLOv5自定义数据集训练 简介 本文介绍如何在自己的VOC格式数据集上训练YOLO5目标检测模型. VOC数据集格式 首先,先来了解一下Pascal VOC数据集的格式,该数据集油5个部分组成,文 ...

  2. 行人属性识别二:添加新网络训练和自定义数据集训练

    序言 上一篇记录了训练过程,但是项目中提供的模型网络都是偏大的,如果想要在边缘设备上部署,还是比较吃力的,所以本文记录如何加入新的网络模型进行训练,以repvgg为例,加入mobilenet.shuf ...

  3. 在集群服务器进行自定义数据集训练记录过程 TensorBoard logging requires TensorBoard with Python summary writer installed.

    先记录解决办法: TensorBoard logging requires TensorBoard with Python summary writer installed. This should ...

  4. 【yolo】yolov3的pytorch版本保存自定义数据集训练好的权重,并载入自己的模型

    多次试验终于测出来了!!很高兴,结果截图: 数据集是来自网上的,代码原型是github一个大概五千多star的pytorch-yolov3,但原代码并没有载入自己的模型进行训测试阶段,然后parser ...

  5. MMrotate自定义数据集训练与验证格式转换脚本

    数据集准备 数据集格式 文件夹格式:Data/ #保存Dota数据集的目录 Train #存放images和labelTxt的文件夹 Images#存放所有训练集图片的文件夹 labelTxt #存放 ...

  6. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  7. 利用PyTorch自定义数据集实现猫狗分类

    看了许多关于PyTorch的入门文章,大抵是从torchvision.datasets中自带的数据集进行训练,导致很难把PyTorch运用于自己的数据集上,真正地灵活运用PyTorch. 这里我采用从 ...

  8. ML之catboost:基于自定义数据集利用catboost 算法实现回归预测(训练采用CPU和GPU两种方式)

    ML之catboost:基于自定义数据集利用catboost 算法实现回归预测(训练采用CPU和GPU两种方式) 目录 基于自定义数据集利用catboost 算法实现回归预测(训练采用CPU和GPU两 ...

  9. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

最新文章

  1. 2022-2028年中国防臭袜行业投资分析及前景预测报告
  2. python 获取闭包函数的参数
  3. DC使用教程系列1-.synopsys.dc.setup的建立
  4. 一种集各种优点于一身的技术面试方式--转
  5. 10行Python代码自动清理电脑内重复文件
  6. leetcode 41. First Missing Positive 1
  7. 滴滴滴,测试工程师简历模板分享一波
  8. IIS6.0 + openssl执行版 + Windows2003--配置篇
  9. vs2019安装python库_vs2019安装和使用详细图文教程
  10. sql判断为0_SQL简单语义分析概述
  11. 金旭亮:我是一只IT小小鸟(新书推荐 序)
  12. PKG安装包的管理与文件格式分析
  13. 区块链研究实验室-首次提出如何确保区块链分片技术的数据完整性
  14. leetcode剑指offe刷题-第一题-用两个栈实现队列
  15. 看未来的企业是如何解决潜规则的
  16. CAD图块无法分解怎么办?CAD块分解教程
  17. 股票 - - 常用指标【中】
  18. linux中perl的环境变量,在Perl脚本中使用Bash环境变量?
  19. c语言二进制数以字符形式输出,如何用C语言输出二进制数据
  20. matlab典型相关函数,典型相关分析(Matlab实现函数)

热门文章

  1. B.FRIENDit壁虎忍者悬浮式黑轴机械键盘 台式电脑笔记本外接有线游戏键盘IGK2-ST白色套装
  2. 前端基础学习——带你全面掌握HTML语言
  3. cos和acos--余弦和反余弦函数
  4. java+mysql 基于jsp828婚庆用品销售购物网站
  5. 【2018的第一封邀请函】开源大数据引擎用户大会
  6. Python + Tkinter 图形化界面设计1 —— 第一个图形化界面
  7. 三星 I8150 详解手机电池的使用与保养
  8. [附源码]java毕业设计网上花店系统
  9. AppleMac重置管理员账户
  10. python列表02