参考代码:RAFT
作者主页:Zachary Teed

1. 概述

导读:这篇文章提出了一种新的光流估计pipline,与之前介绍的PWC-Net类似其也包含特征抽取/correlation volume构建操作。在这篇文章中为了优化光流估计,首先在correlation volume的像素上进行邻域采样得到lookups特征(增强特征相关性,也可以理解为感受野),之后直接使用以CNN-GRU为基础的迭代优化网络,在完整尺寸上对光流估计迭代优化。这样尽管采用了迭代优化的形式,文章的迭代优化机制也比像IRR/FlowNet这类方法轻量化,运行速度也更快,其可以在1080 TI GPU上达到10FPS(输入为1088∗4361088*4361088∗436)。文章的算法在诸如特征处理与融合/上采样策略上设计得细致合理,并且使用迭代优化的策略,从而使得文章算法具有较好的泛化性能

将文章的方法与之前的一些方法作对比,可以将其中对比得到的改进点归纳如下:

  • 1)抛弃了类似PWC-Net中的coarse-to-fine的光流迭代优化策略,直接生成全尺寸的光流估计,从而避免了这种优化策略带来的弊端:coarse层次的预测结果会天然增加丢失小而快速运动的目标的风险,并且训练需要的迭代次数也更多;
  • 2)为了提升光流估计的准确性,一种可行的方式就是进行module的叠加优化,如FlowNet和IRR等,但是这样的操作一个是带来更多的参数量,增加运算的时间。还会使得整个网络的训练过程变得繁琐冗长;
  • 3)光流的更新模块,文章使用以CNN-GRU为基础,在4D的correlation volume上对其采样得到的correlation lookups进行运算,从而得到光流信息。这样的更新模块引入了GRU网络,很好利用了迭代优化的时序特性;

将文章的方法与其它的一些光流估计方法进行比较:

2. 方法设计

2.1 整体pipline

文章的整体pipeline如下:

按照上图所示可以将整体pipeline划分为3个部分(阶段):

  • 1)feature encoder进行输入图像的抽取,以及context encoder进行图像特征的抽取;
  • 2)使用矩阵相乘的方式构建correlation volume,之后使用池化操作得到correlation volume pyramid;
  • 3)对correlation volume在像素邻域上进行采样,之后使用以CNN-GRU为基础构建的光流迭代更新网络进行全尺寸光流估计;

文章按照网络容量的不同设计了一大一小的两个网络,后面的内容都是以大网络为基准,其网络结构为:

文章的整体流程简洁,直接在一个forawrd中完成了所有操作,其具体的步骤可以归纳为:

# core/raft.py#86
# step1:图像1/2的feature encoder特征抽取
fmap1, fmap2 = self.fnet([image1, image2])  # [N, 256, H//8, W//8]# step2:correlation volume pyramid构建
if self.args.alternate_corr:corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)  # 输入两幅图像特征用于构造金字塔相似矩阵# step3:图像1的context encoder特征抽取
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)  # 对输出的特征进行划分
# 一部分用于递归优化的输入,一部分用于GRU递归优化的传递变量
net = torch.tanh(net)  # [N, 256, H//8, W//8]
inp = torch.relu(inp)  # [N, 256, H//8, W//8]# step4:以图像1经过编码之后的尺度构建两个一致的坐标网格
coords0, coords1 = self.initialize_flow(image1)  # 一个用于更新(使用每次迭代预测出来的光流),一个用于作为基准if flow_init is not None:  # 若初始光流不为空,则用其更新初始光流coords1 = coords1 + flow_init# step5:进行光流更新迭代
flow_predictions = []
for itr in range(iters):coords1 = coords1.detach()# 在坐标网格的基础上对correlation volume pyramid进行半径r=4的邻域采样corr = corr_fn(coords1)  # index correlation volume [N, (2*r+1)*(2*r+1)*num_levels, H//8, W//8]# 使用CNN-GRU计算光流偏移量与上采样系数等flow = coords1 - coords0with autocast(enabled=self.args.mixed_precision):net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)  # 迭代之后的特征/采样权重/预测光流偏移# F(t+1) = F(t) + \Delta(t)coords1 = coords1 + delta_flow  # 更新光流# upsample predictionsif up_mask is None:flow_up = upflow8(coords1 - coords0)  # 普通的上采样方式else:flow_up = self.upsample_flow(coords1 - coords0, up_mask)  # 使用卷积构造的上采样方式flow_predictions.append(flow_up)  # 保存当前迭代次数的光流优化结果

2.2 correlation volume

这里主要讲述correlation volume的构建过程,之后在其基础上进行邻域采样构建correlation lookups(用于提升光流信息的特征相关性),以及提出一种更加高效的correlation volume构建方式(减少计算复杂度)。这里的编码器特征抽取部分省略。。。(其输出的维度为:[N,256,H//8,H//8][N,256,H//8,H//8][N,256,H//8,H//8])

构建过程:
correlation volume的构建过程其实是一个矩阵相乘形式:

# core/corr.py#53
def corr(fmap1, fmap2):batch, dim, ht, wd = fmap1.shapefmap1 = fmap1.view(batch, dim, ht*wd)fmap2 = fmap2.view(batch, dim, ht*wd) corr = torch.matmul(fmap1.transpose(1,2), fmap2)  # 图像1/2的特征矩阵乘 [batch, ht*wd, ht*wd]corr = corr.view(batch, ht, wd, 1, ht, wd)  # [batch, ht, wd, 1, ht, wd]return corr  / torch.sqrt(torch.tensor(dim).float())

在此基础上使用池化操作得到correlation volume,这里使用到的层级为4(池化操作的kernel size为{1,2,4,8}\{1,2,4,8\}{1,2,4,8})。也就如下图所示:

correlation lookups的构建:
这里为了增加correlation volume中每个像素对周围像素的感知能力,使用半径r=4r=4r=4的邻域对correlation volume中每个像素进行采样,之后再组合起来。其实现可以参考:

# core/corr.py#29
def __call__(self, coords):r = self.radiuscoords = coords.permute(0, 2, 3, 1)  # flow的idx坐标信息permutebatch, h1, w1, _ = coords.shape  # [batch, h1, w1, 2]out_pyramid = []for i in range(self.num_levels):corr = self.corr_pyramid[i]dx = torch.linspace(-r, r, 2*r+1)  # 构造邻域采样空间dy = torch.linspace(-r, r, 2*r+1)delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i  # 将光流缩放到对应的金字塔尺度上去delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)  # 邻域区域,邻域半径r=4coords_lvl = centroid_lvl + delta_lvl  # 在光流基础上加上邻域区域的偏置,[batch*h1*w1, 2*r+1, 2*r+1, 2]# 在邻域上对correlation volume在坐标coords_lvl引导下进行双线性采样corr = bilinear_sampler(corr, coords_lvl)  # [batch*h1*w1, 1, 2*r+1, 2*r+1]corr = corr.view(batch, h1, w1, -1)  # [batch, h1, w1, (2*r+1) * (2*r+1)]out_pyramid.append(corr)out = torch.cat(out_pyramid, dim=-1)  # [batch, h1, w1, (2*r+1) * (2*r+1) * num_levels]return out.permute(0, 3, 1, 2).contiguous().float()    # [batch, (2*r+1) * (2*r+1) * num_levels, h1, w1]

更加高效的correlation构建:
在之前的correlation volume构建过程中是直接在编码器输出的特征图上运算,其计算的复杂度为O(N2)O(N^2)O(N2),其中NNN是特征图上像素的个数(W//8∗H//8W//8*H//8W//8∗H//8,channel=1)。之后这个保持不变,使用不同kernel size的池化操作迭代计算MMM次(也就是金字塔的层级)。那么对此文章对于层级为mmm处correlation volume的计算其实是可以描述为下面的形式的:
Cijklm=122m∑p2m∑q2m⟨gi,j(1),g2mk+p,2ml+q(2)⟩=⟨gi,j(1),122m(∑p2m∑q2mg2mk+p,2ml+q(2))⟩C_{ijkl}^m=\frac{1}{2^{2m}}\sum_p^{2^m}\sum_q^{2^m}\langle g_{i,j}^{(1)},g_{2^mk+p,2^ml+q}^{(2)}\rangle=\langle g_{i,j}^{(1)},\frac{1}{2^{2m}}(\sum_p^{2^m}\sum_q^{2^m}g_{2^mk+p,2^ml+q}^{(2)})\rangleCijklm​=22m1​p∑2m​q∑2m​⟨gi,j(1)​,g2mk+p,2ml+q(2)​⟩=⟨gi,j(1)​,22m1​(p∑2m​q∑2m​g2mk+p,2ml+q(2)​)⟩
也就是图像1的特征与图像2图像块avg-pooling之后的特征进行计算,进而可以减少计算复杂度,变为$O(NM)。其实现可以参考类:

#core/corr.py#63
class AlternateCorrBlock:...

以及目录alt_cuda_corr下的实现。

2.3 迭代更新机制

文章的光流估计是采用迭代更新的机制实现的,也就是在一个迭代序列中会生成光流序列{f1,…,fN}\{f_1,\dots,f_N\}{f1​,…,fN​},初始情况下f0=0f_0=0f0​=0,每次迭代之后的更新量描述为Δf\Delta fΔf,那么其更新过程描述为:
fk+1=fk+Δff_{k+1}=f_k+\Delta ffk+1​=fk​+Δf
迭代更新的初始值:
在缺省情况下文章的方法是使用0作迭代的初始光流。当然也是可以接受用一个先验光流作为输入,并在该基础上进行更新迭代,也就是文章提到的warm-start。

迭代更新的输入:
在上文中的网络结构图中可以看到输入CNN-GRU网络模块中的信息是包含3个:context encoder的输出特征/上一次光流的迭代结果/correlation volume在邻域内的采样结果。它们在网络中通过concat的形式进行特征融合,融合之后的特征记为xtx_txt​。

特征更新过程:
光流在进行估计之前会经过CNN-GRU模块,这里采用的是Separate的形式,也就是分离的大卷积核(减少参数量的同时,增大感受野)。这里循环递归的隐变量为hth_tht​,它初始的时候使用context encoder产生。其在GRU模块中更新的过程可以描述为:
zt=σ(Conv3∗3([ht−1,xt],Wz))z_t=\sigma(Conv_{3*3}([h_{t-1},x_t],W_z))zt​=σ(Conv3∗3​([ht−1​,xt​],Wz​))
rt=σ(Conv3∗3([ht−1,xt],Wr))r_t=\sigma(Conv_{3*3}([h_{t-1},x_t],W_r))rt​=σ(Conv3∗3​([ht−1​,xt​],Wr​))
hˉt=tanh(Conv3∗3[rt⊙ht−1,xt],Wh)\bar{h}_t=tanh(Conv_{3*3}[r_t\odot h_{t-1},x_t],W_h)hˉt​=tanh(Conv3∗3​[rt​⊙ht−1​,xt​],Wh​)
ht=(1−zt)⊙ht−1+zt⊙hˉth_t=(1-z_t)\odot h_{t-1}+z_t\odot \bar{h}_tht​=(1−zt​)⊙ht−1​+zt​⊙hˉt​

光流更新量估计:
这里估计的光流其实是一个偏移量Δf\Delta fΔf,相对坐标矩阵来讲的。最后将变换后的坐标矩阵减去之前的基准就得到了最后的光流估计。

全尺寸上采样:
这里将传统的双线性采样操作替换为了基于卷积的上采样操作,也就是每个全尺寸中的每个像素是通过在stride=8的光流预测结果对应位置处采样3∗33*33∗3的块,之后通过预测出来的权值进行加权组合。其采样加权的示意图如下:

其实现代码参考:

# core/raft.py#72
def upsample_flow(self, flow, mask):""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """N, _, H, W = flow.shapemask = mask.view(N, 1, 9, 8, 8, H, W)  # [N, 8*8*9, H, W]->[N, 1, 9, 8, 8, H, W]mask = torch.softmax(mask, dim=2)  # 当前像素与邻域像素的权值up_flow = F.unfold(8 * flow, [3, 3], padding=1)  # 对光流进行窗口滑块提取[N, 3*3*2, H, W](提前乘上采样系数8)up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)  # [N, 3*3*2, H, W]->[N, 2, 9, 1, 1, H, W]up_flow = torch.sum(mask * up_flow, dim=2)  # 当前像素与邻域进行加权计算, [N, 2, 9, 1, 1, H, W]->[N, 2, 8, 8, H, W]up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [N, 2, 8, 8, H, W]->[N, 2, H, 8, W, 8]return up_flow.reshape(N, 2, 8*H, 8*W)  # 得到输入图像分辨率的光流估计

将直接双线性上采样的结果与这里提出的上采样结果进行对比如下:

2.4 损失函数

文章的方法是迭代更新的,因而会生成许多个序列的光流估计结果{f1,…,fN}\{f_1,\dots,f_N\}{f1​,…,fN​},对此其损失函数描述为:
L=∑i=1NγN−i∣fgt−fi∣1L=\sum_{i=1}^N\gamma^{N-i}|f_{gt}-f_i|_1L=i=1∑N​γN−i∣fgt​−fi​∣1​
其中,γ=0.8\gamma=0.8γ=0.8。

3. 实验结果

KITTI数据集上的性能比较:

消融实验:

训练策略:

《RAFT:Recurrent All-Pairs Field Transforms for Optical Flow》论文笔记相关推荐

  1. 论文笔记之Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 论文链接:http://dwz.cn/6qPeIb 本文的主要思想是为了剖析出一个跟踪算法中到 ...

  2. 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    本人为目标追踪初入小白,在博客下第一次记录一下自己的论文笔记,如有差错,恳请批评指正!! 论文相关信息:<Understanding and Diagnosing Visual Tracking ...

  3. 论文笔记Understanding and Diagnosing Visual Tracking Systems

    最近在看目标跟踪方面的论文,看到王乃岩博士发的一篇分析跟踪系统的文章,将目标跟踪系统拆分为多个独立的部分进行分析,比较各个部分的效果.本文主要对该论文的重点的一个大致翻译,刚入门,水平有限,如有理解错 ...

  4. 目标跟踪笔记Understanding and Diagnosing Visual Tracking Systems

    Understanding and Diagnosing Visual Tracking Systems 原文链接:https://blog.csdn.net/u010515206/article/d ...

  5. 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)

    追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...

  6. ICCV 2015 《Understanding and Diagnosing Visual Tracking Systems》论文笔记

    目录 写在前面 文章大意 一些benchmark 实验 实验设置 基本模型 数据集 实验1 Featrue Extractor 实验2 Observation Model 实验3 Motion Mod ...

  7. Understanding and Diagnosing Visual Tracking Systems

    文章把一个跟踪器分为几个模块,分别为motion model, feature extractor, observation model, model updater, and ensemble po ...

  8. CVPR 2017 SANet:《SANet: Structure-Aware Network for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做SANet.作者在论文中提到,CNN模型主要适用于类间判别,对于相似物体的判别能力不强.作者提出使用RNN对目标物体的self-structure进行建模,用于提 ...

  9. ICCV 2017 UCT:《UCT: Learning Unified Convolutional Networks forReal-time Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文模型叫做UCT.就像论文题目一样,作者提出了一个基于卷积神经网络的end2end的tracking模型.模型的整体结构如下图所示(图中实线代表online trackin ...

  10. CVPR 2018 STRCF:《Learning Spatial-Temporal Regularized Correlation Filters for Visual Tracking》论文笔记

    理解出错之处望不吝指正. 本文提出的模型叫做STRCF. 在DCF中存在边界效应,SRDCF在DCF的基础上中通过加入spatial惩罚项解决了边界效应,但是SRDCF在tracking的过程中要使用 ...

最新文章

  1. AlphaImageLoader用法
  2. Kattis之旅——Prime Reduction
  3. eclipse创建了java web项目后怎么连接mysql
  4. 使用率激增250%,这份报告再次将 Serverless 推向幕前
  5. gRPC in ASP.NET Core 3.x - gRPC 简介(2)
  6. 成功跳槽百度工资从15K涨到28K,已整理成文档
  7. matlab话pca的双标图biplot,r – 用ggplot2绘制pca biplot
  8. 18春东大计算机在线作业3,东大18春学期《计算机辅助数控编程》在线作业3.docx...
  9. composer安装Workerman报错:Installation failed, reverting ./composer.json to its original content....
  10. jedis连接池的maxIdle和maxtotal参数
  11. 知网被引第一、第二的论文,都出自这位双一流大学教授
  12. 【NLP】之 结巴分词
  13. Day02 郝斌C语言自学视频之C语言编程预备知识
  14. [深度学习概念]·主流声学模型对比
  15. 检测到 Mac 文件格式: 请将源文件转换为 DOS 格式或 UNIX 格式
  16. 时之歌 服务器维护,时之歌抽卡卡住了怎么办 时之歌手游招募吞卡解决方法
  17. 抓取网易云音乐网页歌单(url)js
  18. Git分布式版本管理工具
  19. THREE.JS 与其他库的对比
  20. 小白学python.1

热门文章

  1. 翻译: 网页排名PageRank算法的来龙去脉 以及 Python实现
  2. 什么是pageRank
  3. 想要不被裁,看一看 13 年华为老兵的宝贵经验
  4. pdf.js使用方法
  5. 显示器的 VGA、HDMI、DVI 和 DisplayPort
  6. ASP.NET Core 托管和部署(一)【Kestrel】
  7. WinSCP拒绝访问问题
  8. 知识图谱-生物信息学-医学顶刊论文(Bioinformatics-2021)-SumGNN:通过有效的KG聚集进行多类型DDI预测
  9. 刘德华五条抖音粉丝五千万,流量平台用什么“留量”?
  10. Nervos 双周报第 4 期:经济白皮书的发件小哥正在路上