行人重识别 代码阅读(来自郑哲东 简单行人重识别代码到88%准确率)
来自郑哲东 简单行人重识别代码到88%准确率
- 阅读代码
- prepare.py
- 数据结构
- 部分代码
- 一些函数
- model.py
- ClassBlock
- ResNet50
- train.py
- 一些参数
- 使用fp16
- 预处理
- 数据集迭代器
- 训练模块
阅读代码
因为自己对代码不擅长,所以在参考这个博主的一系列博文,如Person_reID_baseline_pytorch 源码解析之 prepare.py等
prepare.py
数据结构
这部分代码主要用于重构数据集,方便之后运行代码。
在原来的数据集中加入了pytorch
这个文件夹。
原来的
├── Market/
│ ├── bounding_box_test/ /* Files for testing (candidate images pool)
│ ├── bounding_box_train/ /* Files for training
│ ├── gt_bbox/ /* We do not use it
│ ├── gt_query/ /* Files for multiple query testing
│ ├── query/ /* Files for testing (query images)
│ ├── readme.txt
运行代码后的
├── Market/
│ ├── bounding_box_test/ /* Files for testing (candidate images pool)
│ ├── bounding_box_train/ /* Files for training
│ ├── gt_bbox/ /* We do not use it
│ ├── gt_query/ /* Files for multiple query testing
│ ├── query/ /* Files for testing (query images)
│ ├── readme.txt
│ ├── pytorch/
│ ├── train/ /* train,包含train_all除val剩下的图片
│ ├── 0002
| ├── 0007
| ...
│ ├── val/ /* val,包含train_all每个子文件夹的第一张图片
│ ├── train_all/ /* train+val,包含bounding_box_train 的所有图片
│ ├── query/ /* query files,包含所有待测行人的图片
│ ├── gallery/ /* gallery files,包含bounding_box_test 的所有图片
部分代码
#query
query_path = download_path + '/query'
query_save_path = download_path + '/pytorch/query'
if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:]=='jpg':continueID = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0] if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)
主要就是把原来数据集中的Market/query
文件夹中的所有图像按照序号放在了新建的Market/pytorch/query/....
文件夹。
如以0001
开头的图像放在名称为Market/pytorch/query/0001
的文件夹中
其他部分也大同小异,这里不再赘述。
一些函数
os.path.isdir(path) #判断path是否为目录
os.mkdir(path) #创建目录
os.walk(top, topdown=True, onerror=None, followlinks=False)
# top 要遍历的目录。
# topdown 遍历方式:从上到下遍历或者从下到上遍历。
# onerror 用来设置出现错误时的处理函数(该函数接受一个OSError的实例作为参数),设置为空则不作处理。
# followlinks 是否要跟随目录下的链接去继续遍历。要注意的是,os.walk不会记录已经遍历的目录,所以跟随链接遍历的话有可能一直循环调用下去。
model.py
该脚本实现了多种行人重识别模型,在看ResNet50之前,先看一下定义的分类器
ClassBlock
# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, linear=512, return_f = False):super(ClassBlock, self).__init__()self.return_f = return_fadd_block = []if linear>0:add_block += [nn.Linear(input_dim, linear)]else:linear = input_dimif bnorm:add_block += [nn.BatchNorm1d(linear)]if relu:add_block += [nn.LeakyReLU(0.1)]if droprate>0:add_block += [nn.Dropout(p=droprate)]add_block = nn.Sequential(*add_block)add_block.apply(weights_init_kaiming)classifier = []classifier += [nn.Linear(linear, class_num)]classifier = nn.Sequential(*classifier)classifier.apply(weights_init_classifier)self.add_block = add_blockself.classifier = classifierdef forward(self, x):x = self.add_block(x)if self.return_f:f = xx = self.classifier(x)return [x,f]else:x = self.classifier(x)return x
ResNet50
这里使用了pytorch预训练好的模型
# Define the ResNet50-based Model
class ft_net(nn.Module):def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512):## 下面有介绍super(ft_net, self).__init__()super(ft_net, self).__init__()model_ft = models.resnet50(pretrained=True)if ibn==True:## 下面有介绍torch.hub.load()model_ft = torch.hub.load('XingangPan/IBN-Net', 'resnet50_ibn_a', pretrained=True)# avg pooling to global pooling## 也就是平均池化改为了自适应平均池化if stride == 1:model_ft.layer4[0].downsample[0].stride = (1,1)model_ft.layer4[0].conv2.stride = (1,1)## 是使得池化后的每个通道上的大小是一个1x1的,也就是每个通道上只有一个像素点model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))self.model = model_ftself.circle = circle## 定义了自己的分类器self.classifier = ClassBlock(2048, class_num, droprate, linear=linear_num, return_f = circle)def forward(self, x):## 卷积层1x = self.model.conv1(x)## 归一化(Batch Normalization)常用在激活层之前## 可以加快模型训练时的收敛速度,使得模型训练过程更加稳定## 避免梯度爆炸或者梯度消失## 起到一定的正则化作用,几乎代替了Dropout。x = self.model.bn1(x)## 激活层(Rectified linear unit,ReLU)x = self.model.relu(x)## 最大池化x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)## 平均池化x = self.model.avgpool(x)x = x.view(x.size(0), x.size(1))## 自己的分类器x = self.classifier(x)return x
1) nn.Module #nn.Module是PyTorch体系下所有神经网络模块的基类
来自这篇博文
我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。
- 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,也可以把不具有参数的层也放在里面;
- 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
- forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
2) super(ft_net, self).__ init __()
来自这篇博文
super(Net, self).__ init__() 是指首先找到Net的父类(比如是类NNet),然后把类Net的对象self转换为类NNet的对象,然后“被转换”的类NNet对象调用自己的init函数,其实简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的__init__()的东西。
Net类继承nn.Module,super(Net, self).__ init__() 就是对继承自父类nn.Module的属性进行初始化,而且是用nn.Module的初始化方法来初始化继承的属性。
3) def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512):# 想知道参数都是什么意思 但是暂时不知道 先放一放
# stride 步长
4) model_ft = torch.hub.load('XingangPan/IBN-Net', 'resnet50_ibn_a', pretrained=True)
来自这篇博文
torch.hub.load(repo_or_dir, model, *args, source=‘github’, force_reload=False, verbose=True, skip_validation=False, **kwargs)
- repo_or_dir:记得设置本地./pytorch/vision 路径
- model 调用入口在:本地./pytorch/vision 路径下的hubconf.py文件里,因此模型名字要在hubconf.py中存在才能调用
- source=‘local’, 记得设置本地,默认:github
- pretrained=True #设定预训练模式,默认False,为了代码清晰,最好还是加上参数赋值.
为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数
pretrained=True 加载网络结构和预训练参数
pretrained=False 只加载网络结构,不加载预训练参数,即不需要用预训练模型的参数来初始化。
train.py
一些参数
1) erasing_p # 随机擦除(Random Erasing, RE)增强
论文:Random Erasing Data Augmentation
具体就是(来自这篇)
作者提出的目的主要是模拟遮挡,从而提高模型泛化能力。如果把物体遮挡一部分后依然能够分类正确,那么肯定会迫使网络利用局部未遮挡的数据进行识别,加大了训练难度,一定程度会提高泛化能力。可以视为add noise的一种,并且与随机裁剪、随机水平翻转具有一定的互补性,综合应用他们,可以取得更好的模型表现,尤其是对噪声和遮挡具有更好的鲁棒性。
具体操作就是:随机选择一个区域,然后采用随机值进行覆盖,模拟遮挡场景。
使用fp16
使用fp16可以让运算量大大降低,会快很多,不过要安装apex(目前还没有尝试过用fp16跑)
#fp16
try:from apex.fp16_utils import *from apex import ampfrom apex.optimizers import FusedSGD
except ImportError: # will be 3.x seriesprint('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')
预处理
transforms在计算机视觉工具包torchvision下,对图像进行预处理,这样可以使得之后运算的速度更快,这里加了一点注释
transform_train_list = [#transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)transforms.Resize((h, w), interpolation=3), # 缩放transforms.Pad(10), # 填充transforms.RandomCrop((h, w)), # 在一个随机的位置进行裁剪transforms.RandomHorizontalFlip(), # 以0.5的概率水平翻转给定的PIL图像transforms.ToTensor(), # 图片转张量,同时归一化0-255 ---> 0-1transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 用给定的均值和标准差分别对每个通道的数据进行正则化]transform_val_list = [transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
数据集迭代器
来自Person_reID_baseline_pytorch 源码解析之 train.py
训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。
image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all),data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),data_transforms['val'])dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,shuffle=True, num_workers=2, pin_memory=True) # 8 workers may work fasterfor x in ['train', 'val']}
将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’]
训练模块
# Iterate over data.for data in dataloaders[phase]:# get a batch of inputsinputs, labels = datanow_batch_size,c,h,w = inputs.shapeif now_batch_size<opt.batchsize: # skip the last batchcontinue# print(inputs.shape)# wrap them in Variable, if gpu is used, we transform the data to cuda.if use_gpu:inputs = Variable(inputs.cuda())labels = Variable(labels.cuda())else:inputs, labels = Variable(inputs), Variable(labels)# zero the parameter gradientsoptimizer.zero_grad() # 梯度初始化为零#-------- forward --------outputs = model(inputs) # 前向传播求出预测的值_, preds = torch.max(outputs.data, 1) # 见下面的解释loss = criterion(outputs, labels) # 求loss#-------- backward + optimize -------- # only if in training phaseif phase == 'train':loss.backward() # 反向传播求梯度optimizer.step() # 更新所有参数
_, preds = torch.max(outputs.data, 1)
相关解释来自PyTorch系列 | _, predicted = torch.max(outputs.data, 1)的理解
关于为什么要使用 optimizer.zero_grad() 来自为什么要使用 zero_grad()?
行人重识别 代码阅读(来自郑哲东 简单行人重识别代码到88%准确率)相关推荐
- 入门行人重识别 尝试跑(郑哲东 简单行人重识别代码到88%准确率)过程
来自郑哲东 简单行人重识别代码到88%准确率 运行代码和参考步骤 试运行-第一部分 prepare.py model.py train.py 试运行-第二部分 test.py 运行代码和参考步骤 代码 ...
- OPenCV4-颜色识别(一)调色板和简单的颜色识别
OPenCV4-颜色识别(一)调色板和简单的颜色识别 使用 OPenCV4 做颜色识别十分简单.本文章使用 python 语言来实现一个调色板和简单的颜色识别. 1.调色板 绘制一个调色板对颜色识别非 ...
- 嵌入式OCR+RFID识别电子护照阅读器模块MRZ码电子证件识别模组的应用与攻略
嵌入式OCR+RFID识别电子护照阅读器模块|MRZ码电⼦证件验证识别模组是一款USB通讯模式,体积小巧.便于安装,方便嵌入至各类终端设备中,适用于出境游旅行社.海关检查.口岸出入境检查.使领馆签证登 ...
- JavaSE基础知识(五)--面向对象代码实现初步(实现一个简单的类类型代码)
Java SE 是什么,包括哪些内容(五)? 本文内容参考自Java8标准 一.面向对象(代码实现): 首先,在这里我需要说明一个根本性的问题:实际上,面向对象编程包括了两部分,一个是你的编程思想,一 ...
- python文字识别并获取位置_python实现简单的文字识别
将图片翻译成文字一般被称为光学文字识别(Optical Character Recognition,OCR).今天我们用到的就是一个OCR 库--Tesseract. 首先要安装Tesseract,除 ...
- android 代码浏览,Webview实现android简单的浏览器实例代码
WebView是Android中一个非常实用的组件,它和Safai.Chrome一样都是基于Webkit网页渲染引擎,可以通过加载HTML数据的方式便捷地展现软件的界面,下面通过本文给大家介绍Webv ...
- python做数据可视化的代码_Python数据可视化正态分布简单分析及实现代码
Python说来简单也简单,但是也不简单,尤其是再跟高数结合起来的时候... 正态分布(Normaldistribution),也称"常态分布",又名高斯分布(Gaussiandi ...
- 二维正态分布图python代码_Python数据可视化正态分布简单分析及实现代码
Python说来简单也简单,但是也不简单,尤其是再跟高数结合起来的时候... 正态分布(Normaldistribution),也称"常态分布",又名高斯分布(Gaussiandi ...
- 代码阅读器 android,适用于Android的条形码/ Qr代码阅读器
我用zxing构建了我的应用程序.你需要一些编码.首先包括core.jar,它在core / core.jar,在你的构建路径中,然后转到他们的客户端,在 android /-./ com.googl ...
最新文章
- 蟑螂背上芯片板子,组队去救人类
- Spring MVC的框架组件
- 动态添加内容到百度搜索框里
- halcon create_ocr_class_svm 使用SVM分类器创建OCR分类器
- shell命令查阅端口信息_linux运维实用的42个常用命令总结
- Spring Data JPA 实例查询
- 别再学习框架了,看看这些让你起飞的计算机基础知识
- QT--QDockWidget 停靠窗口
- 烧钱两年,做事对得起工资,也要对得起公司这份决心
- c语言王者荣耀制作,易语言制作王者荣耀刷金币脚本的代码
- 点击图片放大功能 jquery
- 【GEE笔记】最大类间方差法(otsu、大津法)算法实现——计算阈值、图像二值化分割
- 企业微信和个人微信的区别是什么?
- CodeIgniter session过期时间问题
- oracle的gca的文件,GCA文件扩展名 - 什么是.gca以及如何打开? - ReviverSoft
- php 大转盘抽奖概率 角度,在线抽奖大转盘和概率计算
- 孩子数学成绩不好怎么办_孩子数学成绩差怎么才能快速提高
- EXCEL干货(1-1): 基本表格操作
- 简单的proxy之TinyHTTPProxy.py
- 2020全球创业者城市Top50