最近因需要粗浅的学习了一下ArcFace损失函数,由于在学习中遇到了很多问题,特将问题的思考分享出来,权当分享个人愚见,希望可以有人看到后进行讨论进步。

  • ArcFace的引入

    人脸识别分为四个过程:人脸检测、人脸对齐、特征提取、特征匹配。其中,特征提取作为人脸识别最关键的步骤,提取到的特征更偏向于该人脸“独有”的特征,对于特征匹配起到举足轻重的作用,而我们的网络和模型承担着提取特征的重任,优秀的网络和训练策略使得模型更加健壮。

    但在Resnet网络表现力十分优秀的情况下,要提高人脸识别模型的性能,除了优化网络结构,修改损失函数是另一种选择,优化损失函数可以使模型从现有数据中学习到更多有价值的信息。

    而在我们以往接触的分类问题有很大一部分使用了Softmax loss来作为网络的损失层,实验表明Softmax loss考虑到样本是否能正确分类,而在扩大异类样本间的类间距离和缩小同类样本间的类内距离的问题上有很大的优化空间,因而作者在Arcface文章中讨论了Softmax到Arcface的变化过程,同时作者还指出了数据清洗的重要性,改善了Resnet网络结构使其“更适合”学习人脸的特征。

  • 对特征提取和分类的个人理解
    首先我们思考在分类层前全连接层的意义是什么,全连接层可以视为一个权重矩阵W和网络模型提取到的特征X(我们可以理解为通过全连接层之前的网络结构并且已经进行过flatten的特征)相乘的过程。即为一个W[^T]*X的过程。那么这个相乘操作的物理意义是什么呢,此时我们可以回忆向量的点乘,向量的点乘即为两个向量的模(常数)的乘积再乘上他们之间夹角的cos值,它的物理意义是两个向量之间的相似度大小。
    我们来看一个例子:

    这个分类任务的目的是为了区分输入图像为笔记本电脑还是平板电脑。我们假设通过网络模型提取到了
    第I个样本的特征featurei,一共有五个特征。全连接层的操作就是将[2,5]的权重矩阵,乘上这个[5,1]的特征矩阵,得到[2,1]的分类结果矩阵。(此时为logits因为还没有经过Softmax层)。我们可以理解全连接层的权重就是这个样本的“标准特征向量“而提取到的特征向量与”标准特征向量“进行点乘其实是计算出了,第i个样本提取特征和分类Ci的标准特征向量的相似度,所以我们取Ci相似度最大的结果作为最后的分类结果。
    我们通常的操作就是将全连接层和提取特征向量的乘积结果送入全连接层,得到一个sum为1的概率向量,取向量中概率最大的index作为分类结果。
    但是这样的分类,我们只能得到类似下图的分类结果:

    这种结果只能让不同类别(用颜色表示)简单分开,并不能拉大类别之间的距离,减小类别内样本之间的距离。
    这样的简单分类不适合做人脸识别的任务(我们可以思考一下,如果仅仅使用softmax完成如图所示的分类效果,如果存在双胞胎这种两个人长的很像的类型,类别之间距离不够,便很难将其分开)于是ArcFace出现了。

  • ArcFace的推导与遇到问题的个人见解
    首先我们回顾Softmax Loss:
    这是我们传统的Softmax公式,其中 代表我们的全连接层输出,我们在使损失L_S下降的过程中,则必须提高我们的{W^T_{y_i}x_i+b_{y_i}}所占有的比重,从而使得该类的样本更多地落入到该类的决策边界之内。
    首先我们看ArcFace的整体公式:

    下面是ArcFace是数学推导,推导前我们要注意一个问题就是L2归一化,即为在推导过程中将W和X化为1的计算过程,L2归一化是将向量内的每个元素除以向量的L2范数的过程。

  • 代码实现(基于pytorch)

class ArcMarginModel(nn.Module):def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):super(ArcMarginModel, self).__init__()self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))# num_classes 训练集中总的人脸分类数# emb_size 特征向量长度nn.init.xavier_uniform_(self.weight)# 使用均匀分布来初始化weightself.easy_margin = easy_marginself.m = m# 夹角差值 0.5 公式中的mself.s = s# 半径 64 公式中的s# 二者大小都是论文中推荐值self.cos_m = math.cos(self.m)self.sin_m = math.sin(self.m)# 差值的cos和sinself.th = math.cos(math.pi - self.m)# 阈值,避免theta + m >= piself.mm = math.sin(math.pi - self.m) * self.mdef forward(self, input, label):x = F.normalize(input)W = F.normalize(self.weight)# 正则化cosine = F.linear(x, W)# cos值sine = torch.sqrt(1.0 - torch.pow(cosine, 2))# sinphi = cosine * self.cos_m - sine * self.sin_m# cos(theta + m) 余弦公式if self.easy_margin:phi = torch.where(cosine > 0, phi, cosine)# 如果使用easy_marginelse:phi = torch.where(cosine > self.th, phi, cosine - self.mm)one_hot = torch.zeros(cosine.size(), device=device)one_hot.scatter_(1, label.view(-1, 1).long(), 1)# 将样本的标签映射为one hot形式 例如N个标签,映射为(N,num_classes)output = (one_hot * phi) + ((1.0 - one_hot) * cosine)# 对于正确类别(1*phi)即公式中的cos(theta + m),对于错误的类别(1*cosine)即公式中的cos(theta)# 这样对于每一个样本,比如[0,0,0,1,0,0]属于第四类,则最终结果为[cosine, cosine, cosine, phi, cosine, cosine]# 再乘以半径,经过交叉熵,正好是ArcFace的公式output *= self.s# 乘以半径return output
  • 个人遇到的主要问题以及查找和思考
    1.参数s和m具体代表什么:
    通过ArcFace,分类结果可以”进化“为

    这种样子,我们把分类的可视化结果视为一个圆,s就代表这个圆的半径,m则可以调整类别之间的夹角距离(?)
    2.代码中这easy_margin部分的意义(为什么需要这两行代码):

    首先没有查到easy_margin相关的资料,希望有人可以指点下作者这部分相关。
    这个部分的代码主要意义是为了保持Cos的单调性,那么我们首先思考为什么要保持Cos这个函数的单调性。因为在ArcFace中,我们将特征向量和“类别标准向量”的相似度衡量标准从点乘结果转变为了仅仅看两者之间Cos“夹角”(此处的夹角表达意思不完全准确,仅供此部分的理解所需)的值。根据余弦函数的特点,当角度超过Pi时,余弦函数会丢失单调性特征。但是我们在衡量相似度时,所用的夹角是建立在余弦函数的单调性之上的,比如夹角(>0)时,夹角越大,余弦值越小,因此我们就可以说余弦值小的两个向量,相似度较小。但是一旦丢失单调性,这种理论基础便不复存在了。因此我们需要cosine-self.mm的这个操作在Cos(theta+m) > Pi 的时候进行代替,强制使其小于Pi

  • 未解决的问题:
    1.为什么减去的值是

以上内容均为本人个人理解,不代表准确立场,希望大家在评论区指出错误,一起讨论问题。

ArcFace的原理以及代码的理解相关推荐

  1. 对LOAM算法原理和代码的理解

    LOAM概况 V-LOAM和LOAM在KITTI上的排名一直是前两名. LOAM在KITTI上的排名 原始LOAM代码(带中文注释)的地址:https://github.com/cuitaixiang ...

  2. 深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用

    深入理解BatchNorm的原理.代码实现以及BN在CNN中的应用 BatchNorm是算法岗面试中几乎必考题,本文将带你理解BatchNorm的原理和代码实现,以及详细介绍BatchNorm在CNN ...

  3. DeepLearning tutorial(1)Softmax回归原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43157801 DeepLearning tutorial(1)Softmax回归原理简介 ...

  4. DeepLearning tutorial(3)MLP多层感知机原理简介+代码详解

    FROM:http://blog.csdn.net/u012162613/article/details/43221829 @author:wepon @blog:http://blog.csdn.n ...

  5. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  6. php 策略模式实现原理,php 策略模式原理与应用深入理解

    php 策略模式原理与应用深入理解,策略,可以用,接口,简单,算法 php 策略模式原理与应用深入理解 易采站长站,站长之家为您整理了php 策略模式原理与应用深入理解的相关内容. 本文实例讲述了ph ...

  7. 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...

  8. python原理及代码_原理+代码|详解层次聚类及Python实现

    前言 聚类分析是研究分类问题的分析方法,是洞察用户偏好和做用户画像的利器之一.聚类分析的方法非常多,能够理解经典又最基础的聚类方法 -- 层次聚类法(系统聚类) 的基本原理并将代码用于实际的业务案例是 ...

  9. TCP 三次握手原理,你真的理解吗

    转载自  TCP 三次握手原理,你真的理解吗 最近,阿里中间件小哥哥蛰剑碰到一个问题--client端连接服务器总是抛异常.在反复定位分析.并查阅各种资料文章搞懂后,他发现没有文章把这两个队列以及怎么 ...

  10. MLP多层感知机(人工神经网络)原理及代码实现

    一.多层感知机(MLP)原理简介 多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),除了输入输出层,它中间 ...

最新文章

  1. c函数scanf(),printf()等常用格式字符串
  2. 浅谈CC攻击原理与防范
  3. MySQL中的datetime与timestamp比较
  4. mysql用户变量递归_MYSQL递归树查询的实现
  5. 华为Mate 40 Pro屏幕贴膜曝光:双孔曲面屏实锤?
  6. Java字符串基本认识
  7. Web架构演变过程以及出现的问题
  8. 图像分割与GAN网络
  9. 按键精灵-5-按键精灵控制脚本流程2
  10. java 生成pdf文件_Java 生成PDF文档的示例代码
  11. 【数据库内核】数据库核心技术演进之路
  12. 微信免卸载降级安装方法
  13. windows10怎么用cmd编译C语言,win10怎么样使用cmd来运行程序
  14. uva 11021 数学概率 麻球
  15. java毕业设计成品源码网站javaweb企业财务|记账|账单管理系统
  16. echars省份地图(安徽地图地图加散点图)亮点展示
  17. 2k2实用球员_nba2kol2实用球员
  18. led trigger
  19. ubuntu英文版变成中文版
  20. 傅里叶变换的虚数部分

热门文章

  1. python安装pyltp_windows 安装pyltp详细教程
  2. flex布局(弹性布局)
  3. 2021年上半年软件设计师下午真题及答案解析
  4. Drozer的安装与使用 | Android逆向工具
  5. 安装Petalinux
  6. 怎么解除计算机管理员的身份,怎么取消管理员权限(怎么取消管理员取得所有权)...
  7. 微信聊天图片视频怎么防撤回?自动备份/保存微信的聊天图片和视频(天有不撤图片视频)
  8. 企业微信的聊天记录保存在了哪里?
  9. 全网首发stm8s的硬件I2C读取bme280(bmp280)的C源程序
  10. 单词毕业设计,微信小程序毕设,小程序毕设源码,单词天天斗 (毕业设计/实战小程序学习/微信小程序完整项目)