【3D detection】CT3D部分代码的理解

  • 获得box的八个角点坐标
  • 在无限高的圆柱中随机采样
  • Embedding和Encoding

paper: Improving 3D Object Detection with Channel-wise Transformer
code:https://github.com/hlsheng1/CT3D

获得box的八个角点坐标

/pcdet/models/roi_heads/ct3d_head.py130行左右

        # corner#输入的rois大小为(batch_size,roi的个数,roi信息)(2,128,7)#其中7包括roi的x,y,z,l,w,h,方向角#最后得到的corner_points是(B, 128, 2*2*2, 3),每个角点的坐标corner_points, _ = self.get_global_grid_points_of_roi(rois)  # (BxN, 2x2x2, 3)corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1])  # (2, 128, 2x2x2, 3)

下面看一下 self.get_global_grid_points_of_roi(rois)这个函数,同样在class CT3DHead这个类中

  def get_global_grid_points_of_roi(self, rois):rois = rois.view(-1, rois.shape[-1]) # (256,7)batch_size_rcnn = rois.shape[0]      # (256)# 得到了box八个角点到中心点的坐标距离# 中心点的三个轴坐标只要在每个轴加上相应的坐标距离# 就能得到每一个角点的坐标了。local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) #(256,8,3)#注意!!!#上面得到的只考虑了box的大小,并没有考虑角度#如果直接拿来计算的话,得到的所有角点都不是真正的,而是平行于坐标轴的!#所以还要通过下面的函数将每个roi的角度考虑信息。#最后得到真正的角点全局坐标global_roi_grid_points = common_utils.rotate_points_along_z(local_roi_grid_points.clone(), rois[:, 6]).squeeze(dim=1) #(256, 8 ,3)#得到中心点的全局坐标global_center = rois[:, 0:3].clone() # (256, 3)# pdb.set_trace()#中心点加坐标距离最终得到每个box八个点的角点global_roi_grid_points += global_center.unsqueeze(dim=1) #(256,8,3  )#(256,8,3)  (256,8,3)return global_roi_grid_points, local_roi_grid_points

看一下 self.get_corner_points(rois, batch_size_rcnn)这个函数,同样在class CT3DHead这个类中

 @staticmethoddef get_corner_points(rois, batch_size_rcnn):#得到一个(2, 2, 2) 的tensor,用来在后面求box八个角点的索引 ,这里还是很巧妙的,可以学习一下。faked_features = rois.new_ones((2, 2, 2))# nonzero() 返回每个这个tensor中非零元素的索引# 例如这里建立了一个(2,2,2)的全是1的tensor,每一个元素都是非零的,#所以得到的非零元素的索引为 [0,0,0], [0,0,1], [0,1,0], [0,1,1] .....[1,1,1] 共8个dense_idx = faked_features.nonzero()  # (8, 3) [x_idx, y_idx, z_idx]dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float()  # (256, 2x2x2, 3)# 取每一个RoI的长宽高local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] # (256,3)#这一步是关键,求每个角点在每个轴上相对于roi中心点的距离#例如:[0, 0, 1] * [l, w, h] - [l/2,w/2, h/2]# = [-l/2, -w/2, h/2]#即这个角点到中心点的坐标距离#中心点坐标 在x轴减去一半长,在y轴减去一半宽,在z轴加上一般高。就是这个角点的坐标了。roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \- (local_roi_size.unsqueeze(dim=1) / 2)  # (B, 2x2x2, 3)return roi_grid_points #(256,8,3)

这个函数得到了box每一个角点到中心点的坐标距离,中心点的三个轴坐标只要在每个轴加上相应的坐标距离,就能得到每一个角点的坐标了。

在无限高的圆柱中随机采样

论文中在RoI中采样点时,将RoI转换为一个高度无限高的圆柱体,并从中随机采样256个点作为RoI的表征。采样方法如下

圆柱体底面的半径为:

代码如下

        num_sample = self.num_points  # 这里是256个点src = rois.new_zeros(batch_size, num_rois, num_sample, 4) #(2,128,256,4)for bs_idx in range(batch_size): # 每次循环一个batch_size# batch_dict[points]是所有batch_size中所有的点# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)]是第bs_idx个batch中的点# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)][:,1:5]是每个点的坐标和反射强度#cur_points是一个batch中所有点的坐标+反射强度(194165,4)cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:5] #(194165,4)# 每个batch的roi boxcur_batch_boxes = batch_dict['rois'][bs_idx] #(128,7)# 求半径公式如上图所示 (128)cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.2# 所有点到roi box中心的距离,# 共128个RoI,每一个RoI 都计算19165个点到其中心的距离dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) # (128,19165)#过滤出半径内的点point_mask = (dis <= cur_radiis.unsqueeze(-1))# 遍历每一个roifor roi_box_idx in range(0, num_rois):# point_mask[roi_box_idx] 是第roi_box_idx个roi的mask#cur_roi_points 这里是(465,4),即这个圆柱roi中有465个点cur_roi_points = cur_points[point_mask[roi_box_idx]]# 如果roi内部的点大于要求采样的数量的话,就随机取256个(源码中num_sample=256)# 如果roi内部的点个数等于0,就用0填充256个# 如果roi内部点在0到256之间,采样所有点,剩余的用0填充if cur_roi_points.shape[0] >= num_sample:random.seed(0)index = np.random.randint(cur_roi_points.shape[0], size=num_sample)cur_roi_points_sample = cur_roi_points[index]elif cur_roi_points.shape[0] == 0:cur_roi_points_sample = cur_roi_points.new_zeros(num_sample, 4)else:empty_num = num_sample - cur_roi_points.shape[0]add_zeros = cur_roi_points.new_zeros(empty_num, 4)add_zeros = cur_roi_points[0].repeat(empty_num, 1)cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0)#这个roi的采样结束,记录到src中src[bs_idx, roi_box_idx, :, :] = cur_roi_points_sample#采样结束,经过整理得到src,共b个batch,每个batch中128个roi,每个roi取256个点(x,y,z,r)src = src.view(batch_size * num_rois, -1, src.shape[-1])  # (b*128, 256, 4)

Embedding和Encoding

论文的主要结构图如下所示:

得到了圆柱体内部的采样点,和box角点,中心点的信息。可以进行embedding了
如下所示

以一个roi举例,roi内部256个采样点,每一个点都计算自己与roi八个角点以及中心点的距离,再包括自己的反射强度,通过线形层升维。具体公式如下所示:对于点采样点pi, fi是对该点的embedding

Enbedding的代码如下:

 # src是采样得到的roi内部的点src = src.view(batch_size * num_rois, -1, src.shape[-1])  # (b*128, 256, 4)#corner_points是上一步得到的roi八个角点坐标(256,8,3) -->(256,24)corner_points = corner_points.view(batch_size * num_rois, -1)#将每个roi的中心点坐标concat到八个角点坐标上# corner_add_center_points (256,24) -->(256,27)corner_add_center_points = torch.cat([corner_points, rois.view(-1, rois.shape[-1])[:,:3]], dim = -1)# 计算每个点与角点中心点得坐标距离#pos_fea(b*roi, num_sample, 27)pos_fea = src[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1).repeat(1,num_sample,1)  # 27 维度#roi的长宽高 lwh (b*roi, num_sample, 3)lwh = rois.view(-1, rois.shape[-1])[:,3:6].unsqueeze(1).repeat(1,num_sample,1)# (l*l + w*w + h*h) ** 0.5diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5# pos_fea(256,256,27) 每一个点到角点,中心点的 球形坐标距离pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))# src(256,256,28)src = torch.cat([pos_fea, src[:,:,-1].unsqueeze(-1)], dim = -1)#src(256,256,256)src = self.up_dimension(src)

下面看一下 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))这个函数:

  def spherical_coordinate(self, src, diag_dist):# src(256,256,27)每个采样点到角点中心点的坐标距离# diag_dist(256,256,1)assert (src.shape[-1] == 27)device = src.device#分别取每个采样点到角点中心点坐标距离indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device)  #indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) #indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) src_x = torch.index_select(src, -1, indices_x)  # (256, 256, 9)src_y = torch.index_select(src, -1, indices_y)  # (256, 256, 9)src_z = torch.index_select(src, -1, indices_z)  # (256, 256, 9)#把坐标距离转换成球形坐标距离??dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5phi = torch.atan(src_y / (src_x + 1e-5))the = torch.acos(src_z / (dis + 1e-5))dis = dis / diag_dist#src(256,256,256)球形坐标距离??src = torch.cat([dis, phi, the], dim = -1)return src

这里为什么要把坐标系转换成球形坐标,作者在github上解释说,换不换没有什么差别,所以在论文中就没有提到…

Tranformer如下所示

def forward(self, src, query_embed, pos_embed):# src (256,256,256)# query_embed (1,256)# pos_embed (256,256,256) 位置编码在Emcoder中q,k相加bs, n, c = src.shapesrc = src.permute(1, 0, 2)pos_embed = pos_embed.permute(1, 0, 2)query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (1,256,256)tgt = torch.zeros_like(query_embed) #(1,256,256)# memory (256,256,256)memory = self.encoder(src, src_key_padding_mask=None, pos=pos_embed)#因为在整体代码中,TransformerDecoder的初始化参数return_intermediate设置为True#因此,Decoder的输出包含了每层的结果,共有一层,shape是[1,num_querie,batch_size,hidden_dim]hs = self.decoder(tgt, memory, memory_key_padding_mask=None,pos=pos_embed, query_pos=query_embed)#(1,256,1,256)return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, n)

Encoder模块没有什么可说的,和DETR一样的,共有三层。encoder得到memory(256,256,256),256个roi,每个roi256个采样点,通道数为256。与输入相同
Decoder共一层,与DETR一样,输入的num_query = 1 ,代表每个roi里256个采样点生成一个box 。
k,v来自encoder,每一个roi得到一个box。
唯一不同的是,decoder在计算self-attention时候。公式由标准的

变成了论文中的:

再乘以V

代码如下

def attention(query, key,  value):dim = query.shape[1]scores_1 = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5scores_2 = torch.einsum('abcd, aced->abcd', key, scores_1)prob = torch.nn.functional.softmax(scores_2, dim=-1)output = torch.einsum('bnhm,bdhm->bdhn', prob, value)return output, prob

【3D detection】CT3D部分代码的理解相关推荐

  1. 3D Detection 论文汇总

    来源丨AI 修炼之路 这篇文章主要是梳理一下近期3D Detection的进展,分类列举出一些我认为的比较重要的.有代表性的工作. 一.论文分类汇总 1. 基于激光雷达点云的3D检测方法(LiDAR ...

  2. 复现KM3D:Monocular 3D Detection with Geometric Constraints Embedding and Semi-supervised Training

    复现KM3D:Monocular 3D Detection with Geometric Constraints Embedding and Semi-supervised Training 时间:2 ...

  3. 【论文阅读】【3d目标检测】Sparse Fuse Dense: Towards High Quality 3D Detection with Depth Completion

    论文题目:Sparse Fuse Dense: Towards High Quality 3D Detection with Depth Completion 飞步科技 cvpr2022 kitti ...

  4. CenterNet-TensorRT 3D Detection

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨Panzerfahrer@知乎 来源丨https://zhuanlan.zhihu.com/p/ ...

  5. SSD-6D: Making RGB-Based 3D Detection and 6D Pose Estimation Great Again—2017(笔记)

    SSD-6D: Making RGB-Based 3D Detection and 6D Pose Estimation Great Again-2017(笔记) SSD-6D让RGB图像的3D检测和 ...

  6. 基于2.5/3D的自主主体室内场景理解研究

    作者:Tom Hardy Date:2020-3-13 来源:基于2.5/3D的自主主体室内场景理解研究 参考链接:https://arxiv.org/abs/1803.03352 主要内容 摘要随着 ...

  7. python绘制三维曲面图-python中Matplotlib实现绘制3D图的示例代码

    Matplotlib 也可以绘制 3D 图像,与二维图像不同的是,绘制三维图像主要通过 mplot3d 模块实现.但是,使用 Matplotlib 绘制三维图像实际上是在二维画布上展示,所以一般绘制三 ...

  8. python 协程 php,python3.x,协程_python协程练习部分代码的理解?,python3.x,协程,asyncio - phpStudy...

    python协程练习部分代码的理解? import asyncio import threading async def wget(host): print('wget {}'.format(host ...

  9. python绘制四边螺旋线代_Python绘制3d螺旋曲线图实例代码

    Line plots Axes3D.plot(xs, ys, *args, **kwargs) 绘制2D或3D数据 参数 描述 xs, ys X轴,Y轴坐标定点 zs Z值,每一个点的值都是1 zdi ...

最新文章

  1. SQL Server 2005下的分页SQL
  2. 香河php程序员_失控的香河最流行的四大职位
  3. 0.0 目录-深度学习第二课《改善神经网络》-Stanford吴恩达教授
  4. 基于注解的Spring AOP的配置和使用--转载
  5. Puppet 的部署与应用,看这一篇就够了
  6. 银行流水你真的会看吗?
  7. dubbo系列(一)
  8. 手把手教你如下在Linux下如何写一个C语言代码,编译并运行
  9. 手游建筑美术资源_建筑商和机械手
  10. linux下expdp定时备份_Linux下定时任务的配置
  11. linux分屏显示文件行数,linux常用命令集合1
  12. Improving Opencv9 Eroding and Dilating 和对opencv窗体上有控制按钮的理解
  13. ArduCopter——ArduPilot——Notch Filter(陷波滤波器)
  14. MATLAB 电子书
  15. 排列组合——排列公式的推理和组合
  16. 微信 支付 h5 开发 使用 best-pay-sdk
  17. 电脑连接上wifi,但是无法打开网页上网,小记
  18. 倪明选:追忆似水流年,祝愿更加辉煌
  19. activiti学习之回退实现
  20. Laravel Trait method broker has not been applied, because there are collisions with other trait meth

热门文章

  1. ArcEngine根据属性分割要素类的实现方法
  2. 使用Python破解维吉尼亚密码
  3. idc运维怎么转linux运维,IDC运维怎么便捷配置机房交换机
  4. 微星B450M迫击炮MAX开启CPU虚拟化功能
  5. 虚数到底有什么意义?
  6. Android小项目———— 冰炭不投de小计算器
  7. 【数据结构】时间复杂度_空间复杂度
  8. 离线安装OneNote for Windows 10
  9. Eclipse使用SVN进行代码提交的步骤
  10. 高数_证明_高斯公式