来自郑哲东 简单行人重识别代码到88%准确率

  • 阅读代码
    • prepare.py
      • 数据结构
      • 部分代码
      • 一些函数
    • model.py
      • ClassBlock
      • ResNet50
    • train.py
      • 一些参数
      • 使用fp16
    • 预处理
    • 数据集迭代器
    • 训练模块

阅读代码

因为自己对代码不擅长,所以在参考这个博主的一系列博文,如Person_reID_baseline_pytorch 源码解析之 prepare.py等

prepare.py

数据结构

这部分代码主要用于重构数据集,方便之后运行代码。
在原来的数据集中加入了pytorch这个文件夹。
原来的

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt

运行代码后的

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                   /* train,包含train_all除val剩下的图片
│           ├── 0002
|           ├── 0007
|           ...
│       ├── val/                     /* val,包含train_all每个子文件夹的第一张图片
│       ├── train_all/               /* train+val,包含bounding_box_train 的所有图片
│       ├── query/                   /* query files,包含所有待测行人的图片
│       ├── gallery/                 /* gallery files,包含bounding_box_test 的所有图片

部分代码


#query
query_path = download_path + '/query'
query_save_path = download_path + '/pytorch/query'
if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:]=='jpg':continueID  = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0] if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)

主要就是把原来数据集中的Market/query文件夹中的所有图像按照序号放在了新建的Market/pytorch/query/....文件夹。

如以0001开头的图像放在名称为Market/pytorch/query/0001的文件夹中

其他部分也大同小异,这里不再赘述。

一些函数

os.path.isdir(path)  #判断path是否为目录
os.mkdir(path)       #创建目录
os.walk(top, topdown=True, onerror=None, followlinks=False)
# top 要遍历的目录。
# topdown 遍历方式:从上到下遍历或者从下到上遍历。
# onerror 用来设置出现错误时的处理函数(该函数接受一个OSError的实例作为参数),设置为空则不作处理。
# followlinks 是否要跟随目录下的链接去继续遍历。要注意的是,os.walk不会记录已经遍历的目录,所以跟随链接遍历的话有可能一直循环调用下去。

model.py

该脚本实现了多种行人重识别模型,在看ResNet50之前,先看一下定义的分类器

ClassBlock

# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, linear=512, return_f = False):super(ClassBlock, self).__init__()self.return_f = return_fadd_block = []if linear>0:add_block += [nn.Linear(input_dim, linear)]else:linear = input_dimif bnorm:add_block += [nn.BatchNorm1d(linear)]if relu:add_block += [nn.LeakyReLU(0.1)]if droprate>0:add_block += [nn.Dropout(p=droprate)]add_block = nn.Sequential(*add_block)add_block.apply(weights_init_kaiming)classifier = []classifier += [nn.Linear(linear, class_num)]classifier = nn.Sequential(*classifier)classifier.apply(weights_init_classifier)self.add_block = add_blockself.classifier = classifierdef forward(self, x):x = self.add_block(x)if self.return_f:f = xx = self.classifier(x)return [x,f]else:x = self.classifier(x)return x

ResNet50

这里使用了pytorch预训练好的模型


# Define the ResNet50-based Model
class ft_net(nn.Module):def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512):## 下面有介绍super(ft_net, self).__init__()super(ft_net, self).__init__()model_ft = models.resnet50(pretrained=True)if ibn==True:## 下面有介绍torch.hub.load()model_ft = torch.hub.load('XingangPan/IBN-Net', 'resnet50_ibn_a', pretrained=True)# avg pooling to global pooling## 也就是平均池化改为了自适应平均池化if stride == 1:model_ft.layer4[0].downsample[0].stride = (1,1)model_ft.layer4[0].conv2.stride = (1,1)## 是使得池化后的每个通道上的大小是一个1x1的,也就是每个通道上只有一个像素点model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))self.model = model_ftself.circle = circle## 定义了自己的分类器self.classifier = ClassBlock(2048, class_num, droprate, linear=linear_num, return_f = circle)def forward(self, x):## 卷积层1x = self.model.conv1(x)## 归一化(Batch Normalization)常用在激活层之前## 可以加快模型训练时的收敛速度,使得模型训练过程更加稳定## 避免梯度爆炸或者梯度消失## 起到一定的正则化作用,几乎代替了Dropout。x = self.model.bn1(x)## 激活层(Rectified linear unit,ReLU)x = self.model.relu(x)## 最大池化x = self.model.maxpool(x)x = self.model.layer1(x)x = self.model.layer2(x)x = self.model.layer3(x)x = self.model.layer4(x)## 平均池化x = self.model.avgpool(x)x = x.view(x.size(0), x.size(1))## 自己的分类器x = self.classifier(x)return x
1) nn.Module    #nn.Module是PyTorch体系下所有神经网络模块的基类

来自这篇博文

我们在定义自已的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__构造函数和forward这两个方法。

  1. 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,也可以把不具有参数的层也放在里面;
  2. 一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
  3. forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。
2) super(ft_net, self).__ init __()

来自这篇博文

super(Net, self).__ init__() 是指首先找到Net的父类(比如是类NNet),然后把类Net的对象self转换为类NNet的对象,然后“被转换”的类NNet对象调用自己的init函数,其实简单理解就是子类把父类的__init__()放到自己的__init__()当中,这样子类就有了父类的__init__()的东西。
Net类继承nn.Module,super(Net, self).__ init__() 就是对继承自父类nn.Module的属性进行初始化,而且是用nn.Module的初始化方法来初始化继承的属性。

3) def __init__(self, class_num=751, droprate=0.5, stride=2, circle=False, ibn=False, linear_num=512):# 想知道参数都是什么意思 但是暂时不知道 先放一放
# stride 步长
4) model_ft = torch.hub.load('XingangPan/IBN-Net', 'resnet50_ibn_a', pretrained=True)

来自这篇博文

torch.hub.load(repo_or_dir, model, *args, source=‘github’, force_reload=False, verbose=True, skip_validation=False, **kwargs)

  • repo_or_dir:记得设置本地./pytorch/vision 路径
  • model 调用入口在:本地./pytorch/vision 路径下的hubconf.py文件里,因此模型名字要在hubconf.py中存在才能调用
  • source=‘local’, 记得设置本地,默认:github
  • pretrained=True #设定预训练模式,默认False,为了代码清晰,最好还是加上参数赋值.
    为了加快学习进度,训练的初期直接加载pretrain模型中预先训练好的参数
    pretrained=True 加载网络结构和预训练参数
    pretrained=False 只加载网络结构,不加载预训练参数,即不需要用预训练模型的参数来初始化。

train.py

一些参数

1) erasing_p     # 随机擦除(Random Erasing, RE)增强

论文:Random Erasing Data Augmentation
具体就是(来自这篇)
作者提出的目的主要是模拟遮挡,从而提高模型泛化能力。如果把物体遮挡一部分后依然能够分类正确,那么肯定会迫使网络利用局部未遮挡的数据进行识别,加大了训练难度,一定程度会提高泛化能力。可以视为add noise的一种,并且与随机裁剪、随机水平翻转具有一定的互补性,综合应用他们,可以取得更好的模型表现,尤其是对噪声和遮挡具有更好的鲁棒性。
具体操作就是:随机选择一个区域,然后采用随机值进行覆盖,模拟遮挡场景。

使用fp16

使用fp16可以让运算量大大降低,会快很多,不过要安装apex(目前还没有尝试过用fp16跑)


#fp16
try:from apex.fp16_utils import *from apex import ampfrom apex.optimizers import FusedSGD
except ImportError: # will be 3.x seriesprint('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')

预处理

transforms在计算机视觉工具包torchvision下,对图像进行预处理,这样可以使得之后运算的速度更快,这里加了一点注释


transform_train_list = [#transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)transforms.Resize((h, w), interpolation=3), # 缩放transforms.Pad(10), # 填充transforms.RandomCrop((h, w)), # 在一个随机的位置进行裁剪transforms.RandomHorizontalFlip(), # 以0.5的概率水平翻转给定的PIL图像transforms.ToTensor(), # 图片转张量,同时归一化0-255 ---> 0-1transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 用给定的均值和标准差分别对每个通道的数据进行正则化]transform_val_list = [transforms.Resize(size=(h, w),interpolation=3), #Image.BICUBIC transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

数据集迭代器

来自Person_reID_baseline_pytorch 源码解析之 train.py
训练模型时,一般不会一次性把所有数据都加载到模型中。通常采用 mini_batch 的方法,按照 batchsize 的大小将一个 batch 的数据载入到模型中。pytorch 框架支持用 torch.utils.data.DataLoader 作为 dataloader 载入数据。


image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train' + train_all),data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),data_transforms['val'])dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,shuffle=True, num_workers=2, pin_memory=True) # 8 workers may work fasterfor x in ['train', 'val']}

将 image_datasets[‘train’] 和 image_datasets[‘val’] 输入 torch.utils.data.DataLoader 后,获得了两个迭代器 dataloaders[‘train’] and dataloaders[‘val’]

训练模块

# Iterate over data.for data in dataloaders[phase]:# get a batch of inputsinputs, labels = datanow_batch_size,c,h,w = inputs.shapeif now_batch_size<opt.batchsize: # skip the last batchcontinue# print(inputs.shape)# wrap them in Variable, if gpu is used, we transform the data to cuda.if use_gpu:inputs = Variable(inputs.cuda())labels = Variable(labels.cuda())else:inputs, labels = Variable(inputs), Variable(labels)# zero the parameter gradientsoptimizer.zero_grad() # 梯度初始化为零#-------- forward --------outputs = model(inputs) # 前向传播求出预测的值_, preds = torch.max(outputs.data, 1) # 见下面的解释loss = criterion(outputs, labels) # 求loss#-------- backward + optimize -------- # only if in training phaseif phase == 'train':loss.backward() # 反向传播求梯度optimizer.step() # 更新所有参数

_, preds = torch.max(outputs.data, 1)相关解释来自PyTorch系列 | _, predicted = torch.max(outputs.data, 1)的理解

关于为什么要使用 optimizer.zero_grad() 来自为什么要使用 zero_grad()?

行人重识别 代码阅读(来自郑哲东 简单行人重识别代码到88%准确率)相关推荐

  1. 入门行人重识别 尝试跑(郑哲东 简单行人重识别代码到88%准确率)过程

    来自郑哲东 简单行人重识别代码到88%准确率 运行代码和参考步骤 试运行-第一部分 prepare.py model.py train.py 试运行-第二部分 test.py 运行代码和参考步骤 代码 ...

  2. OPenCV4-颜色识别(一)调色板和简单的颜色识别

    OPenCV4-颜色识别(一)调色板和简单的颜色识别 使用 OPenCV4 做颜色识别十分简单.本文章使用 python 语言来实现一个调色板和简单的颜色识别. 1.调色板 绘制一个调色板对颜色识别非 ...

  3. 嵌入式OCR+RFID识别电子护照阅读器模块MRZ码电子证件识别模组的应用与攻略

    嵌入式OCR+RFID识别电子护照阅读器模块|MRZ码电⼦证件验证识别模组是一款USB通讯模式,体积小巧.便于安装,方便嵌入至各类终端设备中,适用于出境游旅行社.海关检查.口岸出入境检查.使领馆签证登 ...

  4. JavaSE基础知识(五)--面向对象代码实现初步(实现一个简单的类类型代码)

    Java SE 是什么,包括哪些内容(五)? 本文内容参考自Java8标准 一.面向对象(代码实现): 首先,在这里我需要说明一个根本性的问题:实际上,面向对象编程包括了两部分,一个是你的编程思想,一 ...

  5. python文字识别并获取位置_python实现简单的文字识别

    将图片翻译成文字一般被称为光学文字识别(Optical Character Recognition,OCR).今天我们用到的就是一个OCR 库--Tesseract. 首先要安装Tesseract,除 ...

  6. android 代码浏览,Webview实现android简单的浏览器实例代码

    WebView是Android中一个非常实用的组件,它和Safai.Chrome一样都是基于Webkit网页渲染引擎,可以通过加载HTML数据的方式便捷地展现软件的界面,下面通过本文给大家介绍Webv ...

  7. python做数据可视化的代码_Python数据可视化正态分布简单分析及实现代码

    Python说来简单也简单,但是也不简单,尤其是再跟高数结合起来的时候... 正态分布(Normaldistribution),也称"常态分布",又名高斯分布(Gaussiandi ...

  8. 二维正态分布图python代码_Python数据可视化正态分布简单分析及实现代码

    Python说来简单也简单,但是也不简单,尤其是再跟高数结合起来的时候... 正态分布(Normaldistribution),也称"常态分布",又名高斯分布(Gaussiandi ...

  9. 代码阅读器 android,适用于Android的条形码/ Qr代码阅读器

    我用zxing构建了我的应用程序.你需要一些编码.首先包括core.jar,它在core / core.jar,在你的构建路径中,然后转到他们的客户端,在 android /-./ com.googl ...

最新文章

  1. 蟑螂背上芯片板子,组队去救人类
  2. Spring MVC的框架组件
  3. 动态添加内容到百度搜索框里
  4. halcon create_ocr_class_svm 使用SVM分类器创建OCR分类器
  5. shell命令查阅端口信息_linux运维实用的42个常用命令总结
  6. Spring Data JPA 实例查询
  7. 别再学习框架了,看看这些让你起飞的计算机基础知识
  8. QT--QDockWidget 停靠窗口
  9. 烧钱两年,做事对得起工资,也要对得起公司这份决心
  10. c语言王者荣耀制作,易语言制作王者荣耀刷金币脚本的代码
  11. 点击图片放大功能 jquery
  12. 【GEE笔记】最大类间方差法(otsu、大津法)算法实现——计算阈值、图像二值化分割
  13. 企业微信和个人微信的区别是什么?
  14. CodeIgniter session过期时间问题
  15. oracle的gca的文件,GCA文件扩展名 - 什么是.gca以及如何打开? - ReviverSoft
  16. php 大转盘抽奖概率 角度,在线抽奖大转盘和概率计算
  17. 孩子数学成绩不好怎么办_孩子数学成绩差怎么才能快速提高
  18. EXCEL干货(1-1): 基本表格操作
  19. 简单的proxy之TinyHTTPProxy.py
  20. 2020全球创业者城市Top50

热门文章

  1. uniapp 点击动画_uni-app 点击元素左右抖动效果
  2. 利用java实现提现金额到支付宝账户的功能
  3. 十年技术支持工作的几点感悟
  4. 团队作业—beta冲刺
  5. ESP-Hosted 入门介绍 使用指南
  6. 好的博客学习的地址【持续更新中】
  7. java超级计算器,jdk自带类
  8. 会解方程会画图的超级计算器
  9. 【GStreamer 】2-ubuntu v4l2-ctl 查看USB 相机基本参数
  10. 直面大数据撞击这个时代——畅享网成功举办大数据应用沙龙