《RAFT:Recurrent All-Pairs Field Transforms for Optical Flow》论文笔记
参考代码:RAFT
作者主页:Zachary Teed
1. 概述
导读:这篇文章提出了一种新的光流估计pipline,与之前介绍的PWC-Net类似其也包含特征抽取/correlation volume构建操作。在这篇文章中为了优化光流估计,首先在correlation volume的像素上进行邻域采样得到lookups特征(增强特征相关性,也可以理解为感受野),之后直接使用以CNN-GRU为基础的迭代优化网络,在完整尺寸上对光流估计迭代优化。这样尽管采用了迭代优化的形式,文章的迭代优化机制也比像IRR/FlowNet这类方法轻量化,运行速度也更快,其可以在1080 TI GPU上达到10FPS(输入为1088∗4361088*4361088∗436)。文章的算法在诸如特征处理与融合/上采样策略上设计得细致合理,并且使用迭代优化的策略,从而使得文章算法具有较好的泛化性能。
将文章的方法与之前的一些方法作对比,可以将其中对比得到的改进点归纳如下:
- 1)抛弃了类似PWC-Net中的coarse-to-fine的光流迭代优化策略,直接生成全尺寸的光流估计,从而避免了这种优化策略带来的弊端:coarse层次的预测结果会天然增加丢失小而快速运动的目标的风险,并且训练需要的迭代次数也更多;
- 2)为了提升光流估计的准确性,一种可行的方式就是进行module的叠加优化,如FlowNet和IRR等,但是这样的操作一个是带来更多的参数量,增加运算的时间。还会使得整个网络的训练过程变得繁琐冗长;
- 3)光流的更新模块,文章使用以CNN-GRU为基础,在4D的correlation volume上对其采样得到的correlation lookups进行运算,从而得到光流信息。这样的更新模块引入了GRU网络,很好利用了迭代优化的时序特性;
将文章的方法与其它的一些光流估计方法进行比较:
2. 方法设计
2.1 整体pipline
文章的整体pipeline如下:
按照上图所示可以将整体pipeline划分为3个部分(阶段):
- 1)feature encoder进行输入图像的抽取,以及context encoder进行图像特征的抽取;
- 2)使用矩阵相乘的方式构建correlation volume,之后使用池化操作得到correlation volume pyramid;
- 3)对correlation volume在像素邻域上进行采样,之后使用以CNN-GRU为基础构建的光流迭代更新网络进行全尺寸光流估计;
文章按照网络容量的不同设计了一大一小的两个网络,后面的内容都是以大网络为基准,其网络结构为:
文章的整体流程简洁,直接在一个forawrd中完成了所有操作,其具体的步骤可以归纳为:
# core/raft.py#86
# step1:图像1/2的feature encoder特征抽取
fmap1, fmap2 = self.fnet([image1, image2]) # [N, 256, H//8, W//8]# step2:correlation volume pyramid构建
if self.args.alternate_corr:corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) # 输入两幅图像特征用于构造金字塔相似矩阵# step3:图像1的context encoder特征抽取
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1) # 对输出的特征进行划分
# 一部分用于递归优化的输入,一部分用于GRU递归优化的传递变量
net = torch.tanh(net) # [N, 256, H//8, W//8]
inp = torch.relu(inp) # [N, 256, H//8, W//8]# step4:以图像1经过编码之后的尺度构建两个一致的坐标网格
coords0, coords1 = self.initialize_flow(image1) # 一个用于更新(使用每次迭代预测出来的光流),一个用于作为基准if flow_init is not None: # 若初始光流不为空,则用其更新初始光流coords1 = coords1 + flow_init# step5:进行光流更新迭代
flow_predictions = []
for itr in range(iters):coords1 = coords1.detach()# 在坐标网格的基础上对correlation volume pyramid进行半径r=4的邻域采样corr = corr_fn(coords1) # index correlation volume [N, (2*r+1)*(2*r+1)*num_levels, H//8, W//8]# 使用CNN-GRU计算光流偏移量与上采样系数等flow = coords1 - coords0with autocast(enabled=self.args.mixed_precision):net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) # 迭代之后的特征/采样权重/预测光流偏移# F(t+1) = F(t) + \Delta(t)coords1 = coords1 + delta_flow # 更新光流# upsample predictionsif up_mask is None:flow_up = upflow8(coords1 - coords0) # 普通的上采样方式else:flow_up = self.upsample_flow(coords1 - coords0, up_mask) # 使用卷积构造的上采样方式flow_predictions.append(flow_up) # 保存当前迭代次数的光流优化结果
2.2 correlation volume
这里主要讲述correlation volume的构建过程,之后在其基础上进行邻域采样构建correlation lookups(用于提升光流信息的特征相关性),以及提出一种更加高效的correlation volume构建方式(减少计算复杂度)。这里的编码器特征抽取部分省略。。。(其输出的维度为:[N,256,H//8,H//8][N,256,H//8,H//8][N,256,H//8,H//8])
构建过程:
correlation volume的构建过程其实是一个矩阵相乘形式:
# core/corr.py#53
def corr(fmap1, fmap2):batch, dim, ht, wd = fmap1.shapefmap1 = fmap1.view(batch, dim, ht*wd)fmap2 = fmap2.view(batch, dim, ht*wd) corr = torch.matmul(fmap1.transpose(1,2), fmap2) # 图像1/2的特征矩阵乘 [batch, ht*wd, ht*wd]corr = corr.view(batch, ht, wd, 1, ht, wd) # [batch, ht, wd, 1, ht, wd]return corr / torch.sqrt(torch.tensor(dim).float())
在此基础上使用池化操作得到correlation volume,这里使用到的层级为4(池化操作的kernel size为{1,2,4,8}\{1,2,4,8\}{1,2,4,8})。也就如下图所示:
correlation lookups的构建:
这里为了增加correlation volume中每个像素对周围像素的感知能力,使用半径r=4r=4r=4的邻域对correlation volume中每个像素进行采样,之后再组合起来。其实现可以参考:
# core/corr.py#29
def __call__(self, coords):r = self.radiuscoords = coords.permute(0, 2, 3, 1) # flow的idx坐标信息permutebatch, h1, w1, _ = coords.shape # [batch, h1, w1, 2]out_pyramid = []for i in range(self.num_levels):corr = self.corr_pyramid[i]dx = torch.linspace(-r, r, 2*r+1) # 构造邻域采样空间dy = torch.linspace(-r, r, 2*r+1)delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i # 将光流缩放到对应的金字塔尺度上去delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) # 邻域区域,邻域半径r=4coords_lvl = centroid_lvl + delta_lvl # 在光流基础上加上邻域区域的偏置,[batch*h1*w1, 2*r+1, 2*r+1, 2]# 在邻域上对correlation volume在坐标coords_lvl引导下进行双线性采样corr = bilinear_sampler(corr, coords_lvl) # [batch*h1*w1, 1, 2*r+1, 2*r+1]corr = corr.view(batch, h1, w1, -1) # [batch, h1, w1, (2*r+1) * (2*r+1)]out_pyramid.append(corr)out = torch.cat(out_pyramid, dim=-1) # [batch, h1, w1, (2*r+1) * (2*r+1) * num_levels]return out.permute(0, 3, 1, 2).contiguous().float() # [batch, (2*r+1) * (2*r+1) * num_levels, h1, w1]
更加高效的correlation构建:
在之前的correlation volume构建过程中是直接在编码器输出的特征图上运算,其计算的复杂度为O(N2)O(N^2)O(N2),其中NNN是特征图上像素的个数(W//8∗H//8W//8*H//8W//8∗H//8,channel=1)。之后这个保持不变,使用不同kernel size的池化操作迭代计算MMM次(也就是金字塔的层级)。那么对此文章对于层级为mmm处correlation volume的计算其实是可以描述为下面的形式的:
Cijklm=122m∑p2m∑q2m⟨gi,j(1),g2mk+p,2ml+q(2)⟩=⟨gi,j(1),122m(∑p2m∑q2mg2mk+p,2ml+q(2))⟩C_{ijkl}^m=\frac{1}{2^{2m}}\sum_p^{2^m}\sum_q^{2^m}\langle g_{i,j}^{(1)},g_{2^mk+p,2^ml+q}^{(2)}\rangle=\langle g_{i,j}^{(1)},\frac{1}{2^{2m}}(\sum_p^{2^m}\sum_q^{2^m}g_{2^mk+p,2^ml+q}^{(2)})\rangleCijklm=22m1p∑2mq∑2m⟨gi,j(1),g2mk+p,2ml+q(2)⟩=⟨gi,j(1),22m1(p∑2mq∑2mg2mk+p,2ml+q(2))⟩
也就是图像1的特征与图像2图像块avg-pooling之后的特征进行计算,进而可以减少计算复杂度,变为$O(NM)。其实现可以参考类:
#core/corr.py#63
class AlternateCorrBlock:...
以及目录alt_cuda_corr
下的实现。
2.3 迭代更新机制
文章的光流估计是采用迭代更新的机制实现的,也就是在一个迭代序列中会生成光流序列{f1,…,fN}\{f_1,\dots,f_N\}{f1,…,fN},初始情况下f0=0f_0=0f0=0,每次迭代之后的更新量描述为Δf\Delta fΔf,那么其更新过程描述为:
fk+1=fk+Δff_{k+1}=f_k+\Delta ffk+1=fk+Δf
迭代更新的初始值:
在缺省情况下文章的方法是使用0作迭代的初始光流。当然也是可以接受用一个先验光流作为输入,并在该基础上进行更新迭代,也就是文章提到的warm-start。
迭代更新的输入:
在上文中的网络结构图中可以看到输入CNN-GRU网络模块中的信息是包含3个:context encoder的输出特征/上一次光流的迭代结果/correlation volume在邻域内的采样结果。它们在网络中通过concat的形式进行特征融合,融合之后的特征记为xtx_txt。
特征更新过程:
光流在进行估计之前会经过CNN-GRU模块,这里采用的是Separate的形式,也就是分离的大卷积核(减少参数量的同时,增大感受野)。这里循环递归的隐变量为hth_tht,它初始的时候使用context encoder产生。其在GRU模块中更新的过程可以描述为:
zt=σ(Conv3∗3([ht−1,xt],Wz))z_t=\sigma(Conv_{3*3}([h_{t-1},x_t],W_z))zt=σ(Conv3∗3([ht−1,xt],Wz))
rt=σ(Conv3∗3([ht−1,xt],Wr))r_t=\sigma(Conv_{3*3}([h_{t-1},x_t],W_r))rt=σ(Conv3∗3([ht−1,xt],Wr))
hˉt=tanh(Conv3∗3[rt⊙ht−1,xt],Wh)\bar{h}_t=tanh(Conv_{3*3}[r_t\odot h_{t-1},x_t],W_h)hˉt=tanh(Conv3∗3[rt⊙ht−1,xt],Wh)
ht=(1−zt)⊙ht−1+zt⊙hˉth_t=(1-z_t)\odot h_{t-1}+z_t\odot \bar{h}_tht=(1−zt)⊙ht−1+zt⊙hˉt
光流更新量估计:
这里估计的光流其实是一个偏移量Δf\Delta fΔf,相对坐标矩阵来讲的。最后将变换后的坐标矩阵减去之前的基准就得到了最后的光流估计。
全尺寸上采样:
这里将传统的双线性采样操作替换为了基于卷积的上采样操作,也就是每个全尺寸中的每个像素是通过在stride=8的光流预测结果对应位置处采样3∗33*33∗3的块,之后通过预测出来的权值进行加权组合。其采样加权的示意图如下:
其实现代码参考:
# core/raft.py#72
def upsample_flow(self, flow, mask):""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """N, _, H, W = flow.shapemask = mask.view(N, 1, 9, 8, 8, H, W) # [N, 8*8*9, H, W]->[N, 1, 9, 8, 8, H, W]mask = torch.softmax(mask, dim=2) # 当前像素与邻域像素的权值up_flow = F.unfold(8 * flow, [3, 3], padding=1) # 对光流进行窗口滑块提取[N, 3*3*2, H, W](提前乘上采样系数8)up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) # [N, 3*3*2, H, W]->[N, 2, 9, 1, 1, H, W]up_flow = torch.sum(mask * up_flow, dim=2) # 当前像素与邻域进行加权计算, [N, 2, 9, 1, 1, H, W]->[N, 2, 8, 8, H, W]up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [N, 2, 8, 8, H, W]->[N, 2, H, 8, W, 8]return up_flow.reshape(N, 2, 8*H, 8*W) # 得到输入图像分辨率的光流估计
将直接双线性上采样的结果与这里提出的上采样结果进行对比如下:
2.4 损失函数
文章的方法是迭代更新的,因而会生成许多个序列的光流估计结果{f1,…,fN}\{f_1,\dots,f_N\}{f1,…,fN},对此其损失函数描述为:
L=∑i=1NγN−i∣fgt−fi∣1L=\sum_{i=1}^N\gamma^{N-i}|f_{gt}-f_i|_1L=i=1∑NγN−i∣fgt−fi∣1
其中,γ=0.8\gamma=0.8γ=0.8。
3. 实验结果
KITTI数据集上的性能比较:
消融实验:
训练策略:
《RAFT:Recurrent All-Pairs Field Transforms for Optical Flow》论文笔记相关推荐
- 论文笔记之Understanding and Diagnosing Visual Tracking Systems
Understanding and Diagnosing Visual Tracking Systems 论文链接:http://dwz.cn/6qPeIb 本文的主要思想是为了剖析出一个跟踪算法中到 ...
- 《Understanding and Diagnosing Visual Tracking Systems》论文笔记
本人为目标追踪初入小白,在博客下第一次记录一下自己的论文笔记,如有差错,恳请批评指正!! 论文相关信息:<Understanding and Diagnosing Visual Tracking ...
- 论文笔记Understanding and Diagnosing Visual Tracking Systems
最近在看目标跟踪方面的论文,看到王乃岩博士发的一篇分析跟踪系统的文章,将目标跟踪系统拆分为多个独立的部分进行分析,比较各个部分的效果.本文主要对该论文的重点的一个大致翻译,刚入门,水平有限,如有理解错 ...
- 目标跟踪笔记Understanding and Diagnosing Visual Tracking Systems
Understanding and Diagnosing Visual Tracking Systems 原文链接:https://blog.csdn.net/u010515206/article/d ...
- 追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems)
追踪系统分模块解析(Understanding and Diagnosing Visual Tracking Systems) PROJECT http://winsty.net/tracker_di ...
- ICCV 2015 《Understanding and Diagnosing Visual Tracking Systems》论文笔记
目录 写在前面 文章大意 一些benchmark 实验 实验设置 基本模型 数据集 实验1 Featrue Extractor 实验2 Observation Model 实验3 Motion Mod ...
- Understanding and Diagnosing Visual Tracking Systems
文章把一个跟踪器分为几个模块,分别为motion model, feature extractor, observation model, model updater, and ensemble po ...
- CVPR 2017 SANet:《SANet: Structure-Aware Network for Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文模型叫做SANet.作者在论文中提到,CNN模型主要适用于类间判别,对于相似物体的判别能力不强.作者提出使用RNN对目标物体的self-structure进行建模,用于提 ...
- ICCV 2017 UCT:《UCT: Learning Unified Convolutional Networks forReal-time Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文模型叫做UCT.就像论文题目一样,作者提出了一个基于卷积神经网络的end2end的tracking模型.模型的整体结构如下图所示(图中实线代表online trackin ...
- CVPR 2018 STRCF:《Learning Spatial-Temporal Regularized Correlation Filters for Visual Tracking》论文笔记
理解出错之处望不吝指正. 本文提出的模型叫做STRCF. 在DCF中存在边界效应,SRDCF在DCF的基础上中通过加入spatial惩罚项解决了边界效应,但是SRDCF在tracking的过程中要使用 ...
最新文章
- AlphaImageLoader用法
- Kattis之旅——Prime Reduction
- eclipse创建了java web项目后怎么连接mysql
- 使用率激增250%,这份报告再次将 Serverless 推向幕前
- gRPC in ASP.NET Core 3.x - gRPC 简介(2)
- 成功跳槽百度工资从15K涨到28K,已整理成文档
- matlab话pca的双标图biplot,r – 用ggplot2绘制pca biplot
- 18春东大计算机在线作业3,东大18春学期《计算机辅助数控编程》在线作业3.docx...
- composer安装Workerman报错:Installation failed, reverting ./composer.json to its original content....
- jedis连接池的maxIdle和maxtotal参数
- 知网被引第一、第二的论文,都出自这位双一流大学教授
- 【NLP】之 结巴分词
- Day02 郝斌C语言自学视频之C语言编程预备知识
- [深度学习概念]·主流声学模型对比
- 检测到 Mac 文件格式: 请将源文件转换为 DOS 格式或 UNIX 格式
- 时之歌 服务器维护,时之歌抽卡卡住了怎么办 时之歌手游招募吞卡解决方法
- 抓取网易云音乐网页歌单(url)js
- Git分布式版本管理工具
- THREE.JS 与其他库的对比
- 小白学python.1
热门文章
- 翻译: 网页排名PageRank算法的来龙去脉 以及 Python实现
- 什么是pageRank
- 想要不被裁,看一看 13 年华为老兵的宝贵经验
- pdf.js使用方法
- 显示器的 VGA、HDMI、DVI 和 DisplayPort
- ASP.NET Core 托管和部署(一)【Kestrel】
- WinSCP拒绝访问问题
- 知识图谱-生物信息学-医学顶刊论文(Bioinformatics-2021)-SumGNN:通过有效的KG聚集进行多类型DDI预测
- 刘德华五条抖音粉丝五千万,流量平台用什么“留量”?
- Nervos 双周报第 4 期:经济白皮书的发件小哥正在路上