阿联酋起源人工智能研究院(IIAI)科学家提出了一种新颖的人脸关键点检测方法PIPNet,通过融合坐标回归和热力图回归的优势,并结合半监督学习充分利用大量无标注数据提升跨域的泛化性能,最终得到一个又快又准又稳的人脸关键点检测器。相关论文已被IJCV 2021接收。

论文:https://arxiv.org/abs/2003.03771

代码:https://github.com/jhb86253817/PIPNet

预训练地址:

https://drive.google.com/drive/folders/17OwDgJUfuc5_ymQ3QruD8pUnh5zHreP2

gpu上测试30ms左右。

严重侧脸有时也比较飘。

多人脸时,速度比较慢,一个人脸30多ms,4个人脸1百多ms。

为了得到一个适用于真实应用的人脸关键点模型,本文基于上述挑战提出嵌套网络(PIPNet)模型。该模型主要包含三个重要模块。

首先是一个新颖的检测头,称作嵌套回归(PIP regression)。该方法将关键点定位任务分解成了基于低分辨率特征图的热力图回归和局部特征图上的坐标回归,使模型在不依赖高分辨率特征图的情况下依然具有较高的精度,从而节省了计算量。

此外,我们在检测头上额外设计了近邻回归模块,通过训练每个关键点根据自己的位置定位它的近邻关键点,使得在预测时能得到局部区域的形状约束,从而提升模型的鲁棒性。

最后,我们提出一种基于自训练的半监督学习方法来充分利用大量无标注的不同场景下样本。该方法在对无标注样本估计伪标签时,首先从简单的任务开始,然后在后续迭代中逐渐增加任务的难度,直到变成标准的自训练任务,有效缓解了标准自训练方法在伪标签中引入的噪声问题。

半监督学习:表4展示了STC与基线方法的比较。其中,300W的训练集带有标注,无标注数据集来自COFW和WFLW或CelebA。可以看到,STC无论是与直接跨领域测试,还是与经典的UDA方法DANN以及标准自训练法比较,均取得了更好的结果。

表5展示了与现有方法在跨领域泛化性能上的比较。之前的方法基本遵循在300W上训练,然后直接在测试集上测试(即GSL)。同样遵循这一模式,PIPNet在COFW-68上仅落后于AVS。

而当充分利用CelebA中的无标注数据后,模型在COFW-68上的跨领域性能大幅提高,并超越了AVS,这既表明了STC的有效性,也显示了GSSL范式在实际应用中的可行性。

表4. STC与基线方法的比较

表5. STC与已有方法在同领域及跨领域测试集上的结果比较

速度:为了说明PIPNet在推断速度上的优势,我们与之前的方法比较模型在精度和速度上的平衡。如图1所示,PIPNet在CPU和GPU上均取得了最优的平衡(越靠近右下角越好),尤其是CPU上的优势更为明显。因此,PIPNet很适合计算资源受限的场景。

笔者自己改了一版视频测试脚本:

import cv2, os
import sysfrom FaceBoxesV2.faceboxes_detector import FaceBoxesDetectorsys.path.insert(0, 'FaceBoxesV2')
sys.path.insert(0, '..')
import numpy as np
import importlib
from math import floor
# from faceboxes_detector import *
import timeimport torch
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as modelsfrom networks import *
import data_utils
from functions import *experiment_name = 'pip_32_16_60_r18_l2_l1_10_1_nb10.py '
data_name = 'WFLW'
config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
video_file = '../videos/002.avi'
video_file = 'camera'# my_config = importlib.import_module(config_path, package='PIPNet')
# Config = getattr(my_config, 'Config')
class Config():def __init__(self):self.det_head = 'pip'self.net_stride = 32self.batch_size = 16self.init_lr = 0.0001self.num_epochs = 60self.decay_steps = [30, 50]self.input_size = 256self.backbone = 'resnet18'self.pretrained = Trueself.criterion_cls = 'l2'self.criterion_reg = 'l1'self.cls_loss_weight = 10self.reg_loss_weight = 1self.num_lms = 98self.save_interval = self.num_epochsself.num_nb = 10self.use_gpu = Trueself.gpu_id = 2cfg = Config()
cfg.experiment_name = experiment_name
cfg.data_name = data_namesave_dir = os.path.join('./snapshots', cfg.data_name, cfg.experiment_name)meanface_indices, reverse_index1, reverse_index2, max_len = get_meanface(os.path.join('../data', cfg.data_name, 'meanface.txt'), cfg.num_nb)if cfg.backbone == 'resnet18':resnet18 = models.resnet18(pretrained=cfg.pretrained)net = Pip_resnet18(resnet18, cfg.num_nb, num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)
device = torch.device("cpu")
if cfg.use_gpu:device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net = net.to(device)# weight_file = os.path.join(save_dir, 'epoch%d.pth' % (cfg.num_epochs-1))
weight_file = 'epoch59.pth'
state_dict = torch.load(weight_file, map_location=device)
net.load_state_dict(state_dict)normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([transforms.Resize((cfg.input_size, cfg.input_size)), transforms.ToTensor(), normalize])def demo_video(video_file, net, preprocess, input_size, net_stride, num_nb, use_gpu, device):detector = FaceBoxesDetector('FaceBoxes', '../FaceBoxesV2/weights/FaceBoxesV2.pth', use_gpu, device)my_thresh = 0.9det_box_scale = 1.2net.eval()if video_file == 'camera':cap = cv2.VideoCapture(0)else:cap = cv2.VideoCapture(video_file)if (cap.isOpened()== False): print("Error opening video stream or file")frame_width = int(cap.get(3))frame_height = int(cap.get(4))count = 0while(cap.isOpened()):ret, frame = cap.read()if ret == True:start=time.time()detections, _ = detector.detect(frame, my_thresh, 1)print('detect time',time.time()-start)start = time.time()for i in range(len(detections)):det_xmin = detections[i][2]det_ymin = detections[i][3]det_width = detections[i][4]det_height = detections[i][5]det_xmax = det_xmin + det_width - 1det_ymax = det_ymin + det_height - 1det_xmin -= int(det_width * (det_box_scale-1)/2)# remove a part of top area for alignment, see paper for detailsdet_ymin += int(det_height * (det_box_scale-1)/2)det_xmax += int(det_width * (det_box_scale-1)/2)det_ymax += int(det_height * (det_box_scale-1)/2)det_xmin = max(det_xmin, 0)det_ymin = max(det_ymin, 0)det_xmax = min(det_xmax, frame_width-1)det_ymax = min(det_ymax, frame_height-1)det_width = det_xmax - det_xmin + 1det_height = det_ymax - det_ymin + 1cv2.rectangle(frame, (det_xmin, det_ymin), (det_xmax, det_ymax), (0, 0, 255), 2)det_crop = frame[det_ymin:det_ymax, det_xmin:det_xmax, :]det_crop = cv2.resize(det_crop, (input_size, input_size))inputs = Image.fromarray(det_crop[:,:,::-1].astype('uint8'), 'RGB')inputs = preprocess(inputs).unsqueeze(0)inputs = inputs.to(device)lms_pred_x, lms_pred_y, lms_pred_nb_x, lms_pred_nb_y, outputs_cls, max_cls = forward_pip(net, inputs, preprocess, input_size, net_stride, num_nb)lms_pred = torch.cat((lms_pred_x, lms_pred_y), dim=1).flatten()tmp_nb_x = lms_pred_nb_x[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)tmp_nb_y = lms_pred_nb_y[reverse_index1, reverse_index2].view(cfg.num_lms, max_len)tmp_x = torch.mean(torch.cat((lms_pred_x, tmp_nb_x), dim=1), dim=1).view(-1,1)tmp_y = torch.mean(torch.cat((lms_pred_y, tmp_nb_y), dim=1), dim=1).view(-1,1)lms_pred_merge = torch.cat((tmp_x, tmp_y), dim=1).flatten()lms_pred = lms_pred.cpu().numpy()lms_pred_merge = lms_pred_merge.cpu().numpy()for i in range(cfg.num_lms):x_pred = lms_pred_merge[i*2] * det_widthy_pred = lms_pred_merge[i*2+1] * det_heightcv2.circle(frame, (int(x_pred)+det_xmin, int(y_pred)+det_ymin), 1, (0, 0, 255), 2)print('keypoint time', time.time() - start)count += 1#cv2.imwrite('video_out2/'+str(count)+'.jpg', frame)cv2.imshow('1', frame)if cv2.waitKey(1) & 0xFF == ord('q'):breakelse:breakcap.release()cv2.destroyAllWindows()demo_video(video_file, net, preprocess, cfg.input_size, cfg.net_stride, cfg.num_nb, cfg.use_gpu, device)

IJCV2021 人脸关键点检测器PIPNet相关推荐

  1. 【IJCV2021】 实用人脸关键点检测器PIPNet:快!准!稳!

    关注公众号,发现CV技术之美 导语 阿联酋起源人工智能研究院(IIAI)科学家提出了一种新颖的人脸关键点检测方法PIPNet,通过融合坐标回归和热力图回归的优势,并结合半监督学习充分利用大量无标注数据 ...

  2. 【论文解读】PFLD:高精度实时人脸关键点检测算法

    这篇文章作者分别来自天津大学.武汉大学.腾讯AI实验室.美国天普大学.该算法对在高通ARM 845处理器可达140fps:另外模型大小较小,仅2.1MB:此外在许多关键点检测的benchmark中也取 ...

  3. 创意赛第二季又来了,PaddleHub人脸关键点检测实现猫脸人嘴特效

    前段时间,下班后闲来无事,参加了百度PaddleHub的AI人像抠图创意赛,凭借着大家的阅读量,获得了一个第三名,得了一个小度音响,真香啊! 对,说的是我 小奖品 PaddleHub创意赛第二期又出来 ...

  4. 由6,14以及68点人脸关键点计算头部姿态

    前言 关于头部姿态估计理论部分的内容,网络上包括我所列的参考文献中都有大量概述,我不再重复.这里直入主题,如何通过图像中2D人脸关键点计算出头部姿态角,具体就是计算出俯仰角(pitch),偏航角(ya ...

  5. 基于人脸关键点修复人脸,腾讯等提出优于SOTA的LaFIn生成网络

    作者 | Yang Yang.Xiaojie Guo.Jiayi Ma.Lin Ma.Haibin Ling 译者 | 刘畅 编辑 | Jane 出品 | AI科技大本营(ID:rgznai100) ...

  6. CV:利用cv2+dlib库自带frontal_face_detector(人脸征检测器)实现人脸检测与人脸标记之《极限男人帮》和《NBA全明星球员》

    CV:利用cv2+dlib库自带frontal_face_detector(人脸征检测器)实现人脸检测与人脸标记之<极限男人帮>和<NBA全明星球员> 目录 输出结果 设计思路 ...

  7. caffe 人脸关键点检测_全套 | 人脸检测 人脸关键点检测 人脸卡通化

    点击上方"AI算法与图像处理",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源:CVPy 人脸检测历险记 可能跟我一样,人脸检测是很 ...

  8. 【dlib库】进行人脸检测+人脸关键点检测+人脸对齐

    原图像: 1. 人脸检测 import cv2 import dlib import matplotlib.pyplot as plt # 获取图片 my_img = cv2.imread('my_i ...

  9. 级联MobileNet-V2实现CelebA人脸关键点检测(附训练源码)

    文章目录 一 .引言 1.1为什么是级联? 1.2为什么是MobileNet-V2? 二. 级联MobileNet-V2之人脸关键点检测 2.0 修改caffe 2.1 整体框架及思路 2.2 原始数 ...

最新文章

  1. vert.x 结合JAX-RS
  2. 2020全国大学生数学建模竞赛【论文格式、时间节点及作品提交要求、竞赛题目下载、评分要点】【微信公众号:校苑数模】
  3. 随心篇第九期:我不愿一无所有
  4. 传奇谢幕,回顾霍金76载传奇人生
  5. 新建子窗体 1124
  6. java cache教程_Java 中常用缓存Cache机制的实现
  7. 栈和队列基本概念,顺序栈的表示和实现
  8. 中调用view_在 View 上使用挂起函数
  9. 传说中的ACM大牛们
  10. php网站后台密码忘记,phpweb忘记后台密码
  11. 一个投标经理的标书检查笔记,拿来就用!
  12. Cadence PCB仿真 使用Allegro PCB SI为BRD文件创建通用型IBIS模型的方法图文教程
  13. Uderstanding and using Pointers 读书笔记
  14. STM32F411核心板固件库开发(一) GPIO基本配置
  15. XRAY项目--电荷积分放大器AD8488介绍
  16. c 工厂模式与mysql链接_工厂模式连接数据库
  17. HANA 与 Oracle 12c哪一个更快
  18. VAE_MNIST数字图片识别及生成
  19. 工业镜头参数及选型参考
  20. 树莓派TF卡磁盘扩容 分区扩容

热门文章

  1. Android开发--FileInputStream/OutStream/Sdcard写入
  2. linux内核cfs浅析
  3. 随机生成100万个数,排序后保存在文件中
  4. Android SDK目录结构
  5. The Elements of Statistical Learning的笔记
  6. handler post r 同一个线程的疑惑
  7. 问题集锦(26-29)
  8. 程振波 算法设计与分析_算法分析与设计之动态规划
  9. c语言实现将字符串首尾*删除,C语言实现Trim()函数:删除字符串首尾空格。...
  10. java 模式匹配算法_用Java匹配模式