概述

网上搜了一圈,关于CenterNet 训练关键点数据的资料非常少,而且讲得都很模糊,没法解决实际问题,也未说明细节和要素。在踏坑许久之后,才跑通CenterNet的关键点训练,于是记录一下踏坑历程,以备后忘

环境

cuda11.0
torch1.7.1
torchvision0.8.2
numpy 1.19.2
这是我的环境版本,不是非得这个版本

数据集准备

参考我的另一篇文章
COCO KeyPoints关键点数据集准备

CenterNet 代码修改

训练代码修改

我的数据集 的类别是1类, 关键点是3个

新的数据集代码创建

CenterNet/src/lib/datasets/datasets 目录, coco_hp.py是原来的coco keypoints官方数据集的数据集代码, 我们从这文件copy一份, 命名为handKeyPoints.py

修改其中内容, 先看下对比图

修改代码

num_classes = 1             #类别1num_joints = 3             #关键点个数3default_resolution = [512, 512]mean = np.array([0.636623, 0.642096, 0.649946],dtype=np.float32).reshape(1, 1, 3)                  #数据集计算出的meanstd  = np.array([0.318729, 0.316616, 0.297199],                #数据集计算出的stddtype=np.float32).reshape(1, 1, 3)flip_idx = [[1, 2]]                          #图像翻转,这个我也不懂, 就照着官方的写了一个,跟实际的关键点数对应def __init__(self, opt, split):super(HandKeyPoints, self).__init__()self.edges = [[0, 1], [1, 2]]self.acc_idxs = [1, 2, 3]#数据集文件夹 目录 data/HandCupKeyPoints ,  这个文件夹里面是annotations、test2017 、train2017三个文件夹self.data_dir = os.path.join(opt.data_dir, 'HandCupKeyPoints')                   if split == 'val':                                                                                      #这里我们的是test而不是val,所以改一下split = 'test'self.img_dir = os.path.join(self.data_dir, '{}2017'.format(split))if split == 'test':self.annot_path = os.path.join(self.data_dir, 'annotations', 'test.json')                                                       #直接指定文件名else:self.annot_path = os.path.join(self.data_dir, 'annotations', 'train.json')                                                    #直接指定文件名

这个文件就修改完毕, 其他不用动

修改文件 CenterNet/src/lib/datasets/dataset_factory.py

对比如下:

我这里大小写不一致,不用管, dataset_factory 字段 key 就是 刚刚创建的handKeyPoints.py 的前缀 , value就是 文件里 class 的类名

修改文件 CenterNet/src/lib/opts.py

第15行 修改默认数据集为 handKeyPoints 数据集

第323行

   opt.flip_idx = False#dataset.flip_idxopt.heads = {'hm': opt.num_classes, 'wh': 2, 'hps': 34}          #17个点的x、y 共 34个值

修改为:

    # opt.flip_idx = Falseopt.flip_idx = dataset.flip_idxopt.heads = {'hm': opt.num_classes, 'wh': 2, 'hps': 6}     #3个点的x、y 共 6个值

第345行

 'multi_pose': {'default_resolution': [512, 512], 'num_classes': 1, 'mean': [0.408, 0.447, 0.470], 'std': [0.289, 0.274, 0.278],               #数据集的 mean std'dataset': 'coco_hp', 'num_joints': 17,                                             #关键点个数 17个'flip_idx': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]},                                               #flip_idx

修改为 自己数据集的 mean 和std
flip_idx 修改为相应的

     'multi_pose': {'default_resolution': [512, 512], 'num_classes': 1, 'mean': [0.614, 0.612, 0.622], 'std': [0.348, 0.347, 0.329],                   #数据集的 mean std'dataset': 'handKeyPoints', 'num_joints': 3,                                            #关键点个数 3个'flip_idx': [[1, 2]]},                                                                                   #flip_idx 看情况写

到此,训练部分的代码就修改完了, 可以开始训练了

编写脚本train.sh, 内容

python main.py multi_pose  --arch dla_34 --dataset handKeyPoints --lr 0.25e-4  --batch_size 16 --gpus 0  --load_model ../models/multi_pose_dla_3x.pth

学习率自定义设置,
–batch_size 依据GPU显存大小, 如果CUDA out of memery 就改小点
–load_model …/models/multi_pose_dla_3x.pth 这个是官方训练好的模型,

开始训练

报一堆警告,不用管, 如下:

输出训练过程日志

看损失值, 在验证损失不再下降的时候停止训练

训练的结果

训练的结果保存在exp目录下

模型文件如下:

这样我们就得到了 训练好的模型,
我的数据集较少,不知道是由于标注质量的问题还是学习率的问题,最终验证损失loss=2.7左右时不再下降。

测试

训练代码修改

修改文件 CenterNet/src/lib/utils/debugger.py

增加handKeyPoints 数据集类别判断

第45行增加

    elif num_classes == 1 or dataset == 'handKeyPoints':               #增加自定义的handKeyPoints类别self.names = handKeyPoints_class                                         #类别名称self.names = ['p']                                                                      #类别名称self.num_class = 1                                                                    #类别数self.num_joints = 3                                                                    #关键点数量self.edges = [[0, 1], [1, 2]]                                                        #关键点连接关系self.ec = [(255, 0, 0), (0, 0, 255), (255, 0, 0)]                          #颜色self.colors_hp = [(255, 0, 255), (255, 0, 0), (0, 0, 255)]

第467行增加

handKeyPoints_class = ['hand']

修改文件 CenterNet/src/lib/detectors/multi_pose.py

第85行

 dets[:, :, :4] *= self.opt.down_ratio                                  #bbox 的4个值  dets[:, :, 5:39] *= self.opt.down_ratioemmina.mo               #5~39 是17个关键点的坐标34个值

这里修改为:

 dets[:, :, :4] *= self.opt.down_ratio                          #bbox 的4个值  dets[:, :, 5:11] *= self.opt.down_ratioemmina.mo       #5~11 是3个关键点的坐标6个值

第101行

     debugger.add_coco_bbox(bbox[:4], 0, bbox[4], img_id='multi_pose')            #bbox 的4个值  debugger.add_coco_hp(bbox[5:39], img_id='multi_pose')                                #5~39 是17个关键点的坐标34个值

修改为

     debugger.add_coco_bbox(bbox[:4], 0, bbox[4], img_id='multi_pose')            #bbox 的4个值  debugger.add_coco_hp(bbox[5:11], img_id='multi_pose')                                #5~11 是3个关键点的坐标6个值

到这里测试代码就修改好了

编写测试脚本test_keyPoints.sh, 内容如下:

python demo.py multi_pose --demo ../images/16.jpg --load_model ../models/model_best_keypoints.pth

模型用的就是上述训练好的模型

测试结果


识别出了手臂, 置信度为0.9

CenterNet KeyPoints 关键点训练自己的数据相关推荐

  1. COCO KeyPoints关键点数据集准备

    COCO KeyPoints关键点数据集准备 概述 网上搜了一圈,coco关键点数据集准备的内容比较少,这里写一篇完成的标注流程到数据集准备的文章,以备后忘 标注工具 coco官方标注工具: coco ...

  2. PANet训练自己的数据(VIA标注)

    当前最好的实例分割网络非PANet莫属,可是由于模型太新,网上的资料太少,最近的项目需要 实例分割,只能自己踩踩坑了,目前我还没看到一篇关于PANet训练的博客,只有几篇讲论文的. 环境:ubuntu ...

  3. YOLO-v5训练自己的数据+TensorRT推理部署(2)

    YOLO-v5训练自己的数据+TensorRT推理部署(2) 代码下载地址:下载地址 YOLO v5转TensorRT模型并调用 0.pt模型转wts模型 python3 gen_wts.py # 注 ...

  4. YOLO-v5训练自己的数据+TensorRT推理部署(1)

    YOLO-v5训练自己的数据+TensorRT推理部署(1) 代码下载地址:下载地址 YOLO v5在医疗领域中消化内镜目标检测的应用 YOLO v5训练自己数据集详细教程

  5. YOLOv3: 训练自己的数据(绝对经典版本1)

    为什么80%的码农都做不了架构师?>>>    windows版本:请参考:https://github.com/AlexeyAB/darknet linux       版本:请参 ...

  6. DL之LSTM之MvP:基于TF利用LSTM基于DIY时间训练csv文件数据预测后100个数据(多值预测)状态

    DL之LSTM之MvP:基于TF利用LSTM基于DIY时间训练csv文件数据预测后100个数据(多值预测)状态 目录 数据集csv文件内容 输出结果 设计思路 训练记录全过程 数据集csv文件内容 输 ...

  7. DL之LSTM之UvP:基于TF利用LSTM基于DIY时间训练1200个数据预测后200个数据状态

    DL之LSTM之UvP:基于TF利用LSTM基于DIY时间训练1200个数据预测后200个数据状态 目录 输出结果 设计思路 训练记录全过程 输出结果 设计思路 训练记录全过程 INFO:tensor ...

  8. Dataset之图片数据增强:基于TF实现图片数据增强(原始的训练图片reshaped_image→数据增强→distorted_image(训练时直接使用))

    Dataset之图片数据增强:基于TF实现图片数据增强(原始的训练图片reshaped_image→数据增强→distorted_image(训练时直接使用)) 目录 数据增强步骤 数据增强实现代码 ...

  9. 5招训练你的数据敏感度,数据高手都在用

    真正的数据分析大神是怎样的?有人说能轻松玩转各种分析工具,有人说能从海量数据中找到关联,有人说能一眼识别出报告中的数据异常,还有人说能够撰写一份经典的数据分析报告. 其实对于一个数据大神,这些都是必备 ...

最新文章

  1. 实战教程 | 车道线检测项目实战,霍夫变换 新方法 Spatial CNN
  2. 网络时间协议 --- 网络对时程序
  3. MySQL唯一约束(UNIQUE KEY)
  4. mysql开启慢查询方法(转)
  5. 检测机安装mysql_centos安装mysql的正确方法
  6. python indexerror怎么办_Python IndexError:使用列表作为可迭代对象时...
  7. ffmpeg 静态库使用,undefined reference错误
  8. html改为php报错,**PHP, 这段嵌入html的php代码为何第15、16、17行报错?**
  9. Jsp基本page指令、注释、方法声明,书写规范及注意事项
  10. Hive数据导出入门
  11. 计算机总线相关知识,计算机包括哪几种总线?
  12. sqlite3返回码
  13. 人脸识别技术开发解决方案,人脸识别智慧工地应用开发
  14. 黑客语言——Ruby
  15. 小程序“成语猜题”部分答案
  16. 计算机cpu占用率高,CPU占用率高的原因及解决方法
  17. 2021-10-12 Java 中 Filed.modifiers 之 java.lang.reflect.Modifier
  18. 使用EKL(Elasticsearch、Kibana、Logstash)进行服务器日志的汇聚与监控
  19. 如何快速自动填充空白单元格上一行的内容
  20. 【Chrome 浏览器自带谷歌翻译用不了】

热门文章

  1. 递归删除目录下的所有文件
  2. JavaSE replaceAll 方法
  3. ASP.NET 4.0: 请求验证模式变化导致ValidateRequest=false失效
  4. [推荐]C#快速开发3d游戏工具--Unity3d
  5. linux 统计日志数量总,shell统计日志中时间段内匹配的数量的方法
  6. java struts2值栈ognl_Struts2 (三) — OGNL与值栈
  7. linux 安装redis4.0.6,Redis(4.0.6)在Linux(CentOS7)下的安装
  8. pythonl_Python3 os.lchown() 方法
  9. oracle重启一个节点集群,帮忙分析一例数据库两节点集群每隔几个月节点重启
  10. linux 6.7 nfs安装yum,centos7下NFS使用与配置