来源:Coursera吴恩达深度学习课程

在人脸识别中,我们希望学习“输入两张人脸图片,然后输出相似度”的函数d,然后Siamese 网络(Siamese network)实现了这个功能。这篇文章将探讨如何定义实际的目标函数(define an objective function),能够让神经网络学习并训练Siamese网络架构呢?

要想通过学习神经网络的参数来得到优质的人脸图片编码,方法之一就是定义三元组损失函数(triplet loss function),然后应用梯度下降(gradient descent )

为了应用三元组损失函数,你需要比较成对的图像。用三元组损失的术语来说,你要做的通常是看一个 Anchor 图片,你想让Anchor图片和Positive图片(Positive意味着是同一个人)的距离很接近。然而,当Anchor图片与Negative图片(Negative意味着是非同一个人)对比时,你会想让他们的距离离得更远一点。这就是为什么叫做三元组损失,它代表你通常会同时看三张图片(Anchor图片、Postive图片和Negative图片),通常把它们简写成A、P、N

把这些写成公式的话,你想要的是网络的参数或者编码能够满足以下特性

公式两边分别对应如下距离:

这里的d可以看成距离(distance)函数。现在如果我把方程右边项移到左边,最终就得到:

值得注意的是,上面的表达式有一种情况:把所有东西都学成0,如果f总是输出0,0-0=0,也满足这个方程,但是没有用处。所以为了确保网络对于所有的编码不会总是输出0,也为了确保它不会把所有的编码都设成互相相等的。我们做一些改变,应该是比0还要小,小于一个-alpha值(另一个超参数),即:

按照惯例,一般写成这样:

这个参数也叫做间隔(margin)参数。间隔参数的作用是它拉大了Anchor和Positive 图片对和Anchor与Negative 图片对之间的差距。取下面的这个方框圈起来的方程式,然后定义三元组损失函数。

三元组损失函数的定义基于三张图片,假如三张图片A、P、N,即anchor样本、positive样本和negative样本,其中positive图片和anchor图片是同一个人,但是negative图片和anchor不是同一个人。如上图,为了定义这个损失函数,我们取这个和0的最大值:

这个max函数的作用是只要你能使画绿色下划线部分小于等于0,只要你能达到这个目标,那么这个例子的损失就是0。另一方面,如果绿色下划线部分大于0,则它是最大值,得到一个正的损失值,通过最小化损失函数使得这部分小于或等于0。

这是一个三元组定义的损失,整个网络的成本函数J应该是训练集中这些单个三元组损失的总和(sum)。假如你有一个10000个图片的训练集,里面是1000个不同的人的照片,你要做的就是取这10000个图片,然后生成这样的三元组,然后训练你的学习算法,对这种代价函数用梯度下降,这个代价函数就是定义在你数据集里的这样的三元组图片上。

注意,为了定义三元组的数据集需要成对的A和P,即同一个人的成对的图片,为了训练你的系统你确实需要一个数据集,里面有同一个人的多个照片。如果你只有每个人一张照片,那么根本没法训练这个系统。当然,训练完这个系统之后,你可以应用到你的一次学习问题上,对于你的人脸识别系统,可能你只有想要识别的某个人的一张照片。但对于训练集,你需要确保有同一个人的多个图片,至少是你训练集里的一部分人,这样就有成对(pairs)的Anchor和Positive图片了。

下面看如何选择这些三元组来形成训练集(training set)。

一个问题是如果你从训练集中,随机地选择A、P和N,遵守A和P是同一个人,而A和N是不同的人这一原则。随机选择的话约束条件很容易满足,因为A和N比A和P差别很大的概率很大,这样网络并不能从中学到什么。所以要尽可能的选择难(hard)训练的A、P和N,使得d(A,P)和d(A,N)很接近,即:

这样你的学习算法会竭尽全力使右边这个式子变大,或者使左边这个式子变小,这样左右两边至少有一个间隔(margin),并且选择这样的三元组还可以增加你的学习算法的计算效率。因此,只有选择难的三元组梯度下降法才能发挥作用,使得这两边离得尽可能远。

如果感兴趣的话,这篇论文中有更多细节,作者是Florian Schroff, Dmitry Kalenichenko, James Philbin,他们建立了这个叫做FaceNet的系统,Florian Schroff, Dmitry Kalenichenko, James Philbin (2015). FaceNet: A Unified Embedding for Face Recognition and Clustering

总结一下,如上图,训练这个三元组损失你需要取你的训练集,然后把它做成很多三元组。定义了这些包括A、P和N的数据集之后,还需要最小化成本函数J。这样做的效果就是反向传播到网络中的所有参数来学习到一种编码,使得如果两个图片是同一个人,那么它们的d就会很小,如果两个图片不是同一个人,它们的d就会很大。

这就是三元组损失,并且如何用它来训练网络输出一个好的编码用于人脸识别。现在的人脸识别系统,尤其是大规模的商业人脸识别系统(large-scale commercial face recognition)都是在很大的数据集上训练。幸运的是,一些公司已经训练了这些大型的网络并且上传了模型参数。这一领域的一个实用操作就是下载别人的预训练模型(pretrained model),而不是一切都要从头开始(do everything from scratch yourself)。但是即使你下载了别人的预训练模型,Andrew认为了解怎么训练这些算法也是有用的,以防针对一些应用你需要从头实现这些想法。

说明:记录学习笔记,如果错误欢迎指正!转载请联系我。

三元组损失(Triplet loss)相关推荐

  1. 三元组损失 Triplet Loss及其梯度

    Triplet Loss及其梯度 Triplet Loss及其梯度_jcjx0315的博客-CSDN博客 Triplet Loss简介 我这里将Triplet Loss翻译为三元组损失,其中的三元也就 ...

  2. 人脸检测、人脸定位、人脸对齐、MTCNN、人脸识别(衡量人脸的相似或不同:softmax、三元组损失Triplet Loss、中心损失Center Loss、ArcFace)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) # ArcFace class ArcMarginProduc ...

  3. 一文理解Ranking Loss/Margin Loss/Triplet Loss

    点击蓝字  关注我们 作者丨土豆@知乎 来源丨https://zhuanlan.zhihu.com/p/158853633 本文已获授权,未经作者许可,不得二次转载. 前言 Ranking loss在 ...

  4. Parameter Sharing Exploration and Hetero-center Triplet Loss

    参数共享探索与异心三元祖损失 摘要 探讨了双流网络应该共享多少参数,通过拆分 ResNet50 模型构建特定模态特征提取网络和模态共享特征嵌入网络,证明了双流网络参数共享对 VT Re-ID 的效果. ...

  5. 一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss

    一文理解Ranking Loss/Contrastive Loss/Margin Loss/Triplet Loss/Hinge Loss 翻译自FesianXu, 2020/1/13, 原文链接 h ...

  6. Triplet loss 源码解析

    2021最后一天,赶紧学习一手吧.由于 Triplet loss 很重要,而代码不复习又很容易忘,这里记录一下. 代码在这里:https://github.com/VisualComputingIns ...

  7. 度量学习(Metric learning、损失函数、triplet、三元组损失、fastreid)

    定义 Metric learning 是学习一个度量相似度的距离函数:相似的目标离得近,不相似的离得远. 一般来说,DML包含三个部分,如下图.. 1)特征提取网络:map embedding 2)采 ...

  8. 三元损失“In Defense of the Triplet Loss for Person Re-Identification”

    更全面的阅读记录可以参考这篇博客:https://blog.csdn.net/xuluohongshang/article/details/78965580 背景描述 提出了一个三元损失的变形用于行人 ...

  9. 机器学习笔记:triplet loss

    1 Triplet loss Triplet Loss,即三元组损失,其中的三元是Anchor.Negative.Positive. 通过Triplet Loss的学习后使得Positive元和Anc ...

最新文章

  1. python数据类型-Python语言基本数据类型
  2. JVM详解之:类的加载链接和初始化
  3. oracle中drop、delete和truncate的区别
  4. 时间处理总结(二)oracle
  5. Kaggle新上比赛预测地震-总奖池5万美金
  6. 【C语言】fgets函数返回值
  7. 移植RTT使用cubeMx配置后出现 cannot open source input file stm32f1xx_hal_exti.h: No such file or directory
  8. 全国大学生数学建模竞赛中,哈工大被禁用 MATLAB!
  9. 呼吁成立中国FreeType联盟
  10. VS2008 调试windows服务项目
  11. 阿帕奇apache服务器和webDav服务器快速配置。
  12. tbopen链接生成工具_影视、航空行业都在用的短链接工具,免费短链接生成器?...
  13. swift转场动画_Swift游乐场的演变
  14. GameBuilder开发游戏应用系列之60行代码实现FlappyBird
  15. Uniapp关于 Android原生插件开发案例
  16. 如何有效阅读《C++ Primer》那么厚的书
  17. Aspen中物性方法选择
  18. word中不能设置首字下沉的一个原因
  19. asp sql ip地址排序_SQL语言基础
  20. 关于数据治理的读书笔记 - 数据治理能力成熟度评估

热门文章

  1. 唯品会的服务化[转]
  2. windows10中定时备份文件和清除文件
  3. EDIUS中的图片该如何进行剪裁
  4. 从第三方应用跳回uniapp开发的app
  5. 用python中的turtle库绘制一些有趣的图
  6. 用canvas画“哆啦A梦”时钟
  7. 喾哲~ (八月最佳)
  8. 药明海德在苏州打造疫苗CDMO服务中国基地;现代汽车将在印尼新首都启用“空中出租车” | 美通企业日报...
  9. Web 攻防之业务安全:账号安全案例总结.
  10. SRM 683 div1 hard