点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自|OpenCV学堂

【导语】本文搞了一个小的库,主要是用于定位红外小目标。由于其具有尺度很小的特点,所以可以尝试用点的方式代表其位置。本文主要采用了回归和heatmap两种方式来回归关键点,是一个很简单基础的项目,代码量很小,可供新手学习。

1. 数据来源

数据集:数据来源自小武,经过小武的授权使用,但不会公开。本项目只用了其中很少一部分共108张图片。

标注工具:https://github.com/pprp/landmark_annotation

标注工具也可以在公众号后台回复“landmark”关键字获取

部分样例展示

上图是数据集中的两张图片,红圈代表对应的目标,标注的时候只需要在其中心点一下即可得到该点对应的横纵坐标。

该数据集有一个特点,每张图只有一个目标(不然没法用简单的方法回归),多余一个目标的图片被剔除了。

1
0.42 0.596

以上是一个标注文件的例子,1.jpg对应1.txt

2. 回归确定关键点

回归确定关键点比较简单,网络部分采用手工构建的一个两层的小网络,训练采用的是MSELoss。

这部分代码在:https://github.com/pprp/SimpleCVReproduction/tree/master/simple_keypoint/regression

2.1 数据加载

数据的组织比较简单,按照以下格式组织:

- data- images- 1.jpg- 2.jpg- ...- labels- 1.txt- 2.txt- ...

重写一下Dataset类,用于加载数据集。

class KeyPointDatasets(Dataset):def __init__(self, root_dir="./data", transforms=None):super(KeyPointDatasets, self).__init__()self.img_path = os.path.join(root_dir, "images")# self.txt_path = os.path.join(root_dir, "labels")self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels") for item in self.img_list]if transforms is not None:self.transforms = transformsdef __getitem__(self, index):img = self.img_list[index]txt = self.txt_list[index]img = cv2.imread(img)if self.transforms:img = self.transforms(img)label = []with open(txt, "r") as f:for i, line in enumerate(f):if i == 0:# 第一行num_point = int(line.strip())else:x1, y1 = [(t.strip()) for t in line.split()]# range from 0 to 1x1, y1 = float(x1), float(y1)tmp_label = (x1, y1)label.append(tmp_label)return img, torch.tensor(label[0])def __len__(self):return len(self.img_list)@staticmethoddef collect_fn(batch):imgs, labels = zip(*batch)return torch.stack(imgs, 0), torch.stack(labels, 0)

返回的结果是图片和对应坐标位置。

2.2 网络模型

import torch
import torch.nn as nnclass KeyPointModel(nn.Module):def __init__(self):super(KeyPointModel, self).__init__()self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)self.bn1 = nn.BatchNorm2d(6)self.relu1 = nn.ReLU(True)self.maxpool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)self.bn2 = nn.BatchNorm2d(12)self.relu2 = nn.ReLU(True)self.maxpool2 = nn.MaxPool2d((2, 2))self.gap = nn.AdaptiveMaxPool2d(1)self.classifier = nn.Sequential(nn.Linear(12, 2),nn.Sigmoid())def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.maxpool2(x)x = self.gap(x)x = x.view(x.shape[0], -1)return self.classifier(x)

其结构就是卷积+pooling+卷积+pooling+global average pooling+Linear,返回长度为2的tensor。

2.3 训练

def train(model, epoch, dataloader, optimizer, criterion):model.train()for itr, (image, label) in enumerate(dataloader):bs = image.shape[0]output = model(image)loss = criterion(output, label)optimizer.zero_grad()loss.backward()optimizer.step()if itr % 4 == 0:print("epoch:%2d|step:%04d|loss:%.6f" % (epoch, itr, loss.item()/bs))vis.plot_many_stack({"train_loss": loss.item()*100/bs})total_epoch = 300
bs = 10
########################################
transforms_all = transforms.Compose([transforms.ToPILImage(),transforms.Resize((360,480)),transforms.ToTensor(),transforms.Normalize(mean=[0.4372, 0.4372, 0.4373],std=[0.2479, 0.2475, 0.2485])
])datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)data_loader = DataLoader(datasets, shuffle=True,batch_size=bs, collate_fn=datasets.collect_fn)model = KeyPointModel()optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# criterion = torch.nn.SmoothL1Loss()
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)for epoch in range(total_epoch):train(model, epoch, data_loader, optimizer, criterion)loss = test(model, epoch, data_loader, criterion)if epoch % 10 == 0:torch.save(model.state_dict(),"weights/epoch_%d_%.3f.pt" % (epoch, loss*1000))

loss部分使用Smooth L1 loss或者MSE loss均可。

MSE Loss:

Smooth L1 Loss:

2.4 测试结果

3. heatmap确定关键点

这部分代码很多参考了CenterNet,不过曾经尝试CenterNet中的loss在这个问题上收敛效果不好,所以参考了kaggle人脸关键点定位的解决方法,发现使用简单的MSELoss效果就很好。

3.1 数据加载

这部分和CenterNet构建heatmap的过程类似,不过半径的确定是人工的。因为数据集中的目标都比较小,半径的范围最大不超过半径为30个像素的圆。

class KeyPointDatasets(Dataset):def __init__(self, root_dir="./data", transforms=None):super(KeyPointDatasets, self).__init__()self.down_ratio = 1self.img_w = 480 // self.down_ratioself.img_h = 360 // self.down_ratioself.img_path = os.path.join(root_dir, "images")self.img_list = glob.glob(os.path.join(self.img_path, "*.jpg"))self.txt_list = [item.replace(".jpg", ".txt").replace("images", "labels") for item in self.img_list]if transforms is not None:self.transforms = transformsdef __getitem__(self, index):img = self.img_list[index]txt = self.txt_list[index]img = cv2.imread(img)if self.transforms:img = self.transforms(img)label = []with open(txt, "r") as f:for i, line in enumerate(f):if i == 0:# 第一行num_point = int(line.strip())else:x1, y1 = [(t.strip()) for t in line.split()]# range from 0 to 1x1, y1 = float(x1), float(y1)cx, cy = x1 * self.img_w, y1 * self.img_hheatmap = np.zeros((self.img_h, self.img_w))draw_umich_gaussian(heatmap, (cx, cy), 30)return img, torch.tensor(heatmap).unsqueeze(0)def __len__(self):return len(self.img_list)@staticmethoddef collect_fn(batch):imgs, labels = zip(*batch)return torch.stack(imgs, 0), torch.stack(labels, 0)

核心函数是draw_umich_gaussian,具体如下:

def gaussian2D(shape, sigma=1):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m + 1, -n:n + 1]h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))h[h < np.finfo(h.dtype).eps * h.max()] = 0# 限制最小的值return hdef draw_umich_gaussian(heatmap, center, radius, k=1):diameter = 2 * radius + 1gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)# 一个圆对应内切正方形的高斯分布x, y = int(center[0]), int(center[1])width, height = heatmap.shapeleft, right = min(x, radius), min(width - x, radius + 1)top, bottom = min(y, radius), min(height - y, radius + 1)masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]masked_gaussian = gaussian[radius - top:radius +bottom, radius - left:radius + right]if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:  # TODO debugnp.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)# 将高斯分布覆盖到heatmap上,取最大,而不是叠加return heatmap

sigma参数直接沿用了CenterNet中的设置,没有调节这个超参数。

3.2 网络结构

网络结构参考了知乎上一个复现YOLOv3中提到的模块,Sematic Embbed Block(SEB)用于上采样部分,将来自低分辨率的特征图进行上采样,然后使用3x3卷积和1x1卷积统一通道个数,最后将低分辨率特征图和高分辨率特征图相乘得到融合结果。

class SematicEmbbedBlock(nn.Module):def __init__(self, high_in_plane, low_in_plane, out_plane):super(SematicEmbbedBlock, self).__init__()self.conv3x3 = nn.Conv2d(high_in_plane, out_plane, 3, 1, 1)self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)self.conv1x1 = nn.Conv2d(low_in_plane, out_plane, 1)def forward(self, high_x, low_x):high_x = self.upsample(self.conv3x3(high_x))low_x = self.conv1x1(low_x)return high_x * low_xclass KeyPointModel(nn.Module):"""downsample ratio=2"""def __init__(self):super(KeyPointModel, self).__init__()self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)self.bn1 = nn.BatchNorm2d(6)self.relu1 = nn.ReLU(True)self.maxpool1 = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(6, 12, 3, 1, 1)self.bn2 = nn.BatchNorm2d(12)self.relu2 = nn.ReLU(True)self.maxpool2 = nn.MaxPool2d((2, 2))self.conv3 = nn.Conv2d(12, 20, 3, 1, 1)self.bn3 = nn.BatchNorm2d(20)self.relu3 = nn.ReLU(True)self.maxpool3 = nn.MaxPool2d((2, 2))self.conv4 = nn.Conv2d(20, 40, 3, 1, 1)self.bn4 = nn.BatchNorm2d(40)self.relu4 = nn.ReLU(True)self.seb1 = SematicEmbbedBlock(40, 20, 20)self.seb2 = SematicEmbbedBlock(20, 12, 12)self.seb3 = SematicEmbbedBlock(12, 6, 6)self.heatmap = nn.Conv2d(6, 1, 1)def forward(self, x):x1 = self.conv1(x)x1 = self.bn1(x1)x1 = self.relu1(x1)m1 = self.maxpool1(x1)x2 = self.conv2(m1)x2 = self.bn2(x2)x2 = self.relu2(x2)m2 = self.maxpool2(x2)x3 = self.conv3(m2)x3 = self.bn3(x3)x3 = self.relu3(x3)m3 = self.maxpool3(x3)x4 = self.conv4(m3)x4 = self.bn4(x4)x4 = self.relu4(x4)up1 = self.seb1(x4, x3)up2 = self.seb2(up1, x2)up3 = self.seb3(up2, x1)out = self.heatmap(up3)return out

网络模型也是自己写的小网络,用了四个卷积层,三个池化层,然后进行了三次上采样。最终输出分辨率和输入分辨率相同。

3.3 训练过程

训练过程和基于回归的方法几乎一样,代码如下:

datasets = KeyPointDatasets(root_dir="./data", transforms=transforms_all)data_loader = DataLoader(datasets, shuffle=True,batch_size=bs, collate_fn=datasets.collect_fn)model = KeyPointModel()if torch.cuda.is_available():model = model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)
criterion = torch.nn.MSELoss()  # compute_loss
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=30,gamma=0.1)for epoch in range(total_epoch):train(model, epoch, data_loader, optimizer, criterion, scheduler)loss = test(model, epoch, data_loader, criterion)if epoch % 5 == 0:torch.save(model.state_dict(),"weights/epoch_%d_%.3f.pt" % (epoch, loss*10000))

用的是MSELoss进行监督,训练曲线如下:

训练过程中的loss曲线

3.4 测试过程

测试过程和CenterNet的推理过程一致,也用到了3x3的maxpooling来筛选极大值点

for iter, (image, label) in enumerate(dataloader):# print(image.shape)bs = image.shape[0]hm = model(image)hm = _nms(hm)hm = hm.detach().numpy()for i in range(bs):hm = hm[i]hm = np.maximum(hm, 0)hm = hm/np.max(hm)hm = normalization(hm)hm = np.uint8(255 * hm)hm = hm[0]# heatmap = torch.sigmoid(heatmap)# hm = cv2.cvtColor(hm, cv2.COLOR_RGB2BGR)hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)cv2.imwrite("./test_output/output_%d_%d.jpg" % (iter, i), hm)cv2.waitKey(0)

以上的nms和topk代码都在CenterNet系列最后一篇讲过了。这里直接对模型输出结果使用nms,然后进行可视化,结果如下:

放大结果

上图中白色的点就是目标位置,为了更形象的查看结果,detect.py部分负责可视化。

3.5 可视化

可视化的问题经常遇见,比如CAM、Grad CAM等可视化特征图的时候就会碰到。以下是可视化的一个简单的方法(参考了CSDN的一位博主的方案,具体链接因太过久远找不到了)。

可视化流程

具体实现代码如下:

def normalization(data):_range = np.max(data) - np.min(data)return (data - np.min(data)) / _rangeheatmap = model(img_tensor_list)
heatmap = heatmap.squeeze().cpu()for i in range(bs):img_path = img_list[i]img = cv2.imread(img_path)img = cv2.resize(img, (480, 360))single_map = heatmap[i]hm = single_map.detach().numpy()hm = np.maximum(hm, 0)hm = hm/np.max(hm)hm = normalization(hm)hm = np.uint8(255 * hm)hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)hm = cv2.resize(hm, (480, 360))superimposed_img = hm * 0.2 + imgcoord_x, coord_y = landmark_coord[i]cv2.circle(superimposed_img, (int(coord_x), int(coord_y)), 2, (0, 0, 0), thickness=-1)cv2.imwrite("./output2/%s_out.jpg" % (img_name_list[i]), superimposed_img)

注意通过处理以后的hm和原图叠加的时候0.2只是一个参考值,这个值既不会影响原图显示又能将heatmap中重点关注的位置可视化出来。

结果如下:

可视化结果

可以看到,定位结果要比回归更准一些,图中黑色点是获取到最终坐标的位置,几乎和目标是重叠的状态,效果比较理想。

4. 总结

笔者做这个小项目初心是想搞清楚如何用关键点进行定位的,关键点被用在很多领域比如人脸关键点定位、车牌定位、人体姿态检测、目标检测等等领域。当时用小武的数据的时候,发现这个数据集的特点就是目标很小,比较适合用关键点来做。之后又开始陆陆续续看CenterNet源码,借鉴了其中很多代码,这才完成了这个小项目。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

使用关键点进行小目标检测相关推荐

  1. PPDet:减少Anchor-free目标检测中的标签噪声,小目标检测提升明显

    本文转载自AI算法修炼营. 这篇文章收录于BMVC2020,主要的思想是减少anchor-free目标检测中的label噪声,在COCO小目标检测上表现SOTA!性能优于FreeAnchor.Cent ...

  2. 2021年小目标检测最新研究综述 很全面值得收藏

    摘要 小目标检测长期以来是计算机视觉中的一个难点和研究热点.在深度学习的驱动下,小目标检测已取得了重大突破,并成功应用于国防安全.智能交通和工业自动化等领域.为了进一步促进小目标检测的发展,本文对小目 ...

  3. 【CV】小目标检测问题中“小目标”如何定义?其主要技术难点在哪?

    前言: 目标检测是计算机视觉领域中的一个重要研究方向,同时也是解决分割.场景理解.目标跟踪.图像描述和事件检测等更高层次视觉任务的基础.在现实场景中,由于小目标是的大量存在,因此小目标检测具有广泛的应 ...

  4. 暴力改进SSD | 小目标检测的福音

    作者 | 小书童  编辑 | 集智书童 点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 小目标检测是一个具有挑战性的问题.在 ...

  5. 小目标检测的增强算法

    小目标检测的增强算法 Augmentation for small object detection 摘要 近年来,目标检测取得了令人瞩目的进展.尽管有了这些改进,但在检测小目标和大目标之间的性能仍有 ...

  6. YOLOV5 的小目标检测网络结构优化方法汇总(附代码)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨南山 来源丨 AI约读社 YOLOv5是一种非常受欢迎的单阶段目标检测,以其性能和速度著称,其结 ...

  7. YOLO-Z | 记录修改YOLOv5以适应小目标检测的实验过程

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨ChaucerG 来源丨集智书童 随着自动驾驶汽车和自动赛车越来越受欢迎,对更快.更准确的检测器 ...

  8. 如何改进YOLOv3使其更好应用到小目标检测(比YOLO V4高出4%)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨ChaucerG 来源丨集智书童 编辑丨极市平台 导读 针对微小目标的特征分散和层间语义差异的问 ...

  9. 快速小目标检测--Feature-Fused SSD: Fast Detection for Small Objects

    Feature-Fused SSD: Fast Detection for Small Objects 本文针对小目标检测问题,对 SSD 模型进行了一个小的改进,将 contextual infor ...

最新文章

  1. elasticsearch原理_ElasticSearch读写底层原理及性能调优
  2. 单片机c语言NTC温度查表程序,STM32查表法读NTC值并显示温度
  3. 中央空调如何调节温度html,中央空调怎么调温度
  4. JS设计模式(2)策略模式
  5. 动态生成lookup字段
  6. .NPT 扩展名格式文件类型及打开方式分析:首次渗入 XR 内容领域
  7. 浪潮云海OS C位出道,融合开放基础设施呼之欲出
  8. JDK 9 发布仅数月,为何在生产环境中却频遭嫌弃?
  9. 机器人能力再进化,组装宜家椅子只需20分钟! | Science Robotics论文
  10. Matlab中的滤波器
  11. DPDK Release 21.05
  12. H5和原生开发的区别
  13. 数学英语计算机拼音,幼儿英语拼音数学早教机
  14. 公有云NAT 网关比较
  15. android(微博 微信 qq) 分享和第三分认证登录的封装
  16. 机器学习模型度量方法,分类及回归模型评估
  17. 联想小新520新品实测,对比当贝投影D3X竟无还手之力
  18. 牛客网小白二(2018.4.21)
  19. 路漫漫其修远兮···VB 来15个数尝尝咸淡
  20. python - 文件操作函数练习

热门文章

  1. 南大和中大“合体”拯救手残党:基于GAN的PI-REC重构网络,“老婆”画作有救了 | 技术头条...
  2. 年后跳槽BAT必看:10种数据结构、算法和编程课助你面试通关
  3. AI一分钟 | 知乎融资2.7亿美元;腾讯投资特斯拉大赚特赚
  4. 盘点|最实用的机器学习算法优缺点分析,没有比这篇说得更好了
  5. 手工模拟实现 Docker 容器网络!
  6. 面试官:给我一个避免消息重复消费的解决方案?
  7. 美团二面:Redis与MySQL双写一致性如何保证?
  8. Grafana 7.0 发布:改进的界面、新的插件平台和可视化等
  9. 聊一聊 Java 服务端中的乱象
  10. 面试题:如何理解 Linux 的零拷贝技术?