【3D detection】CT3D部分代码的理解
【3D detection】CT3D部分代码的理解
- 获得box的八个角点坐标
- 在无限高的圆柱中随机采样
- Embedding和Encoding
paper: Improving 3D Object Detection with Channel-wise Transformer
code:https://github.com/hlsheng1/CT3D
获得box的八个角点坐标
在/pcdet/models/roi_heads/ct3d_head.py
中 130行左右
# corner#输入的rois大小为(batch_size,roi的个数,roi信息)(2,128,7)#其中7包括roi的x,y,z,l,w,h,方向角#最后得到的corner_points是(B, 128, 2*2*2, 3),每个角点的坐标corner_points, _ = self.get_global_grid_points_of_roi(rois) # (BxN, 2x2x2, 3)corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) # (2, 128, 2x2x2, 3)
下面看一下 self.get_global_grid_points_of_roi(rois)
这个函数,同样在class CT3DHead
这个类中
def get_global_grid_points_of_roi(self, rois):rois = rois.view(-1, rois.shape[-1]) # (256,7)batch_size_rcnn = rois.shape[0] # (256)# 得到了box八个角点到中心点的坐标距离# 中心点的三个轴坐标只要在每个轴加上相应的坐标距离# 就能得到每一个角点的坐标了。local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) #(256,8,3)#注意!!!#上面得到的只考虑了box的大小,并没有考虑角度#如果直接拿来计算的话,得到的所有角点都不是真正的,而是平行于坐标轴的!#所以还要通过下面的函数将每个roi的角度考虑信息。#最后得到真正的角点全局坐标global_roi_grid_points = common_utils.rotate_points_along_z(local_roi_grid_points.clone(), rois[:, 6]).squeeze(dim=1) #(256, 8 ,3)#得到中心点的全局坐标global_center = rois[:, 0:3].clone() # (256, 3)# pdb.set_trace()#中心点加坐标距离最终得到每个box八个点的角点global_roi_grid_points += global_center.unsqueeze(dim=1) #(256,8,3 )#(256,8,3) (256,8,3)return global_roi_grid_points, local_roi_grid_points
看一下 self.get_corner_points(rois, batch_size_rcnn)
这个函数,同样在class CT3DHead
这个类中
@staticmethoddef get_corner_points(rois, batch_size_rcnn):#得到一个(2, 2, 2) 的tensor,用来在后面求box八个角点的索引 ,这里还是很巧妙的,可以学习一下。faked_features = rois.new_ones((2, 2, 2))# nonzero() 返回每个这个tensor中非零元素的索引# 例如这里建立了一个(2,2,2)的全是1的tensor,每一个元素都是非零的,#所以得到的非零元素的索引为 [0,0,0], [0,0,1], [0,1,0], [0,1,1] .....[1,1,1] 共8个dense_idx = faked_features.nonzero() # (8, 3) [x_idx, y_idx, z_idx]dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (256, 2x2x2, 3)# 取每一个RoI的长宽高local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] # (256,3)#这一步是关键,求每个角点在每个轴上相对于roi中心点的距离#例如:[0, 0, 1] * [l, w, h] - [l/2,w/2, h/2]# = [-l/2, -w/2, h/2]#即这个角点到中心点的坐标距离#中心点坐标 在x轴减去一半长,在y轴减去一半宽,在z轴加上一般高。就是这个角点的坐标了。roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \- (local_roi_size.unsqueeze(dim=1) / 2) # (B, 2x2x2, 3)return roi_grid_points #(256,8,3)
这个函数得到了box每一个角点到中心点的坐标距离,中心点的三个轴坐标只要在每个轴加上相应的坐标距离,就能得到每一个角点的坐标了。
在无限高的圆柱中随机采样
论文中在RoI中采样点时,将RoI转换为一个高度无限高的圆柱体,并从中随机采样256个点作为RoI的表征。采样方法如下
圆柱体底面的半径为:
代码如下
num_sample = self.num_points # 这里是256个点src = rois.new_zeros(batch_size, num_rois, num_sample, 4) #(2,128,256,4)for bs_idx in range(batch_size): # 每次循环一个batch_size# batch_dict[points]是所有batch_size中所有的点# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)]是第bs_idx个batch中的点# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)][:,1:5]是每个点的坐标和反射强度#cur_points是一个batch中所有点的坐标+反射强度(194165,4)cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:5] #(194165,4)# 每个batch的roi boxcur_batch_boxes = batch_dict['rois'][bs_idx] #(128,7)# 求半径公式如上图所示 (128)cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.2# 所有点到roi box中心的距离,# 共128个RoI,每一个RoI 都计算19165个点到其中心的距离dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) # (128,19165)#过滤出半径内的点point_mask = (dis <= cur_radiis.unsqueeze(-1))# 遍历每一个roifor roi_box_idx in range(0, num_rois):# point_mask[roi_box_idx] 是第roi_box_idx个roi的mask#cur_roi_points 这里是(465,4),即这个圆柱roi中有465个点cur_roi_points = cur_points[point_mask[roi_box_idx]]# 如果roi内部的点大于要求采样的数量的话,就随机取256个(源码中num_sample=256)# 如果roi内部的点个数等于0,就用0填充256个# 如果roi内部点在0到256之间,采样所有点,剩余的用0填充if cur_roi_points.shape[0] >= num_sample:random.seed(0)index = np.random.randint(cur_roi_points.shape[0], size=num_sample)cur_roi_points_sample = cur_roi_points[index]elif cur_roi_points.shape[0] == 0:cur_roi_points_sample = cur_roi_points.new_zeros(num_sample, 4)else:empty_num = num_sample - cur_roi_points.shape[0]add_zeros = cur_roi_points.new_zeros(empty_num, 4)add_zeros = cur_roi_points[0].repeat(empty_num, 1)cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0)#这个roi的采样结束,记录到src中src[bs_idx, roi_box_idx, :, :] = cur_roi_points_sample#采样结束,经过整理得到src,共b个batch,每个batch中128个roi,每个roi取256个点(x,y,z,r)src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4)
Embedding和Encoding
论文的主要结构图如下所示:
得到了圆柱体内部的采样点,和box角点,中心点的信息。可以进行embedding了
如下所示
以一个roi举例,roi内部256个采样点,每一个点都计算自己与roi八个角点以及中心点的距离,再包括自己的反射强度,通过线形层升维。具体公式如下所示:对于点采样点pi, fi是对该点的embedding
Enbedding的代码如下:
# src是采样得到的roi内部的点src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4)#corner_points是上一步得到的roi八个角点坐标(256,8,3) -->(256,24)corner_points = corner_points.view(batch_size * num_rois, -1)#将每个roi的中心点坐标concat到八个角点坐标上# corner_add_center_points (256,24) -->(256,27)corner_add_center_points = torch.cat([corner_points, rois.view(-1, rois.shape[-1])[:,:3]], dim = -1)# 计算每个点与角点中心点得坐标距离#pos_fea(b*roi, num_sample, 27)pos_fea = src[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1).repeat(1,num_sample,1) # 27 维度#roi的长宽高 lwh (b*roi, num_sample, 3)lwh = rois.view(-1, rois.shape[-1])[:,3:6].unsqueeze(1).repeat(1,num_sample,1)# (l*l + w*w + h*h) ** 0.5diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5# pos_fea(256,256,27) 每一个点到角点,中心点的 球形坐标距离pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))# src(256,256,28)src = torch.cat([pos_fea, src[:,:,-1].unsqueeze(-1)], dim = -1)#src(256,256,256)src = self.up_dimension(src)
下面看一下 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))
这个函数:
def spherical_coordinate(self, src, diag_dist):# src(256,256,27)每个采样点到角点中心点的坐标距离# diag_dist(256,256,1)assert (src.shape[-1] == 27)device = src.device#分别取每个采样点到角点中心点坐标距离indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device) #indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) #indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) src_x = torch.index_select(src, -1, indices_x) # (256, 256, 9)src_y = torch.index_select(src, -1, indices_y) # (256, 256, 9)src_z = torch.index_select(src, -1, indices_z) # (256, 256, 9)#把坐标距离转换成球形坐标距离??dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5phi = torch.atan(src_y / (src_x + 1e-5))the = torch.acos(src_z / (dis + 1e-5))dis = dis / diag_dist#src(256,256,256)球形坐标距离??src = torch.cat([dis, phi, the], dim = -1)return src
这里为什么要把坐标系转换成球形坐标,作者在github上解释说,换不换没有什么差别,所以在论文中就没有提到…
Tranformer如下所示
def forward(self, src, query_embed, pos_embed):# src (256,256,256)# query_embed (1,256)# pos_embed (256,256,256) 位置编码在Emcoder中q,k相加bs, n, c = src.shapesrc = src.permute(1, 0, 2)pos_embed = pos_embed.permute(1, 0, 2)query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (1,256,256)tgt = torch.zeros_like(query_embed) #(1,256,256)# memory (256,256,256)memory = self.encoder(src, src_key_padding_mask=None, pos=pos_embed)#因为在整体代码中,TransformerDecoder的初始化参数return_intermediate设置为True#因此,Decoder的输出包含了每层的结果,共有一层,shape是[1,num_querie,batch_size,hidden_dim]hs = self.decoder(tgt, memory, memory_key_padding_mask=None,pos=pos_embed, query_pos=query_embed)#(1,256,1,256)return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, n)
Encoder模块没有什么可说的,和DETR一样的,共有三层。encoder得到memory(256,256,256)
,256个roi,每个roi256个采样点,通道数为256。与输入相同
Decoder共一层,与DETR一样,输入的num_query = 1 ,代表每个roi里256个采样点生成一个box 。
k,v来自encoder,每一个roi得到一个box。
唯一不同的是,decoder在计算self-attention时候。公式由标准的
变成了论文中的:
再乘以V
代码如下
def attention(query, key, value):dim = query.shape[1]scores_1 = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5scores_2 = torch.einsum('abcd, aced->abcd', key, scores_1)prob = torch.nn.functional.softmax(scores_2, dim=-1)output = torch.einsum('bnhm,bdhm->bdhn', prob, value)return output, prob
【3D detection】CT3D部分代码的理解相关推荐
- 3D Detection 论文汇总
来源丨AI 修炼之路 这篇文章主要是梳理一下近期3D Detection的进展,分类列举出一些我认为的比较重要的.有代表性的工作. 一.论文分类汇总 1. 基于激光雷达点云的3D检测方法(LiDAR ...
- 复现KM3D:Monocular 3D Detection with Geometric Constraints Embedding and Semi-supervised Training
复现KM3D:Monocular 3D Detection with Geometric Constraints Embedding and Semi-supervised Training 时间:2 ...
- 【论文阅读】【3d目标检测】Sparse Fuse Dense: Towards High Quality 3D Detection with Depth Completion
论文题目:Sparse Fuse Dense: Towards High Quality 3D Detection with Depth Completion 飞步科技 cvpr2022 kitti ...
- CenterNet-TensorRT 3D Detection
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨Panzerfahrer@知乎 来源丨https://zhuanlan.zhihu.com/p/ ...
- SSD-6D: Making RGB-Based 3D Detection and 6D Pose Estimation Great Again—2017(笔记)
SSD-6D: Making RGB-Based 3D Detection and 6D Pose Estimation Great Again-2017(笔记) SSD-6D让RGB图像的3D检测和 ...
- 基于2.5/3D的自主主体室内场景理解研究
作者:Tom Hardy Date:2020-3-13 来源:基于2.5/3D的自主主体室内场景理解研究 参考链接:https://arxiv.org/abs/1803.03352 主要内容 摘要随着 ...
- python绘制三维曲面图-python中Matplotlib实现绘制3D图的示例代码
Matplotlib 也可以绘制 3D 图像,与二维图像不同的是,绘制三维图像主要通过 mplot3d 模块实现.但是,使用 Matplotlib 绘制三维图像实际上是在二维画布上展示,所以一般绘制三 ...
- python 协程 php,python3.x,协程_python协程练习部分代码的理解?,python3.x,协程,asyncio - phpStudy...
python协程练习部分代码的理解? import asyncio import threading async def wget(host): print('wget {}'.format(host ...
- python绘制四边螺旋线代_Python绘制3d螺旋曲线图实例代码
Line plots Axes3D.plot(xs, ys, *args, **kwargs) 绘制2D或3D数据 参数 描述 xs, ys X轴,Y轴坐标定点 zs Z值,每一个点的值都是1 zdi ...
最新文章
- SQL Server 2005下的分页SQL
- 香河php程序员_失控的香河最流行的四大职位
- 0.0 目录-深度学习第二课《改善神经网络》-Stanford吴恩达教授
- 基于注解的Spring AOP的配置和使用--转载
- Puppet 的部署与应用,看这一篇就够了
- 银行流水你真的会看吗?
- dubbo系列(一)
- 手把手教你如下在Linux下如何写一个C语言代码,编译并运行
- 手游建筑美术资源_建筑商和机械手
- linux下expdp定时备份_Linux下定时任务的配置
- linux分屏显示文件行数,linux常用命令集合1
- Improving Opencv9 Eroding and Dilating 和对opencv窗体上有控制按钮的理解
- ArduCopter——ArduPilot——Notch Filter(陷波滤波器)
- MATLAB 电子书
- 排列组合——排列公式的推理和组合
- 微信 支付 h5 开发 使用 best-pay-sdk
- 电脑连接上wifi,但是无法打开网页上网,小记
- 倪明选:追忆似水流年,祝愿更加辉煌
- activiti学习之回退实现
- Laravel Trait method broker has not been applied, because there are collisions with other trait meth