文章目录

  • 前言
  • 1、SiamRPN相比于SiamFC的创新
  • 2、architecture
    • 2.1 特征提取网络(Siamese Network)和RPN
    • 2.2 训练与测试网络推导
      • 2.2.1 训练推导
      • 2.2.1 测试推导
  • 下一篇链接

前言

  SiamRPN代码分析从代码角度出发,分析时结合本人对论文的理解,总共由三部分组成:architecture、training、test。本文是SiamRPN代码分析的第一部分。
  SiamRPN论文
  代码


1、SiamRPN相比于SiamFC的创新

  SiamFC的一个最明显的缺陷在于,当目标发生较大形变时无法跟踪,原因在于它使用的固定多尺度tracking方式,在前一帧的目标位置生成固定的几个尺度框,选择得分最高的框作为当前帧的跟踪结果。这种tracking策略当目标发生大形变时将无法跟踪,然而SiamRPN很好地解决了该问题,它第一次在目标跟踪领域引入了anchor的使用,有了anchor,就不再需要多尺度策略,有人会问:anchor也是多种尺度啊?关键在于数据关联方式,SiamFC是利用全卷积结构把目标模板当作卷积核在搜索图像上卷积,有点类似滑动窗口,而SiamRPN是以一种穷举的方式,在搜索图像上举出所有可能存在目标的anchor,简而言之,就是SiamFC相邻帧的尺度变化是较小的,所以无法跟踪大形变物体,而由于穷举anchor的缘故,SiamRPN相邻帧的尺度变化较大,可以由长条变为横条。anchor可以理解为在搜索图像上生成大大小小的各种框,足以覆盖图像内的每个物体(包含待跟踪目标),示意图如下所示,左侧是经过数据预处理的原始图片,中间图片列出了所有anchor的尺度,最右侧图片是在图片规定好的正方形范围以每个像素为中心生成所有尺度的anchor。
  对比两篇论文的实验结果,发现SiamRPN的fps大于SiamFC,虽然设备不一样,但也不会差别这么大(86fps<160fps),我在相同环境下实验时,发现SiamRPN只比SiamFC快些许。单看网络结构,会发现SiamRPN的更复杂,而且SiamFC网络有的SiamRPN也都有,速度的差异是看测试过程,SiamFC需要多尺度测试,意味着要进行多次全卷积,而SiamRPN只要一次,这是两者跟踪速度差距的主要因素。

2、architecture

  下面是从论文截取的SiamRPN框架图
左边是用于特征提取的孪生子网络。区域建议子网络位于中间,有两个分支,上分支用于分类,下分支用于回归。采用互相关的方法得到两个分支的输出。这两个输出特性映射的详细信息在右边。在分类分支中,输出特征图有2k个通道,对应k个anchor的前景和背景。回归分支中,输出特征图有4k通道,对应4个坐标,用于对anchor位置的微调。回归的作用是得到预测的目标框,回归的任务是将得到的目标框微调,使得位置更加精确。

2.1 特征提取网络(Siamese Network)和RPN

  代码如下,这部分代码是SiamRPNNet的网络结构定义__init__和网络参数初始化_init_weights。

class SiamRPNNet(nn.Module):def __init__(self, init_weight=False):super(SiamRPNNet, self).__init__()self.featureExtract = nn.Sequential(nn.Conv2d(3, 96, 11, stride=2),  #stride=2  [batch,3,127,127]->[batch,96,59,59]nn.BatchNorm2d(96),nn.MaxPool2d(3, stride=2),       #stride=2  [batch,96,58,58]->[batch,96,29,29]nn.ReLU(inplace=True), nn.Conv2d(96, 256, 5),           #[batch,256,29,29]->[batch,256,25,25]nn.BatchNorm2d(256),nn.MaxPool2d(3, stride=2),       #stride=2  [batch,256,25,25]->[batch,256,12,12]nn.ReLU(inplace=True),nn.Conv2d(256, 384, 3),          #[batch,256,12,12]->[batch,384,10,10]nn.BatchNorm2d(384),nn.ReLU(inplace=True),nn.Conv2d(384, 384, 3),          #[batch,384,10,10]->[batch,384,8,8]nn.BatchNorm2d(384),nn.ReLU(inplace=True),nn.Conv2d(384, 256, 3),          #[batch,384,8,8]->[batch,256,6,6]nn.BatchNorm2d(256),)self.anchor_num = config.anchor_num    #每一个位置有5个anchor""" 模板的分类和回归"""self.examplar_cla = nn.Conv2d(256, 256 * 2 * self.anchor_num, kernel_size=3, stride=1, padding=0)self.examplar_reg = nn.Conv2d(256, 256 * 4 * self.anchor_num, kernel_size=3, stride=1, padding=0)""" 搜索图像的分类和回归"""self.instance_cla = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)self.instance_reg = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)#这一步是SiamRPN框架中没有的,1x1的卷积,感觉可有可无,仅仅用于回归分支#简单理解就是增加了网络的学习能力self.regress_adjust = nn.Conv2d(4 * self.anchor_num, 4 * self.anchor_num, 1)if init_weight:self._init_weights()def _init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.xavier_uniform_(m.weight, 1) #xavier是参数初始化,它的初始化思想是保持输入和输出方差一致,这样就避免了所有输出值都趋向于0if m.bias is not None:nn.init.constant_(m.bias, 0)     #偏置初始化为0elif isinstance(m, nn.BatchNorm2d):      #在激活函数之前,希望输出值由较好的分布,以便于计算梯度和更新参数,这时用到BatchNorm2d函数nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight, 1)if m.bias is not None:nn.init.constant_(m.bias, 0)

self.featureExtract里面定义了特征提取网络,使用的是alexnet,这部分具体讲解课参考这篇博客,rpn结构定义了4个不同的nn.Conv2d,分别对应architrcture4个橙黄色的Conv,主要的推导过程在2.2小节详解。

2.2 训练与测试网络推导

2.2.1 训练推导

  先给出代码,再看着分析,训练推导代码如下:

"""——————————前向传播用于训练——————————————————"""def forward(self, template, detection):N = template.size(0)    # batch=32template_feature = self.featureExtract(template)    #[32,256,6,6]detection_feature = self.featureExtract(detection)  #[32,256,24,24]"""对应模板分支,求分类核以及回归核"""# [32,256*2*5,4,4]->[32,2*5,256,4,4]kernel_score = self.examplar_cla(template_feature)# 类似可得[32,4*5,256,4,4]kernel_regression = self.examplar_reg(template_feature)"""对应搜素图像的分支,得到搜索图像的特征图"""conv_score = self.instance_cla(detection_feature)      #[32,256,22,22]conv_regression = self.instance_reg(detection_feature)   #[32,256,22,22]"""对应模板和搜索图像的分类"""# [32,256,22,22]->[1,32*256=8192,22,22]conv_scores = conv_score.reshape(1, -1, 22, 22)score_filters = kernel_score.reshape(-1, 256, 4, 4)    #[32*2*5,256,4,4]#inout=[1,8192,22,22],filter=[320,256,4,4],得到output=[1,32*2*5,19,19]->[32,10,19,19] 32始终是batch不变,勿忘!pred_score = F.conv2d(conv_scores, score_filters, groups=N).reshape(N, 10, 19,19)"""对应模板和搜索图像的回归----------"""#[32,256,22,22]->[1,32*256=8192,22,22]conv_reg = conv_regression.reshape(1, -1, 22, 22)reg_filters = kernel_regression.reshape(-1, 256, 4, 4)   #[32*4*5,256,4,4]#input=[1,8192,22,22],filter=[340,256,4,4],得到output=[1,32*4*5,19,19]->[32,4*5=20,19,19]——>微调pred_regression = self.regress_adjust(F.conv2d(conv_reg, reg_filters, groups=N).reshape(N, 20, 19, 19))#score.shape=[32,10,19,19],regression.shape=[32,20,19,19]return pred_score, pred_regression

代码注释中,假定训练时的batch_size=32,首先127x127x3大小的目标模板图像和271x271x3大小检测图像(搜索图像)共享特征提取网络的网络参数,得到相应大小的feature map,分别为6x6x256、24x24x256,论文中的检测图像大小是255,所以输出的feature map大小为22x22x256;然后,将6x6x256大小的检测图像通过两个卷积层self.instance_cla和self.instance_reg分别用于回归和分类,可以看到,这两个卷积层的结构是完全相同的,最终大小都为22x22x256,但这两个feature map的用途不一样,分别用于分类和回归,再将上分支通过特征提取网络得到的6x6x256大小的feature map分别通过self.examplar_cla和self.examplar_reg,得到的feature map大小为[32,256x2x5,4,4]和[32,4x5x256,4,4],其中2和4分别表示需要学习的实际参数,2代表前景和背景,前景分数越大,代表是目标的概率越高,4表示dx、dy、dw、dh四个微调参数。
  接下来就是将得到的4张feature map两两成对互相关,如果有对互相关不理解的小伙伴在博客中有讲解。最终得到用于分类和回归的feature map大小分别是[32,10,19,19]、[32,20,19,19]就可用于训练。

2.2.1 测试推导

  先上代码

"""—————————————初始化————————————————————"""def track_init(self, template):N = template.size(0) #1template_feature = self.featureExtract(template)# [1,256, 6, 6]# kernel_score=[1,2*5*256,4,4]   kernel_regression=[1,4*5*256,4,4]kernel_score = self.examplar_cla(template_feature)kernel_regression = self.examplar_reg(template_feature)self.score_filters = kernel_score.reshape(-1, 256, 4, 4)    #[2*5,256,4,4]self.reg_filters = kernel_regression.reshape(-1, 256, 4, 4) #[4*5,256,4,4]"""—————————————————跟踪—————————————————————"""def track_update(self, detection):N = detection.size(0)# [1,256,24,24]detection_feature = self.featureExtract(detection)"""----得到搜索图像的feature map-----"""conv_score = self.instance_cla(detection_feature)     #[1,256,22,22]conv_regression = self.instance_reg(detection_feature)  #[1,256,22,22]"""---------与模板互相关"""#input=[1,256,22,22] filter=[2*5,256,4,4] gropu=1 得output=[1,2*5,19,19]pred_score = F.conv2d(conv_score, self.score_filters, groups=N)# input=[1,256,22,22] filter=[4*5,256,4,4] gropu=1 得output=[1,4*5,19,19]pred_regression = self.regress_adjust(F.conv2d(conv_regression, self.reg_filters, groups=N))#score.shape=[1,10,19,19],regression.shape=[1,20,19,19]return pred_score, pred_regression

仔细分析,这部分代码与forward类似,只是结构有些变化,训练时是上下分支同时输入输出,而在测试时,明确分为了初始帧和后续待跟踪帧通过的网络结构,后续帧与固定不变的初始帧互相关运算。论文中也提到过,如下,论文说这是SiamRPN快速的原因,但要知道这并不是SiamRPN相比于SiamFC更快的原因。

最后,验证下搭建网络的正确性,验证代码和结果如下,可以看到网络搭建是正确的。

if __name__ == '__main__':model = SiamRPNNet()z_train = torch.randn([32,3,127,127])  #batch=8x_train = torch.randn([32,3,271,271])# 返回shape为[32,20,19,19] [32,10,19,19]  20=5*4 10=5*2pred_score_train, pred_regression_train = model(z_train,x_train)z_test = torch.randn([1,3,127,127])x_test = torch.randn([1,3,271,271])model.track_init(z_test)# 返回shape为[1,20,19,19] [1,10,19,19]pred_score_test, pred_regression_test = model.track_update(x_test)

Over,第一部分architecture代码分析就结束了!!

下一篇链接

  SiamRPN代码分析:training

SiamRPN代码分析:architecture相关推荐

  1. Device Tree(三):代码分析

    2019独角兽企业重金招聘Python工程师标准>>> 一.前言 Device Tree总共有三篇,分别是: 1.为何要引入Device Tree,这个机制是用来解决什么问题的?(请 ...

  2. 模块加载过程代码分析1

    一.概述 模块是作为ELF对象文件存放在文件系统中的,并通过执行insmod程序链接到内核中.对于每个模块,系统都要分配一个包含以下数据结构的内存区. 一个module对象,表示模块名的一个以null ...

  3. kernel 3.10代码分析--KVM相关--虚拟机创建\VCPU创建\虚拟机运行

    分三部分:一是KVM虚拟机创建.二是VCPU创建.三是KVM虚拟机运行 第一部分: 1.基本原理 如之前分析,kvm虚拟机通过对/dev/kvm字符设备的ioctl的System指令KVM_CREAT ...

  4. crt0.S(_main)代码分析

    crt0,S(_main)代码分析 --- 1. 设置sp寄存器地址 //设置SP栈指针 #if defined(CONFIG_SPL_BUILD) && defined(CONFIG ...

  5. Device Tree(三):代码分析【转】

    转自:http://www.wowotech.net/linux_kenrel/dt-code-analysis.html Device Tree(三):代码分析 作者:linuxer 发布于:201 ...

  6. Pixhawk代码分析-启动代码及入口函数

    启动代码及入口函数 基础知识 关于坐标系 1)GeographicCoordinate System Represents position on earth with alongitude and ...

  7. Linux GIC代码分析

    一.前言 GIC(Generic Interrupt Controller)是ARM公司提供的一个通用的中断控制器,其architecture specification目前有四个版本,V1-V4(V ...

  8. ARM GICv3 ITS介绍及代码分析

    前言: 在ARM gicv3中断控制器,有提到过ITS的作用,本篇就ITS进行更详细的介绍以及分析linux 内核中ITS代码的实现. 本文基于linux 4.19,介绍DT方式初始化的ITS代码. ...

  9. arm linux kernel 从入口到start_kernel 的代码分析

    Linux系统启动过程分析(主要是加载内核前的动作) 经过对Linux系统有了一定了解和熟悉后,想对其更深层次的东西做进一步探究.这当中就包括系统的启动流程.文件系统的组成结构.基于动态库和静态库的程 ...

  10. kernel 3.10代码分析--KVM相关--虚拟机运行

    1.基本原理 KVM虚拟机通过字符设备/dev/kvm的ioctl接口创建和运行,相关原理见之前的文章说明. 虚拟机的运行通过/dev/kvm设备ioctl VCPU接口的KVM_RUN指令实现,在V ...

最新文章

  1. BCH优于BCE+LN的5个理由
  2. VMware或者KVM克隆的虚拟机网卡无法启动
  3. 互联网1分钟 | 0117 IBM入驻上海张江人工智能岛;IoT业务将成为小米新支撑点
  4. 2013\Province_C_C++_A\4.颠倒的价牌
  5. 鸿蒙系统中的 JS 开发框架
  6. LeetCode(617)——合并二叉树(JavaScript)
  7. 天天都在用的 Nginx,可你知道如何用一个反向代理实现多个不同类型的后端网站访问吗?...
  8. Flex结合java实现一个登录功能
  9. 网页转化成pdf,网页转换图片,wkhtmltopdf,wkhtmltoimage使用小结
  10. c++实现解释器模式完整源代码
  11. sci二区计算机类有哪些期刊,二区材料类sci期刊有哪些
  12. android wifi已停用,为什么手机连接wifi时总显示已停用
  13. 矩阵的分解——LU分解
  14. Java UI设计 计算三角形周长
  15. OptaPlanner快速开始
  16. JavaSE基础(21) 打印数组
  17. [安卓相机1]简单小Demo
  18. 《管理学》第九章 沟通
  19. Java设计模式之(九)——门面模式
  20. 2022年高压电工上岗证题库及答案

热门文章

  1. 怎样查询网站关键字的排名
  2. 打开FOXMAIL常见错误提示“Message format error”
  3. 爬取分析雪球网实盘用户数据
  4. synchronized的底层实现
  5. win10蓝屏无法进入系统_WIN10系统进“吃鸡”蓝屏
  6. SpringBoot+H5微信登陆(网页)
  7. 全角半角英文字母及符号
  8. canvas多重阴影发光效果
  9. Android tips(十)--允许模拟位置在Android M下的坑
  10. 如何应用计算机键盘截图,计算机屏幕截图的键盘快捷键是哪个键?在计算机上截图的方法...