一、Triplet结构:

triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor、Positive、Negative:

整个训练过程是:

  1. 首先从训练集中随机选一个样本,称为Anchor(记为x_a)。
  2. 然后再随机选取一个和Anchor属于同一类的样本,称为Positive (记为x_p)
  3. 最后再随机选取一个和Anchor属于不同类的样本,称为Negative (记为x_n)

由此构成一个(Anchor,Positive,Negative)三元组。

二、Triplet Loss:

在上一篇讲了Center Loss的原理和实现,会发现现在loss的优化的方向是比较清晰好理解的。在基于能够正确分类的同时,我们更希望模型能够:1、把不同类之间分得很开,也就是更不容易混淆;2、同类之间靠得比较紧密,这个对于模型的鲁棒性的提高也是比较有帮助的(基于此想到Hinton的Distillation中给softmax加的一个T就是人为的对训练过程中加上干扰,让distribution变得更加soft从而去把错误信息放大,这样模型能够不光知道什么是正确还知道什么是错误。即:模型可以从仅仅判断这个最可能是7,变为可以知道这个最可能是7、一定不是8、和2比较相近,论文讲解可以参看Hinton-Distillation)。

回归正题,三元组的三个样本最终得到的特征表达计为:

triplet loss的目的就是让Anchor这个样本的feature和positive的feature直接的距离比和negative的小,即:

除了让x_a和x_p特征表达之间的距离尽可能小,而x_a和x_n的特征表达之间的距离尽可能大之外还要让x_a与x_n之间的距离和x_a与x_p之间的距离之间有一个最小的间隔α,于是修改loss为:

于是目标函数为:

距离用欧式距离度量,+表示[  ***  ]内的值大于零的时候,取该值为损失,小于零的时候,损失为零。

故也可以理解为:

    L = max([ ] ,  0)

在code中就是这样实现的,利用marginloss,详见下节。

三、Code实现:

笔者使用pytorch:

from torch import nn
from torch.autograd import Variableclass TripletLoss(object):def __init__(self, margin=None):self.margin = marginif margin is not None:self.ranking_loss = nn.MarginRankingLoss(margin=margin)else:self.ranking_loss = nn.SoftMarginLoss()def __call__(self, dist_ap, dist_an):"""Args:dist_ap: pytorch Variable, distance between anchor and positive sample, shape [N]dist_an: pytorch Variable, distance between anchor and negative sample, shape [N]Returns:loss: pytorch Variable, with shape [1]"""y = Variable(dist_an.data.new().resize_as_(dist_an.data).fill_(1))if self.margin is not None:loss = self.ranking_loss(dist_an, dist_ap, y)else:loss = self.ranking_loss(dist_an - dist_ap, y)return loss

理解起来非常简单,当margin为空时,使用SoftMarginLoss:

当margin不为空时,使用MarginRankingLoss,y中填充的都是1,代表希望dist_an>dist_ap,即anchor到negative样本的距离大于到positive样本的距离,margin为dist_an - dist_ap的值需要大于多少:

与我们要得到的loss类似:当与正例距离+固定distance大于负例距离时为正值,则惩罚,否则不惩罚。

四、github项目介绍

class TripletLoss(nn.Module):"""Triplet loss with hard positive/negative mining.Reference:Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.Args:margin (float, optional): margin for triplet. Default is 0.3."""def __init__(self, margin=0.3,global_feat, labels):super(TripletLoss, self).__init__()self.margin = marginself.ranking_loss = nn.MarginRankingLoss(margin=margin)def forward(self, inputs, targets):"""Args:inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).targets (torch.LongTensor): ground truth labels with shape (num_classes)."""n = inputs.size(0)    # batch_size# Compute pairwise distance, replace by the official when mergeddist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)dist = dist + dist.t()dist.addmm_(1, -2, inputs, inputs.t())dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability# For each anchor, find the hardest positive and negativemask = targets.expand(n, n).eq(targets.expand(n, n).t())dist_ap, dist_an = [], []for i in range(n):dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))dist_ap = torch.cat(dist_ap)dist_an = torch.cat(dist_an)# Compute ranking hinge lossy = torch.ones_like(dist_an)loss = self.ranking_loss(dist_an, dist_ap, y)return loss

pytorch triple-loss相关推荐

  1. pytorch 区间loss 损失函数

    pytorch 区间loss 损失函数 我们知道sigmoid可以把值转化为0-1之间. tanh函数可以把值转化到[-1,1]之间, 但是在回归时,希望梯度是均匀的,有么有别的方法呢? 答案是肯定的 ...

  2. 对比学习triple loss

    对比学习与多分类学习密切相关,包括对比学习损失和softmax分类损失的相关,最大的差别也就是距离的约束条件不同,也可以说是损失函数的不同.因为我目前做的是二分类,所以先简单的看一下二分类的损失. 1 ...

  3. Pytorch自定义Loss

    Pytorch如何自定义Loss 设置空loss: loss=torch.tensor(0).float().to(outs[0].device) 这个可以试试: regression_losses. ...

  4. pytorch 画loss曲线_Pytorch使用tensorboardX可视化。超详细!!!

    1 引言 我们都知道tensorflow框架可以使用tensorboard这一高级的可视化的工具,为了使用tensorboard这一套完美的可视化工具,未免可以将其应用到Pytorch中,用于Pyto ...

  5. pytorch查看loss曲线_pytorch loss总结与测试

    pytorch loss 参考文献: loss 测试 import torch from torch.autograd import Variable ''' 参考文献: https://blog.c ...

  6. pytorch自定义loss损失函数

    自定义loss的方法有很多,但是在博主查资料的时候发现有挺多写法会有问题,靠谱一点的方法是把loss作为一个pytorch的模块,比如: class CustomLoss(nn.Module): # ...

  7. 【解决方案】pytorch中loss变成了nan | 神经网络输出nan | MSE 梯度爆炸/梯度消失

    loss_func = nn.MSELoss() loss = loss_func(val, target) 最近在跑一个项目,计算loss时用了很普通的MSE,在训练了10到300个batch时,会 ...

  8. Triple loss

    一.Triplet结构: triplet loss是一种比较好理解的loss,triplet是指的是三元组:Anchor.Positive.Negative: 整个训练过程是: 首先从训练集中随机选一 ...

  9. 使用pytorch的loss.backward()时,出现element 0 of tensors does not require grad and does not have a grad_fn

    仅作为记录,大佬请跳过. 用loss.requires_grad_(True) 是因为loss没有设置梯度(所以不能反向传播)(loss的数据类型是tensor) 设置梯度后即可,展示: 参考 pyt ...

  10. pytorch focal loss

    focal loss的改进1: Focal Loss改进版 GFocal Loss_jacke121的专栏-CSDN博客_focal loss改进 focal loss的改进2: GHM loss 讲 ...

最新文章

  1. 对话jQuery之父John Resig:JavaScript的开发之路
  2. 学成在线--13.RabbitMQ工作模式
  3. 【MongoDB for Java】Java操作MongoDB
  4. ORACLE索引重建方法与索引的三种状态
  5. AutoCAD ObjectARX(VC)开发基础与实例教程2014版光盘镜像
  6. java安装_如何在 Mac 上安装 Java | Linux 中国
  7. JS删除数组中某一项或几项的方法汇总
  8. larval中redis的用法
  9. 压缩软件Bandizip
  10. 我国20年农药年施用量增百万吨 生产方式需反思
  11. ps命令查看进程详解
  12. 尝试一下LLJ大佬的理论AC大法
  13. 动态规划算法学习(一)爬楼梯和凑金额
  14. python案例——体脂率项目
  15. 计算机启动后 不显示桌面,电脑开机后不显示桌面怎么办?
  16. 利用appimage工具对开发好的项目进行打包
  17. 无法解析域名“cn.archive.ubuntu.com”。
  18. centos7.5系统动态扩容磁盘及系统挂载未分配硬盘空间
  19. 编译osgEarth2.8遇到gdal_vrt.h找不到的问题
  20. opencv仿射变换:平移,缩放和旋转

热门文章

  1. 如何用六边形网格制作炫酷的地名地址信息热力图?
  2. java内存图片_Java程序缩放图片时,内存占用令我百思不得其解
  3. About Birthday
  4. c语言printf char数组,在C中输出二维char数组的最快方法
  5. typedef函数指针
  6. [计算机毕业设计]MATLAB的人脸识别
  7. 基本面量化:一种多因子选股策略
  8. 七种排序(长路漫漫)
  9. Couldn‘t connect to trainer on port 5004 using API version 1.5.0. Will perform inference inst
  10. oracle的前端是什么,Oracle的那些事情