为什么用ArcFace

前一篇文章已经提到了问什么不能直接用softmax loss做损失,是因为类与类之间交界处不容易分开,而center loss能把他分开是因为缩小了类内距,就是给每一个中心点,让每个类中的特征点无限向中心点靠拢。缩小类内距的同时,间接缩小了类间距。而ArcFace是直接缩小了类间距。
下面是我用mnist数字十分类做的直接用softmax loss和arcFace做的效果图:


第一个图是直接用softmax loss做的,很明显在交接处没有分开,第二个图是arcface做的效果,每一个类都清晰可见。

arcFace推导过程。

因为arcFace是对softmax loss的改进,先看softmax loss。
softmax loss:

N是样本的数量,i代表第i个样本,j代表第j个类别,fyi代表着第i个样本所属的类别的分数
fyi是全连接层的输出,代表着每一个类别的分数,
每一个分数即为权重W和特征向量X的内积

每个样本的softmax值即为:

由于w是通过损失反向传播不断更新的,x是随着前面的w变化而变化的,所以要改进softmax需要只能改cos(θ)或者θ,而论文作者实验证明改θ比改cos(θ)效果更好,所以有了Arcface。
Arcface公式:

arcface限制条件:

mnist数据集实现arcface(Pytorch):

import matplotlib.pyplot as plt
import numpy as np
import os
import torch.nn.functional as Fclass Arcsoftmax(nn.Module):def __init__(self, feature_num, cls_num):super(Arcsoftmax, self).__init__()self.w = nn.Parameter(torch.randn(feature_num, cls_num).cuda())self.func = nn.Softmax()def forward(self, x, s, m):x_norm = F.normalize(x, dim=1)w_norm = F.normalize(self.w, dim=0)cosa = torch.matmul(x_norm, w_norm)/10a = torch.acos(cosa)arcsoftmax = torch.exp(s * torch.cos(a + m) * 10) / (torch.sum(torch.exp(s * cosa * 10), dim=1, keepdim=True) - torch.exp(s * cosa * 10) + torch.exp(s * torch.cos(a + m) * 10))# arcsoftmax = torch.exp(s*torch.cos(a+m)*10) / (torch.sum(torch.exp(s*cosa*10# ), dim=1, keepdim=True) - torch.exp(s*cosa*10) + torch.exp(s*torch.cos(a+m) * 10))return arcsoftmaxclass ClsNet(nn.Module):def __init__(self):super().__init__()self.conv_layer = nn.Sequential(nn.Conv2d(1, 32, 3), nn.BatchNorm2d(32), nn.PReLU(),nn.Conv2d(32, 64, 3), nn.BatchNorm2d(64), nn.PReLU(),nn.MaxPool2d(3, 2))self.feature_layer = nn.Sequential(nn.Linear(11 * 11 * 64, 256), nn.BatchNorm1d(256), nn.PReLU(),nn.Linear(256, 128), nn.BatchNorm1d(128), nn.PReLU(),nn.Linear(128, 2), nn.PReLU())self.arcsoftmax = Arcsoftmax(2, 10)self.loss_fn = nn.NLLLoss()def forward(self, x, s, m):conv = self.conv_layer(x)conv = conv.reshape(x.size(0), -1)feature = self.feature_layer(conv)out = self.arcsoftmax(feature, s, m)out = torch.log(out)print(out.shape)return feature, outdef get_loss(self, out, ys):return self.loss_fn(out, ys)if __name__ == '__main__':train_data = datasets.MNIST(root='mnist',train=True,transform=torchvision.transforms.ToTensor(),download=True)test_data = torchvision.datasets.MNIST(root='mnist',train=False,transform = torchvision.transforms.ToTensor(),download=False)train = DataLoader(dataset=train_data, batch_size=1024, shuffle=True, drop_last= True)test = DataLoader(dataset=test_data, batch_size=1024, shuffle=True)# transform = transforms.Compose([#     transforms.Resize(28, 28),#     transforms.ToTensor(),#     transforms.Normalize((0.5,), (0.5,)),net = ClsNet().cuda()# net = net.to(device)path = r'params/weightnet2.pt'if os.path.exists(path):net.load_state_dict(torch.load(path))net.eval()print('load susseful')else:print('load fail')# epoch = 1024# optimism = optim.SGD(net.parameters(), lr=1e-3)optimism = optim.Adam(net.parameters(), lr=0.0005)# scheduler = lr_scheduler.StepLR(optimism, 10, gamma=0.8)# optimizer = optim.SGD(net.parameters(), weight_decay=0.0005, lr=0.001, momentum=0.9)# scheduler = lr_scheduler.StepLR(optimizer, 20, gamma=0.8)# optimizercenter = optim.SGD(Centerloss.parameters(), lr=0.5)losses = []# In[]c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff','#ff00ff', '#990000', '#999900', '#009900', '#009999']epoch = 10000d = 0# fig, ax = plt.subplots()for i in range(epoch):# scheduler.step()print('epoch: {}'.format(i))print(len(train))tar = []out = []for j, (input, target) in enumerate(train):input = input.cuda()target = target.cuda()feature, output = net(input, 1, 0.01)loss = net.get_loss(output, target)# label = torch.argmax(output, dim=1)  # 选出最大值的索引作为标签# 清空梯度 反向传播 更新梯度optimism.zero_grad()loss.backward()optimism.step()feature = feature.cpu().detach().numpy()# print(output)target = target.cpu().detach()# print(target)out.extend(feature)tar.extend(target)print('[epochs - {} - {} / {}] loss: {} '.format(i, j, len(train), loss.float()))outstack = np.stack(out)tarstack = torch.stack(tar)# plt.cla()plt.ion()if j == 3:d += 1for m in range(10):index = torch.tensor(torch.nonzero(tarstack == m))# print(index)plt.scatter(outstack[:, 0][index[:, 0]], outstack[:, 1][index[:, 0]], c=c[m], marker='.')plt.show()plt.pause(1)plt.savefig('picture1.2/{0}.jpg'.format(d))print('save sussece')# plt.ioff()# plt.clf()plt.close()torch.save(net.state_dict(), r'params/weightnet2.pt')

人脸识别 ArcFace 实现相关推荐

  1. 使用Delphi接入虹软人脸识别ArcFace,开发人脸库服务器

    利用虹软 SDK 开发局域网人脸库服务器 一.选择开发平台 以前做单位食堂人脸识别就餐时,会用到在线人脸识别,终端设备必须并入互联网,单位对人脸信息比较敏感,客户会要求提供内部网人脸库使用. 利用人脸 ...

  2. 人脸识别-arcface损失函数

    参考博客: L-margin softmax loss:https://blog.csdn.net/u014380165/article/details/76864572 A-softmax loss ...

  3. 虹软-人脸识别SDK的使用

    1.登录虹软开发者平台,进行注册. 官网:https://ai.arcsoft.com.cn/ucenter/resource/build/index.html#/login 注册成功之后,选择视觉开 ...

  4. 基于深度学习的人脸识别闸机开发(基于飞桨PaddlePaddle)

    目录 一.概述 1.1 人脸识别背景 1.2 实现 1.2.1 算法说明 1.2.2 环境设置 1.2.3 实现思路 二.示例脚本 2.1 安装PaddlePaddle和PLSC 2.2 下载人脸检测 ...

  5. ArcFace - 人脸识别

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源:知乎 作者:科密中的科蜜 链接:https://zhuanl ...

  6. 【java】人脸识别 虹软ArcFace 2.0-抽取人脸特征并做比对

    虹软产品地址:http://ai.arcsoft.com.cn/product/arcface.html 虹软ArcFace功能简介 人脸检测 人脸跟踪 人脸属性检测(性别.年龄) 人脸三维角度检测 ...

  7. 【人脸识别】arcface详解

    论文题目:<ArcFace Additive Angular Margin Loss for Deep Face Recognition > 论文地址:https://arxiv.org/ ...

  8. 人脸识别系列(十七):ArcFace/Insight Face

    论文链接:ArcFace: Additive Angular Margin Loss for Deep Face Recognition 作者开源代码:https://github.com/deepi ...

  9. 深度篇——人脸识别(一)  ArcFace 论文 翻译

    返回主目录 返回 人脸识别 目录 下一章:深度篇--人脸识别(二) 人脸识别代码 insight_face_pro 项目讲解 目录内容: 深度篇--人脸识别(一) ArcFace 论文 翻译 深度篇- ...

最新文章

  1. Android Service
  2. 第十六讲 傅里叶级数拓展
  3. 前端学习01-04格式标签
  4. liferay 指定默认首页
  5. NVIDIA Parallel Nsight
  6. eclipse-注释
  7. Windows 使用浮动键盘语言栏
  8. linux之ClamAV杀毒软件安装配置
  9. 【2021杭电多校赛】2021“MINIEYE杯”中国大学生算法设计超级联赛(2)签到题5题
  10. 数据结构与算法 | Leetcode 19. Remove Nth Node From End of List
  11. Kotlin — 竞技程序设计(类似天梯训练)
  12. 公众号淘宝客自营商城外卖返利小程序淘宝客小程序流量主返利app
  13. delphi 发送html邮件,Delphi下html编辑器,像foxmail或者Outlook的邮件编辑器一样 能够保存为单一文件如 mht,eml (200分)...
  14. 魔兽世界怀旧服正式服风铃键盘鼠标同步器TBC70级燃烧远征
  15. 译:25个面试中最常问的问题和答案
  16. Redis源码分析之双索引机制
  17. Maya布料解算入门
  18. PLC实验:LED 数码显示控制
  19. 做量化交易需要了解的国外在线量化平台有哪些?
  20. 视频文件损坏怎么修复?简单的修复办法分享

热门文章

  1. 2021-10-27 - 开发人员将大多数时间花到了探究系统本身上
  2. 虚拟机(VMware Workstation或Hyper-V)装ghost版系统提示“ntldr is missing Press Ctrl+Alt+del to Resta
  3. 细数饿了么开源的前端项目及实践
  4. HRBUST - 1646
  5. 增量式PID控制算法及仿真
  6. 雨滴网易云播放器html代码,求大佬帮忙看下 雨滴音乐插件怎么改代码关联网易云?...
  7. JS手机触摸屏的事件用法详解
  8. php 短网址 算法,php生成短网址的思路以及实现方法
  9. 例一---骨骼肌肉模型简介
  10. C# 经常忘 该记记