实战CenterNet,训练猫脸关键点检测数据集并测试

  • 一、本机配置
  • 二、环境搭建
  • 三、数据准备
  • 四、开始训练
  • 五、测试
  • 六、references

这里主要记录一下,调试CenterNet用来训练猫脸 关键点检测的过程。因为网上现在大多都是利用CenterNet进行目标检测,但是我觉得目标检测用CenterNet显然不是最好的,关于目标检测可以看mmdetection这个框架,可以参考我这篇博客:
mmdetection实战,训练扑克牌数据集(VOC格式)并测试计算mAP
争取按照本篇博客,都能跑成功!
论文地址: Objects as Points【 国内镜像】
代码地址: https://github.com/xingyizhou/CenterNet
博客参考: 扔掉anchor!真正的CenterNet——Objects as Points论文解读

一、本机配置

ubuntu18.04.3 + cuda10.0 + cudnn7.4.2 + PyTorch1.2 + torchvision0.4 + python3.6

二、环境搭建

  1. 其实官方文档INSTALL.md里面已经说的很详细了,虽然作者使用的是pytorch0.4,但是在pytorch1.x上是完全可以运行的。
  2. 在issues/7里面也非常详细了,我基本是安装都是参照这个问题来的,装好环境可以参照里面的test部分测试一下demo。
  3. 只是有一个注意的地方,就是编译DCNv2的时候得去这个tree下载下来,然后注意不是复制到CenterNet/src/lib/models/networks去替换,而是需要删除整个DCNv2文件夹rm -rf DCNv2 ,然后把下载好的复制进去,否则编译会出错。
  4. 以下修改和用到的代码文件,官方预训练模型都打包在这里,无需下载比赛的优胜代码。
  5. 有问题欢迎评论区,这里不过多介绍。

三、数据准备

其实这是一个AI研习社的比赛,介绍和数据集可以里面下,这里用的也是优胜代码,但是他没有写任何的介绍和给出训练好的模型,甚至log日志也没有(只有一点他运行时留下的自己复制的内容,和json格式文件,害我硬着头皮看了好几天,这里就记录一下,只要按照我下面来就行):

  1. 数据集解压出来只有train,test,train.csv,需要把train.csv放到刚刚创建环境中CenterNet/data下,运行generate_train_val_txt.py(也放在data下)生成train.txt,test.txt,这里放一下代码:generate_train_val_txt.py
import csv
import randomwith open('train.csv','r') as csvfile:reader = csv.reader(csvfile)column = [row[0] for row in reader]total_file = column[1:]train_percent = 0.85
num = len(total_file)
print("total images numbers:", num)  # 10548list = range(num)
tr = int(num * train_percent)
train = random.sample(list, tr)
print("train size:", tr)ftrain = open('train.txt', 'w')
fval = open('val.txt', 'w')for i in list:name = total_file[i] + '\n'if i in train:ftrain.write(name)else:fval.write(name)ftrain.close()
fval.close()
print("write finished!")
  1. 然后运行generate_coco_json.py把训练数据集生成COCO格式的文件train.json,test.json,代码参考了在目标检测和关键点检测任务中如何将自己的数据集转为coco格式,想了解COCO格式的可以看一下COCO数据集的标注格式,这里也放一下代码:generate_coco_json.py
# *_* : coding: utf-8 *_*'''
datasets process for object detection project.
for convert customer dataset format to coco data format,
'''import traceback
import argparse
import json
import cv2
import csv
import os__CLASS__ = ['__background__', 'CatFace']   # class dictionary, background must be in first index.def argparser():parser = argparse.ArgumentParser("define argument parser for pycococreator!")parser.add_argument("-r", "--image_root", default='D:\\catface\\train\\', help="path of root directory")parser.add_argument("-p", "--phase_folder", default=["train", "val"], help="datasets split")return parser.parse_args()def MainProcessing(args):'''main process source code.'''annotations = {}    # annotations dictionary, which will dump to json format file.image_root = args.image_rootphase_folder = args.phase_folderwith open('train.csv', 'r') as f:reader = csv.reader(f)result = list(reader)result = result[1:]# coco annotations info.annotations["info"] = {"description": "customer dataset format convert to COCO format","url": "http://cocodataset.org","version": "1.0","year": 2020,"contributor": "ezra","date_created": "2020/03/15"}# coco annotations licenses.annotations["licenses"] = [{"url": "https://www.apache.org/licenses/LICENSE-2.0.html","id": 1,"name": "Apache License 2.0"}]# coco annotations categories.annotations["categories"] = []for cls, clsname in enumerate(__CLASS__):if clsname == '__background__':continueannotations["categories"].append({"id": cls,"name": clsname,"supercategory": "Cat",})for catdict in annotations["categories"]:if "CatFace" == catdict["name"]:catdict["keypoints"] = [0, 1, 2, 3, 4, 5, 6, 7, 8]catdict["skeleton"] = [[0,1],[1,2],[0,2],[3,4],[4,5],[5,6],[6,7],[7,8],[8,3]]for phase in phase_folder:annotations["images"] = []annotations["annotations"] = []fphase = open(phase + '.txt', 'r')step = 0for id, line in enumerate(fphase.readlines()):line = line.strip("\n")file_name = line + '.jpg'images_id = int(line)height, width, _ = cv2.imread(image_root + file_name).shapev = [2, 2, 2, 2, 2, 2, 2, 2, 2]point_str = result[images_id][1:]point = [int(k) for k in point_str]for j in range(9):if min(point[2*j:2*j+2]) < 0:v[j] = 1keypoint = [point[0],  point[1],  v[0], point[2],  point[3],  v[1],point[4],  point[5],  v[2], point[6],  point[7],  v[3],point[8],  point[9],  v[4], point[10], point[11], v[5],point[12], point[13], v[6], point[14], point[15], v[7],point[16], point[17], v[8]]bw = max(point[0::2]) - min(point[0::2]) + 10bh = max(point[1::2]) - min(point[1::2]) + 10if (min(point[0::2]) - 5) < 0:x1 = 0else:x1 = (min(point[0::2]) - 5)if (min(point[1::2]) - 5) < 0:y1 = 0else:y1 = (min(point[1::2]) - 5)annotations["images"].append({"file_name": file_name,"height": height,"width": width,"id": images_id})# coco annotations annotations.annotations["annotations"].append({"id": id + 1,"num_keypoints": 9,"keypoints": keypoint,"area": bw * bh,"iscrowd": 0,"image_id": images_id,"bbox": [x1, y1, bw, bh],"category_id": 1,"segmentation": [],})step += 1if step % 100 == 0:print("processing {} ...".format(step))json_path = phase+".json"with open(json_path, "w") as f:json.dump(annotations, f)if __name__ == "__main__":print("begining to convert customer format to coco format!")args = argparser()try:MainProcessing(args)except Exception as e:traceback.print_exc()print("successful to convert customer format to coco format")

注意一点,不要问我这里的bboxarea怎么会这样计算出来的,我是硬是看他的数据得到了规律,准确的计算方法是要用cocoapi的,因为下面的任务对这两个属性不要求,所以没关系。
3. 然后在data下创建coco文件夹,在其中创建两个文件夹:annotationsimages,把生成的train.jsonval.json放在annotations下,之前解压出来的数据文件夹traintest都放在images下,这样做好后的文件目录长这样:

├── CenterNet/data
│   ├── coco
│   │   ├── annotations
│   │   │   ├── train.json
│   │   │   ├── val.json
│   │   ├── images
│   │   │   ├── train
│   │   │   ├── test

四、开始训练

这里只要把打包文件里面的coco_hp.pyopts.py替换掉对应的CenterNet/src/lib/datasets/dataset/coco_hp.pyCenterNet/src/lib/opts.py,然后就可以运行下面的命令了【windows可能还需要设置–num_workers 0,batch_size看自己硬件】:

python main.py multi_pose --exp_id dla_1x_catface --dataset coco_hp --lr 5e-4 --lr_step '17,27' --num_epochs 37 --batch_size 16 --gpus 0  --load_model ../models/multi_pose_dla_3x.pth

五、测试

  1. 这里只要把打包文件里面的post_process.pymulti_pose.pydemo.py替换掉对应的CenterNet/src/lib/utils/post_process.pyCenterNet/src/lib/detectors/multi_pose.pyCenterNet/src/demo.py
  2. coco_hp.py31行改为:
self.img_dir = os.path.join(self.data_dir, 'images/test')
  1. 运行下面的命令:
python demo.py multi_pose --demo '/home/lsm/文档/CenterNet/data/coco/images/test' --exp_id dla_1x_catfacetest --dataset coco_hp --load_model /home/lsm/文档/CenterNet/exp/multi_pose/dla_1x_catface/model_best.pth

就会在/home/lsm/文档/CenterNet/exp/multi_pose/dla_1x_catfacetest下生成一个result.csv文件,在练习赛里面提交后可以看到成绩:

这个成绩比之前的奖金赛的第一名,也是这份代码的提供者还高了一点,说明我还train得好一点:

4. 在CenterNet/src/lib/utils/debugger.py里面改动下面两个属性:

self.num_joints = 9
self.edges = [[0, 1], [1, 2], [0, 2], [3, 4],[4, 5], [5, 6], [6, 7],[7, 8], [8, 3]]

然后运行下面的命令:

python demo.py multi_pose --demo /home/lsm/文档/CenterNet/data/coco/images/test/54.jpg --load_model /home/lsm/文档/CenterNet/exp/multi_pose/dla_1x_catface/model_best.pth --debug 4

就会得到下面的图:

六、references

做完的感觉就是我真的佩服自己硬着头皮看代码,之前没有调试过CenterNet代码,就是根据追风筝的小伙伴的代码凑出来的!
自己论文看的差不多,代码只是会跑通,不是特别深入,感谢下面博客的指引,也可以去他们里面看看一些loss曲线的绘制等其他功能:
CenterNet训练自己的数据集
(最新版本)如何在CenterNet上训练自己的数据集?
详解Centernet训练自己的数据集(win10 + cuda10 + pytorch1.0.1)

实战CenterNet,训练猫脸关键点检测数据集并测试相关推荐

  1. 猫脸关键点检测大赛:三种方法,轻松实现猫脸识别!

    导语:挑战猫脸,就差你了! 今天这个比赛,得从一个做程序猿的铲屎官开始说起...... 话说,有一天「铲屎猿」早起之后,发现猫主子竟然没了身影:他找啊找啊,找了好久,可仍然到处都没找到猫主子.这时,客 ...

  2. 猫脸关键点检测Baseline【阿水】

    关键点检测是许多计算机视觉任务的基础,例如表情分析.异常行为检测.大家接触最多的可能是人脸关键点检测,广泛应用于人脸识别.美颜.换妆等. 本次AI研习社举办猫脸关键点检测,训练集有10468张,测试集 ...

  3. deep-hight-relolution-net.pytorch训练自己的关键点检测数据步骤

    数据集准备 标注转为coco格式 部署源码(https://github.com/HuangJunJie2017/UDP-Pose/tree/master/deep-high-resolution-n ...

  4. 基于OpenCV训练口罩检测数据集并测试

    以下内容是利用opencv自带的训练器opencv_traincascade.exe与opencv_createsamples.exe,来对口罩数据集进行训练.内容是自己操作过程中的笔记,可能会有些杂 ...

  5. 只讲关键点之兼容100+种关键点检测数据增强方法

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨DefTruth 编辑丨极市平台 本文介绍了已有的几种关键点检测数据增强的方法,将其的优缺点进行 ...

  6. 计算机视觉人体骨骼点动作识别-1.训练自己的关键点检测模型

    人体关键点检测算法 关键点并不特指人体骨骼关键点,还有人脸关键点,物体的关键点.其中人体的关键点,也叫作pose Estimation,是最热门,也是最有难度,应用最广的. 应用可以包括:行为识别,人 ...

  7. python使用opencv对猫脸进行检测,并且框出猫脸

    首先导入需要的cv2库,如果没有的可以在Terminal中使用pip install opencv-python导入opencv的主要库包. import cv2 使用filepath赋值照片的路径并 ...

  8. PyTorch深度学习实战 | 基于ResNet的人脸关键点检测

    人脸关键点检测指的是用于标定人脸五官和轮廓位置的一系列特征点的检测,是对于人脸形状的稀疏表示.关键点的精确定位可以为后续应用提供十分丰富的信息.因此,人脸关键点检测是人脸分析领域的基础技术之一.许多应 ...

  9. 一分钟教会您使用Yolov5训练自己的数据集并测试

    1. 下载YOLO项目代码 点击这里下载并解压YOLO的官方代码:https://github.com/ultralytics/yolov5/tree/v5.0 2. 环境安装 cd进入到下载的YOL ...

最新文章

  1. 第15章 进程间通行 15.6 XSI IPC 15.7 消息队列
  2. asp.net面试的题目
  3. table列宽控制,word-break等
  4. php函数库快速记忆法_PHP速成大法
  5. 【渝粤题库】陕西师范大学201701 高等数学(二)作业 (高起本、专升本)
  6. c语言报错spawning 插1,C语言错误····error spawning c1.exe
  7. vue中 v-show和v-if的区别?
  8. Netty工作笔记0009---Channel基本介绍
  9. 算法复习——割点(洛谷3388)
  10. 系统在此应用程序中检测到基于堆栈的缓冲区_Linux 中的零拷贝技术
  11. 使用Java FXGL构建太空游侠游戏
  12. magento邮件使用php,用Magento的Email模板机制发邮件
  13. MySQL docker yml 3_分享一个mysql的docker-compose.yml
  14. Structs2文件上传以及预览
  15. 关于ztree的使用
  16. java宠物商店_Java如何实现宠物商店管理 Java实现宠物商店管理代码示例
  17. 计算机组成原理--数制与编码(校验码,CRC,汉明码详解)
  18. Navicat premium 导入大数据的Excel文件失败的方法
  19. 【资源】公开的电子书 合集 (计算机相关、多高清、pdf)
  20. electron与jquery起冲突,使用jquery报错解决方法

热门文章

  1. MacOS 中使用 killall kill
  2. 提交App Store被拒
  3. postgresql json数据筛选 字符串转json
  4. 百度网盘安装,双击一闪而过没有反应
  5. 7-1 圆形体体积计算器 (20分) C++
  6. springboot使用监听器实现用户在线离线状态的监控
  7. C# async / await 用法
  8. 提升App用户活跃度的5个小技巧
  9. Linux网卡新增虚拟ip
  10. 前端接modelmap的list_SpringMVC - 数据怎么从后端到前端?Model, ModelMap, ModelAndView