以下链接是个人关于FSA-Net(头部姿态估算) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。文末附带\color{blue}{文末附带}文末附带公众号−\color{blue}{公众号 -}公众号−海量资源。\color{blue}{ 海量资源}。海量资源。

姿态估计1-00:FSA-Net(头部姿态估算)-目录-史上最新无死角讲解

分析前言

我相信,大家跟到这里来,说明你以及看完论文了,既然如此,我们在来看看training_and_testing/run_fsanet_train.sh文件,这是我们训练的脚本内容如下:

.......

是的,我这里是空的,因为我觉得复制出来太臃肿,显得我的博客不够帅气,所以就不复制出来了,你看自己源码的即可。该文件可以看到如下字样的注释:

# Train on protocal 1
# SSRNET_MT
# FSANET_Capsule
# FSANET_Netvlad
# FSANET_Metric# Train on protocal 2
# SSRNET_MT
# FSANET_Capsule
# FSANET_Netvlad
# FSANET_Metric# Fine-tuned on BIWI with synhead pre-trained model

看完了论文的朋友应该就比较熟悉了,因为在论文中存在如下图示:

也就是说,作者需要执行多次训练脚本,全是为了为了完成实验的对比。那么我们在分析的时候,当然是选择效果最好的哪个进行分析。说得简单,但是现在我也没办法一眼就知道哪个效果最好,不过通过观察可以知道,他们主要的差别就是在于–model_type的参数不一样,其可以选择1到10之间。那么在分析源码的时候,着重分析其处理过程即可。

源码解析

源码注释:

import os
import sys
sys.path.append('..')
import logging
import argparse
import pandas as pd
import numpy as npfrom lib.FSANET_model import *
from lib.SSRNET_model import *import TYY_callbacks
from TYY_generators import *from keras.utils import np_utils
from keras.utils import plot_model
from keras.optimizers import SGD, Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler, ModelCheckpointlogging.basicConfig(level=logging.DEBUG)def load_data_npz(npz_path):d = np.load(npz_path)return d["image"], d["pose"]def mk_dir(dir):try:os.mkdir( dir )except OSError:passdef get_args():parser = argparse.ArgumentParser(description="This script trains the CNN model for head pose estimation.",formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument("--batch_size", type=int, default=16,help="batch size")parser.add_argument("--nb_epochs", type=int, default=90,help="number of epochs")parser.add_argument("--validation_split", type=float, default=0.2,help="validation split ratio")parser.add_argument("--model_type", type=int, default=3,help="type of model")parser.add_argument("--db_name", type=str, default='300W_LP',help="type of model")args = parser.parse_args()return argsdef main():# 解析并且赋值相关参数args = get_args()db_name = args.db_namebatch_size = args.batch_sizenb_epochs = args.nb_epochsvalidation_split = args.validation_splitmodel_type = args.model_typeimage_size = 64logging.debug("Loading data...")# 如果训练的数据集为300W_LPif db_name == '300W_LP':# 获得对应的npz文件db_list = ['AFW.npz','AFW_Flip.npz','HELEN.npz','HELEN_Flip.npz','IBUG.npz','IBUG_Flip.npz','LFPW.npz','LFPW_Flip.npz']# 用于保存像素image = []# 用于保存姿态pose = []# 循环加入所有的图片像素,以及对应的姿态for i in range(0,len(db_list)):image_temp, pose_temp = load_data_npz('../data/type1/'+db_list[i])image.append(image_temp)pose.append(pose_temp)# 把链表转化为np数组格式.# 加载完数据之后为[122450, 64, 64, 3]image = np.concatenate(image,0)# 加载完数据之后为[122450, 3]pose = np.concatenate(pose,0)# 对于其角度不在[-99,99]之间的数据,全部剔除掉# we only care the angle between [-99,99] and filter other anglesx_data = []y_data = []print(image.shape)print(pose.shape)for i in range(0,pose.shape[0]):temp_pose = pose[i,:]if np.max(temp_pose)<=99.0 and np.min(temp_pose)>=-99.0:x_data.append(image[i,:,:,:])y_data.append(pose[i,:])x_data = np.array(x_data)y_data = np.array(y_data)print(x_data.shape)print(y_data.shape)elif db_name == 'synhead_noBIWI':image, pose = load_data_npz('../data/synhead/media/jinweig/Data2/synhead2_release/synhead_noBIWI.npz')x_data = imagey_data = pose# 如果训练的数据集为BIWIelif db_name == 'BIWI':image, pose = load_data_npz('../data/BIWI_train.npz')x_train = imagey_train = poseimage_test, pose_test = load_data_npz('../data/BIWI_test.npz')x_test = image_testy_test = pose_testelse:print('db_name is wrong!!!')return# 训练到30ep和60ep会进行学习率衰减start_decay_epoch = [30,60]#优化器optMethod = Adam()# 论文中Stage的数目stage_num = [3,3,3]lambda_d = 1# 输出姿态为yaw, pitch, rollnum_classes = 3# 是否使用最好的方法isFine = False#根据model_type参数 进行模型构建if model_type == 0:model = SSR_net_ori_MT(image_size, num_classes, stage_num, lambda_d)()save_name = 'ssrnet_ori_mt'elif model_type == 1:model = SSR_net_MT(image_size, num_classes, stage_num, lambda_d)()save_name = 'ssrnet_mt'elif model_type == 2:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 7*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_Capsule(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_capsule'+str_S_setelif model_type == 3:#num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 7*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_Var_Capsule(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_var_capsule'+str_S_setelif model_type == 4:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 8*8*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_noS_Capsule(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_noS_capsule'+str_S_setelif model_type == 5:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 7*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_NetVLAD(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_netvlad'+str_S_setelif model_type == 6:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 7*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_Var_NetVLAD(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_var_netvlad'+str_S_setelif model_type == 7:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 8*8*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_noS_NetVLAD(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_noS_netvlad'+str_S_setelif model_type == 8:# 论文中num_capsule = 3# 论文中的c’=16dim_capsule = 16# 论文中的stream数目routings = 2# 论文中的n'=7,num_primcaps = 7*3# 论文中的m=5m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_Metric(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_metric'+str_S_setelif model_type == 9:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 7*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_Var_Metric(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_var_metric'+str_S_setelif model_type == 10:num_capsule = 3dim_capsule = 16routings = 2num_primcaps = 8*8*3m_dim = 5S_set = [num_capsule, dim_capsule, routings, num_primcaps, m_dim]str_S_set = ''.join('_'+str(x) for x in S_set)model = FSA_net_noS_Metric(image_size, num_classes, stage_num, lambda_d, S_set)()save_name = 'fsanet_noS_metric'+str_S_set# 指定模型的优化方法,以及loss(均值绝对误差)计算方式,model.compile(optimizer=optMethod, loss=["mae"],loss_weights=[1])logging.debug("Model summary...")# 计算模型参数,打印模型结构model.count_params()model.summary()logging.debug("Saving model...")# 创建必要的目录,如保存模型的路径等等mk_dir(db_name+"_models")mk_dir(db_name+"_models/"+save_name)mk_dir(db_name+"_checkpoints")# 把模型绘画成图,便于分析(总体结构)plot_model(model, to_file=db_name+"_models/"+save_name+"/"+save_name+".png")# 绘画网络模型的细致结构for i_L,layer in enumerate(model.layers):if i_L >0 and i_L< len(model.layers)-1:if 'pred' not in layer.name and 'caps' != layer.name and 'merge' not in layer.name and 'model' in layer.name:plot_model(layer, to_file=db_name+"_models/"+save_name+"/"+layer.name+".png")# 迭代到指定次数,进行学习率衰减decaylearningrate = TYY_callbacks.DecayLearningRate(start_decay_epoch)# 查看指定路径下的模型知否存在,存在则自动加载该目录下的模型callbacks = [ModelCheckpoint(db_name+"_checkpoints/weights.{epoch:02d}-{val_loss:.2f}.hdf5",monitor="val_loss",verbose=1,save_best_only=True,mode="auto"), decaylearningrate]logging.debug("Running training...")# 如果为'BIWI'数据集,则进行测试集和训练集的划分if db_name != 'BIWI':data_num = len(x_data)indexes = np.arange(data_num)np.random.shuffle(indexes)x_data = x_data[indexes]y_data = y_data[indexes]train_num = int(data_num * (1 - validation_split))x_train = x_data[:train_num]x_test = x_data[train_num:]y_train = y_data[:train_num]y_test = y_data[train_num:]elif db_name == 'BIWI':train_num = np.shape(x_train)[0]# 为模型绑定训练数据,测试数据,并且进行训练(真的是个讨厌的框架,用了pytorch之后,其他的框架越看越难受)hist = model.fit_generator(generator=data_generator_pose(X=x_train, Y=y_train, batch_size=batch_size),steps_per_epoch=train_num // batch_size,validation_data=(x_test, y_test),epochs=nb_epochs, verbose=1,callbacks=callbacks)logging.debug("Saving weights...")model.save_weights(os.path.join(db_name+"_models/"+save_name, save_name+'.h5'), overwrite=True)pd.DataFrame(hist.history).to_hdf(os.path.join(db_name+"_models/"+save_name, 'history_'+save_name+'.h5'), "history")if __name__ == '__main__':main()

360行,行行出状元,总得先有钱。三千大道归一大法,无非就是,加载数据,构建模型,训练数据,保存模型。没了,就这么通俗的总结一下。

通过代码的浏览,可以清楚地知道其复杂的地方,是模型的构建过程,具体细节下篇博客进行讲解。

姿态估计1-06:FSA-Net(头部姿态估算)-源码无死角讲解(1)-训练代码总览相关推荐

  1. 姿态估计2-08:PVNet(6D姿态估计)-源码无死角解析(4)-RANSAC投票机制

    以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正.有兴趣的朋友可以加微信:17575010159 相互讨论技术.若是帮助到了你什么,一定要记得点赞!因为这 ...

  2. (01)ORB-SLAM2源码无死角解析-(06) 图像金字塔_ORB特征点

    讲解关于slam一系列文章汇总链接:史上最全slam从零开始,针对于本栏目讲解的(01)ORB-SLAM2源码无死角解析链接如下(本文内容来自计算机视觉life ORB-SLAM2 课程课件): (0 ...

  3. 基于确定性最大似然算法 DML 的 DoA 估计,用牛顿法实现(附 MATLAB 源码)

    本文首次在公众号[零妖阁]上发表,为了方便阅读和分享,我们将在其他平台进行自动同步.由于不同平台的排版格式可能存在差异,为了避免影响阅读体验,建议如有排版问题,可前往公众号查看原文.感谢您的阅读和支持 ...

  4. 姿态估计对maskrcnn的优化,姿态估计相比Mask-RCNN提高8.2%,上海交大卢策吾团队开源AlphaPose

    转 2018年02月05日 14:29:24 zchang81 阅读数:3334 查看全文 http://www.taodudu.cc/news/show-5238019.html 相关文章: 上海交 ...

  5. 计算机视觉中头部姿态估计的研究综述--Head Pose Estimation in Computer Vision - A Survey

    计算机视觉中头部姿态估计的研究综述 埃里克.莫非,IEEE的初级会员 默罕 马努拜特里维迪,IEEE高级会员 摘要---让计算机视觉系统作为一个普通人拥有识别另一个人的头部姿势的能力这一想法的提出,对 ...

  6. 姿态估计1-05:FSA-Net(头部姿态估算)-训练测试数据制作-预处理代码讲解

    以下链接是个人关于FSA-Net(头部姿态估算) 所有见解,如有错误欢迎大家指出,我会第一时间纠正.有兴趣的朋友可以加微信:17575010159 相互讨论技术.若是帮助到了你什么,一定要记得点赞!因 ...

  7. 经典论文复现 | PyraNet:基于特征金字塔网络的人体姿态估计

    过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含"伪代码".这是今年 AAAI ...

  8. 论文阅读笔记--Monocular Human Pose Estimation: A Survey of Deep Learning-based Methods 人体姿态估计综述

    趁着寒假有时间,把之前的论文补完,另外做了一点点笔记,也算是对论文的翻译,尝试探索一条适合自己的论文阅读方法. 这篇笔记基本按照原文的格式来,但是有些地方翻译成中文读起来不太顺,因此添加了一些自己的理 ...

  9. 姿态估计1-02:HR-Net(人体姿态估算)-官方模型训练测试-报错解决

    以下链接是个人关于HR-Net(人体姿态估算)所有见解,如有错误欢迎大家指出,我会第一时间纠正.有兴趣的朋友可以加微信:17575010159 相互讨论技术.若是帮助到了你什么,一定要记得点赞!因为这 ...

  10. 人体姿态估计-评价指标(一)

    人体姿态估计-评价指标(一) 摘要 评价指标 oks(object keypoint similarity) AP(Average Precision)平均准确率 mAP(mean Average P ...

最新文章

  1. LeetCode简单题之字符串中的单词数
  2. cascader 动态加载 回显_Element中的Cascader(级联列表)动态加载省\市\区数据的方法...
  3. 龙格库塔法基本C程序
  4. 分布式系统发展史--转
  5. 多文件目录下makefile文件递归执行编译所有c文件 很不错
  6. 如何用Vue实现简易的富文本编辑器,并支持Markdown语法
  7. 计算几何 -- 旋转坐标系
  8. 智能优化算法应用:基于麻雀搜索算法与双伽马校正的图像自适应增强算法 - 附代码
  9. 现实世界与虚拟世界的差别在哪里
  10. Linux下使用SFTP命令
  11. 交叉火力dsp手机调音软件_DSP调音软件手机版下载-DSP音效处理器app下载 v1.0 安卓版-都去下载...
  12. 华为A1路由器虚拟服务器,华为a1路由器怎么设置
  13. PLC和工控机有什么关系?
  14. PB 数据窗口数据导入Excel, 如果存在则追加,不存在则创建。
  15. jquery获取父级元素、子级元素、兄弟元素的方法
  16. 云计算的未来:看「泛在计算」如何促进数字化生态和计算网络融合
  17. mysql命令查看表内容
  18. 大数据分析和人工智能科普
  19. 青云上NAS服务器挂的操作(他们的文档)
  20. 使用高德API接口查询两个地址之间的距离

热门文章

  1. 微信怎么录屏聊天记录?这两个方法值得收藏!
  2. java批处理查询_java 实现批量查询
  3. 前端别再错过2022的金三银四了。。
  4. php繁体转为简体的函数,繁体中文转换为简体中文的PHP函数_php
  5. pwn unlink
  6. 清明祭曾祖@20130402
  7. Hadoop Web 控制台安全认证
  8. 如何把微信公众号平台做成找券机器人并自动回复优惠券
  9. TestNG单元测试框架详解
  10. 洛谷P2689 东南西北