It's best to write at noon, because in the morning and in the evening I don't feel like writing.

I. Paper Review

Paper: AlignedReID: Surpassing Human-Level Performance in Person Re-Identification

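AlignedReID is an interesting paper: it proposes using dynamic programming to compute the shortest path between the local features of two images. Each edge on this shortest path corresponds to a match between a pair of local features, so the path gives a way to align human bodies: among all alignments that keep the body parts in the correct relative order, it is the one with the smallest total distance. During training, the length of the shortest path is added to the local loss function, which helps the network learn the global feature of the pedestrian.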

To understand this figure, we break it down into a few small questions and analyze them one at a time.

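1. How is the 7x7 matrix obtained? Let F and G denote the two images, with local features $F=\{f_1,\dots,f_H\}$ and $G=\{g_1,\dots,g_H\}$; in the paper H = 7. The distance between every pair of parts is computed with Equation 1:

$$d_{i,j}=\frac{e^{\lVert f_i-g_j\rVert_2}-1}{e^{\lVert f_i-g_j\rVert_2}+1},\qquad i,j\in\{1,\dots,H\}$$

This gives 7x7 = 49 values, i.e. a 7x7 distance matrix (a smaller $d_{i,j}$ means parts $i$ and $j$ are more similar).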

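2. What do the turning points mean? A turning point marks a pair of corresponding slice positions in the two images. For example, (2,4) means that slice 2 of Image A and slice 4 of Image B are aligned.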

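3. How should the shortest path be read? Look at the black lines and the black arrows separately. First, the black line in row 1 spans columns 1-4 of the matrix, meaning that slice 1 of Image A is in corresponding alignment with slices 1-4 of Image B. Second, the black arrow in row 1 also spans columns 1-4 and marks the smallest distance (i.e. the highest similarity) between slice 1 of Image A and Image B as a whole. In other words, the model considers slice 1 of A similar only to the first four slices of B and not to the last three, so those three slices are not part of the aligned region for slice 1 and do not lie on the shortest path.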

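4. What formula gives the shortest path traced by the black line? (You will say Equation 2, and you are right, but can you derive it yourself?)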

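Distance from slice 1 of Image A to Image B (row 1 of the matrix): cells in the first row can only be reached from the left, so the distances simply accumulate:

$$S_{1,1}=d_{1,1},\qquad S_{1,j}=S_{1,j-1}+d_{1,j}\quad (j=2,\dots,7)$$

Distance from slice i of Image A to Image B (rows i = 2, ..., 7): the first cell of each row can only be reached from above, and every other cell from either the cell above or the cell to the left, whichever accumulated distance is smaller:

$$S_{i,1}=S_{i-1,1}+d_{i,1},\qquad S_{i,j}=\min\left(S_{i-1,j},\,S_{i,j-1}\right)+d_{i,j}\quad (j=2,\dots,7)$$

Once row 7 is filled in, $S_{7,7}$ is the total length of the shortest path from $(1,1)$ to $(7,7)$, i.e. the local distance between the two images. Putting these cases together gives Equation 2.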
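The matrix, the recurrence and the turning points above can be tied together in a small pure-Python sketch. It works on plain lists instead of tensors, and the names local_distance_matrix and shortest_alignment are hypothetical, not from the AlignedReID code. The first function applies Equation 1, the second fills the table S with Equation 2 and then walks back from the bottom-right cell to recover which slice of A is aligned with which slice of B.

```python
import math


def local_distance_matrix(F, G):
    """Equation 1: d[i][j] = (exp(||f_i - g_j||) - 1) / (exp(||f_i - g_j||) + 1).

    F and G are lists of local feature vectors, one per horizontal stripe."""
    d = []
    for f in F:
        row = []
        for g in G:
            e = math.dist(f, g)  # Euclidean distance (Python 3.8+)
            row.append((math.exp(e) - 1.0) / (math.exp(e) + 1.0))
        d.append(row)
    return d


def shortest_alignment(d):
    """Fill S with Equation 2 and backtrack one shortest path.

    Returns (S_HH, path); path lists aligned (slice of A, slice of B)
    pairs with 1-based indices, as in the figure."""
    m, n = len(d), len(d[0])
    S = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            if i == 0 and j == 0:
                S[i][j] = d[i][j]
            elif i == 0:  # first row: only reachable from the left
                S[i][j] = S[i][j - 1] + d[i][j]
            elif j == 0:  # first column: only reachable from above
                S[i][j] = S[i - 1][j] + d[i][j]
            else:
                S[i][j] = min(S[i - 1][j], S[i][j - 1]) + d[i][j]

    # Walk back from the bottom-right cell, stepping each time to the
    # neighbour (above or left) with the smaller accumulated distance.
    i, j = m - 1, n - 1
    path = [(i + 1, j + 1)]
    while (i, j) != (0, 0):
        if i == 0:
            j -= 1
        elif j == 0:
            i -= 1
        elif S[i - 1][j] <= S[i][j - 1]:  # ties: prefer the cell above
            i -= 1
        else:
            j -= 1
        path.append((i + 1, j + 1))
    path.reverse()
    return S[-1][-1], path
```

For d = [[1, 2, 3], [2, 1, 2], [3, 2, 1]] the table gives S_{3,3} = 7 along the path (1,1), (1,2), (2,2), (2,3), (3,3). The horizontal step from (1,1) to (1,2) means slice 1 of A is aligned with slices 1 and 2 of B, just like a horizontal black line in the figure. When several paths tie, this backtracking returns only one of them.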

II. Code Walkthrough

The AlignedReID code is very similar to the Triplet loss code. Since I have already walked through the Triplet loss source in detail (https://blog.csdn.net/m0_57541899/article/details/122243847?spm=1001.2014.3001.5501), the analysis here is done directly in the code comments. First, a few functions worth noting:

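1. torch.mean(input, dim, keepdim): averages over the specified dimension. With keepdim=True that dimension is kept with size 1; for example, averaging a (2,3) tensor (dim 0 has size 2, dim 1 has size 3) over dim 0 gives a (1,3) tensor. With keepdim=False (the default) the dimension is removed and the result has shape (3,).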

2. permute(dims): reorders the dimensions of a tensor. For example, if img has size (28,28,3), img.permute(2,0,1) returns a tensor of size (3,28,28).

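3. squeeze(dim): removes dimensions of size 1; when dim is given, only that dimension is removed (and only if its size is 1). The code uses dist_ap.squeeze(1) to turn an [N, 1] tensor into [N]. NumPy's counterpart is numpy.squeeze(a, axis=None), where a is the input array and axis selects the dimensions to remove.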

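4. clamp(min, max): limits every element x to the range [min, max]: if min <= x <= max it returns x; if x > max it returns max; if x < min it returns min. Either bound may be omitted; the code uses dist.clamp(min=1e-12) to keep the value under the square root strictly positive.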
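The clamp-before-sqrt step used in euclidean_dist can be mirrored in plain Python. This is a sketch on Python lists rather than tensors, and clamp, expanded_sq_dist and safe_dist are hypothetical helper names. It uses the same expansion ||x||^2 + ||y||^2 - 2x·y as the code and clamps the result to 1e-12 before the square root, because rounding can push the expanded value to zero or slightly below (math.sqrt raises on negatives, torch.sqrt returns NaN, and the derivative of sqrt is infinite at 0).

```python
import math


def clamp(x, min_value=None, max_value=None):
    """Scalar analogue of torch.clamp: limit x to [min_value, max_value]."""
    if min_value is not None and x < min_value:
        return min_value
    if max_value is not None and x > max_value:
        return max_value
    return x


def expanded_sq_dist(x, y):
    """||x||^2 + ||y||^2 - 2 x.y, the expansion euclidean_dist uses."""
    xx = sum(a * a for a in x)
    yy = sum(b * b for b in y)
    xy = sum(a * b for a, b in zip(x, y))
    return xx + yy - 2 * xy


def safe_dist(x, y, eps=1e-12):
    # Clamp first so that sqrt never sees a value <= 0.
    return math.sqrt(clamp(expanded_sq_dist(x, y), min_value=eps))
```

For the identical vectors [0.1, 0.2] and [0.1, 0.2] the expansion comes out as 0, so safe_dist returns sqrt(1e-12) = 1e-6 instead of exactly 0; for [0, 0] and [3, 4] it returns 5.0.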

from __future__ import print_function
import torch


def normalize(x, axis=-1):
    """Normalizing to unit length along the specified dimension.
    Args:
      x: pytorch Variable
    Returns:
      x: pytorch Variable, same shape as input
    """
    # torch.norm(input, p, dim): p-norm along the specified dimension
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


def euclidean_dist(x, y):
    """
    Args:
      x: pytorch Variable, with shape [m, d]
      y: pytorch Variable, with shape [n, d]
    Returns:
      dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)  # m: 128, n: 128
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)  # sum over axis 1
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist = 1 * dist - 2 * (x @ y^T); addmm_ is the in-place version of addmm
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    # clamp for numerical stability, then sqrt: pairwise distance matrix
    dist = dist.clamp(min=1e-12).sqrt()
    return dist


def batch_euclidean_dist(x, y):
    """
    Args:
      x (local_feat): pytorch Variable, with shape [N, m, d] = [128, 8, 128]
      y (local_feat[p_inds]): pytorch Variable, with shape [N, n, d] = [128, 8, 128]
    Returns:
      dist: pytorch Variable, with shape [N, m, n]
    """
    assert len(x.size()) == 3
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)

    N, m, d = x.size()  # N: 128, m: 8, d: 128
    N, n, d = y.size()  # n: 8

    # shape [N, m, n]
    xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n)  # (128, 8, 8)
    yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1)  # (128, 8, 8)
    dist = xx + yy
    # dist = 1 * dist - 2 * (x @ y^T) for each sample; baddbmm_ is the in-place version of baddbmm
    dist.baddbmm_(x, y.permute(0, 2, 1), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def shortest_dist(dist_mat):
    """Parallel version.
    Args:
      dist_mat: pytorch Variable, available shape:
        1) [m, n]
        2) [m, n, N], N is batch size
        3) [m, n, *], * can be arbitrary additional dimensions
    Returns:
      dist: three cases corresponding to `dist_mat`:
        1) scalar
        2) pytorch Variable, with shape [N]
        3) pytorch Variable, with shape [*]
    """
    m, n = dist_mat.size()[:2]
    # Just offering some reference for accessing intermediate distance.
    dist = [[0 for _ in range(n)] for _ in range(m)]  # initialization
    for i in range(m):
        for j in range(n):
            if (i == 0) and (j == 0):
                dist[i][j] = dist_mat[i, j]
            elif (i == 0) and (j > 0):
                dist[i][j] = dist[i][j - 1] + dist_mat[i, j]
            elif (i > 0) and (j == 0):
                dist[i][j] = dist[i - 1][j] + dist_mat[i, j]
            else:
                dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j]
    dist = dist[-1][-1]
    return dist


def local_dist(x, y):
    """
    Args:
      x: pytorch Variable, with shape [M, m, d]
      y: pytorch Variable, with shape [N, n, d]
    Returns:
      dist: pytorch Variable, with shape [M, N]
    """
    M, m, d = x.size()
    N, n, d = y.size()
    x = x.contiguous().view(M * m, d)
    y = y.contiguous().view(N * n, d)
    # shape [M * m, N * n]
    dist_mat = euclidean_dist(x, y)
    dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
    # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N]
    dist_mat = dist_mat.contiguous().view(M, m, N, n).permute(1, 3, 0, 2)
    # shape [M, N]
    dist_mat = shortest_dist(dist_mat)
    return dist_mat


def batch_local_dist(x, y):
    """
    Args:
      x (local_feat): pytorch Variable, with shape [N, m, d] = [128, 8, 128]
      y (local_feat[p_inds]): pytorch Variable, with shape [N, n, d] = [128, 8, 128]
    Returns:
      dist: pytorch Variable, with shape [N]
    """
    assert len(x.size()) == 3  # x must be 3-D, otherwise raise an error
    assert len(y.size()) == 3
    assert x.size(0) == y.size(0)
    assert x.size(-1) == y.size(-1)

    # shape [N, m, n]
    dist_mat = batch_euclidean_dist(x, y)  # dist_mat: (128, 8, 8)
    # normalization: (e^d - 1) / (e^d + 1) maps each distance into [0, 1)
    dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.)
    # shape [N]
    # (128, 8, 8) -> (8, 8, 128), then compute the shortest distance by dynamic programming
    dist = shortest_dist(dist_mat.permute(1, 2, 0))
    return dist


def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
      labels: pytorch LongTensor, with shape [N]
      return_inds: whether to return the indices. Save time if `False`(?)
    Returns:
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
    NOTE: Only consider the case in which all labels have same num of samples,
      thus we can cope with all anchors in parallel.
    """
    assert len(dist_mat.size()) == 2  # dist_mat must be 2-D, otherwise raise an error
    assert dist_mat.size(0) == dist_mat.size(1)  # dist_mat must be square
    N = dist_mat.size(0)  # N: 128

    # shape [N, N]
    is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())  # mask of positive pairs
    is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())  # mask of negative pairs

    # `dist_ap` means distance(anchor, positive)
    # both `dist_ap` and `relative_p_inds` with shape [N, 1]
    # hardest positive: lowest similarity (largest distance), plus its position
    dist_ap, relative_p_inds = torch.max(
        dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
    # `dist_an` means distance(anchor, negative)
    # both `dist_an` and `relative_n_inds` with shape [N, 1]
    # hardest negative: highest similarity (smallest distance), plus its position
    dist_an, relative_n_inds = torch.min(
        dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
    # shape [N]
    dist_ap = dist_ap.squeeze(1)  # drop the size-1 dimension
    dist_an = dist_an.squeeze(1)

    if return_inds:
        # shape [N, N]
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        # shape [N, 1]
        p_inds = torch.gather(
            ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
        n_inds = torch.gather(
            ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
        # shape [N]
        p_inds = p_inds.squeeze(1)
        n_inds = n_inds.squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an


def global_loss(tri_loss, global_feat, labels, normalize_feature=True):
    """
    Args:
      tri_loss: a `TripletLoss` object
      global_feat: pytorch Variable, shape [N, C]
      labels: pytorch LongTensor, with shape [N]
      normalize_feature: whether to normalize feature to unit length along the
        Channel dimension
    Returns:
      loss: pytorch Variable, with shape [1]
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
      =============
      For Debugging
      =============
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      ===================
      For Mutual Learning
      ===================
      dist_mat: pytorch Variable, pairwise euclidean distance; shape [N, N]
    """
    if normalize_feature:
        global_feat = normalize(global_feat, axis=-1)
    # shape [N, N]
    dist_mat = euclidean_dist(global_feat, global_feat)
    dist_ap, dist_an, p_inds, n_inds = hard_example_mining(
        dist_mat, labels, return_inds=True)
    loss = tri_loss(dist_ap, dist_an)
    return loss, p_inds, n_inds, dist_ap, dist_an, dist_mat


def local_loss(tri_loss, local_feat, p_inds=None, n_inds=None, labels=None,
               normalize_feature=True):
    """
    Args:
      tri_loss: a `TripletLoss` object
      local_feat: pytorch Variable, shape [N, H, c] (NOTE THE SHAPE!)
      p_inds: pytorch LongTensor, with shape [N];
        indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
      n_inds: pytorch LongTensor, with shape [N];
        indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
      labels: pytorch LongTensor, with shape [N]
      normalize_feature: whether to normalize feature to unit length along the
        Channel dimension

    If hard samples are specified by `p_inds` and `n_inds`, then `labels` is not
    used. Otherwise, local distance finds its own hard samples independent of
    global distance.

    Returns:
      loss: pytorch Variable, with shape [1]
      =============
      For Debugging
      =============
      dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
      dist_an: pytorch Variable, distance(anchor, negative); shape [N]
      ===================
      For Mutual Learning
      ===================
      dist_mat: pytorch Variable, pairwise local distance; shape [N, N]
    """
    if normalize_feature:
        local_feat = normalize(local_feat, axis=-1)  # local_feat: (128, 8, 128)
    if p_inds is None or n_inds is None:
        dist_mat = local_dist(local_feat, local_feat)
        dist_ap, dist_an = hard_example_mining(dist_mat, labels, return_inds=False)
        loss = tri_loss(dist_ap, dist_an)
        return loss, dist_ap, dist_an, dist_mat
    else:
        # dist_ap: local anchor-positive distance; local_feat[p_inds]: positive local features
        dist_ap = batch_local_dist(local_feat, local_feat[p_inds])
        dist_an = batch_local_dist(local_feat, local_feat[n_inds])
        loss = tri_loss(dist_ap, dist_an)  # local loss
        return loss, dist_ap, dist_an
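To see what hard_example_mining followed by tri_loss computes, here is a pure-Python sketch on a tiny hand-written distance matrix. The names hardest_pairs and margin_triplet_loss are hypothetical, and the post does not show the TripletLoss class, so the loss form, mean of max(0, margin + d_ap - d_an), and the margin 0.3 are assumptions. As in the original is_pos mask, each anchor counts as its own positive (distance 0), which does not change the maximum when other positives exist.

```python
def hardest_pairs(dist, labels):
    """Batch-hard mining on a list-of-lists distance matrix.

    Returns dist_ap, dist_an, p_inds, n_inds, mirroring hard_example_mining."""
    n = len(labels)
    dist_ap, dist_an, p_inds, n_inds = [], [], [], []
    for i in range(n):
        pos = [j for j in range(n) if labels[j] == labels[i]]  # includes i, like is_pos
        neg = [j for j in range(n) if labels[j] != labels[i]]
        p = max(pos, key=lambda j: dist[i][j])  # hardest positive: farthest
        q = min(neg, key=lambda j: dist[i][j])  # hardest negative: closest
        dist_ap.append(dist[i][p])
        dist_an.append(dist[i][q])
        p_inds.append(p)
        n_inds.append(q)
    return dist_ap, dist_an, p_inds, n_inds


def margin_triplet_loss(dist_ap, dist_an, margin=0.3):
    # Assumed loss form: mean over anchors of max(0, margin + d_ap - d_an).
    terms = [max(0.0, margin + p - q) for p, q in zip(dist_ap, dist_an)]
    return sum(terms) / len(terms)
```

In global_loss the indices p_inds and n_inds are returned, and local_loss can reuse them through batch_local_dist, so the local branch is trained on the same hard triplets that the global distance picked.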
