Pointnet(part_seg)train.py,test.py代码随记
train.py
我将代码全部简化,将关键步骤全部列出
hdf5_data_dir = 数据集路径 #读取数据集的路径创建os.mkdir(train_result) #创建train_result文件夹color_map_file = part_color_mapping.json #读取颜色json文件路径,一共50类color_map = json.load() #读取.json文件内容读取overallid_to_catid_partid.json 列表形式 #读取物体零件编号training_file_list = train_hdf5 路径
testing_file_list = test_hdf5 路径model_storage_path = 'trained_models' #在train_result下创建了一个trained_models文件夹,#用于存放训练好的模型
创建logs日志文件夹
创建summaries文件夹,可视化def train():pointclouds_ph = (32,2048,3)input_label_ph = (32,16)label_ph = (32)seg_ph = (32,2048) """以上是导入输入数据的占位符"""batch = 初始化变量0learning_rate = 指数衰减学习率bn_decay = 批标准化衰减率labels_pred , seg_pred, end_points = model.get_model( (32,2048,3),(32,16),...)#模型训练loss = get_loss()#计算损失train_variables = tf.trainable_variables() #可训练参数trainer = tf.train.AdamOptimizer(learning_rate) #优化器优化train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)#梯度优化,更新var_list最大程度减少损失saver = tf.train.Saver() #保存和加载模型init = tf.global_variables_initializer() #全局变量初始化sess.run(init) #图结构创建好了,开始会话 for epoch in range(training_epoches): #训练次数eval_one_epoch(epoch)train_file_idx = np.arange(0,6)打乱顺序train_one_epoch(train_one_idx , epoch)if(epoch+1) %10 == 0:cp_filename = saver.save(sess , 保存路径) #保存训练模型def eval_one_epoch(epoch_num):total_label_acc_per_cat = np.zeros[16] #每一类物体分类标签的正确数total_seg_acc_per_cat = np.zeros[16] #每一类分割正确数total_seen_per_cat = np.zeros[16] #每类个数for i in range(num_test_file):cur_data = (2048,2048,3) #测试集的点云数据cur_labels =(2048,16) #点云数据物体对应的16类cur_seg = (2048,2048) #每个点对应的50类其中之一cur_labels_one_hot = convert_label_to_one_hot(cur_labels)"""将label都换为one_hot形式"""for j in range(num_batch): #按批次运行beginidx-----endidx #开始到结束的索引loss = sess.run()per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx:endidx,...],axis=1 )"""求每个物体的零件正确率"""average_part_acc = np.mean(per_instance_part_acc)"""求这32个物体的平均零件正确率"""per_instance_label_pred = np.argmax(label_pred_val, axis=1)""" 求出这32个物体对类别预测的标签 """total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...]))""" 算出预测标签的正确率并求平均进行累加"""total_seg_acc += average_part_acc """将平均零件分割正确率累加"""for shape_idx in range(begidx, endidx):total_seen_per_cat[cur_labels[shape_idx]] += 1"""test过的每一类的个数"""total_label_acc_per_cat[cur_labels[shape_idx]]+=np.int32(per_instance_label_pred[shape_idx-begidx] == cur_labels[shape_idx])"""每一类标签判断正确的个数:预测标签与正确标签对比,如果正确就在相应位置+1"""total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx]"""将每个物体分割的正确率累加"""total_loss = total_loss * 1.0 / total_seentotal_label_loss = total_label_loss * 1.0 / total_seentotal_seg_loss = total_seg_loss * 1.0 / total_seentotal_label_acc = total_label_acc * 1.0 / total_seentotal_seg_acc = total_seg_acc * 1.0 / total_seen
train_one_proch比eval_one_peoch 多一个优化器过程
test.py
自行定义命令参数获取 model_path ,保存的训练模型pretrained_model_path = FLAGS.model_path #获取保存好的模型
hdf5_data_dir = './hdf5_data' # 获取h5数据集
ply_data_dir = './PartAnnotation' # 导入测试数据集test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt')
""" testing_ply_file_list.txt为从PartAnnotation数据集中采样出的2874个数据,分别包括点云数据 / 分割数据 / 实例类别编号"""oid2cpid = 'overallid_to_catid_partid.json'
"""oid2cpid读取物体零件编号[["02691156", 1], ["02691156", 2],....]""" object2setofoid = {} #oid对象集
for idx in range(len(oid2cpid)):objid, pid = oid2cpid[idx] #objid对象标识符 pid编号if not objid in object2setofoid.keys():object2setofoid[objid] = []object2setofoid[objid].append(idx)
"""创建一个字典,将每个物体编号按顺序0~49索引排序{'02691156':[0,1,2,3],'02773838':[4,5],.....}"""all_obj_cat_file = 'all_object_categories.txt' 获取16类物体和编号的文件,并分别划分到两个列表中
objcats = split()[0]
"""['02691156','02773838',......]"""
objnames = split()[1]
"""['Airplane','Bag',......]"""color_map = json.load('part_color_mapping.json') 获取颜色cpid2oid = 'catid_partid_to_overallid.json'
"""cpid2oid为对物体零件进行分类1~50类对应{"03642806_2": 29, "03642806_1": 28,...."""------------------------------------数据集的前期处理全部完成-----------------------------------
def predict():pointclouds_ph = (1,3000,3)input_label_ph = (1,16)pred , seg_pred , end_points = get_model(pointclouds_ph, input_label_ph,...)"""模型占位符"""saver = tf.train.Saver()"""添加操作用来保存和重现所有变量"""with tf.Seesion(config=config) as sess:saver.restore(sess, pretrained_model_path)"""导入训练好的模型"""batch_data = np.zeros[1,3000,3]total_per_cat_acc = np.zeros(16)"""每一类正确的个数"""total_per_cat_iou = np.zeros(16)""" 每一类的IOU"""total_per_cat_seen = np.zeros(16)""" 每一类测试的总个数"""获取测试用的数据集test_file_list,并进行预处理,将其划分为3类列表pts_files = split()[0] """获取的点云文件路径"""seg_files = split()[1]"""获取seg文件路径"""labels = split()[2]""" 获取物体类别编号""""""开始逐个对测试数据集中的数据进行操作,测试数据有2874个"""for shape_idx in range(len_pts_files):cur_gt_label = on2oid[labels[shape_idx]]""" on2oid为物体编号对应索引,总共有16个,获取当前数据集的编号对应索引"""将其转换为独热编码pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx])seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx])"""根据shape_idx将pts文件和seg文件读取出来"""pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label])"""将各物体编号都统一到1~50类当中,这个操作非常关键!!!!! """def load_pts_seg_files(pts_file, seg_file, catid):with open(pts_file, 'r') as f:pts_str = [item.rstrip() for item in f.readlines()]pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32)a = len(pts) with open(seg_file, 'r') as f:part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8)"""在单独一个物体中以1,2,3将不同零件进行分类,得出的零件索引[2 2 2 1 1 1 1 1 ....]"""seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids])"""cpid2oid为每个物体零件对应的0~50类编号,将单个物体零件的分类通过cpid2oid转换为总的50类别"""label_pred_val , seg_pred_res = sess.run()""" 预测出的label 和 seg"""label_pred_val = np.argmax(label_pred_val[0, :]) """将预测出的label得出"""seg_pred_res = seg_pred_res[0,....] #进行降维处理c = seg_pred_res.shaoe #(3000,50)iou_oids = object2setofoid[objcats[cur_gt_label]]""" 将该物体的零件索引提取出来objacts:['02691156','02773838',......]object2setofoid:{'02691156':[0,1,2,3],'02773838':[4,5],.....}[12,13,14,15]"""non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids))) """创建一个0~49的数组,剔除12,13,14,15"""mini = np.min(seg_pred_res) #获取预测中的最小值seg_pred_res[:, non_cat_labels] = mini - 1000 #将除12,13,14,15的其他标签都减小seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num]"""比较12,13,14,15这个位置的数,取最大判断为该类"""seg_acc = np.mean(seg_pred_val == seg)"""预测的类与正确实际的类做比较,得出seg的正确率"""total_acc += seg_acc"""将分割的正确率进行累加"""total_seen += 1""" 测试总的个数"""total_per_cat_seen[cur_gt_label] += 1total_per_cat_acc[cur_gt_label] += seg_accmask = np.int32(seg_pred_val == seg)"""预测类与正确的比较,相等为1,不等为0""" 计算IOU = n_intersect/(n_pred + n_seg - n_intersect)n_pred = 预测的12标签的个数n_seg = 实际的12标签的个数n_intersect = 判断正确的12标签的个数"""对预测结果,保存在obj文件"""if output_verbose:output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj'))output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj'))output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), os.path.join(output_dir, str(shape_idx)+'_diff.obj'))
Pointnet(part_seg)train.py,test.py代码随记相关推荐
- python打不开py文件查看代码,用python打开py文件
.py文件无法用python打开 刚刚把python更新到python3.7.2 但是发现之前写的.py的文件双击没有任何我去,你的情况和我一模一样,我也是环境变量和注册表按照网上的方法设置了,却还是 ...
- libsvm中tools(easy.py,subset.py,grid.py,checkdata.py)的使用
这几天在用libsvm(2.8.6)中的一些工具,总结一下. libsvm的一些工具还是非常有用的,1.可以调用subset.py将你的样本集合按你所想要的比例进行抽样出两个子样本集合.2.还可以调用 ...
- PointNet训练与测试github开源代码(PointNet实现第5步骤pytorch版)
PointNet第5步--PointNet训练与测试github开源代码 在运行github上的代码时,经常版本不匹配会出现大量的不同,或者报错,这篇主要记录我解决相关报错的方法. 本次测试的是git ...
- Tensorflow2.0 实现 YOLOv3(二):网络结构(common.py + backbone.py)
文章目录 文章说明 总体结构 common.py Convolutional 结构 Residual 残差模块 Upsample 结构 backbone.py Darknet53 结构 yolov3. ...
- yolov5 test.py val.py detec.py 区别在哪里呢?
yolov5 test.py val.py detec.py 区别在哪里呢? 用户在训练数据的时候必须使用 train.py 来进行 数据训练和验证,但我很难理解 detect.py 和 test.p ...
- python endswith py pyw_表达式 'test.py'.endswith(('.py', '.pyw')) 的值为 __________ 。_学小易找答案...
[简答题]实验十二:服务.doc [填空题]表达式 'Hello world'.lower() 的值为 _____________ . [填空题]表达式 ':'.join('hello world.' ...
- MATLAB代码:记及电转气协同的含碳捕集与垃圾焚烧虚拟电厂优化调度
MATLAB代码:记及电转气协同的含碳捕集与垃圾焚烧虚拟电厂优化调度 为了促进多能源互补及能源低碳化,本文提出了计及电转气协同的含碳捕集与垃圾焚烧虚拟电厂优化调度模型. 完美复现 ID:7934668 ...
- 1 PyTorch版YOLOv3 代码中文注释 之 训练 train.py test.py detect.py
文章目录 PyTorch版YOLOv3 代码中文注释 1. 相关链接: 2. 代码结构: 3. train.py 3.1. train.py 中包含的主要功能 4. test.py 4.1. test ...
- MMDet——用单卡train.py debug分布式代码
增加如下代码即可 前边的输入参数添加DEBUG: parser.add_argument('--DEBUG', type=bool, default=True, help='debug mode') ...
最新文章
- 盘点路由协议之RIP协议及IGRP协议
- Spring boot配置log4j输出日志
- oracle 10g real application clusters introduction (RAC原理)
- HttpServletRequest类用途
- SpringAOP描述及实现_AspectJ详解_基于注解的AOP实现_SpringJdbcTemplate详解
- 最近在群里┏━━━━━━━━━飞鸽传书━━━━━━━━━━┓
- textbox内容转为字符串_【公告】整改文章内容
- Bootstrap_导航
- 安全运维 - Linux系统维护
- File的创建删除复制等功能实现
- 个人博客网站搭建详细视频教程和源码
- 微信小程序内无法播放第三方服务器上的视频资源
- VMware虚拟机安装macOS黑苹果教程,亲测流程,全过程问题解决方案记录
- 计算机显示应用程序错误窗口,电脑出现应用程序错误窗口怎么办
- 如何理解总体标准差、样本标准差与标准误
- 关于杂质过滤的一点研究
- 简述python程序的书写规范_简明的 Python 编程规范
- 计算机网络——IP数据报分片
- 关于MFC模态对话框dlg.DoModal()返回-1的可能原因
- Symfony 入门教程