人脸识别中的损失函数ArcFace及其实现过程(pytorch)

  • 简述ArcFace的原理
  • ArcFace代码部分
  • 后续使用:

简述ArcFace的原理

人脸识别的步骤分为人脸目标检测->特征提取->特征对比
在训练特征提取器的时候,我们要获得一个比较好的特征提取器,要求特征间分离得比较开,这样就不容易认错人了。

所以我们特别需要一个好的损失函数来完成大类间距的任务。
ArcFace其实就是从softmax loss衍生而来的,所以先要明白softmax loss是怎么一回事。
softmax和softmax loss虽然差不多,但这是不同的概念

个人认为比较讲的通俗易懂的softmax损失的链接,点击链接

为了使得特征之间分的更加开,ArcFace选择减少类内距,增加类间距的方式(角度)。
先看一下ArcFace loss的表达式

在softmax中,以e为底的指数
wx=∥w∥∥x∥cos⁡θ=s×cos⁡θwx=\lVert w \rVert \lVert x \rVert \cos \theta =s\times \cos \theta wx=∥w∥∥x∥cosθ=s×cosθ
现在让θ\thetaθ变成了θ+m\theta+mθ+m

使得向量x与中心线向量w角度变大,这样设计的目的是为了在损失中增加角度的贡献量,从而使得优化过程中,角度收更小(也就是说w作为一个类别的中心线,w与x的角度变小),那么一个类别的xix_ixi​与其他类别的xjx_jxj​之间的角度就增大了,从而实现了,减小类内距,增大类间距的方式。

看了下图应该就明白了,这是最终我们想要达成的目的。w1w_1w1​与w2w_2w2​对应两个中心线(也就是两个类),最后它们的类间距是比较大的。

其中中心线w是一个可学习的参数,可以理解为一堆数据的中心线(类似二维平面中的一些点的聚类中心点)

ArcFace代码部分

下面是ArcFace实现的过程,为了方便理解这个损失,并不写过多累赘代码,并用比较小的特征数代替。实际人脸检测中所需要的特征向量的维度还是比较大的,以及Arc函数的还需要完善。

import torch
from torch import nn
import torch.nn.functional as Fclass Arc(nn.Module):def __init__(self,feature_dim=2,cls_dim=10):super(Arc, self).__init__()#x是(N,V)结构,那么W是(V,C结构),V是特征的维度,C是代表类别数self.W = nn.Parameter(torch.randn(feature_dim,cls_dim))def forward(self,feature,m=1,s=10):x = F.normalize(feature,dim=1) w = F.normalize(self.W,dim=0)cos = torch.matmul(x,w)/10 #(N,C)a = torch.acos(cos) #(N,C)top = torch.exp(s*torch.cos(a+m))  #(N,C)down = torch.sum(torch.exp(s*torch.cos(a)),dim=1,keepdim=True)-torch.exp(s*torch.cos(a))  #第一项(N,1)  keepdim=True保持形状不变.这是我们原有的softmax的分布。第二项(N,C),最后结果是(N,C)out = torch.log(top/(top+down))  #(N,C)return out

cosθcos\thetacosθ用两个归一化后的x与w的乘积可得,因为cosθcos\thetacosθ为x与w的内积并除以它们的模。
x在第一维度归一化,w在第零维度归一化,因为后续作了矩阵相乘,torch.matmul(x,w),x的行乘以w的列。
对cos的结果还要除10,是因为torch.matmul(x,w)的范围不确定,可能会超过1,这样就超过arccos的定义域范围了,就会产生NaN的结果。当然后续也不需要乘回来,因为w是一个可学习参数,它会自己去改变。
至于s(w与x的模的乘积),乘在cosθcos\thetacosθ前就相当于一个超参数,和m一样。可以通过改变m(加减)和改变s(缩放),来调节你对最终想要特征间距的结果。

后续使用:

平时使用的交叉熵CrossEntropyLoss()是log+softmax+nn.NLLloss()
ArcFace就是将log+softmax替换成了Arc(),在角度上加了一个值,使得特征间的角度更加小(有点类似于正则化前的系数)。

现在需要一个特征提取器:比如desnet、resnet、mobileNetV2等等都行,它们的输出形状为(N,feature_dim)
将特征输入ArcFace层,得到输出形状(N,cls)
下面的net就是特征提取器+ArcFace层

提前定义一下损失

loss_fn = nn.NLLLoss()

训练过程中把损失这样写就行了

cls = net(xs)
loss = loss_fn(cls, ys)

数据集载入,写一个特征提取器或者使用预训练模型,训练过程,这些常规流程就不再赘述了。

人脸识别中的损失函数ArcFace及其实现过程代码(pytorch)--理解softmax损失函数及Arcface相关推荐

  1. CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss、Center Loss)简介、使用方法之详细攻略

    CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss.Center Loss)简介.使用方法之详细攻略 目录 T1.Triplet Loss 1.英文原文解释 ...

  2. 【CVPR 2018】腾讯AI lab提出深度人脸识别中的大间隔余弦损失

    论文导读] 深度卷积神经网络(DCNN)在人脸识别中已经取得了巨大的进展,通常的人脸识别的核心任务都包括人脸验证与人脸识别,涉及到特征判别.很多模型都是使用Softmax损失函数去监督模型的训练,但是 ...

  3. 人脸识别中Softmax-based Loss的演化史

    点击我爱计算机视觉标星,更快获取CVML新技术 近期,人脸识别研究领域的主要进展之一集中在了 Softmax Loss 的改进之上:在本文中,旷视研究院(上海)(MEGVII Research Sha ...

  4. 人脸识别中常用的几种分类器

    人脸识别中常用的几种分类器 在人脸识别中有几种常用的分类器,一是最邻近分类器:二是线性分类器 (1)最邻近分类器 最近邻分类器是模式识别领域中最常用的分类方法之一,其直观简单,在通常的应用环境中非常有 ...

  5. 人脸识别中的深度学习

    深度学习在人脸识别中的应用 人脸识别的过程包括: 人脸检测 人脸对齐 特征提取(在数学上,实质上是:空间变换) 特征度量 其中,特征提取与度量,是人脸识别问题中的关键问题,也是相关研究的难点之一. 传 ...

  6. yii2 模型中set_Day184:人脸识别中open-set与close-set

    人脸识别 可以简单的分为如下两类: face verification:人脸验证时将人脸分类到某个ID,比如给定两张人脸,判断是否是同一个人(ID) face identification:给定一张人 ...

  7. 人脸识别中的rank-n

    人脸识别中的rank-n 代表的意思 原创这个昵称唯一 最后发布于2017-09-02 11:05:13 阅读数 2247  收藏 展开 Rank-1 看一些论文总是在结果中看到rank-1,等等,但 ...

  8. 计算机视觉子方向,计算机视觉方向简介 | 人脸识别中的活体检测算法综述

    原标题:计算机视觉方向简介 | 人脸识别中的活体检测算法综述 本文转载自"SIGAI人工智能学习与实践平台"(ID:SIGAICN) 导言 1. 什么是活体检测? 判断捕捉到的人脸 ...

  9. 人脸识别中的阈值应该如何设置?

    人脸识别中的阈值应该如何设置? 标签: 人脸识别 分类: 人脸识别 人脸识别中的阈值应该如何设置? 随着人脸识别技术使用范围越来越广,大部分使用者可能对人脸识别中的某一方面不是很懂,咨询的问题也五花八 ...

  10. 人脸识别中的全脸/半脸/中脸

    人脸识别中的一个常用概念是全脸/中脸/半脸 下面讲一下区别 目前DeepFaceLab拥有三种不同类型的脸部模式,H64和H128是半脸(half face)模型,DF LIAEF128 Quick9 ...

最新文章

  1. 如何支持亿级用户分流实验?AB实验平台在爱奇艺的实践
  2. 比特币现金比BTC节能40%以上
  3. android double比较大小吗,java – Android – 比较方法违反了它的一般...
  4. 项目: 最简单的飞机大战游戏
  5. Java异常处理——try-with-resource 语法糖
  6. 关于你,关于我. 你好 5G
  7. Linux常用的分区方案及目录结构
  8. columns列:Rows 工作表上所有的行
  9. 你好,欢迎来到我的空间,
  10. java模拟浏览器http请求_java使用HttpClient模拟浏览器请求
  11. postgress无法远程连接问题解决方案
  12. 实高斯随机向量与复高斯随机向量
  13. [读书笔记]《Windows游戏编程之从零开始》(零)
  14. 【Unity游戏开发基础】如何做可以调整音量的UI滚动条组件
  15. c语言题目详解——实现四舍五入
  16. [黑客光盘2009年更新]100张黑客光盘 数百G黑客资源~~~~~~~~~~~~~
  17. Unity制作卡牌游戏
  18. 深度 | 为什么通用AMM模型不适用于期权交易定价?
  19. iostat命令参数详解
  20. git 避免提交_新秀Git错误避免

热门文章

  1. 怎么让电脑微信安装到别的盘路径
  2. VS创建和使用C++动态链接库教程
  3. 极点五笔状态栏菜单不显示,如何再让它显示?
  4. vs2010下配置HPSocket教程,详细截图附入门demo源码
  5. Hp-socket高性能网络库三--tcp组件pack接收模型
  6. 51单片机步进电机c语言程序,51单片机的步进电机c语言驱动程序
  7. somachine3.1安装包和安装方法
  8. java调用js模板引擎_JavaScript模板引擎用法实例
  9. xjoi 1524 枚举集合
  10. SCARA机器人matlab仿真