使用关键点进行小目标检测
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
本文转自|OpenCV学堂
【导语】本文搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heatmap两种方式来回归关键点,是一个很简单基础的项目,代码量很小,可供新手学习。
1. 数据来源
数据集:数据来源自小武,经过小武的授权使用,但不会公开。本项目只用了其中很少一部分共108张图片。
标注工具:https://github.com/pprp/landmark_annotation
标注工具也可以在公众号后台回复“landmark”关键字获取
上图是数据集中的两张图片,红圈代表对应的目标,标注的时候只需要在其中心点一下即可得到该点对应的横纵坐标。
该数据集有一个特点,每张图只有一个目标(不然没法用简单的方法回归),多余一个目标的图片被剔除了。
1
0.42 0.596
以上是一个标注文件的例子,1.jpg对应1.txt
2. 回归确定关键点
回归确定关键点比较简单,网络部分采用手工构建的一个两层的小网络,训练采用的是MSELoss。
这部分代码在:https://github.com/pprp/SimpleCVReproduction/tree/master/simple_keypoint/regression
2.1 数据加载
数据的组织比较简单,按照以下格式组织:
- data- images- 1.jpg- 2.jpg- ...- labels- 1.txt- 2.txt- ...
重写一下Dataset类,用于加载数据集。
class KeyPointDatasets(Dataset):def __init__(self, root_dir="./data", transforms=None):super(KeyPointDatasets, self).__init__()self.img_path = os.path.join(root_dir, "images")# self.txt_path = os.path.join(root_dir, "labels")self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels") for item in self.img_list]if transforms is not None:self.transforms = transformsdef __getitem__(self, index):img = self.img_list[index]txt = self.txt_list[index]img = cv2.imread(img)if self.transforms:img = self.transforms(img)label = []with open(txt, "r") as f:for i, line in enumerate(f):if i == 0:# 第一行num_point = int(line.strip())else:x1, y1 = [(t.strip()) for t in line.split()]# range from 0 to 1x1, y1 = float(x1), float(y1)tmp_label = (x1, y1)label.append(tmp_label)return img, torch.tensor(label[0])def __len__(self):return len(self.img_list)@staticmethoddef collect_fn(batch):imgs, labels = zip(*batch)return torch.stack(imgs, 0), torch.stack(labels, 0)
返回的结果是图片和对应坐标位置。
2.2 网络模型
import torch
import torch.nn as nnclass KeyPointModel(nn.Module):def __init__(self):super(KeyPointModel, self).__init__()self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)self.bn1 = nn.BatchNorm2d(6)self.relu1 = nn.ReLU(True)self.maxpool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)self.bn2 = nn.BatchNorm2d(12)self.relu2 = nn.ReLU(True)self.maxpool2 = nn.MaxPool2d((2, 2))self.gap = nn.AdaptiveMaxPool2d(1)self.classifier = nn.Sequential(nn.Linear(12, 2),nn.Sigmoid())def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.maxpool2(x)x = self.gap(x)x = x.view(x.shape[0], -1)return self.classifier(x)
其结构就是卷积+pooling+卷积+pooling+global average pooling+Linear,返回长度为2的tensor。
2.3 训练
def train(model, epoch, dataloader, optimizer, criterion):model.train()for itr, (image, label) in enumerate(dataloader):bs = image.shape[0]output = model(image)loss = criterion(output, label)optimizer.zero_grad()loss.backward()optimizer.step()if itr % 4 == 0:print("epoch:%2d|step:%04d|loss:%.6f" % (epoch, itr, loss.item()/bs))vis.plot_many_stack({"train_loss": loss.item()*100/bs})total_epoch = 300
bs = 10
########################################
transforms_all = transforms.Compose([transforms.ToPILImage(),transforms.Resize((360,480)),transforms.ToTensor(),transforms.Normalize(mean=[0.4372, 0.4372, 0.4373],std=[0.2479, 0.2475, 0.2485])
])datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)data_loader = DataLoader(datasets, shuffle=True,batch_size=bs, collate_fn=datasets.collect_fn)model = KeyPointModel()optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# criterion = torch.nn.SmoothL1Loss()
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)for epoch in range(total_epoch):train(model, epoch, data_loader, optimizer, criterion)loss = test(model, epoch, data_loader, criterion)if epoch % 10 == 0:torch.save(model.state_dict(),"weights/epoch_%d_%.3f.pt" % (epoch, loss*1000))
loss部分使用Smooth L1 loss或者MSE loss均可。
MSE Loss:
Smooth L1 Loss:
2.4 测试结果
3. heatmap确定关键点
这部分代码很多参考了CenterNet,不过曾经尝试CenterNet中的loss在这个问题上收敛效果不好,所以参考了kaggle人脸关键点定位的解决方法,发现使用简单的MSELoss效果就很好。
3.1 数据加载
这部分和CenterNet构建heatmap的过程类似,不过半径的确定是人工的。因为数据集中的目标都比较小,半径的范围最大不超过半径为30个像素的圆。
class KeyPointDatasets(Dataset):def __init__(self, root_dir="./data", transforms=None):super(KeyPointDatasets, self).__init__()self.down_ratio = 1self.img_w = 480 // self.down_ratioself.img_h = 360 // self.down_ratioself.img_path = os.path.join(root_dir, "images")self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels") for item in self.img_list]if transforms is not None:self.transforms = transformsdef __getitem__(self, index):img = self.img_list[index]txt = self.txt_list[index]img = cv2.imread(img)if self.transforms:img = self.transforms(img)label = []with open(txt, "r") as f:for i, line in enumerate(f):if i == 0:# 第一行num_point = int(line.strip())else:x1, y1 = [(t.strip()) for t in line.split()]# range from 0 to 1x1, y1 = float(x1), float(y1)cx, cy = x1 * self.img_w, y1 * self.img_hheatmap = np.zeros((self.img_h, self.img_w))draw_umich_gaussian(heatmap, (cx, cy), 30)return img, torch.tensor(heatmap).unsqueeze(0)def __len__(self):return len(self.img_list)@staticmethoddef collect_fn(batch):imgs, labels = zip(*batch)return torch.stack(imgs, 0), torch.stack(labels, 0)
核心函数是draw_umich_gaussian,具体如下:
def gaussian2D(shape, sigma=1):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m + 1, -n:n + 1]h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))h[h < np.finfo(h.dtype).eps * h.max()] = 0# 限制最小的值return hdef draw_umich_gaussian(heatmap, center, radius, k=1):diameter = 2 * radius + 1gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)# 一个圆对应内切正方形的高斯分布x, y = int(center[0]), int(center[1])width, height = heatmap.shapeleft, right = min(x, radius), min(width - x, radius + 1)top, bottom = min(y, radius), min(height - y, radius + 1)masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]masked_gaussian = gaussian[radius - top:radius +bottom, radius - left:radius + right]if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debugnp.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)# 将高斯分布覆盖到heatmap上,取最大,而不是叠加return heatmap
sigma参数直接沿用了CenterNet中的设置,没有调节这个超参数。
3.2 网络结构
网络结构参考了知乎上一个复现YOLOv3中提到的模块,Sematic Embbed Block(SEB)用于上采样部分,将来自低分辨率的特征图进行上采样,然后使用3x3卷积和1x1卷积统一通道个数,最后将低分辨率特征图和高分辨率特征图相乘得到融合结果。
class SematicEmbbedBlock(nn.Module):def __init__(self, high_in_plane, low_in_plane, out_plane):super(SematicEmbbedBlock, self).__init__()self.conv3x3 = nn.Conv2d(high_in_plane, out_plane, 3, 1, 1)self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)self.conv1x1 = nn.Conv2d(low_in_plane, out_plane, 1)def forward(self, high_x, low_x):high_x = self.upsample(self.conv3x3(high_x))low_x = self.conv1x1(low_x)return high_x * low_xclass KeyPointModel(nn.Module):"""downsample ratio=2"""def __init__(self):super(KeyPointModel, self).__init__()self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)self.bn1 = nn.BatchNorm2d(6)self.relu1 = nn.ReLU(True)self.maxpool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)self.bn2 = nn.BatchNorm2d(12)self.relu2 = nn.ReLU(True)self.maxpool2 = nn.MaxPool2d((2, 2))self.conv3 = nn.Conv2d(12, 20, 3, 1, 1)self.bn3 = nn.BatchNorm2d(20)self.relu3 = nn.ReLU(True)self.maxpool3 = nn.MaxPool2d((2, 2))self.conv4 = nn.Conv2d(20, 40, 3, 1, 1)self.bn4 = nn.BatchNorm2d(40)self.relu4 = nn.ReLU(True)self.seb1 = SematicEmbbedBlock(40, 20, 20)self.seb2 = SematicEmbbedBlock(20, 12, 12)self.seb3 = SematicEmbbedBlock(12, 6, 6)self.heatmap = nn.Conv2d(6, 1, 1)def forward(self, x):x1 = self.conv1(x)x1 = self.bn1(x1)x1 = self.relu1(x1)m1 = self.maxpool1(x1)x2 = self.conv2(m1)x2 = self.bn2(x2)x2 = self.relu2(x2)m2 = self.maxpool2(x2)x3 = self.conv3(m2)x3 = self.bn3(x3)x3 = self.relu3(x3)m3 = self.maxpool3(x3)x4 = self.conv4(m3)x4 = self.bn4(x4)x4 = self.relu4(x4)up1 = self.seb1(x4, x3)up2 = self.seb2(up1, x2)up3 = self.seb3(up2, x1)out = self.heatmap(up3)return out
网络模型也是自己写的小网络,用了四个卷积层,三个池化层,然后进行了三次上采样。最终输出分辨率和输入分辨率相同。
3.3 训练过程
训练过程和基于回归的方法几乎一样,代码如下:
datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)data_loader = DataLoader(datasets, shuffle=True,batch_size=bs, collate_fn=datasets.collect_fn)model = KeyPointModel()if torch.cuda.is_available():model = model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
criterion = torch.nn.MSELoss() # compute_loss
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)for epoch in range(total_epoch):train(model, epoch, data_loader, optimizer, criterion, scheduler)loss = test(model, epoch, data_loader, criterion)if epoch % 5 == 0:torch.save(model.state_dict(),"weights/epoch_%d_%.3f.pt" % (epoch, loss*10000))
用的是MSELoss进行监督,训练曲线如下:
3.4 测试过程
测试过程和CenterNet的推理过程一致,也用到了3x3的maxpooling来筛选极大值点
for iter, (image, label) in enumerate(dataloader):# print(image.shape)bs = image.shape[0]hm = model(image)hm = _nms(hm)hm = hm.detach().numpy()for i in range(bs):hm = hm[i]hm = np.maximum(hm, 0)hm = hm/np.max(hm)hm = normalization(hm)hm = np.uint8(255 * hm)hm = hm[0]# heatmap = torch.sigmoid(heatmap)# hm = cv2.cvtColor(hm, cv2.COLOR_RGB2BGR)hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)cv2.imwrite("./test_output/output_%d_%d.jpg" % (iter, i), hm)cv2.waitKey(0)
以上的nms和topk代码都在CenterNet系列最后一篇讲过了。这里直接对模型输出结果使用nms,然后进行可视化,结果如下:
上图中白色的点就是目标位置,为了更形象的查看结果,detect.py部分负责可视化。
3.5 可视化
可视化的问题经常遇见,比如CAM、Grad CAM等可视化特征图的时候就会碰到。以下是可视化的一个简单的方法(参考了CSDN的一位博主的方案,具体链接因太过久远找不到了)。
具体实现代码如下:
def normalization(data):_range = np.max(data) - np.min(data)return (data - np.min(data)) / _rangeheatmap = model(img_tensor_list)
heatmap = heatmap.squeeze().cpu()for i in range(bs):img_path = img_list[i]img = cv2.imread(img_path)img = cv2.resize(img, (480, 360))single_map = heatmap[i]hm = single_map.detach().numpy()hm = np.maximum(hm, 0)hm = hm/np.max(hm)hm = normalization(hm)hm = np.uint8(255 * hm)hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)hm = cv2.resize(hm, (480, 360))superimposed_img = hm * 0.2 + imgcoord_x, coord_y = landmark_coord[i]cv2.circle(superimposed_img, (int(coord_x), int(coord_y)), 2, (0, 0, 0), thickness=-1)cv2.imwrite("./output2/%s_out.jpg" % (img_name_list[i]), superimposed_img)
注意通过处理以后的hm和原图叠加的时候0.2只是一个参考值,这个值既不会影响原图显示又能将heatmap中重点关注的位置可视化出来。
结果如下:
可以看到,定位结果要比回归更准一些,图中黑色点是获取到最终坐标的位置,几乎和目标是重叠的状态,效果比较理想。
4. 总结
笔者做这个小项目初心是想搞清楚如何用关键点进行定位的,关键点被用在很多领域比如人脸关键点定位、车牌定位、人体姿态检测、目标检测等等领域。当时用小武的数据的时候,发现这个数据集的特点就是目标很小,比较适合用关键点来做。之后又开始陆陆续续看CenterNet源码,借鉴了其中很多代码,这才完成了这个小项目。
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
使用关键点进行小目标检测相关推荐
- PPDet:减少Anchor-free目标检测中的标签噪声,小目标检测提升明显
本文转载自AI算法修炼营. 这篇文章收录于BMVC2020,主要的思想是减少anchor-free目标检测中的label噪声,在COCO小目标检测上表现SOTA!性能优于FreeAnchor.Cent ...
- 2021年小目标检测最新研究综述 很全面值得收藏
摘要 小目标检测长期以来是计算机视觉中的一个难点和研究热点.在深度学习的驱动下,小目标检测已取得了重大突破,并成功应用于国防安全.智能交通和工业自动化等领域.为了进一步促进小目标检测的发展,本文对小目 ...
- 【CV】小目标检测问题中“小目标”如何定义?其主要技术难点在哪?
前言: 目标检测是计算机视觉领域中的一个重要研究方向,同时也是解决分割.场景理解.目标跟踪.图像描述和事件检测等更高层次视觉任务的基础.在现实场景中,由于小目标是的大量存在,因此小目标检测具有广泛的应 ...
- 暴力改进SSD | 小目标检测的福音
作者 | 小书童 编辑 | 集智书童 点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 小目标检测是一个具有挑战性的问题.在 ...
- 小目标检测的增强算法
小目标检测的增强算法 Augmentation for small object detection 摘要 近年来,目标检测取得了令人瞩目的进展.尽管有了这些改进,但在检测小目标和大目标之间的性能仍有 ...
- YOLOV5 的小目标检测网络结构优化方法汇总(附代码)
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨南山 来源丨 AI约读社 YOLOv5是一种非常受欢迎的单阶段目标检测,以其性能和速度著称,其结 ...
- YOLO-Z | 记录修改YOLOv5以适应小目标检测的实验过程
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨ChaucerG 来源丨集智书童 随着自动驾驶汽车和自动赛车越来越受欢迎,对更快.更准确的检测器 ...
- 如何改进YOLOv3使其更好应用到小目标检测(比YOLO V4高出4%)
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨ChaucerG 来源丨集智书童 编辑丨极市平台 导读 针对微小目标的特征分散和层间语义差异的问 ...
- 快速小目标检测--Feature-Fused SSD: Fast Detection for Small Objects
Feature-Fused SSD: Fast Detection for Small Objects 本文针对小目标检测问题,对 SSD 模型进行了一个小的改进,将 contextual infor ...
最新文章
- elasticsearch原理_ElasticSearch读写底层原理及性能调优
- 单片机c语言NTC温度查表程序,STM32查表法读NTC值并显示温度
- 中央空调如何调节温度html,中央空调怎么调温度
- JS设计模式(2)策略模式
- 动态生成lookup字段
- .NPT 扩展名格式文件类型及打开方式分析:首次渗入 XR 内容领域
- 浪潮云海OS C位出道,融合开放基础设施呼之欲出
- JDK 9 发布仅数月,为何在生产环境中却频遭嫌弃?
- 机器人能力再进化,组装宜家椅子只需20分钟! | Science Robotics论文
- Matlab中的滤波器
- DPDK Release 21.05
- H5和原生开发的区别
- 数学英语计算机拼音,幼儿英语拼音数学早教机
- 公有云NAT 网关比较
- android(微博 微信 qq) 分享和第三分认证登录的封装
- 机器学习模型度量方法,分类及回归模型评估
- 联想小新520新品实测,对比当贝投影D3X竟无还手之力
- 牛客网小白二(2018.4.21)
- 路漫漫其修远兮···VB 来15个数尝尝咸淡
- python - 文件操作函数练习
热门文章
- 南大和中大“合体”拯救手残党:基于GAN的PI-REC重构网络,“老婆”画作有救了 | 技术头条...
- 年后跳槽BAT必看:10种数据结构、算法和编程课助你面试通关
- AI一分钟 | 知乎融资2.7亿美元;腾讯投资特斯拉大赚特赚
- 盘点|最实用的机器学习算法优缺点分析,没有比这篇说得更好了
- 手工模拟实现 Docker 容器网络!
- 面试官:给我一个避免消息重复消费的解决方案?
- 美团二面:Redis与MySQL双写一致性如何保证?
- Grafana 7.0 发布:改进的界面、新的插件平台和可视化等
- 聊一聊 Java 服务端中的乱象
- 面试题:如何理解 Linux 的零拷贝技术?