Part 1: Reading the DataIter in insightface's triplet loss implementation

Preface:

An MXNet DataIter is usually written by subclassing io.DataIter and implementing its key methods.

The main methods to implement are __init__(), reset() and next(), plus a few attributes that fit() relies on, such as provide_data and provide_label.
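To make that interface concrete, here is a minimal, framework-free sketch of the protocol. It deliberately does not import MXNet; the class name ToyBatchIter, the 'data' / 'softmax_label' names and the tuple returned by next() are assumptions for illustration only. A real iterator would subclass mx.io.DataIter and return an mx.io.DataBatch from next().

```python
import numpy as np


class ToyBatchIter:
    """Hypothetical stand-in for an mx.io.DataIter subclass.

    __init__ sets up state, reset() rewinds, next() returns one batch or
    raises StopIteration, and provide_data / provide_label describe shapes.
    """

    def __init__(self, data, labels, batch_size):
        self.data = np.asarray(data, dtype=np.float32)
        self.labels = np.asarray(labels, dtype=np.float32)
        self.batch_size = batch_size
        self.cur = 0

    @property
    def provide_data(self):
        # MXNet expects a list of (name, shape) pairs (or DataDesc objects).
        return [('data', (self.batch_size,) + self.data.shape[1:])]

    @property
    def provide_label(self):
        return [('softmax_label', (self.batch_size,))]

    def reset(self):
        # Rewind to the start of the data (called by fit() at each epoch).
        self.cur = 0

    def next(self):
        # Drop the final partial batch, for simplicity.
        if self.cur + self.batch_size > len(self.data):
            raise StopIteration
        sl = slice(self.cur, self.cur + self.batch_size)
        self.cur += self.batch_size
        # A real DataIter would return mx.io.DataBatch(data=[...], label=[...]).
        return self.data[sl], self.labels[sl]

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()
```

Usage: `for x, y in ToyBatchIter(features, labels, batch_size=32): ...` iterates once over the data; call `reset()` before the next pass, just as fit() does with a real DataIter.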

This post walks through triplet_image_iter.py from the insightface code base and explains its source.

Code walkthrough:

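1. The first method to look at is the DataIter's reset(). The functions the source's reset() relies on are pick_triplets, select_triplets and pairwise_dists: select_triplets calls pick_triplets, which in turn calls pairwise_dists.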
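pairwise_dists itself is not reproduced in this post. Because pick_triplets compares each row of its result against pos_dist_sqr, which is a squared distance, the rows must hold squared Euclidean distances. A vectorized NumPy sketch under that assumption could look like this; the name pairwise_sq_dists is hypothetical and this is not insightface's own implementation.

```python
import numpy as np


def pairwise_sq_dists(embeddings):
    """Squared Euclidean distance between every pair of rows.

    Uses ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y and clips at 0 to absorb
    tiny negative values caused by floating-point cancellation.
    """
    emb = np.asarray(embeddings, dtype=np.float64)
    sq_norms = np.sum(emb * emb, axis=1)
    d = sq_norms[:, None] + sq_norms[None, :] - 2.0 * emb @ emb.T
    return np.maximum(d, 0.0)
```

Row `a_idx` of the result plays the role of `neg_dists_sqr` in pick_triplets: the distances from anchor `a_idx` to every sample in the bag.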

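  • Selecting triplets: pick_triplets

Purpose: following FaceNet's triplet selection rule, prefer negatives that lie close to the anchor (semi-hard negatives).

Returns: all triplets that satisfy the selection criteria.

Python concepts: np.logical_and(), np.where()

Annotated source: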

    import numpy as np


    # Select triplets according to FaceNet's rule and return them as a list
    def pick_triplets(self, embeddings, nrof_images_per_class):
        emb_start_idx = 0
        triplets = []
        people_per_batch = len(nrof_images_per_class)  # number of classes in the bag
        pdists = self.pairwise_dists(embeddings)  # pairwise distances between all embeddings
        for i in range(people_per_batch):
            nrof_images = int(nrof_images_per_class[i])  # number of images in class i
            for j in range(1, nrof_images):
                a_idx = emb_start_idx + j - 1
                neg_dists_sqr = pdists[a_idx]  # distances from the anchor to every sample
                for pair in range(j, nrof_images):  # for every possible positive pair
                    p_idx = emb_start_idx + pair
                    # squared distance between anchor and positive
                    pos_dist_sqr = np.sum(np.square(embeddings[a_idx] - embeddings[p_idx]))
                    # Set the distances to same-class samples to NaN: every comparison
                    # with NaN is False, so these samples can never be picked as negatives
                    neg_dists_sqr[emb_start_idx:emb_start_idx + nrof_images] = np.nan
                    if self.triplet_max_ap > 0.0:
                        if pos_dist_sqr > self.triplet_max_ap:
                            continue
                    # np.where: returns the indices where the condition holds
                    #   (one index array per dimension for multi-dimensional input)
                    # np.logical_and: element-wise AND of FaceNet's two conditions:
                    #   an - ap < alpha  and  ap < an   (semi-hard negatives)
                    all_neg = np.where(np.logical_and(neg_dists_sqr - pos_dist_sqr < self.triplet_alpha,
                                                      pos_dist_sqr < neg_dists_sqr))[0]  # FaceNet selection
                    # all_neg = np.where(neg_dists_sqr - pos_dist_sqr < alpha)[0]  # VGG Face selection
                    nrof_random_negs = all_neg.shape[0]
                    if nrof_random_negs > 0:
                        # pick one qualifying negative at random
                        rnd_idx = np.random.randint(nrof_random_negs)
                        n_idx = all_neg[rnd_idx]
                        triplets.append((a_idx, p_idx, n_idx))
            emb_start_idx += nrof_images
        # shuffle the triplet order
        np.random.shuffle(triplets)
        return triplets
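The selection filter in pick_triplets is easiest to see on toy numbers. The sketch below uses made-up distances and alpha = 0.2 (both assumptions for illustration). It shows that NaN entries drop out of both comparisons, that np.logical_and keeps only the semi-hard negatives, and that np.where returns a tuple with one index array per dimension.

```python
import numpy as np

alpha = 0.2
pos_dist_sqr = 0.5
# Squared distances from one anchor to every sample in a toy bag.
neg_dists_sqr = np.array([0.3, 0.6, 0.65, 0.9, 0.55])
# Pretend samples 0 and 1 share the anchor's class: mask them with NaN.
neg_dists_sqr[0:2] = np.nan

# Semi-hard: farther than the positive, but by less than the margin alpha.
semi_hard = np.logical_and(neg_dists_sqr - pos_dist_sqr < alpha,
                           pos_dist_sqr < neg_dists_sqr)
candidates = np.where(semi_hard)[0]  # 1-D input -> take element [0] of the tuple

# On a 2-D mask, np.where returns (row_indices, col_indices).
rows, cols = np.where(np.array([[True, False], [False, True]]))
print(semi_hard, candidates, rows, cols)
```

Sample 3 (0.9) is rejected because it already beats the positive by more than alpha (an "easy" negative that would contribute zero loss), so only samples 2 and 4 remain as candidates.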
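  • select_triplets: the main function that produces the anchors, positives (p) and negatives (n)

It calls pick_triplets.

Output: self.seq, a flat list of record indices; each group of per_batch_size // 3 triplets is stored as its anchors, then its positives, then its negatives (anchor_batch, p_batch, n_batch).

Main steps: read a bag of images indexed by triplet_seq, run them through the model in forward (inference) mode to get embeddings, select triplets from those embeddings, and save the result.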

    import numbers

    import mxnet as mx
    import numpy as np
    import sklearn.preprocessing
    from mxnet import nd, recordio


    def select_triplets(self):
        self.seq = []
        while len(self.seq) < self.seq_min_size:
            self.time_reset()
            embeddings = None
            bag_size = self.triplet_bag_size
            batch_size = self.batch_size
            tag = []
            print('eval %d images..' % bag_size, self.triplet_cur)
            print('triplet time stat', self.times)
            if self.triplet_cur + bag_size > len(self.triplet_seq):
                self.triplet_reset()
                print('eval %d images..' % bag_size, self.triplet_cur)
            self.times[0] += self.time_elapsed()
            self.time_reset()
            data = nd.zeros(self.provide_data[0][1])
            label = None
            if self.provide_label is not None:
                label = nd.zeros(self.provide_label[0][1])
            ba = 0
            # walk through the bag [0, bag_size) one batch at a time
            while True:
                bb = min(ba + batch_size, bag_size)
                if ba >= bb:
                    break
                _count = bb - ba
                for i in range(ba, bb):
                    _idx = self.triplet_seq[i + self.triplet_cur]  # triplet_seq is initialised in triplet_reset()
                    s = self.imgrec.read_idx(_idx)
                    header, img = recordio.unpack(s)
                    img = self.imdecode(img)
                    data[i - ba][:] = self.postprocess_data(img)
                    _label = header.label
                    if not isinstance(_label, numbers.Number):
                        _label = _label[0]
                    if label is not None:
                        label[i - ba][:] = _label
                    tag.append((int(_label), _idx))  # (class label, record index)
                db = mx.io.DataBatch(data=(data,))
                # forward pass over the current batch (inference mode)
                self.mx_model.forward(db, is_train=False)
                net_out = self.mx_model.get_outputs()  # fetch the forward results
                net_out = net_out[0].asnumpy()
                if embeddings is None:
                    embeddings = np.zeros((bag_size, net_out.shape[1]))
                # net_out always has batch_size rows; keep only the _count rows filled in this pass
                embeddings[ba:bb, :] = net_out[:_count]
                ba = bb
            assert len(tag) == bag_size
            self.triplet_cur += bag_size
            embeddings = sklearn.preprocessing.normalize(embeddings)
            self.times[1] += self.time_elapsed()
            self.time_reset()
            # count the classes and the samples per class (same-class samples are contiguous)
            nrof_images_per_class = [1]
            for i in range(1, bag_size):
                if tag[i][0] == tag[i - 1][0]:  # same label as the previous sample
                    nrof_images_per_class[-1] += 1
                else:
                    nrof_images_per_class.append(1)
            # select triplets
            triplets = self.pick_triplets(embeddings, nrof_images_per_class)  # shape=(T,3)
            print('found triplets', len(triplets))
            ba = 0
            while True:
                bb = ba + self.per_batch_size // 3
                if bb > len(triplets):
                    break
                _triplets = triplets[ba:bb]
                # append all anchors, then all positives, then all negatives of this chunk
                for i in range(3):
                    for triplet in _triplets:
                        _pos = triplet[i]
                        _idx = tag[_pos][1]  # record index
                        self.seq.append(_idx)
                ba = bb
            self.times[2] += self.time_elapsed()
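The two bookkeeping steps at the end of select_triplets (counting samples per class, then flattening the triplets into self.seq) can be checked in isolation. The sketch below re-expresses them as plain functions; the names count_per_class and triplets_to_seq are hypothetical, and record_ids stands in for the record index stored in tag[pos][1].

```python
def count_per_class(labels):
    """Run-length counts of consecutive equal labels (like nrof_images_per_class).

    Assumes same-class samples are contiguous in the bag, which is what the
    loop in select_triplets relies on.
    """
    counts = []
    for i, lab in enumerate(labels):
        if i > 0 and lab == labels[i - 1]:
            counts[-1] += 1
        else:
            counts.append(1)
    return counts


def triplets_to_seq(triplets, record_ids, per_batch_size):
    """Flatten triplets into a record-index sequence, one chunk per batch.

    Each chunk holds per_batch_size // 3 triplets laid out as
    [all anchors][all positives][all negatives]; leftover triplets that do
    not fill a whole chunk are dropped, as in select_triplets.
    """
    seq = []
    k = per_batch_size // 3
    for ba in range(0, len(triplets) - k + 1, k):
        chunk = triplets[ba:ba + k]
        for slot in range(3):  # 0 = anchor, 1 = positive, 2 = negative
            for t in chunk:
                seq.append(record_ids[t[slot]])
    return seq
```

This layout is what lets the loss symbol in train_triplet.py recover anchors, positives and negatives simply by slicing each batch into thirds along axis 0.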

2. next()

Returns: an io.DataBatch holding the batch data and the batch labels.

    import logging
    import random

    import mxnet as mx
    from mxnet import io, nd


    def next(self):
        """Returns the next batch of data."""
        if not self.is_init:
            self.reset()
            self.is_init = True
        self.nbatch += 1
        batch_size = self.batch_size
        c, h, w = self.data_shape
        batch_data = nd.empty((batch_size, c, h, w))
        if self.provide_label is not None:
            batch_label = nd.empty(self.provide_label[0][1])
        i = 0
        try:
            while i < batch_size:
                label, s, bbox, landmark = self.next_sample()
                _data = self.imdecode(s)  # decode the raw image bytes s into an HWC NDArray
                if self.rand_mirror:  # random horizontal mirror
                    _rd = random.randint(0, 1)
                    if _rd == 1:
                        _data = mx.ndarray.flip(data=_data, axis=1)
                if self.cutoff > 0:  # cutout: fill a random square patch with gray (127.5)
                    centerh = random.randint(0, _data.shape[0] - 1)
                    centerw = random.randint(0, _data.shape[1] - 1)
                    half = self.cutoff // 2
                    starth = max(0, centerh - half)
                    endh = min(_data.shape[0], centerh + half)
                    startw = max(0, centerw - half)
                    endw = min(_data.shape[1], centerw + half)
                    _data = _data.astype('float32')
                    _data[starth:endh, startw:endw, :] = 127.5
                data = [_data]
                try:
                    self.check_valid_image(data)
                except RuntimeError as e:
                    logging.debug('Invalid image, skipping:  %s', str(e))
                    continue
                for datum in data:
                    assert i < batch_size, 'Batch size must be multiples of augmenter output length'
                    batch_data[i][:] = self.postprocess_data(datum)
                    if self.provide_label is not None:
                        batch_label[i][:] = label
                    i += 1
        except StopIteration:
            if i < batch_size:
                raise StopIteration
        _label = None
        if self.provide_label is not None:
            _label = [batch_label]
        # third argument is pad: the number of unfilled slots in the batch
        return io.DataBatch([batch_data], _label, batch_size - i)
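The two augmentations in next() (random mirror and cutout) are independent of MXNet. Below is a NumPy sketch of the same logic on an HWC array, so it can be tried without a record file. It swaps Python's random module for a numpy.random.Generator and NDArray for ndarray; the function name mirror_and_cutout and its parameters are hypothetical.

```python
import numpy as np


def mirror_and_cutout(img, cutoff, rng, rand_mirror=True, fill=127.5):
    """Random horizontal flip plus cutout on an HWC image.

    Flips along axis 1 (width) with probability 0.5, then fills a square of
    side about `cutoff`, centred on a random pixel and clipped at the image
    borders, with a constant gray value.
    """
    out = img.astype(np.float32)  # astype copies, so the input is never modified
    if rand_mirror and rng.integers(0, 2) == 1:
        out = out[:, ::-1, :]
    if cutoff > 0:
        h, w = out.shape[:2]
        ch = rng.integers(0, h)  # like random.randint(0, h - 1)
        cw = rng.integers(0, w)
        half = cutoff // 2
        out[max(0, ch - half):min(h, ch + half),
            max(0, cw - half):min(w, cw + half), :] = fill
    return out
```

Usage: `aug = mirror_and_cutout(img, cutoff=32, rng=np.random.default_rng(0))`. As in the original, the patch can be smaller than cutoff × cutoff when its centre lands near a border.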

Part 2: The loss in insightface's triplet loss implementation

Source: train_triplet.py

    import mxnet as mx

    # embedding, gt_label and args are defined elsewhere in train_triplet.py
    nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
    # Split along axis 0 (the batch axis) into three equal parts: anchor, positive, negative
    anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size//3)
    positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size//3, end=2*args.per_batch_size//3)
    negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2*args.per_batch_size//3, end=args.per_batch_size)
    # triplet loss
    ap = anchor - positive
    an = anchor - negative
    ap = ap*ap
    an = an*an
    ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
    an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
    triplet_loss = mx.symbol.Activation(data=(ap-an+args.triplet_alpha), act_type='relu')
    triplet_loss = mx.symbol.mean(triplet_loss)
    # triplet_loss = mx.symbol.sum(triplet_loss)/(args.per_batch_size//3)
    triplet_loss = mx.symbol.MakeLoss(triplet_loss)
    # Group the monitored outputs with the loss; BlockGrad stops gradients from flowing
    # back through the embedding and label outputs, so only the loss drives training
    out_list = [mx.symbol.BlockGrad(embedding)]
    out_list.append(mx.sym.BlockGrad(gt_label))
    out_list.append(triplet_loss)
    out = mx.symbol.Group(out_list)
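Written out, the symbol graph computes L = (1/T) Σ max(0, ||a_i − p_i||² − ||a_i − n_i||² + α), where T = per_batch_size // 3 and α is triplet_alpha. The NumPy sketch below mirrors it so the arithmetic can be checked by hand. The function names are hypothetical, and like the original it assumes per_batch_size is a multiple of 3 so that the three slices have the same length.

```python
import numpy as np


def l2_normalize(x):
    # Same role as L2Normalization(mode='instance') on 2-D input: unit norm per row.
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def triplet_loss(nembedding, per_batch_size, alpha):
    """mean(relu(||a - p||^2 - ||a - n||^2 + alpha)) over the T triplets in a batch."""
    k = per_batch_size // 3
    anchor = nembedding[0:k]
    positive = nembedding[k:2 * k]
    negative = nembedding[2 * k:per_batch_size]
    ap = np.sum((anchor - positive) ** 2, axis=1, keepdims=True)  # (T, 1)
    an = np.sum((anchor - negative) ** 2, axis=1, keepdims=True)  # (T, 1)
    return float(np.mean(np.maximum(ap - an + alpha, 0.0)))
```

In a toy batch of two triplets where the first is already separated by more than α and the second has its positive and negative swapped, only the second contributes: its term is 2 − 0 + α, so with α = 0.2 the mean is 1.1.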
