Debugging deep cosine metric learning (cosine-metric-learning) on the VeRi dataset

  • Training
  • Data loading

Training

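While debugging deep cosine metric learning on the VeRi dataset I ran into many bugs. The one that cost me the most time and effort was an error caused by the dimensions of the input images not matching those of the placeholder pipeline. I only got it working after changing the source code. The cause: the original code decodes every image inside tf.map_fn and resizes only afterwards. tf.map_fn stacks its per-element results into a single tensor, so all decoded images must have the same shape. That holds for Market-1501, whose crops are all 128×64, but not for VeRi, whose images come in different sizes.
In train_app.py, the original code was: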

filename_var = tf.placeholder(tf.string, (None, ))
image_var = tf.map_fn(lambda x:tf.image.decode_jpeg(tf.read_file(x), channels=num_channels),filename_var, back_prop=False, dtype=tf.float32)
image_var = tf.image.resize_images(image_var, image_shape[:2])

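Modified code. Each image is now resized inside the mapped function, so every element already has the shape image_shape[:2] when tf.map_fn stacks them: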

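filename_var = tf.placeholder(tf.string, (None, ))
# Decode and resize each image inside map_fn so that all elements share the same
# shape before they are stacked; resize_images returns float32, hence dtype=tf.float32.
image_var = tf.map_fn(
    lambda x: tf.image.resize_images(
        tf.image.decode_jpeg(tf.read_file(x), channels=num_channels),
        image_shape[:2]),
    filename_var, back_prop=False, dtype=tf.float32)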
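The failure comes from how tf.map_fn assembles its output: it stacks the per-element results, so every element must have the same shape. The NumPy sketch below mimics that behaviour outside TensorFlow, purely as an analogy. nearest_resize and stack_images are hypothetical helpers; the nearest-neighbour resize only stands in for tf.image.resize_images. The image sizes are made-up examples of differently sized crops.

```python
import numpy as np


def nearest_resize(img, out_h, out_w):
    """Nearest-neighbour resize of an H x W x C array (stand-in for tf.image.resize_images)."""
    h, w = img.shape[:2]
    rows = np.arange(out_h) * h // out_h  # source row for each output row
    cols = np.arange(out_w) * w // out_w  # source column for each output column
    return img[rows[:, None], cols[None, :]]


def stack_images(images, image_shape=(128, 64, 3), resize_first=True):
    """Mimic map_fn: stack per-image results, optionally resizing each one first."""
    if resize_first:
        images = [nearest_resize(im, image_shape[0], image_shape[1]) for im in images]
    # np.stack, like map_fn, requires every element to have the same shape.
    return np.stack(images)


if __name__ == "__main__":
    crops = [np.zeros((150, 200, 3), np.uint8), np.ones((90, 120, 3), np.uint8)]
    print(stack_images(crops).shape)
    try:
        stack_images(crops, resize_first=False)
    except ValueError as err:
        print("stacking failed:", err)
```

Resizing per element before stacking is the same idea as moving tf.image.resize_images into the lambda passed to tf.map_fn.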

Data loading

VeRi.py

# vim: expandtab:ts=4:sw=4
import os
import numpy as np
import cv2
import scipy.io as sio

# The maximum vehicle ID in the dataset.
MAX_LABEL = 769  # VeRi max label
# Kept identical to Market1501 for simplicity, so the model structure does not change.
IMAGE_SHAPE = 128, 64, 3


def _parse_filename(filename):
    filename_base, ext = os.path.splitext(filename)
    if '.' in filename_base:
        # Some images have double filename extensions.
        filename_base, ext = os.path.splitext(filename_base)
    if ext != ".jpg":
        return None
    # VeRi names: <vehicle ID>_c<camera>_<timestamp>_<index>, e.g. 0002_c002_00030600_0.
    vehicle_id, cam_seq, frame_idx, detection_idx = filename_base.split('_')
    # cam_seq looks like "c002": parse all digits (cam_seq[1] alone is always "0").
    return int(vehicle_id), int(cam_seq[1:]), filename_base, ext


def read_train_split_to_str(dataset_dir):
    filenames, ids, camera_indices = [], [], []

    image_dir = os.path.join(dataset_dir, "image_train")
    for filename in sorted(os.listdir(image_dir)):
        meta_data = _parse_filename(filename)
        if meta_data is None:
            # This is not a valid filename (e.g., Thumbs.db).
            continue

        filenames.append(os.path.join(image_dir, filename))
        ids.append(meta_data[0])
        camera_indices.append(meta_data[1])

    return filenames, ids, camera_indices


def read_train_split_to_image(dataset_dir):
    filenames, ids, camera_indices = read_train_split_to_str(dataset_dir)

    images = np.zeros((len(filenames), 128, 64, 3), np.uint8)
    for i, filename in enumerate(filenames):
        # Resize to (128, 64, 3) to stay consistent with the original structure
        # (cv2.resize takes the target size as (width, height)).
        images[i] = cv2.resize(cv2.imread(filename, cv2.IMREAD_COLOR), (64, 128))

    ids = np.asarray(ids, np.int64)
    camera_indices = np.asarray(camera_indices, np.int64)
    return images, ids, camera_indices


def read_test_split_to_str(dataset_dir):
    # NOTE: "bounding_box_test", "query", "gt_query" and the *_junk.mat files
    # follow the Market-1501 layout used by the original loader.

    # Read gallery.
    gallery_filenames, gallery_ids = [], []

    image_dir = os.path.join(dataset_dir, "bounding_box_test")
    for filename in sorted(os.listdir(image_dir)):
        meta_data = _parse_filename(filename)
        if meta_data is None:
            # This is not a valid filename (e.g., Thumbs.db).
            continue

        gallery_filenames.append(os.path.join(image_dir, filename))
        gallery_ids.append(meta_data[0])

    # Read queries.
    query_filenames, query_ids, query_junk_indices = [], [], []

    image_dir = os.path.join(dataset_dir, "query")
    for filename in sorted(os.listdir(image_dir)):
        meta_data = _parse_filename(filename)
        if meta_data is None:
            # This is not a valid filename (e.g., Thumbs.db).
            continue

        filename_base = meta_data[2]
        junk_matfile = filename_base + "_junk.mat"
        mat = sio.loadmat(os.path.join(dataset_dir, "gt_query", junk_matfile))
        if np.any(mat["junk_index"] < 1):
            indices = []
        else:
            # MATLAB to Python index.
            indices = list(mat["junk_index"].astype(np.int64).ravel() - 1)

        query_junk_indices.append(indices)
        query_filenames.append(os.path.join(image_dir, filename))
        query_ids.append(meta_data[0])

    # The following matrix maps from query (row) to gallery image (column) such
    # that element (i, j) evaluates to 0 if query i and gallery image j should
    # be excluded from computation of the evaluation metrics and 1 otherwise.
    good_mask = np.ones(
        (len(query_filenames), len(gallery_filenames)), np.float32)
    for i, junk_indices in enumerate(query_junk_indices):
        good_mask[i, junk_indices] = 0.

    return gallery_filenames, gallery_ids, query_filenames, query_ids, good_mask


def read_test_split_to_image(dataset_dir):
    gallery_filenames, gallery_ids, query_filenames, query_ids, good_mask = (
        read_test_split_to_str(dataset_dir))

    gallery_images = np.zeros((len(gallery_filenames), 128, 64, 3), np.uint8)
    for i, filename in enumerate(gallery_filenames):
        # Resize to (128, 64, 3) to stay consistent with the original structure.
        gallery_images[i] = cv2.resize(cv2.imread(filename, cv2.IMREAD_COLOR), (64, 128))

    query_images = np.zeros((len(query_filenames), 128, 64, 3), np.uint8)
    for i, filename in enumerate(query_filenames):
        # Query images must be resized too, otherwise the assignment fails
        # for any image that is not already 128 x 64.
        query_images[i] = cv2.resize(cv2.imread(filename, cv2.IMREAD_COLOR), (64, 128))

    gallery_ids = np.asarray(gallery_ids, np.int64)
    query_ids = np.asarray(query_ids, np.int64)
    return gallery_images, gallery_ids, query_images, query_ids, good_mask
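VeRi file names have the form <vehicle ID>_c<camera>_<timestamp>_<index>.jpg, for example 0002_c002_00030600_0.jpg. The Market-1501 parser indexes cam_seq[1], which picks the camera digit out of names such as 0002_c1s1_000451_03.jpg. For VeRi's zero-padded c002, that expression is always "0", so every image would receive camera index 0. Below is a minimal stand-alone sketch of a VeRi-specific parser. parse_veri_filename is a hypothetical name, not part of the repository, and the second example file name is made up.

```python
import os


def parse_veri_filename(filename):
    """Return (vehicle_id, camera_index, base_name), or None for non-JPEG files.

    Assumes VeRi-style names: <vehicle ID>_c<camera>_<timestamp>_<index>.jpg
    """
    base, ext = os.path.splitext(filename)
    if ext != ".jpg":
        return None  # e.g. Thumbs.db
    vehicle_id, cam_seq, _timestamp, _index = base.split("_")
    # "c002" -> 2; cam_seq[1] alone would be "0" for every VeRi camera.
    return int(vehicle_id), int(cam_seq[1:]), base


if __name__ == "__main__":
    names = ["0002_c002_00030600_0.jpg", "0123_c015_00012345_1.jpg", "Thumbs.db"]
    parsed = [p for p in map(parse_veri_filename, names) if p is not None]
    print(parsed)
    # The largest ID bounds the classifier size (num_classes = max ID + 1).
    print(max(vid for vid, _, _ in parsed))
```

Taking the maximum over all parsed IDs is one way to check the MAX_LABEL value used for num_classes.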

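The test-split reader builds good_mask from the gt_query/*_junk.mat files of the Market-1501 distribution. A common re-identification convention is to ignore gallery images that show the query's identity from the query's own camera. The sketch below builds such a mask directly from IDs and camera indices. same_camera_junk_mask is hypothetical, and whether this rule matches VeRi's own evaluation files is an assumption to verify against the dataset.

```python
import numpy as np


def same_camera_junk_mask(query_ids, query_cams, gallery_ids, gallery_cams):
    """Return a (num_query, num_gallery) float32 mask: 0 = exclude, 1 = keep.

    A gallery image is excluded when it has the same ID *and* the same camera
    as the query (assumed junk rule, mirroring the usual re-ID protocol).
    """
    q_ids = np.asarray(query_ids)[:, None]
    q_cams = np.asarray(query_cams)[:, None]
    g_ids = np.asarray(gallery_ids)[None, :]
    g_cams = np.asarray(gallery_cams)[None, :]
    junk = (q_ids == g_ids) & (q_cams == g_cams)  # broadcast to (Q, G)
    return np.where(junk, 0.0, 1.0).astype(np.float32)


if __name__ == "__main__":
    mask = same_camera_junk_mask([2, 5], [1, 3], [2, 2, 5, 7], [1, 4, 3, 3])
    print(mask)
```

The result has the same layout as good_mask (query rows, gallery columns), so it could be swapped in if the .mat files are unavailable.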
train_VeRi_dataset.py

# vim: expandtab:ts=4:sw=4
import functools
import os
import numpy as np
import scipy.io as sio
import train_app
from datasets import VeRi  # assumes VeRi.py sits in datasets/ next to market1501.py
from datasets import util
import nets.deep_sort.network_definition as net


class VeRi_dataset(object):

    def __init__(self, dataset_dir, num_validation_y=0.1, seed=1):
        # Hold out 10% of the training set (split by identity) for validation.
        self._dataset_dir = dataset_dir
        self._num_validation_y = num_validation_y
        self._seed = seed

    def read_train(self):
        filenames, ids, camera_indices = VeRi.read_train_split_to_str(
            self._dataset_dir)
        train_indices, _ = util.create_validation_split(
            np.asarray(ids, np.int64), self._num_validation_y, self._seed)

        filenames = [filenames[i] for i in train_indices]
        ids = [ids[i] for i in train_indices]
        camera_indices = [camera_indices[i] for i in train_indices]
        return filenames, ids, camera_indices

    def read_validation(self):
        filenames, ids, camera_indices = VeRi.read_train_split_to_str(
            self._dataset_dir)
        _, valid_indices = util.create_validation_split(
            np.asarray(ids, np.int64), self._num_validation_y, self._seed)

        filenames = [filenames[i] for i in valid_indices]
        ids = [ids[i] for i in valid_indices]
        camera_indices = [camera_indices[i] for i in valid_indices]
        return filenames, ids, camera_indices

    def read_test(self):
        return VeRi.read_test_split_to_str(self._dataset_dir)


def main():
    arg_parser = train_app.create_default_argument_parser("VeRi")
    arg_parser.add_argument(
        "--dataset_dir", help="Path to VeRi dataset directory.",
        default="data/VeRi")
    args = arg_parser.parse_args()
    # Instantiate the dataset class (VeRi is the module, VeRi_dataset the class).
    dataset = VeRi_dataset(args.dataset_dir, num_validation_y=0.1, seed=1234)

    if args.mode == "train":
        train_x, train_y, _ = dataset.read_train()
        print("Train set size: %d images, %d identities" % (
            len(train_x), len(np.unique(train_y))))

        network_factory = net.create_network_factory(
            is_training=True, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=args.loss_mode == "cosine-softmax")
        train_kwargs = train_app.to_train_kwargs(args)
        train_app.train_loop(
            net.preprocess, network_factory, train_x, train_y,
            num_images_per_id=4, image_shape=VeRi.IMAGE_SHAPE,
            **train_kwargs)
    elif args.mode == "eval":
        valid_x, valid_y, camera_indices = dataset.read_validation()
        print("Validation set size: %d images, %d identities" % (
            len(valid_x), len(np.unique(valid_y))))

        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=args.loss_mode == "cosine-softmax")
        eval_kwargs = train_app.to_eval_kwargs(args)
        train_app.eval_loop(
            net.preprocess, network_factory, valid_x, valid_y, camera_indices,
            image_shape=VeRi.IMAGE_SHAPE, **eval_kwargs)
    elif args.mode == "export":
        # Export one specific model.
        gallery_filenames, _, query_filenames, _, _ = dataset.read_test()

        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        gallery_features = train_app.encode(
            net.preprocess, network_factory, args.restore_path,
            gallery_filenames, image_shape=VeRi.IMAGE_SHAPE)
        sio.savemat(
            os.path.join(args.sdk_dir, "feat_test.mat"),
            {"features": gallery_features})

        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=True)
        query_features = train_app.encode(
            net.preprocess, network_factory, args.restore_path,
            query_filenames, image_shape=VeRi.IMAGE_SHAPE)
        sio.savemat(
            os.path.join(args.sdk_dir, "feat_query.mat"),
            {"features": query_features})
    elif args.mode == "finalize":
        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        train_app.finalize(
            functools.partial(net.preprocess, input_is_bgr=True),
            network_factory, args.restore_path,
            image_shape=VeRi.IMAGE_SHAPE,
            output_filename="./VeRi.ckpt")
    elif args.mode == "freeze":
        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        train_app.freeze(
            functools.partial(net.preprocess, input_is_bgr=True),
            network_factory, args.restore_path,
            image_shape=VeRi.IMAGE_SHAPE,
            output_filename="./VeRi.pb")
    else:
        raise ValueError("Invalid mode argument.")


if __name__ == "__main__":
    main()
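VeRi_dataset delegates the train/validation split to util.create_validation_split with num_validation_y=0.1. As far as I can tell from the repository, a float value there is read as a fraction of identities. Under that reading, all images of a held-out vehicle land in the validation set and none leak into training. Below is a minimal pure-Python sketch of such an identity-level split under that assumption. split_by_identity is hypothetical and is not the repository's implementation.

```python
import random


def split_by_identity(ids, num_validation_y=0.1, seed=1234):
    """Return (train_indices, valid_indices) such that no ID appears in both.

    num_validation_y: float in [0, 1) -> fraction of identities,
                      int >= 1        -> absolute number of identities.
    """
    unique_ids = sorted(set(ids))
    if isinstance(num_validation_y, float):
        num_validation_y = int(round(num_validation_y * len(unique_ids)))
    rng = random.Random(seed)  # fixed seed -> reproducible split
    valid_ids = set(rng.sample(unique_ids, num_validation_y))
    train_indices = [i for i, y in enumerate(ids) if y not in valid_ids]
    valid_indices = [i for i, y in enumerate(ids) if y in valid_ids]
    return train_indices, valid_indices


if __name__ == "__main__":
    ids = [vid for vid in range(20) for _ in range(3)]  # 20 toy IDs, 3 images each
    train_idx, valid_idx = split_by_identity(ids)
    print(len(train_idx), len(valid_idx))
```

Splitting by identity rather than by image keeps the validation metrics honest: the network never sees a validation vehicle during training.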

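The --loss_mode cosine-softmax option (add_logits=True) is where the method gets its name. Instead of ordinary logits, the classifier compares the L2-normalised embedding with L2-normalised class weight vectors and scales the cosine similarities by a factor kappa before the softmax. In the original method that scale is learned during training. The NumPy sketch below shows only the forward computation, with kappa fixed to an assumed value and random weights; it is not the repository's TensorFlow code.

```python
import numpy as np


def cosine_softmax(features, weights, kappa=10.0):
    """Class probabilities from a cosine-softmax classifier.

    features: (N, D) embeddings, weights: (D, K) class weight vectors.
    kappa: scale applied to the cosine similarities (fixed here for illustration).
    """
    r = features / np.linalg.norm(features, axis=1, keepdims=True)  # unit-length embeddings
    w = weights / np.linalg.norm(weights, axis=0, keepdims=True)     # unit-length class vectors
    logits = kappa * r @ w                                           # scaled cosine similarities
    logits -= logits.max(axis=1, keepdims=True)                      # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(4, 128))
    class_weights = rng.normal(size=(128, 10))
    probs = cosine_softmax(feats, class_weights)
    print(probs.shape, probs.sum(axis=1))
```

Because both sides are normalised, only directions matter. That is why the learned embedding can be compared with a plain cosine distance at test time.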
Copyright belongs to the author; please credit the source when reposting.