前言

tensorflow官方有个姿态估计项目,这个输入和openpose还有点不一样,这里写个单人情况下的模型输出解析方案。

国际惯例,参考博客:

博客: 使用 TensorFlow.js 在浏览器端上实现实时人体姿势检测

tensorflow中posnet的IOS代码

解析

不要下载官方overview网址下的posenet模型multi_person_mobilenet_v1_075_float.tflite,要去下载IOS端的posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite模型,在github上一搜有一堆,文末放网盘下载地址。

读取模型

先载入必要的工具包:

import numpy as np
import tensorflow as tf
import cv2 as cv
import matplotlib.pyplot as plt
import time

使用tflite载入模型文件

model = tf.lite.Interpreter('posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite')
model.allocate_tensors()
input_details = model.get_input_details()
output_details = model.get_output_details()

看看输入输出分别是什么

print(input_details)
print(output_details)
'''
[{'name': 'sub_2', 'index': 93, 'shape': array([  1, 257, 257,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'MobilenetV1/heatmap_2/BiasAdd', 'index': 87, 'shape': array([ 1,  9,  9, 17], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/offset_2/BiasAdd', 'index': 90, 'shape': array([ 1,  9,  9, 34], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/displacement_fwd_2/BiasAdd', 'index': 84, 'shape': array([ 1,  9,  9, 32], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}, {'name': 'MobilenetV1/displacement_bwd_2/BiasAdd', 'index': 81, 'shape': array([ 1,  9,  9, 32], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
'''

很容易看出输入是(257,257)尺寸的彩色图像。

输出就比较麻烦了,有两块:(9,9,17)的称为heatmap的热度图;(9,9,34)的称为offset的偏移图。其实想想也能知道,热度图定位关节的大概位置,用偏移图做进一步的矫正。接下来逐步分析怎么利用这两个输出将关节位置定位的。

输入图像推断

必须将图像resize一下再丢进去,但是tensorflowjs里面说不用resize的方法,我还没试过。

img = cv.imread('../../photo/1.jpeg')
input_img = tf.reshape(tf.image.resize(img, [257,257]), [1,257,257,3])
floating_model = input_details[0]['dtype'] == np.float32
if floating_model:input_img = (np.float32(input_img) - 127.5) / 127.5
model.set_tensor(input_details[0]['index'], input_img)
start = time.time()
model.invoke()
print('time:',time.time()-start)
output_data =  model.get_tensor(output_details[0]['index'])
offset_data = model.get_tensor(output_details[1]['index'])
heatmaps = np.squeeze(output_data)
offsets = np.squeeze(offset_data)
print("output shape: {}".format(output_data.shape))
'''
time: 0.12212681770324707
output shape: (1, 9, 9, 17)
'''

可视化变换后的图

show_img = np.squeeze((input_img.copy()*127.5+127.5)/255.0)[:,:,::-1]
show_img = np.array(show_img*255,np.uint8)
plt.imshow(show_img)
plt.axis('off')

解析输出

一句话概括原理:热度图将图像划分网格,每个网格的得分代表当前关节在此网格点附近的概率;偏移图代表xy两个坐标相对于网格点的偏移情况。

假设提取第2个关节的坐标位置:

  • 先得到最可能的网格点:

    i=1
    joint_heatmap = heatmaps[...,i]
    max_val_pos = np.squeeze(np.argwhere(joint_heatmap==np.max(joint_heatmap)))
    remap_pos = np.array(max_val_pos/8*257,dtype=np.int32)
    
  • offset加上去,前1-17是x坐标偏移,后18-34是y坐标偏移

    refine_pos = np.zeros((2),dtype=int)
    refine_pos[0] = int(remap_pos[0] + offsets[max_val_pos[0],max_val_pos[1],i])
    refine_pos[1] = int(remap_pos[1] + offsets[max_val_pos[0],max_val_pos[1],i+heatmaps.shape[-1]])
    

可视化看看

show_img = np.squeeze((input_img.copy()*127.5+127.5)/255.0)[:,:,::-1]
show_img = np.array(show_img*255,np.uint8)
plt.figure(figsize=(8,8))
plt.imshow(cv.circle(show_img,(refine_pos[1],refine_pos[0]),2,(0,255,0),-1))

映射原图

因为上面是把原图resize乘(257,257)以后的坐标,所以根据原图的缩放系数,重新映射回去

ratio_x = img.shape[0]/257
ratio_y = img.shape[1]/257
refine_pos[0]=refine_pos[0]*ratio_x
refine_pos[1]=refine_pos[1]*ratio_y

可视化

show_img1 = img[:,:,::-1]
plt.figure(figsize=(8,8))
plt.imshow(cv.circle(show_img1.copy(),(refine_pos[1],refine_pos[0]),2,(0,255,0),-1))

封装函数

上面是提取单个关节的,写成函数提取所有关节的坐标就是

def parse_output(heatmap_data,offset_data):joint_num = heatmap_data.shape[-1]pose_kps = np.zeros((joint_num,2),np.uint8)for i in range(heatmap_data.shape[-1]):joint_heatmap = heatmap_data[...,i]max_val_pos = np.squeeze(np.argwhere(joint_heatmap==np.max(joint_heatmap)))remap_pos = np.array(max_val_pos/8*257,dtype=np.int32)pose_kps[i,0] = int(remap_pos[0] + offset_data[max_val_pos[0],max_val_pos[1],i])pose_kps[i,1] = int(remap_pos[1] + offset_data[max_val_pos[0],max_val_pos[1],i+joint_num])return pose_kps

画图的函数也很容易

def draw_kps(show_img,kps):for i in range(kps.shape[0]):cv.circle(show_img,(kps[i,1],kps[i,0]),2,(0,255,0),-1)return show_img

画出来瞅瞅

kps = parse_output(heatmaps,offsets)
plt.figure(figsize=(8,8))
plt.imshow(draw_kps(show_img.copy(),kps))
plt.axis('off')

后记

模型文件:链接:https://pan.baidu.com/s/1heRKFFz28yvpAmvFqDeAXw 密码:5tuw

博客代码:链接:https://pan.baidu.com/s/1Y7WXfQ4WC9QyOGkkN2-kUQ 密码:ono0

本文已经同步到微信公众号中,公众号与本博客将持续同步更新运动捕捉、机器学习、深度学习、计算机视觉算法,敬请关注

tensorflow官方posenet模型解析相关推荐

  1. 解析Tensorflow官方PTB模型的demo

    正文共7138个字,1张图,预计阅读时间18分钟. 01 seq2seq代码案例解读 RNN 模型作为一个可以学习时间序列的模型被认为是深度学习中比较重要的一类模型.在Tensorflow的官方教程中 ...

  2. 3D姿态估计——ThreeDPose项目简单易用的模型解析

    前言 之前写过tensorflow官方的posenet模型解析,用起来比较简单,但是缺点是只有2D关键点,本着易用性的原则,当然要再来个简单易用的3D姿态估计.偶然看见了ThreeDPose的项目,感 ...

  3. 深度学习利器:TensorFlow与NLP模型

    深度学习利器:TensorFlow与NLP模型 享到:微博微信FacebookTwitter有道云笔记邮件分享 稍后阅读 我的阅读清单 前言 自然语言处理(简称NLP),是研究计算机处理人类语言的一门 ...

  4. TensorFlow与PyTorch模型部署性能比较

    TensorFlow与PyTorch模型部署性能比较 前言 2022了,选 PyTorch 还是 TensorFlow?之前有一种说法:TensorFlow 适合业界,PyTorch 适合学界.这种说 ...

  5. TensorFlow官方入门实操课程-一个神经元的网络(线性曲线预测)

    基于如下的课程进行的学习记录 TensorFlow官方入门实操课程 #设置显卡内存使用率,根据使用率占用 import os os.environ["TF_FORCE_GPU_ALLOW_G ...

  6. 【tensorflow速成】Tensorflow图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [tensorflow速成]Tensorflow图像分类从模型自定义到测试 这是给大家准备的tensorflow速成例子 上一篇介绍了 Caffe , ...

  7. tensorflow笔记:模型的保存与训练过程可视化

    tensorflow笔记系列:  (一) tensorflow笔记:流程,概念和简单代码注释  (二) tensorflow笔记:多层CNN代码分析  (三) tensorflow笔记:多层LSTM代 ...

  8. 【tfcoreml】tensorflow向CoreML模型的转换工具封装

    安装tf向apple coreml模型转换包tfcoreml 基于苹果自己的转换工具coremltools进行封装 tfcoreml 为了将训练的模型转换到apple中使用,需要将模型转换为ios支持 ...

  9. TensorFlow官方发布剪枝优化工具:参数减少80%,精度几乎不变

    晓查 编译自 Medium 量子位 报道 | 公众号 QbitAI 去年TensorFlow官方推出了模型优化工具,最多能将模型尺寸减小4倍,运行速度提高3倍. 最近现又有一款新工具加入模型优化&qu ...

最新文章

  1. 高薪源于专注和极致!
  2. matlab多维数组、结构体数组
  3. 【机器视觉】 fuzzy_measure_pairs算子
  4. 测试私有方法_史上最轻量!阿里开源了新型单元测试Mock工具
  5. [pandas]方法总结
  6. CTime类,CTime 与 CString转换
  7. WEB-移动端图片适配-弹框
  8. 怎样对齐文体框和图像按钮
  9. 译DevExpress v16.1更新说明(WinForms篇)
  10. 软件工程网络15个人阅读作业2(201521123111 陈伟泽)
  11. 新手经常忽略的嵌入式基础知识点,你都掌握了吗?
  12. 【渝粤教育】电大中专营销策划原理与实务答案作业 题库
  13. Unity Shader - 基础光照之漫反射
  14. dataframe分组并求平均
  15. 年度光电领域盛会——CIOE中国光博会开幕在即!小枣君将全程在线直播!
  16. 戴尔 Latitude E5430 non-vPro 笔记本电脑
  17. arXiv每日推荐-5.16:语音/音频每日论文速递
  18. 剑指Offer_入门_JZZ_斐波那契数列
  19. [Erlang] XML处理方案
  20. Flask框架二 Jinja2

热门文章

  1. oracle通过执行计划cost,Oracle 执行计划(5)—cost成本之索引范围扫描-B树索引
  2. Obtain a Permutation(思维)
  3. 对于C++中多态的理解
  4. extjs年月日时分选择控件_UI设计|网站公共控件及交互事件
  5. 玩转GIT系列之【git切换到某个tag之后提示“detached HEAD】
  6. 【Tensorflow】tf.set_random_seed(seed)
  7. Opencv3编程入门学习笔记(四)之split通道分离Debug过程中0xC0000005内存访问冲突问题
  8. 梯度下降的三种形式——BGD、SGD、MBGD
  9. 图像算法中常用的数学概念
  10. 算法竞赛训练指南代码仓库_数据仓库综合指南