深度学习笔记(44) Triplet 损失

  • 1. 三元组损失函数
  • 2. 损失函数公式
  • 3. 训练集

1. 三元组损失函数

已经了解了Siamese网络架构,并且知道想要网络输出什么,即什么是好的编码
但是如何定义实际的目标函数
能够让神经网络学习并做到 深度学习笔记(43) Siamese网络 讨论的内容呢?

要想通过学习神经网络的参数来得到优质的人脸图片编码
方法之一就是:定义三元组损失函数然后应用梯度下降


为了应用三元组损失函数,需要比较成对的图像
比如这个图片,为了学习网络的参数,需要同时看几幅图片
比如这对图片(编号1和编号2),想要它们的编码相似,因为这是同一个人
然而假如是这对图片(编号3和编号4),会想要它们的编码差异大一些,因为这是不同的人

用三元组损失的术语来说,要做的通常是看一个 Anchor 图片
想让 Anchor图片 和 Positive图片(Positive意味着是同一个人)的距离很接近
当 Anchor图片 与 Negative图片(Negative意味着是非同一个人)对比时
会想让他们的距离离得更远一点


它代表你通常会同时看三张图片,需要看 Anchor图片、Postive图片,还有Negative图片
要把 Anchor图片、Positive图片和Negative图片简写成 APN


2. 损失函数公式

把这些写成公式的话,想要的是网络的参数或者编码能够满足以下特性:
想要 ||f( A )-f( P )||2,希望这个数值很小
准确地说,想让它小于等 f( A ) 和 f( N ) 之间的距离
或者说是它们的范数的平方,即:||f( A ) - f( P )||2 ≤ ||f( A ) - f( N )||2
||f( A ) - f( P )||2 ,这就是 d(A,P)d(A,P)d(A,P)
||f( A ) - f( N )||2 ,这就是 d(A,N)d(A,N)d(A,N)
可以把 ddd 看作是距离(distance)函数,这也是为什么把它命名为 ddd

现在如果把方程右边项移到左边,最终就得到:
||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 ≤ 0

现在要对这个表达式做一些小的改变
有一种情况满足这个表达式,但是没有用处,就是把所有的东西都学成0
如果 f 总是输出0,即0 - 0 ≤ 0,这就是0减去0还等于0
如果所有图像的 f 都是一个零向量,那么总能满足这个方程

所以为了确保网络对于所有的编码不会总是输出0
也为了确保它不会把所有的编码都设成互相相等的
另一种方法能让网络得到这种没用的输出
就是如果每个图片的编码和其他图片一样
这种情况,还是得到0 - 0

为了阻止网络出现这种情况,需要修改这个目标
也就是这个不能是刚好小于等于0,应该是比0还要小
所以这个应该小于一个 -a 值(即 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 ≤ -a)
这里的 a 是另一个超参数,这个就可以阻止网络输出无用的结果
按照惯例,习惯写 +a(即 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0)
而不是把 -a 写在后面,它也叫做 间隔(margin)
这个术语你会很熟悉,如果看过关于支持向量机 (SVM) 的文献
可以把上面这个方程(||f( A ) - f( P )||2 - ||f( A ) - f( N )||2)也修改一下,加上这个间隔参数

举个例子,假如间隔设置成0.2
如果在这个例子中,如果 Anchor和 Negative图片的d(A,P)d(A,P)d(A,P) = 0.5,
d(A,N)d(A,N)d(A,N)只大一点,比如说0.51,条件就不能满足
虽然0.51也是大于0.5的,但还是不够好

想要 d(A,N)d(A,N)d(A,N)d(A,P)d(A,P)d(A,P) 大很多,会想让 d(A,N)d(A,N)d(A,N) 至少是0.7或者更高
或者为了使这个间隔,或者间距至少达到0.2,可以把这项调大或者这个调小
这样这个间隔a,超参数a 至少是0.2
d(A,P)d(A,P)d(A,P)d(A,N)d(A,N)d(A,N) 之间至少相差0.2,这就是间隔参数a的作用
它拉大了Anchor和Positive 图片对和Anchor与Negative 图片对之间的差距

取下面的这个方框圈起来的方程式,更公式化表示,然后定义三元组损失函数

其中positive图片和anchor图片是同一个人,但是negative图片和anchor不是同一个人

接下来定义损失函数,这个例子的损失函数,它的定义基于三元图片组


所以为了定义这个损失函数
L(A,P,N)L(A,P,N)L(A,P,N) =maxmaxmax(||f( A ) - f( P )||2 - ||f( A ) - f( N )||2+a,0)

这个max函数的作用就是,只要这个||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0
那么损失函数就是0
只要能使画绿色下划线部分小于等于0,只要能达到这个目标
那么这个例子的损失就是0

另一方面如果这个||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a ≤ 0
然后取它们的最大值,最终会得到绿色下划线部分,即||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a是最大值
这样会得到一个正的损失值
通过最小化这个损失函数达到的效果:使这部分 ||f( A ) - f( P )||2 - ||f( A ) - f( N )||2 + a 小于或者等于0
只要这个损失函数小于等于0,网络不会关心它负值有多大

整个网络的代价函数应该是 训练集中这些单个三元组损失的总和
假如有一个10000个图片的训练集,里面是1000个不同的人的照片(每个人十张图片)
要做的就是取这10000个图片,然后生成这样的三元组,然后训练学习算法
对这种代价函数用梯度下降,这个代价函数就是定义在数据集里的这样的三元组图片上

注意,为了定义三元组的数据集需要成对的A和P
即同一个人的成对的图片,为了训练系统确实需要一个数据集,里面有同一个人的多个照片
这样在1000个不同的人的10000张照片中,也许是这1000个人平均每个人10张照片
如果只有每个人一张照片,那么根本没法训练这个系统

当然,训练完这个系统之后,可以应用到一次学习问题上
对于人脸识别系统,可能只有想要识别的某个人的一张照片
但对于训练集,需要确保有同一个人的多个图片,至少是训练集里的一部分人
这样就有成对的Anchor和Positive图片了


3. 训练集

现在来看,如何选择这些三元组来形成训练集
一个问题是如果从训练集中,随机地选择A、P和N
遵守A和P是同一个人,而A和N是不同的人这一原则

有个问题就是,如果随机的选择它们,那么这个约束条件(d(A,P)d(A,P)d(A,P) + a ≤ d(A,N)d(A,N)d(A,N))很容易达到
因为随机选择的图片,A和N比A和P差别很大的概率很大
如果A和N是随机选择的不同的人,有很大的可能性 ||f( A ) - f( N )||2会比左边这项 ||f( A ) - f( P )||2
而且差距远大于a,这样网络并不能从中学到什么


所以为了构建一个数据集,要做的就是尽可能选择难训练的三元组A、P和N

具体而言,想要所有的三元组都满足这个条件(d(A,P)d(A,P)d(A,P) + a ≤ d(A,N)d(A,N)d(A,N)
难训练的三元组就是,A、P和N的选择使得d(A,P)d(A,P)d(A,P)很接近d(A,N)d(A,N)d(A,N),即 d(A,P)≈d(A,N)d(A,P)≈d(A,N)d(A,P)d(A,N)
这样学习算法会竭尽全力使右边这个式子变大(d(A,N)d(A,N)d(A,N)),或者使左边这个式子(d(A,P)d(A,P)d(A,P))变小
这样左右两边至少有一个a的间隔
并且选择这样的三元组还可以增加学习算法的计算效率

如果随机的选择这些三元组,其中有太多会很简单,梯度算法不会有什么效果
因为网络总是很轻松就能得到正确的结果
只有选择难的三元组梯度下降法才能发挥作用,使得这两边离得尽可能远

如果对此感兴趣的话,这篇论文中有更多细节
作者是Florian Schroff, Dmitry Kalenichenko, James Philbin
他们建立了这个叫做 FaceNet 的系统,博客的许多观点都是来自于他们的工作
FaceNet: A Unified Embedding for Face Recognition and Clustering

总结一下,训练这个三元组损失需要取训练集,然后把它做成很多三元组,这就是一个三元组(编号1)


有一个Anchor图片和Positive图片,这两个图片是同一个人,还有一张另一个人的Negative图片
这是另一组(编号2),其中Anchor和Positive图片是同一个人,但这两个图片不是同一个人,等等


定义了这些包括A、P和N图片的数据集之后
还需要做的就是用梯度下降最小化之前定义的代价函数 JJJ
这样做的效果就是反向传播到网络中的所有参数来学习到一种编码,使得如果两个图片是同一个人
那么它们的 ddd 就会很小,如果两个图片不是同一个人,它们的 ddd 就会很大

这就是三元组损失,并且如何用它来训练网络输出一个好的编码用于人脸识别
现在的人脸识别系统,尤其是大规模的商业人脸识别系统都是在很大的数据集上训练
超过百万图片的数据集并不罕见,一些公司用千万级的图片,还有一些用上亿的图片来训练这些系统
这些是很大的数据集,即使按照现在的标准,这些数据集并不容易获得

幸运的是,一些公司已经训练了这些大型的网络并且上传了模型参数
所以相比于从头训练这些网络,在这一领域,由于这些数据集太大
这一领域的一个实用操作就是下载别人的预训练模型,而不是一切都要从头开始
但是即使下载了别人的预训练模型,了解怎么训练这些算法也是有用的
以防针对一些应用需要从头实现这些想法


参考:

《神经网络和深度学习》视频课程


相关推荐:

深度学习笔记(43) Siamese网络
深度学习笔记(42) 人脸识别
深度学习笔记(41) 候选区域
深度学习笔记(40) YOLO
深度学习笔记(39) Anchor Boxes


谢谢!

深度学习笔记(44) Triplet 损失相关推荐

  1. 深度学习笔记(49) 风格代价函数

    深度学习笔记(49) 风格代价函数 1. 风格相关系数 2. 风格矩阵 3. 风格代价函数 1. 风格相关系数 比如有这样一张图片,可能已经对这个计算很熟悉了,它能算出这里是否含有不同隐藏层 现在选择 ...

  2. 深度学习笔记(48) 内容代价函数

    深度学习笔记(48) 内容代价函数 1. 激活函数值 2. 内容代价函数 1. 激活函数值 风格迁移网络的代价函数有一个内容代价部分,还有一个风格代价部分 先定义内容代价部分,不要忘了这就是整个风格迁 ...

  3. 深度学习笔记(47) 神经风格迁移

    深度学习笔记(47) 神经风格迁移 1. 神经风格迁移 2. 代价函数 1. 神经风格迁移 近,卷积神经网络最有趣的应用是神经风格迁移 来看几个例子,比如这张照片,照片是在斯坦福大学拍摄的 如果想利用 ...

  4. 深度学习笔记(46) 深度卷积网络学习

    深度学习笔记(46) 深度卷积网络学习 1. 学习内容 2. 第一层 3. 第二层 4. 第三层 5. 第四层 6. 第五层 1. 学习内容 深度卷积网络到底在学什么? 来看一个例子,假如训练了一个卷 ...

  5. 吴恩达深度学习笔记(四)

    吴恩达深度学习笔记(四) 卷积神经网络CNN-第二版 卷积神经网络 深度卷积网络:实例探究 目标检测 特殊应用:人脸识别和神经风格转换 卷积神经网络编程作业 卷积神经网络CNN-第二版 卷积神经网络 ...

  6. 深度学习笔记(45) 人脸验证与二分类

    深度学习笔记(45) 人脸验证与二分类 1. 二分类问题 2. 逻辑回归单元的处理 3. 计算技巧 1. 二分类问题 深度学习笔记(44) Triplet 损失 的Triplet loss是一个学习人 ...

  7. 一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

    <繁凡的深度学习笔记>第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列) 3043331995@qq.com https://fanfansann.blog.csdn.net ...

  8. 下载量过百万的吴恩达机器学习和深度学习笔记更新了!(附PDF下载)

    今天,我把吴恩达机器学习和深度学习课程笔记都更新了,并提供下载,这两本笔记非常适合机器学习和深度学习入门.(作者:黄海广) 0.导语 我和同学将吴恩达老师机器学习和深度学习课程笔记做成了打印版,放在g ...

  9. 33万字!深度学习笔记在线版发布!

    吴恩达老师的深度学习课程(deeplearning.ai),可以说是深度学习入门的最热门课程,我和志愿者编写了这门课的笔记,并在 github 开源,为满足手机阅读的需要,我将笔记做成了在线版,可以在 ...

最新文章

  1. 推荐系统-03-简单基于用户的推荐
  2. 底板芯片组与内存映射(Motherboard Chipsets and the Memory Map) 【转】
  3. dcba oracle,【转】dcba的文章:Oracle的SET UNUSED COLUMN操作到底做了什么?
  4. java心电图心率计算_java如何画心电图?
  5. SpringBoot中使用thymeleaf的switch来实现if-else if -else的效果
  6. 网页设计界面 电脑版设计
  7. SAP Analytics Cloud里避免类型为个数的measure出现小数点
  8. 如何添加JWT生成的token在请求头中
  9. Qt-VS开发:解决VS中使用带有信号槽的导出对象库时,信号槽不工作的问题
  10. 后通用芯片时代: 专用芯片兴起背后的经济学
  11. ArcSDE数据库连接(直连、服务连)与GT_Geometry存储配置图解
  12. springboot读取src下文件_springboot获取src/main/resource下的文件
  13. JS学习--取整方法整理
  14. ROP_return to dl-resolve学习笔记
  15. EasyTouch的使用
  16. 【mysql的编程专题①】流程控制与其他语法
  17. QT绘制同心扇形(Paintevent实现)
  18. [简单dp]toj1179
  19. 2500个常用中文字符 + 130常用中英文字符
  20. LeetCode 刷题记录模板

热门文章

  1. Composition-based Multi-Relational Graph Convolutional Networks 多关系图神经网络 ICLR 2020
  2. mysql 字符串特殊字符_转:MySQL数据入库时特殊字符处理
  3. anaconda进出某个环境
  4. Redis总结集群方式之主从复制
  5. 电子商务c语言实训报告,中南民族大学电子商务C语言实验报告.doc
  6. vscode linux新建c语言,Ubuntu16.04下配置VScode的C/C++开发环境
  7. 根目录_Ubuntu的根目录下的var/log/apt突然爆满,电脑卡死
  8. linux学习第四周作业练习
  9. 从零起步到Linux运维经理,你必须管好的23个细节
  10. 浅谈iOS中的蓝牙技术(一) GameKit framework