pytorch版本RetinaFace人脸检测模型推理加速_胖胖大海的博客-CSDN博客

pytorch版本RetinaFace人脸检测模型推理加速,去掉FPN第一层,不检测特别小的人脸框_胖胖大海的博客-CSDN博客

代码地址:https://github.com/xxcheng0708/Pytorch_Retinaface_Accelerate


本文介绍的方法是提升pytorch版本RetinaFace代码在数据预处理阶段的速度,使用纯pytorch框架进行模型推理,并不涉及模型的onnx、tensorrt部署等方法。本文介绍的方法适用于从磁盘加载分辨率相同的一批图像使用RetinaFace进行人脸检测,能够带来30%的性能提升。关于pytorch_retinaface使用tensorrt部署请参考https://github.com/wang-xinyu/tensorrtx/tree/master/retinaface。

先上优化前后处理性能的结论:

优化前

优化后

提升效果

分辨率

fps

总耗时(s)

平均耗时(ms)

fps

总耗时(s)

平均耗时(ms)

1920 x 1080

5.92

134

168

8.84

90

113

32.7%

1280 x 720

13.08

256

76

19.49

172

51

32.8%

模型推理耗时主要来自于三个方面:
1、数据预处理:数据预处理阶段通常包括数据的读取、格式转化、归一化、维度扩充等。
2、模型预测:模型的forward过程
3、后处理:数据后处理如目标检测算法中的NMS等操作。
    
https://github.com/biubug6/Pytorch_Retinaface/blob/master/detect.py开源代码中,在数据预处理阶段,主要包括以下几个步骤:
1、使用opencv读取图片数据
2、将读取到的图片数据类型从uint8转换为float32
3、图像数据归一化,各通道减去一个数值,这里是主要耗时部分
4、图像矩阵轴对换,转换成[C, H, W]的形式,然后转换为tensor,并进行维度扩充到[1, C, H, W]的形式
5、将tensor放到GPU上,进行模型推理预测

这一系列操作都是在CPU上进行的,处理速度就会比较慢。

image_path = "./curve/test.jpg"
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
img = np.float32(img_raw)
img -= (104, 117, 123)
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).unsqueeze(0)
img = img.to(device)

在之前的torchvision.transforms GPU加速,提升预测阶段数据预处理速度内容中,我们介绍过torchvision 0.8.0版本之后提供了read_image函数,可以将图片直接读取为tensor然后放到GPU上做数据预处理操作。在torchvision 0.8.0版本之后,torchvision.io.read_image可以直接将图片读取为[C, H, W]形状的tensor,然后就可以将归一化、维度扩充等操作都放在GPU上进行,速度自然就会提升。经过转化后的数据预处理主要包括以下几个步骤:
1、使用read_image读取图片数据为[C, H, W]的tensor,并放到GPU上
2、使用torchvision.transform.Normalize转换算子在GPU上进行数据归一化
3、扩充数据维度为[1, C, H, W],然后进行模型推理

# read_image读取的是RGB通道顺序,RetinaFace输入的是BGR通道顺序,所以使用[[2, 1, 0], :, :]转换通道顺序
img = read_image(image_path, torchvision.io.ImageReadMode.RGB).to(device)[[2, 1, 0], :, :].float()
img = torchvision.transforms.Normalize(mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])
img = img.unsqueeze(0)

完整推理检测代码demo.py如下:

# coding:utf-8import os
import cv2
import torch
from torch import nn
import torch.backends.cudnn as cudnn
import numpy as np
from data import cfg_mnet, cfg_re50
from layers.functions.prior_box import PriorBox
from utils.nms.py_cpu_nms import py_cpu_nms
from models.retinaface import RetinaFace
from utils.box_utils import decode, decode_landm
from imutils.video import FPS
import torchvision
from torchvision.io import read_image
# from utils.timer import print_execute_infoclass RetinaFaceDetector(object):def __init__(self, trained_model, network, use_cpu=False, confidence_threshold=0.02, top_k=5000,nms_threshold=0.4, keep_top_k=750, vis_thres=0.6, im_height=720, im_width=1280):super(RetinaFaceDetector, self).__init__()self.trained_model = trained_modelself.network = networkself.use_cpu = use_cpuself.confidence_threshold = confidence_thresholdself.top_k = top_kself.nms_threshold = nms_thresholdself.keep_top_k = keep_top_kself.vis_thres = vis_thresself.im_height = im_heightself.im_width = im_widthself.device = torch.device("cpu" if self.use_cpu else "cuda")self.norm = torchvision.transforms.Normalize(mean=[104.0, 117.0, 123.0], std=[1.0, 1.0, 1.0])torch.set_grad_enabled(False)self.cfg = Noneif self.network == "mobile0.25":self.cfg = cfg_mnetelif self.network == "resnet50":self.cfg = cfg_re50self.net = RetinaFace(cfg=self.cfg, phase="test")self.load_model(self.trained_model, self.use_cpu)self.net.eval()print(self.net)cudnn.benchmark = Trueself.net = self.net.to(self.device)self.resize = 1self.scale = torch.Tensor([self.im_width, self.im_height, self.im_width, self.im_height])self.scale = self.scale.to(self.device)self.scale1 = torch.Tensor([self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height])self.scale1 = self.scale1.to(self.device)self.priorbox = PriorBox(self.cfg, image_size=(self.im_height, self.im_width))self.priors = self.priorbox.forward()self.priors = self.priors.to(self.device)self.prior_data = self.priors.datadef check_keys(self, model, pretrained_state_dict):ckpt_keys = set(pretrained_state_dict.keys())model_keys = set(model.state_dict().keys())used_pretrained_keys = model_keys & ckpt_keysunused_pretrained_keys = ckpt_keys - model_keysmissing_keys = model_keys - ckpt_keysprint('Missing keys:{}'.format(len(missing_keys)))print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))print('Used keys:{}'.format(len(used_pretrained_keys)))assert len(used_pretrained_keys) > 0, 'load NONE from pretrained checkpoint'return Truedef remove_prefix(self, state_dict, prefix):''' Old style model is stored with all names of parameters sharing common prefix 'module.' '''print('remove prefix \'{}\''.format(prefix))f = lambda x: x.split(prefix, 1)[-1] if x.startswith(prefix) else xreturn {f(key): value for key, value in state_dict.items()}def load_model(self, pretrained_path, load_to_cpu):print('Loading pretrained model from {}'.format(pretrained_path))if load_to_cpu:pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)else:pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(self.device))if "state_dict" in pretrained_dict.keys():pretrained_dict = self.remove_prefix(pretrained_dict['state_dict'], 'module.')else:pretrained_dict = self.remove_prefix(pretrained_dict, 'module.')self.check_keys(self.net, pretrained_dict)self.net.load_state_dict(pretrained_dict, strict=False)# @print_execute_infodef detect(self, img):_, im_height, im_width = img.shapeif im_height != self.im_height or im_width != self.im_width:self.im_height = im_heightself.im_width = im_widthself.scale = torch.Tensor([self.im_width, self.im_height, self.im_width, self.im_height])self.scale = self.scale.to(self.device)self.scale1 = torch.Tensor([self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height,self.im_width, self.im_height])self.scale1 = self.scale1.to(self.device)self.priorbox = PriorBox(self.cfg, image_size=(self.im_height, self.im_width))self.priors = self.priorbox.forward()self.priors = self.priors.to(self.device)self.prior_data = self.priors.dataimg = img.to(self.device)[[2, 1, 0], :, :].float()img = self.norm(img)img = img.unsqueeze(0)loc, conf, landms = self.net(img)boxes = decode(loc.data.squeeze(0), self.prior_data, self.cfg['variance'])boxes = boxes * self.scale / self.resizeboxes = boxes.cpu().numpy()scores = conf.squeeze(0).data.cpu().numpy()[:, 1]landms = decode_landm(landms.data.squeeze(0), self.prior_data, self.cfg['variance'])landms = landms * self.scale1 / self.resizelandms = landms.cpu().numpy()# ignore low scoresinds = np.where(scores > self.confidence_threshold)[0]boxes = boxes[inds]landms = landms[inds]scores = scores[inds]# keep top-K before NMSorder = scores.argsort()[::-1][:self..top_k]boxes = boxes[order]landms = landms[order]scores = scores[order]# do NMSdets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)keep = py_cpu_nms(dets, self..nms_threshold)# keep = nms(dets, args.nms_threshold,force_cpu=args.cpu)dets = dets[keep, :]landms = landms[keep]# keep top-K faster NMSdets = dets[:self.keep_top_k, :]landms = landms[:self.keep_top_k, :]dets = np.concatenate((dets, landms), axis=1)return detsif __name__ == '__main__':import shutildetector = RetinaFaceDetector(trained_model="./weights/Resnet50_Final.pth", network="resnet50",im_height=720, im_width=1280)fps = FPS()fps.start()data_path = "./images"output_path = "./outputs"if os.path.exists(output_path) is False:shutil.rmtree(output_path)os.makedirs(output_path)for image_name in os.listdir(data_path):image_path = os.path.join(data_path, image_name)img = read_image(image_path, mode=torchvision.io.ImageReadMode.RGB)results = detector.detect(img)fps.update()# save resultsif False:img_raw = cv2.imread(image_path)for b in results:if b[4] < detector.vis_thres:continuetext = "{:.4f}".format(b[4])b = list(map(int, b))cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)cx = b[0]cy = b[1] + 12cv2.putText(img_raw, text, (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))# landmscv2.circle(img_raw, (b[5], b[6]), 1, (0, 0, 255), 4)cv2.circle(img_raw, (b[7], b[8]), 1, (0, 255, 255), 4)cv2.circle(img_raw, (b[9], b[10]), 1, (255, 0, 255), 4)cv2.circle(img_raw, (b[11], b[12]), 1, (0, 255, 0), 4)cv2.circle(img_raw, (b[13], b[14]), 1, (255, 0, 0), 4)# save imagecv2.imwrite(os.path.join(output_path, image_name), img_raw)fps.stop()print("duration time: {} s, fps: {}".format(fps.elapsed(), fps.fps()))

pytorch版本RetinaFace人脸检测模型推理加速相关推荐

  1. Pytorch搭建Retinaface人脸检测与关键点定位平台

    学习前言 一起来看看Retinaface的Pytorch实现吧. 在这里插入图片描述 什么是Retinaface人脸检测算法 Retinaface是来自insightFace的又一力作,基于one-s ...

  2. 【项目实战课】基于Pytorch的RetinaFace人脸与关键点检测实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的RetinaFace人脸与关键点检测实战>. 所谓项目实战课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题 ...

  3. 大小仅1MB,超轻量级通用人脸检测模型登上GitHub趋势榜

    机器之心报道 项目作者:Linzaer 近日,用户 Linzaer 在 Github 上推出了一款适用于边缘计算设备.移动端设备以及 PC 的超轻量级通用人脸检测模型,该模型文件大小仅 1MB,320 ...

  4. 3模型大小_Github推荐一个国内牛人开发的超轻量级通用人脸检测模型

    Ultra-Light-Fast-Generic-Face-Detector-1MB 1MB轻量级通用人脸检测模型 作者表示该模型设计是为了边缘计算设备以及低功耗设备(如arm)设计的实时超轻量级通用 ...

  5. java rfb,github上开源的超轻量级人脸检测模型及github地址。

    该模型设计是针对边缘计算设备或低算力设备(如用ARM推理)设计的实时超轻量级通用人脸检测模型,可以在低算力设备中如用ARM进行实时的通用场景的人脸检测推理,同样适用于移动端.PC.在模型大小上,默认F ...

  6. 超轻量级通用人脸检测模型

    项目地址:github.com/Linzaer/Ult- 以下是作者对此项目的介绍: 该模型设计是针对边缘计算设备或低算力设备 (如用 ARM 推理) 设计的一款实时超轻量级通用人脸检测模型,旨在能在 ...

  7. 模型仅1MB,更轻量的人脸检测模型开源,效果不弱于主流算法

    乾明 编辑整理  量子位 报道 | 公众号 QbitAI AI模型越来越小,需要的算力也也来越弱,但精度依旧有保障. 最新代表,是一个刚在GitHub上开源的中文项目:一款超轻量级通用人脸检测模型. ...

  8. Ubuntu 下使用 FDDB 测试人脸检测模型并生成 ROC 曲线,详细步骤

    原 Ubuntu 下使用 FDDB 测试人脸检测模型并生成 ROC 曲线 2018年08月01日 20:18:44 Xing_yb 阅读数:101 标签: FDDB 人脸检测 模型测试 ROC 曲线 ...

  9. 【模型推理加速系列】05: 推理加速格式TorchScript简介及其应用

    简介 本文紧接前文:模型推理加速系列|04:BERT模型推理加速 TorchScript vs. ONNX 实验结果:在动态文本长度且大batch size的场景下,TorchScript 格式的in ...

最新文章

  1. 005_JavaScript使用
  2. Servlet映射路径中的通配符
  3. (Java常用类)Object类
  4. java接口开发规范,干货满满
  5. AS查看Android系统源码
  6. Linux常用的配置文件
  7. 转:Linux 僵尸进程详解
  8. c语言伪常量const理解
  9. MFC学习(实时更新)
  10. PMP-36项目风险管理
  11. TapTap实习三个月总结
  12. AM调制解调matlab实验报告,MATLAB仿真AM调制解调 无线通信实验报告.doc
  13. 广告代码(弹窗和富媒体)
  14. 五个真实的数据挖掘故事
  15. 工作流与BPM的区别
  16. STATA面板数据模型进行Hausman检验
  17. python中词云图是用来描述_Python如何实现中国地图词云图
  18. 读“王东升 新时空 硅碳融合的产业革命”拙见
  19. Eclipse中properties配置文件的中文乱码
  20. vc文件拖曳(控件)

热门文章

  1. Dell戴尔笔记本电脑灵越Inspiron 5580原装出厂Windows10系统恢复原厂oem系统
  2. 万能实用工具箱微信小程序
  3. 微信小程序开发入门实例
  4. 最新在线客服系统源码软件代码+自动回复+管理后台
  5. Android各国语言对照表
  6. 工业机器人与计算机控制,不懂工业机器人控制技术?那你一定是没看过这篇文章...
  7. 前端工程师必备的 10款开发工具
  8. Python网络爬虫实战1:百度新闻数据爬取
  9. 数百面试问题、覆盖AI核心主题,401页的深度学习面经免费下载了
  10. ESD静电二极管|静电保护器件