码字不易,转载请注明出处!

前面我们制作好了训练所需要的文件:train.rec,property,以及验证时所需要的val.bin,那么接下来就是该探索如何进行数据的训练。

这部分内容相对来说比较简单,毕竟框架和代码都是作者已经写好的,可供更改的内容还是有限的,所以也没有太多技巧的内容,更多就是按部就班的来。

模型训练

训练文件在"src"=>"train_softmax.py"文件内:


打开train_softmax.py文件后,主要关注的是一些训练参数,这里的内容还是挺多的,需要花点时间看下每项都在做什么。


我们用到的内容主要有以下几个:

  1. 87行中,我们需要指定我们制作的训练样本train.rec文件所在的文件夹
  2. 88行中,我们需要指定将来模型训练好之后保存到哪个位置
  3. 89行中,表示我们所使用的作者提供的预训练模型路径和名称
  4. 91行中,表示我们所希望使用的loss种类,这里作者提供了5中loss可供使用,分别是1)原始的softmax、2)SphereFace;3)cosineface;4)arcface;5)各种loss结合版本。这部分内容在作者的github主页上面有介绍:


5. 92行中,表示每隔多少个iteration做一次验证并保存模型
6. 95行中,network表示使用何种网络模型,在路径"src"=>"symbols"文件夹下有不同的模型,并且在代码中也用get_symbol()函数中定义了不同的模型,可以根据自己的需求使用,这里作者默认为resnet50
7. 97行中,是否使用se网路结构,这部分内容可以查看Squeeze and excitations networks论文,作者在论文中是使用的,模型的表现也不错,虽然默认是0,但是建议还是用上。
8. 112行中,设置batchsize的大小
9. 113-114行设置特征归一化后的大小和margin的大小,这部分内容很多论文都有提到
10. 128行中,target表示验证数据。作者原始使用了lfw、cfp、agedb三种验证集,但是我们这里因为要使用我们自己的,所以这里可以通过参数更改,或者直接在默认值中添加我们的数据集。并且,一定要注意的是,这里的target一定得有,不然网络训练中是不会保存模型的。即使我们在92行中设定了验证的间隔,但是如果我们没有提供验证数据,就不会有验证的过程,而且也不会保存模型。我在刚刚开始训练的时候,没有添加验证集,结果导致了模型训练了好久,训练准确度虽然提升了,但是由于没有验证集,所以模型不做验证也不保存数据。所以这里一定要加上。
别的内容自己了解下,根据自己的需要进行下尝试。

实际在使用的时候,依然是将训练参数写入一个.sh文件中,这样就可以直接通过更改.sh文件的内容来达到控制网络运行的目的:

#!/usr/bin/env bash
CUDA_VISIBLE_DEVICES='0' python -u  train_softmax_my.py --prefix ../models/model-r50-am-lfw --loss-type 0  --data-dir ../dataset/train --per-batch-size 32 \
--version-se 1 --verbose 1000 --target val --margin-s 64 --emb-size 512

上面的参数可以解释为,使用0号gpu,训练我的train_softmax_my.py文件,预训练模型保存路径为上级目录中的models文件夹内的"model-r50-am-lfw" ,使用原始的softmax损失,数据存放路径为上级目录中的dataset文件夹中的train文件夹内,batchsize为32, 使用se网络,每隔1000次进行一下验证,并保存模型,验证集的名称为val,因为它与train文件是放在同一个目录下的,所以模型自己会去train的路径下去找,归一化后特征尺度为64,最后的输出特征层的维度为512维。

那么在命令行敲入sh 你的sh文件名称.sh或者 ./你的sh文件名称.sh文件就可以跑起来了。

模型验证

同样是在"src"=>"train_softmax.py"文件中的第403行:


我们可以看到,定义了一个ver_test()函数来进行验证,验证使用的是verification.test函数,这部分的内容是在"src"=>“eval”=>"verification.py"文件中定义的。

而调用ver_test()函数,则是通过系统的一个回调函数来进行调用:

从上图中可以看到,418行中初始化了一个global_step,训练的时候会在432行中将这个变量加1,并将值赋给mbatch变量,当这个mbatch变量满足444行中的判定条件,那么就调用ver_test函数,从而进入验证阶段。

而验证阶段的目的只有一个:采用n-fold交叉验证的方式,确定一个阈值来使模型对验证集中相同的图相对判定为1,不相同的图相对判定为0。
具体而言,在"verification.py"文件中,找到test函数:


可以看到,该函数加载了数据和模型之后,通过将训练数据data_set划分为多个batch之后,得到_data,然后将其送入网络中进行前向推理(232行),获得输出的特征。也就是通过这部分内容,net_out变量保存的内容就是网络在前向推理得到的不同图像的特征向量。
然后对特征向量进行后续的操作:


从上图中可以看到,这里首先对特征向量进行了标准化(278行),然后将标准化后的特征向量送入281行中的evaluate函数进行评估。而evaluate函数如下:


可以看到,这里设定的阈值为从0到4,变化幅度为0.01,一共400个阈值,算法将从这400个阈值中选取最合适的阈值。注意,这个函数输入的参数embeddings其实在上一步中已经对输入的图相对进行了重新组合,还是拿上面的NBA例子作为示例,原来我们送入的图像对是这样的:


但是这里已经embeddings中保存的图像为如下形式:


唯一的区别就是,这里的embeddings分别对应着每个人的特征向量,而不是原始图像。embeddings1和embeddings2分别保存着embeddings中特征向量的奇数行和偶数行,这样做可以非常方便的进行比较操作,比如embeddings1[0] (韦德1)和embeddings2[0](韦德2)就表示相同,actual_issame[0]对应为1,embeddings1[3](科比3)和embeddings2[3](科比4)也表示相同, actual_issame[1]为1,但是embeddings1[8](韦德1) 和embeddings2[8](科比2)就表示不相同,actual_issame[8]为0。当然也可以不这样做,只不过操作稍微麻烦一点。
在177行,所有的内容都被送入calculate_roc中进行计算,求取tpr和fpr。


该函数中,首先根据参数将送入的embeddings和标签actual_issame划分为k-fold(这里为10折),然后对两个特征向量求取欧氏距离(76、77行),以便下步计算两者之间的相似度。
而且,在79行,将特征向量划分为训练集和验证集(划分比例为9:1),然后在100行的循环中将训练集送入calculate_accuracy函数中,计算训练集在不同阈值情况下的准确率。得到最高的准确率所对应的阈值之后,再在104行的循环中和106行,利用该阈值在测试集上进行试验,获得对应的准确率指标。这样总共经过10次循环,就可以计算出在10次划分中,不同划分下的最优阈值,这个阈值将是我们将不同人脸区分开的重要阈值。


这里,我们可以看到,函数是对不同阈值情况下的不同划分,计算器混淆矩阵,然后求解tpr和fpr和acc。这里假设我们将相同人的图像对认为是正样本,不同人的图像对认为是负样本。
tpr表示的就是:所有的正样本中,模型正确判断其为正样本所占的比例。
fpr表示的就是:所有的负样本中,模型误判某些负样本为正样本所占的比例。
acc表示的就是:所有样本中,正样本和负样本都划分对的部分占所有样本的比例。

因为对于不同图像对来讲,我们的希望模型能够尽量把相同的人和不同的人都划分正确。但实际上是不可能的。通常我们希望tpr越高越好,代表模型能够充分将正样本判定为正样本;同时,也希望模型fpr越低越好,表示模型不会对负样本发生误判。而实际上这两者之间是互相制约的。想要tpr高,则模型应该不要太敏感,但是如果模型不敏感,那么fpr就会升高,反之亦然。所以实际上在使用中是根据需要尽量取到一个平衡点。
总之,通过了10折交叉验证,我们获得了在不同的阈值情况下的一个准确率。而使得准确率最高的阈值就是我们想要的。

通过上面的过程,就完成了模型的验证过程。并且得到了不同人脸之间区分的阈值,那么这个阈值就是我们将来进行人脸识别重要指标。如果两个人的特征向量之间的距离小于该阈值,那么我们判定其为同一个人,如果大于该阈值,则判定其为不同的人。

通过训练和验证过程,我们最终获得了一个训练好的模型和一个计算好的阈值。其实到这步我们就已经具备了1:1的识别能力了。但是想要实现1:N,我们还需要构建一个人脸特征库。这部分内容在下节进行介绍。

人脸识别之insightface开源代码使用:训练、验证、测试(3)相关推荐

  1. 人脸识别之insightface开源代码使用:训练、验证、测试(1)

    码字不易,转载请注明出处! 代码放于github:https://github.com/vincentwei0919/insightface_for_face_recognition 训练权重: 链接 ...

  2. 人脸识别之insightface开源代码使用:训练、验证、测试(4)

    通过前面的几个小节,我们已经实现了模型的训练以及阈值的选取.此时利用我们已经训练好的模型和手上的阈值,我们已经能够做到1:1这样的验证了.所要做的就是拿两张图片,相同人或者不同人,然后送入网络中,网络 ...

  3. 人脸识别之insightface开源代码使用:训练、验证、测试(2)

    码字不易,转载请注明出处! 有了前面的准备工作之后,我们就开始动手了. Let The Hunt Begin! 数据集规模 数据当然是越多越好,然而实际我们可能没有那么多数据,那么多大的量就可以了呢? ...

  4. 人脸识别之insightface开源代码使用——自定义数据集制作

    人脸识别简介 简单来讲,人脸识别这个问题,就是给定两个人脸,然后判定他们是不是同一个人,这是它最原始的定义.它有很多应用场景,比如银行柜台.海关.手机解锁.酒店入住.网吧认证,会查身份证跟你是不是同一 ...

  5. 离线识别率高达99%的Python人脸识别系统,开源~

    来源:https://zhuanlan.zhihu.com/p/46931078 大家好,我是辰哥 以往的人脸识别主要是包括人脸图像采集.人脸识别预处理.身份确认.身份查找等技术和系统.现在人脸识别已 ...

  6. python人脸识别系统早已开源,离线识别率高达99%以上!

    以往的人脸识别主要是包括人脸图像采集.人脸识别预处理.身份确认.身份查找等技术和系统.现在人脸识别已经慢慢延伸到了ADAS中的驾驶员检测.行人跟踪.甚至到了动态物体的跟踪. 由此可以看出,人脸识别系统 ...

  7. DeepLearning tutorial(5)CNN卷积神经网络应用于人脸识别(详细流程+代码实现)

    DeepLearning tutorial(5)CNN卷积神经网络应用于人脸识别(详细流程+代码实现) @author:wepon @blog:http://blog.csdn.net/u012162 ...

  8. 人脸识别SeetaFace2原理与代码详解

    人脸识别SeetaFace2原理与代码详解 前言 一.人脸识别步骤 二.SeetaFace2基本介绍 三.seetaFace2人脸注册.识别代码详解 3.1 人脸注册 3.1.1 人脸检测 3.1.2 ...

  9. GitHub上YOLOv5开源代码的训练数据定义

    GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...

最新文章

  1. 通过js让页面中的元素上下居中的写法
  2. JAVA入门级教学之(switch语句)
  3. python中常见的运行时错误_python--17个新手常见Python运行时错误
  4. 爬虫实例7 爬取豆瓣电影数据 (json+ajax)
  5. 敏感词库快速添加到mysql数据库,并在页面使用方法过滤敏感词
  6. android高帧率模式,《和平精英》等五款游戏已经适配小米10系列高帧率模式
  7. K线形态分析交易系统
  8. [车联网安全自学篇] Android安全之Android so文件分析「详细版」
  9. 线上发版如何做到分批发的?详解蓝绿部署,滚动升级,A/B 测试,灰度发布/金丝雀发布
  10. 一款消消乐游戏的自动解法
  11. 用决策树预测获胜球队
  12. flutter图片聊天泡泡_Flutter极致的业务封装——各类聊天气泡(一)
  13. c++ IO流---输入输出流 格式控制字符
  14. 目标检测实战篇1——数据集介绍(PASCAL VOC,MS COCO)
  15. 【小萝莉说Crash】第二期:Unrecognized selector xxx 之 ForwardInvocation
  16. 运用javascript的成员访问特性来实现通用版的兼容所有浏览器的打开对话框功能...
  17. 实验吧-MD5之守株待兔
  18. fetchMetadata: sill resolveWithNewModule raw-loader@0.5.1 checking installable status
  19. 算法日记(十三)之动态规划
  20. 2018计算机专业研究院教育部评估

热门文章

  1. 外贸SOHO怎么开发新客户
  2. AI 2021 条形码插件
  3. jCO--http://www.cnblogs.com/zfswff/p/5671148.html
  4. 电脑C盘必须要删除的四个文件夹
  5. 单片机实验4 外部中断EX0 EX1
  6. python之某公司不同年份不同财务指标比较
  7. 工业智能网关BL110应用之八十: 实现西门子S7-400 PLC 接入华为云平台
  8. 炒股做短线好还是中长线好?区别对比分析
  9. 海思联咏安霸视觉AI SOC横向对比,你心中的王者有没有动摇过。
  10. H5旅游webapp(顺便游)|移动端旅行项目