参考:https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial

1、环境的安装

系统的基础环境:

  • ubantu16.04
  • CUDA9.0+Cudnn7.4.2
  • Python3.7.4
  • Anaconda 3

创建虚拟环境

conda create -n re_id python=3.7.4
source activate re_id

安装Pytorch

根据CUDA的版本来安装:
https://pytorch.org/get-started/previous-versions/

conda install pytorch==1.0.1 torchvision==0.2.2 cudatoolkit=9.0 -c pytorch

安装 yacs

git clone https://github.com/rbgirshick/yacs
cd yacs
python setup.py install

安装其他依赖库

pip install pretrainedmodels
conda install matplotlib
conda install future
pip install torchvision
pip install tensorboardX
pip install tensorflow -i https://pypi.mirrors.ustc.edu.cn/simple
conda install scipy
conda install Cython

2、开始

数据集和代码的准备

数据集:Market-1501
代码:Practical-Baseline

2.1训练

2.1.1:数据的准备(python prepare.py)

数据集分布如下:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* Files for multiple query testing
│   ├── gt_query/                   /* We do not use it
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
1. "bounding_box_test" – 19732张图片,测试集,也是所谓的gallery参考图像集;2. "bounding_box_train" – 12936张图片,训练集;3. "query" – 3368张query图片,即要查询的图片,在 "bounding_box_test"中执行搜索;4. "gt_bbox" – 25259张图片(人工标注),对应test和train数据集中1501个个体,用于区分"good"、“junk"和"distractors”;5. "gt_query" – 对于3368张query图片的每一个,都有"good"和"junk"相关的图像(包含相同个体),这个文件夹包含了"good"和"junk"图像的索引,用在性能评估中。

打开代码prepare.py。 将第五行的地址改为你本地的地址,比如 \home\zzd\Download\Market,然后在终端中运行代码。
记得所有操作都在刚刚创建的虚拟环境下

python prepare.py

运行后文件的改变如下:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* Files for multiple query testing
│   ├── gt_query/                   /* We do not use it
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                   /* train
│           ├── 0002
|           ├── 0007
|           ...
│       ├── val/                     /* val
│       ├── train_all/               /* train+val
│       ├── query/                   /* query files
│       ├── gallery/                 /* gallery files

2.1.2:搭建神经网络模型(model.py)

我们可以使用预先训练好的网络结构,例如“ AlexNet”,“ VGG16”,“ ResNet”和“ DenseNet”。 通常,经过预训练的网络结构有助于保留更好的性能,因为它保留了ImageNet的优点[1].

在 pytorch中, 两行代码就可以导入模型:

from torchvision import models
model = models.resnet50(pretrained=True)

但是我们需要稍微调整一下网络结构。 Market-1501中有751个类别(不同的人),与ImageNet中的1,000个类别所不同。 因此,在这里我们修正模型以使用分类器。

import torch
import torch.nn as nn
from torchvision import models# Define the ResNet50-based Model
class ft_net(nn.Module):def __init__(self, class_num = 751):super(ft_net, self).__init__()#load the modelmodel_ft = models.resnet50(pretrained=True) # change avg pooling to global poolingmodel_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))self.model = model_ftself.classifier = ClassBlock(2048, class_num) #define our classifier.def forward(self, x):x = self.model.conv1(x)x = self.model.bn1(x)x = self.model.relu(x)x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)x = self.model.avgpool(x)x = torch.squeeze(x)x = self.classifier(x) #use our classifier.return x

为什么我们使用AdaptiveAvgPool2d? AvgPool2d和AdaptiveAvgPool2d有什么区别? 该模型现在有参数吗? 如何在新的网络层中初始化参数?

仔细看看model.py吧。
这里我们不需要修改model.py,已经修改好了

2.1.3:开始训练(train.py)

  • 训练方法 【ResNet-50】
python train.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32  --data_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/
  • 训练方法 【ResNet-50(alltricks)
python train.py --warm_epoch 5 --stride 1 --erasing_p 0.5 --batchsize 8 --lr 0.02 --name warm5_s1_b8_lr2_p0.5 --gpu_ids 0  --data_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/
--gpu_ids 运行的gpu型号--name 模型名字--data_dir 训练数据路径--train_all 所有用来训练的图像.--batchsize batch大小--erasing_p 随机删除参数.

2.1.4:开始测试(test.py)

  • 测试Market-1501数据集
python test.py --gpu_ids 0 --name ft_ResNet50 --test_dir /home/huan/deep_learning/ReID/Person_reID_baseline_pytorch/Market/pytorch/  --batchsize 32 --which_epoch 19
  • 测试自己的数据集,并生成json格式的文件。

行人重识别的代码复现相关推荐

  1. 本周新出开源计算机视觉代码汇总(含图像超分辨、视频目标分割、行人重识别、点云识别等)...

    点击我爱计算机视觉标星,更快获取CVML新技术 今天汇总了本周新出的计算机视觉开源代码.(有部分已经有git地址但还没上传代码) 共有12份来自前沿计算机视觉研究的代码,CV君数了数,竟然发现其中10 ...

  2. 行人重识别 代码阅读(来自郑哲东 简单行人重识别代码到88%准确率)

    来自郑哲东 简单行人重识别代码到88%准确率 阅读代码 prepare.py 数据结构 部分代码 一些函数 model.py ClassBlock ResNet50 train.py 一些参数 使用f ...

  3. 入门行人重识别 尝试跑(郑哲东 简单行人重识别代码到88%准确率)过程

    来自郑哲东 简单行人重识别代码到88%准确率 运行代码和参考步骤 试运行-第一部分 prepare.py model.py train.py 试运行-第二部分 test.py 运行代码和参考步骤 代码 ...

  4. ReID行人重识别(训练+检测,附代码),可做图像检索,陌生人检索等项目

    利用ReID和目标检测对视频进行检测,可以对视频中的人进行重识别,支持更换数据集可以做车辆重识别等.可应用于图像.视频检索,行人跟踪等 在以前学习ReID的时候,是跟着下面视频学习的,该论文和代码也可 ...

  5. CVPR2021 行人重识别/Person Re-identification 论文+开源代码汇总

    点击上方"AI算法与图像处理",选择加"星标"或"置顶"重磅干货,第一时间送达 行人重识别(Person re-identification ...

  6. 点云编码是计算机视觉吗,本周新出开源计算机视觉代码汇总(含图像超分辨、视频目标分割、行人重识别、点云识别等)...

    今天汇总了本周新出的计算机视觉开源代码.(有部分已经有git地址但还没上传代码) 共有12份来自前沿计算机视觉研究的代码,CV君数了数,竟然发现其中10份代码所属论文的第一作者是华人! 可见,华人学者 ...

  7. 利用行人重识别代码训练车辆重识别

    一.参考 行人重识别代码:Person_reID_baseline_pytorch-master 代码链接:https://github.com/layumi/Person_reID_baseline ...

  8. 代码开源!!行人检测与行人重识别结合 person search

    0 前言 最近在做自己课题相关的小项目,行人检测与行人重识别进行结合进行场景图片进行特定行人的检索由于也比较着急,先利用现有的模型搭建了简单的demo,简单效果展示如下,速度大概在20FPS: 先是给 ...

  9. 行人重识别0-08:DG-Net(ReID)-代码无死角解读(4)-网络Es编码解码过程

    以下链接是个人关于DG-Net(行人重识别ReID)所有见解,如有错误欢迎大家指出,我会第一时间纠正.有兴趣的朋友可以加微信:17575010159 相互讨论技术.若是帮助到了你什么,一定要记得点赞! ...

最新文章

  1. Velocity知识点总结
  2. 领域驱动设计(2)怎么使用沟通
  3. bzoj2875: [Noi2012]随机数生成器
  4. vs code gitee使用_实用为王!来看看Gitee上五款新鲜出炉的WebUI组件
  5. 诗与远方:无题(八十四)- 自己醉了
  6. 农夫山泉终于“玩砸了”
  7. linux vmware 服务,学习笔记:在Linux虚拟机上搭建node服务
  8. Package Control 使用
  9. IDL实现遥感影像融合(批量)TASK(三)
  10. ascll编码表图片_ASCII码一览表,ASCII码对照表
  11. java使用elasticsearch进行模糊查询-已在项目中实际应用
  12. mq消息队列到底是什么
  13. vue 图片写入文字,图片注入文字,图片添加文字
  14. 君望永远--纠缠在爱情的起点上 (转载)
  15. M3U8文件简介及在线播放器
  16. Linux 安装 Intel 网卡驱动
  17. 使用navicat导入SQL语句的教程
  18. JavaScript——ES8新特性
  19. vs2015已停止工作,事件名称APPCRASH 故障模块KERNELBASE.dll
  20. POST、GET请求及对应的参数获取

热门文章

  1. 李建忠设计模式之”数据结构“模式
  2. Web前端大作业—里约热内卢奥运会(html+css+javascript)
  3. ol-ext transform 对象,旋转、拉伸、放大(等比例缩放),事件监听
  4. 【边做项目边学Android】手机安全卫士07-手机防盗之进入限制
  5. 3DMAX快速入门 界面介绍【上】
  6. 手工命令行打包java工程为war包
  7. SSM实现养老院管理系统
  8. 电脑 耳机播放声音,左右耳朵不平衡解决方法
  9. 夯实第一超市地位 京东超市成超10大品类超50家品牌线上最大渠道
  10. 网址怎样收藏到我计算机桌面,电脑应该如何收藏网址