代码库在:https://github.com/deepinsight/insightface

知识点:

mx.model.load_checkpoint('../recognition/mxnet/new_model', 0)

net_model是json的名字,0是训练的epoch值。

fc1_output,感觉每一层都有output,fc1是最后一层的名字,最后一层是bn层,fc1_output是最后一层bn层的输出。

3.将model上传到insightface/models中。

gpu版替换代码:

gpu_id = 0
ctx = mx.gpu(gpu_id)
model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)

cpu版单张图片提取特征:

# !/usr/bin/env python
# -*- coding: utf-8 -*-import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtupledef single_input(path):img = cv2.imread(path)# mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img, (112, 112))img = img.transpose(2, 0, 1)# 添加一个第四维度并构建NDArrayimg = img[np.newaxis, :]array = mx.nd.array(img)return arrayif __name__ == "__main__":time_start = time.time()time0 = time.time()sym, arg_params, aux_params = mx.model.load_checkpoint('../recognition/mxnet/new_model', 0)# print(sym)# 提取中间某层输出帖子特征层作为输出all_layers = sym.get_internals()# print(all_layers)sym = all_layers['fc1_output']# 重建模型model = mx.mod.Module(symbol=sym, label_names=None)model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])model.set_params(arg_params, aux_params)time1 = time.time()time_load = time1 - time0# print("模型加载和重建时间:{0}".format(time1 - time0))Batch = namedtuple("batch", ['data'])img1_path=r'E:\MNN\project\android-mnn_new\app\src\main\assets\0.jpg'array1 = single_input(img1_path)model.forward(Batch([array1]))vector1 = model.get_outputs()[0].asnumpy()vector1 = np.squeeze(vector1)print(vector1)

vim test.py

import face_model
import argparse
import cv2
import sys
import numpy as np
parser = argparse.ArgumentParser(description='face model test')
# general
parser.add_argument('--image-size', default='112,112', help='')
# 这个比较重要,人脸识别的model,我的经验是要用绝对路径
parser.add_argument('--model', default='/disk1t/insightface/insightface/models/model-r100-ii/model,0', help='path to load model.')
# 这个比较重要,年龄性别的model,我的经验是要用绝对路径
parser.add_argument('--ga-model', default='/disk1t/insightface/insightface/models/gamodel-r50/model,0', help='path to load model.')
parser.add_argument('--gpu', default=0, type=int, help='gpu id')
parser.add_argument('--det', default=0, type=int, help='mtcnn option, 1 means using R+O, 0 means detect from begining')
parser.add_argument('--flip', default=0, type=int, help='whether do lr flip aug')
parser.add_argument('--threshold', default=1.24, type=float, help='ver dist threshold')
args = parser.parse_args()
print(args)
# 加载model
model = face_model.FaceModel(args)
# 读取图片
img = cv2.imread('Tom_Hanks_54745.png')
# 模型加载图片
img = model.get_input(img)
# 获得特征
f1 = model.get_feature(img)
# 输出特征
print(f1)
#print(f1[0:10])
#gender, age = model.get_ga(img)
#print(gender)
#print(age)
#sys.exit(0)
#img = cv2.imread('/raid5data/dplearn/megaface/facescrubr/112x112/Tom_Hanks/Tom_Hanks_54733.png')
#f2 = model.get_feature(img)
#dist = np.sum(np.square(f1-f2))
#print(dist)
#sim = np.dot(f1, f2.T)
#print(sim)
#diff = np.subtract(source_feature, target_feature)
#dist = np.sum(np.square(diff),1)

获取指定层的输出
有些时候我们不需要网络的输出,而是只需要网络某个层的输出来通过网络提取图片的特征,这时候我们就需要指定提取层的名称,这里我们通过提取网络最后一层的全连接层为例

def get_specify_mod(model_str,ctx,data_shpae,layer_name):_vec = model_str.split(",")prefix = _vec[0]epoch = int(_vec[1])sym,arg_params,aux_params = mx.model.load_checkpoint(prefix,epoch)#获取神经网络所有的层all_layers = sym.get_internals()#获取输出层sym = all_layers[layer_name+"_output"]mod = mx.mod.Module(symbol=sym,context=ctx)mod.bind(data_shapes=[("data",data_shpae)])mod.set_params(arg_params,aux_params)return moddef predict_specify(model_str,ctx,data_shape,img_path,label_path):label_names = get_label_names(label_path)#通过输出网络层的名称,输出层全连接层的名称为fc1mod = get_specify_mod(model_str,ctx,data_shape,layer_name="fc1")nd_img = preprocess_img(img_path,data_shape,ctx)#将需要预测的图片封装为Batchdata_batch = mx.io.DataBatch(data=(nd_img,))#计算网络的预测值mod.forward(data_batch,is_train=False)#获取网络的输出值output = mod.get_outputs()[0]#对输出值进行softmax处理proba = mx.nd.softmax(output)#获取前top5的值top_proba = proba.topk(k=5)[0].asnumpy()for index in top_proba:probability = proba[0][int(index)].asscalar()*100pred_label_name = label_names[int(index)]print("label name=%s,probability=%f"%(pred_label_name,probability))

6.到此,人脸特征提取就完成了

人脸识别数据集精度验证:这个是固定阈值,测精度不太合适

# !/usr/bin/env python
# -*- coding: utf-8 -*-import os
import time
import math
import mxnet as mx
import cv2
import numpy as np
from collections import namedtupledef str_expansion(sstring):"""将“数字型”字符扩展"""ssize = len(sstring)if ssize == 1:sstring = "00" + sstringelif ssize == 2:sstring = "0" + sstringreturn sstringdef str_expansion_lfw(sstring):""":param sstring: 将“数字型”字符扩展,针对lfw数据集"""ssize = len(sstring)if ssize == 1:sstring = "000" + sstringelif ssize == 2:sstring = "00" + sstringelif ssize == 3:sstring = "0" + sstringreturn sstringdef cos_similarity(x, y):length = len(x)x_squre = 0y_squre = 0xy_inner_product = 0for i in range(length):x_squre += x[i] * x[i]y_squre += y[i] * y[i]xy_inner_product += x[i] * y[i]print(x_squre)print(y_squre)return xy_inner_product / (math.sqrt(x_squre) * math.sqrt(y_squre))def single_input(path):img = cv2.imread(path)# mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = cv2.resize(img, (112, 112))img = img.transpose(2, 0, 1)# 添加一个第四维度并构建NDArrayimg = img[np.newaxis, :]array = mx.nd.array(img)return arrayif __name__ == "__main__":time_start = time.time()# verication_folder = "E:/face_detection/verification/IRimg_verification"verication_folder = r"E:\project\faceid\MobileFaceNet_Tutorial_Pytorch\data_set\LFW\lfw_align_112"file = r"E:\project\faceid\MobileFaceNet_Tutorial_Pytorch\data_set\LFW\pairs.txt"pair_list = []with open(file, 'r') as f:pair_list = f.readlines()time0 = time.time()sym, arg_params, aux_params = mx.model.load_checkpoint("mxnet/zwnwet_model", 0)print(sym)# print(arg_params)# print(aux_params)# 提取中间某层输出帖子特征层作为输出all_layers = sym.get_internals()print(all_layers)sym = all_layers['fc1_output']# 重建模型model = mx.mod.Module(symbol=sym, label_names=None)model.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))])model.set_params(arg_params, aux_params)time1 = time.time()time_load = time1 - time0# print("模型加载和重建时间:{0}".format(time1 - time0))Batch = namedtuple("batch", ['data'])threshold = 0.6TP = 0TN = 0NUM_IR = 3280 / 2NUM_lfw = 6000 / 2time_frame = 0for item in pair_list:line = item.rstrip().split("\t")print(line)# 属于同一个人的图片对验证if len(line) == 3:time2 = time.time()folder = line[0]# img1 = str_expansion(str(int(line[1])-1)) + ".jpg"# img2 = str_expansion(str(int(line[2])-1)) + ".jpg"img1 = folder + '_' + str_expansion_lfw(line[1]) + ".jpg"img2 = folder + '_' + str_expansion_lfw(line[2]) + ".jpg"img1_path = os.path.join(verication_folder, folder, img1)img2_path = os.path.join(verication_folder, folder, img2)array1 = single_input(img1_path)array2 = single_input(img2_path)model.forward(Batch([array1]))vector1 = model.get_outputs()[0].asnumpy()vector1 = np.squeeze(vector1)model.forward(Batch([array2]))vector2 = model.get_outputs()[0].asnumpy()vector2 = np.squeeze(vector2)similarity = cos_similarity(vector1, vector2)time3 = time.time()time_frame = time3 - time2 + time_frameprint(similarity, "\n")if similarity >= threshold:TP += 1# 属于不同的人的图片对验证if len(line) == 4:time4 = time.time()folder1 = line[0]# img1 = str_expansion(str(int(line[1])-1)) + ".jpg"img1 = folder1 + "_" + str_expansion_lfw(line[1]) + ".jpg"folder2 = line[2]# img2 = str_expansion(str(int(line[3])-1)) + ".jpg"img2 = folder2 + "_" + str_expansion_lfw(line[3]) + ".jpg"img1_path = os.path.join(verication_folder, folder1, img1)img2_path = os.path.join(verication_folder, folder2, img2)array1 = single_input(img1_path)array2 = single_input(img2_path)model.forward(Batch([array1]))vector1 = model.get_outputs()[0].asnumpy()vector1 = np.squeeze(vector1)model.forward(Batch([array2]))vector2 = model.get_outputs()[0].asnumpy()vector2 = np.squeeze(vector2)similarity = cos_similarity(vector1, vector2)time5 = time.time()time_frame = time5 - time4 + time_frameprint(similarity, "\n")if similarity < threshold:TN += 1print("检真正确率:{0:.4f}".format(TP / NUM_lfw))print("拒假正确率:{0:.4f}".format(TN / NUM_lfw))print("模型加载时间: {0:.3f}s".format(time_load))print("检测一帧平均时间: {0:.3f}s".format(time_frame / (NUM_lfw * 2)))time_end = time.time()print("程序运行时间: {0:.2f}min".format((time_end - time_start) / 60))

mxnet insightface特征提取相关推荐

  1. 解密阿里云大规模深度学习性能优化实践

    云栖号资讯:[点击查看更多行业资讯] 在这里您可以找到不同行业的第一手的上云资讯,还在等什么,快来! 作者 | 阿里云异构计算AI加速负责人 游亮 近日,斯坦福大学公布了最新的 DAWNBench 深 ...

  2. Ali-Perseus(擎天):统一深度学习分布式通信框架 [弹性人工智能]...

    [作者]  驭策(龚志刚) 笋江(林立翔)蜚廉(王志明) 昀龙(游亮) 近些年来,深度学习在图像识别,自然语言处理等领域快速发展.各种网络模型,需要越来越多的计算力来进行训练.以典型的中等规模的图像分 ...

  3. Perseus(擎天):统一深度学习分布式通信框架

    作者  驭策(龚志刚) 笋江(林立翔)蜚廉(王志明) 昀龙(游亮) 近些年来,深度学习在图像识别,自然语言处理等领域快速发展.各种网络模型,需要越来越多的计算力来进行训练.以典型的中等规模的图像分类网 ...

  4. Ali-Perseus(擎天):统一深度学习分布式通信框架 [弹性人工智能]

    [作者]  驭策(龚志刚) 笋江(林立翔)蜚廉(王志明) 昀龙(游亮) 近些年来,深度学习在图像识别,自然语言处理等领域快速发展.各种网络模型,需要越来越多的计算力来进行训练.以典型的中等规模的图像分 ...

  5. 轻松上手UAI-Train,拍拍贷人脸识别算法优化效率提升85.7%

    2019独角兽企业重金招聘Python工程师标准>>> "UAI-Train平台可以让我们方便地在短时内使用大量的GPU资源,用较低的成本训练海量的数据集,提高算法模型迭代 ...

  6. 深度学习框架哪家强?MXNet称霸CNN、RNN和情感分析,TensorFlow仅擅长推断特征提取

    深度学习框架哪家强:TensorFlow?Caffe?MXNet?Keras?PyTorch?对于这几大框架在运行各项深度任务时的性能差异如何,各位读者不免会有所好奇. 微软数据科学家Ilia Kar ...

  7. Insightface项目爬坑指南+使用本地数据集训练流程(MXNET版)

    其实半年多前就已经把insightface训练等一系列环节弄熟了,不得不说IBUG组的这个模型确实是开源界的翘楚,但是还是存在一些问题在某些程度上和商汤云从等大厂存在一点差距,这不妨碍大部分人日常人脸 ...

  8. @property python知乎_使用Mxnet进行图像深度学习训练工具 InsightFace - 使用篇, 如何一键刷分LFW 99.80%, MegaFace 98%....

    开头先把论文和开源项目地址放一下: Additive Angular Margin Loss for Deep Face Recognition​arxiv.org deepinsight/insig ...

  9. insightface mxnet训练 旧版

    11年it研发经验,从一个会计转行为算法工程师,学过C#,c++,java,android,php,go,js,python,CNN神经网络,四千多篇博文,三千多篇原创,只为与你分享,共同成长,一起进 ...

最新文章

  1. Qt实现截屏并保存(转载)
  2. Android之项目推荐使用的第三方库
  3. Python在mysql中进行操作是十分容易和简洁的
  4. 经典的机器学习方面源代码库(非常全,数据挖掘,计算...)
  5. supervisor管理mysql靠谱吗_Supervisor 从入门到放弃
  6. java opc 读取到数据块的数据_MES系统功能数据传输的介绍
  7. 移动应用程序框架Kendo UI Mobile发布R2 2016 SP2
  8. 创建标签等操作DOM的原生js API
  9. ASP.NET MVC Framework体验(1):从一个简单实例开始(转)
  10. 【MATLAB深度学习工具箱】学习笔记--螃蟹公母分类Crab Classification
  11. 时光倒流软件测试简历,时光倒流 28款数据恢复软件大比拼
  12. 江苏警官学院计算机科学与技术专业,江苏警官学院什么专业好就业,哪些专业适合女生...
  13. 在VisualBasic6.0中实现0.5数值修约
  14. Vue知识点总结(一)
  15. 用BitBlt实现透明贴图
  16. 面向对象之唐城NBA选秀大会
  17. 【Python Sympy】将表达式化为关于x的多项式,求出多项式系数
  18. 前端插件库之vue3使用vue-codemirror插件
  19. 一些Java实用技巧(量变转变为质变后会单独整理出来)
  20. OneDrive和OneDrive for Business映射到本地网络驱动器

热门文章

  1. 探测Windows2K/XP/2003本机系统信息
  2. eclipse运行android项目出现The connection to adb is down, and a severe error has occured.的问题
  3. Java的深拷贝和浅拷贝
  4. kernel 3.10代码分析--KVM相关--虚拟机创建\VCPU创建\虚拟机运行
  5. php单例模式的实例,PHP的单例模式的一个实例_php
  6. php twig扩展,如何写一个自定义的 Twig 扩展
  7. object转成实体对象_Object.assign 原理及其实现
  8. netty 文件传输服务器,Netty之二进制文件传输
  9. linux如何取文件列名,Linux ps 指定列名
  10. java生成xsd_java 生成XSD