Pointnet++代码详解(一):farthest_point_sample函数
初入Pointnet++,看相关源码感觉很费力,想着把自己学到的记下来,避免后面忘记要用到又得重新思考,本系列主要讲解Pointnet++代码,其理论部分大家可以在网上自行搜索相关资料。本系列分析的源码来自:https://github.com/yanx27/Pointnet_Pointnet2_pytorch
farthest_point_sample函数是来自于Pointnet++的FPS(Farthest Point Sampling) 最远点采样法,该方法比随机采样的优势在于它可以尽可能的覆盖空间中的所有点。
最远点采样是Set Abstraction模块中较为核心的步骤,其目的是从一个输入点云中按照所需要的点的个数npoint采样出足够多的点,并且点与点之间的距离要足够远。最后的返回结果是npoint个采样点在原始点云中的索引。
FPS的逻辑如下:
假设一共有n个点,整个点集为N = {f1, f2,…,fn}, 目标是选取n1个起始点做为下一步的中心点:
- 随机选取一个点fi为起始点,并写入起始点集 B = {fi};
- 选取剩余n-1个点计算和fi点的距离,选择最远点fj写入起始点集B={fi,fj};
- 选取剩余n-2个点计算和点集B中每个点的距离, 将最短的那个距离作为该点到点集的距离, 这样得到n-2个到点集的距离,选取最远的那个点写入起始点B = {fi, fj ,fk},同时剩下n-3个点, 如果n1=3 则到此选择完毕;
- 如果n1 > 3则重复上面步骤直到选取n1个起始点为止.
具体实现步骤如下:
- 先随机初始化一个centroids矩阵,后面用于存储npoint个采样点的索引位置,大小为B×npoint,其中B为BatchSize的个数,即B个样本;
- 利用distance矩阵记录某个样本中所有点到某一个点的距离,初始化为B×N矩阵,初值给个比较大的值,后面会迭代更新;
- 利用farthest表示当前最远的点,也是随机初始化,范围为0~N,初始化B个,对应到每个样本都随机有一个初始最远点;
- batch_indices初始化为0~(B-1)的数组;
- 直到采样点达到npoint,否则进行如下迭代:
- (1)设当前的采样点centroids为当前的最远点farthest;
- (2)取出这个中心点centroid的坐标;
- (3)求出所有点到这个farthest点的欧式距离,存在dist矩阵中;
- (4) 建立一个mask,如果dist中的元素小于distance矩阵中保存的距离值,则更新distance中的对应值,随着迭代的继续distance矩阵中的值会慢慢变小,其相当于记录着某个样本中每个点距离所有已出现的采样点的最小距离;
- (5)最后从distance矩阵取出最远的点为farthest,继续下一轮迭代.
def farthest_point_sample(xyz, npoint):"""Input:xyz: pointcloud data, [B, N, 3]npoint: number of samplesReturn:centroids: sampled pointcloud index, [B, npoint]"""device = xyz.devicebatchsize, ndataset, dimension = xyz.shape#to方法Tensors和Modules可用于容易地将对象移动到不同的设备(代替以前的cpu()或cuda()方法)# 如果他们已经在目标设备上则不会执行复制操作centroids = torch.zeros(batchsize, npoint, dtype=torch.long).to(device)distance = torch.ones(batchsize, ndataset).to(device) * 1e10#randint(low, high, size, dtype)# torch.randint(3, 5, (3,))->tensor([4, 3, 4])farthest = torch.randint(0, ndataset, (batchsize,), dtype=torch.long).to(device)#batch_indices=[0,1,...,batchsize-1]batch_indices = torch.arange(batchsize, dtype=torch.long).to(device)for i in range(npoint):# 更新第i个最远点centroids[:,i] = farthest# 取出这个最远点的xyz坐标centroid = xyz[batch_indices, farthest, :].view(batchsize, 1, 3)# 计算点集中的所有点到这个最远点的欧式距离#等价于torch.sum((xyz - centroid) ** 2, 2)dist = torch.sum((xyz - centroid) ** 2, -1)# 更新distances,记录样本中每个点距离所有已出现的采样点的最小距离mask = dist < distancedistance[mask] = dist[mask]# 从更新后的distances矩阵中找出距离最远的点,作为最远点用于下一轮迭代#取出每一行的最大值构成列向量,等价于torch.max(x,2)farthest = torch.max(distance, -1)[1]return centroids
1、xyz是点云的坐标数据,其维度为[B,N,3], B代表Batchsize,即有多少样本, N代表每个样本的总点数,3代表点云的x,y,z坐标;
npoint代表采样点数,centroids代表采样点的索引,其维度为[B, N]。
2、关于device
device = xyz.device
因此,这句代码说的就是将xyz的device属性赋给device,这是为了后续操作所采用的。
3、shape
可以看出shape与size()是一样的,而dim()返回的是Tensor的维度(秩)
4、to(device)
centroids = torch.zeros(batchsize, npoint, dtype=torch.long).to(device)distance = torch.ones(batchsize, ndataset).to(device) * 1e10
to方法Tensors和Modules可用于容易地将对象移动到不同的设备(代替以前的cpu()或cuda()方法)
注意:如果数据已经在目标设备上则不会执行复制操作
5、torch.randint和torch.arange
torch.
randint
(low=0, high, size):size是元组,产生从low到high之间的随机整数,大小为size。
torch.arange(start, end, step) # 不包括end, step是两个点间距,start默认为0,step默认为1
#randint(low, high, size, dtype)# torch.randint(3, 5, (3,))->tensor([4, 3, 4])farthest = torch.randint(0, ndataset, (batchsize,), dtype=torch.long).to(device)#batch_indices=[0,1,...,batchsize-1]batch_indices = torch.arange(batchsize, dtype=torch.long).to(device)
常用函数:https://www.jianshu.com/p/46a8ad87d238
6、
for i in range(npoint):# 更新第i个最远点,centroids:[B,npoint],farthest是最远点的索引centroids[:,i] = farthest# 取出batchsize的每个样本这个最远点的xyz坐标,xyz:[B,N,3]centroid = xyz[batch_indices, farthest, :].view(batchsize, 1, 3)# 计算点集中的所有点到这个最远点的欧式距离#等价于torch.sum((xyz - centroid) ** 2, 2)dist = torch.sum((xyz - centroid) ** 2, -1)# 更新distances,记录样本中每个点距离所有已出现的采样点的最小距离mask = dist < distancedistance[mask] = dist[mask]# 从更新后的distances矩阵中找出距离最远的点,作为最远点用于下一轮迭代#torch.max(distance, -1)取出每一行的最大值构成列向量,等价于torch.max(x,2)#torch.max(distance, -1)[1]是取列向量的索引,若torch.max(distance, -1)[0]则是取列向量farthest = torch.max(distance, -1)[1]
torch.sum(input, dim, out=None) → Tensor
- input (Tensor) – 输入张量
- dim (int) – 缩减的维度
- out (Tensor, optional) – 结果张量
import torch
x = torch.randn(4, 5)print(x)print(x.sum(0)) #按列求和
print(x.sum(1)) #按行求和
print(torch.sum(x)) #按列求和
print(torch.sum(x, 0))#按列求和
print(torch.sum(x, 1))#按行求和#结果:
tensor([[ 0.2210, 1.8035, 0.7671, -0.1836, -0.2794],[-0.7922, -1.0881, -2.0180, 1.0981, 0.2320],[-0.4681, 0.1820, 0.0502, 0.0067, 1.3218],[ 0.4785, 1.0799, 1.6197, 0.6642, 0.6915]])
tensor([-0.5608, 1.9773, 0.4190, 1.5854, 1.9660])
tensor([ 2.3287, -2.5682, 1.0926, 4.5338])
tensor(5.3868)
tensor([-0.5608, 1.9773, 0.4190, 1.5854, 1.9660])
tensor([ 2.3287, -2.5682, 1.0926, 4.5338])
对于三维而言,
import torch
xyz = torch.tensor([[[3,7,9],[10,5,2]],[[5,4,2],[1,6,9]]])dist0 = torch.sum(xyz, -1)
dist1 = torch.sum(xyz, 2)
dist2 = torch.sum(xyz, 1)
dist3 = torch.sum(xyz)
print("xyz:",xyz)
print("sum-1:",dist0)
print("sum2:", dist1)
print("sum1:",dist2)
print("sum:", dist3)结果:
xyz: tensor([[[ 3, 7, 9],[10, 5, 2]],[[ 5, 4, 2],[ 1, 6, 9]]])
sum-1: tensor([[19, 17],[11, 16]])
sum2: tensor([[19, 17],[11, 16]])
sum1: tensor([[13, 12, 11],[ 6, 10, 11]])
sum: tensor(63)
更多sum用法详见:https://blog.csdn.net/qq_39463274/article/details/105145029
torch.max:
对于tensorA和tensorB:
- torch.max(tensorA):返回tensor中的最大值。
- torch.max(tensorA,dim):dim表示指定的维度,返回指定维度的最大数和对应下标
- torch.max(tensorA,tensorB):比较tensorA和tensorB相对较大的元素。
若为三阶张量,则结果如下:
import torch
x= torch.tensor([[[3,7,9],[10,5,2]],[[5,4,2],[1,6,9]]])
k0=torch.max(x,0)
k1=torch.max(x,1)
k2=torch.max(x,2)
k3=torch.max(x,-1)
print("x:",x)
print("k0:",k0)
print("k1:",k1)
print("k2:",k2)
print("k-1:",k3)结果:
x: tensor([[[ 3, 7, 9],[10, 5, 2]],[[ 5, 4, 2],[ 1, 6, 9]]])
k0: (tensor([[ 5, 7, 9],[10, 6, 9]]), tensor([[1, 0, 0],[0, 1, 1]]))
k1: (tensor([[10, 7, 9],[ 5, 6, 9]]), tensor([[1, 0, 0],[0, 1, 1]]))
k2: (tensor([[ 9, 10],[ 5, 9]]), tensor([[2, 0],[0, 2]]))
k-1: (tensor([[ 9, 10],[ 5, 9]]), tensor([[2, 0],[0, 2]]))
详细请见:https://blog.csdn.net/Linux_bin/article/details/95599849
Pointnet++代码详解(一):farthest_point_sample函数相关推荐
- Pointnet++代码详解:farthest_point_sample函数
FPS farthest_point_sample函数是来自于Pointnet++的FPS(Farthest Point Sampling) 最远点采样法,该方法比随机采样的优势在于它可以尽可能的覆盖 ...
- PointNet代码详解
PointNet代码详解 最近在做点云深度学习的机器人抓取,这篇博客主要是把近期学习PointNet的一些总结的知识点汇总一下. PointNet概述详见以下网址和博客,这里也就不再赘述了. 三维深度 ...
- Windows扫雷游戏代码详解【memset函数】
题目描述: 扫雷是Windows自带的游戏.游戏的目标是尽快找到雷区中的所有地雷,而不许踩到地雷.如果方块上的是地雷,将输掉游戏. 如果方块上出现数字,则表示在其周围的八个方块中共有多少颗地雷. 你的 ...
- Pointnet++代码详解(三):query_ball_point函数
query_ball_point函数对应于Grouping layer, 这一层使用Ball query方法生成N'个局部区域,根据论文中的意思,这里有两个变量 ,一个是每个区域中点的数量K,另一个是 ...
- [PointNet代码详解]PointNet各模块代码实现超详细注释
pointnet.py pointnet模型各个模块的实现 import torch import torch.nn as nn import torch.nn.parallel import tor ...
- 三维深度学习之pointnet系列详解(一)
目前二维深度学习取得了很大的进步并且应用范围越来越广,随着三维设备的发展,三维深度学习得到了很大的关注. 最近接触了三维深度学习方面的研究,从pointnet入手,对此有了一点点了解希望记录下来并分享 ...
- 从PointNet到PointNet++理论及代码详解
从PointNet到PointNet++理论及代码详解 1. 点云是什么 1.1 三维数据的表现形式 1.2 为什么使用点云 1.3 点云上以往的相关工作 2. PointNet 2.1 基于点云的置 ...
- PointNet模型的Pytorch代码详解
前言 关于PointNet模型的构成.原理.效果等等论文部分内容,我在之前一篇论文中写到过,可以参考这个链接:PointNet论文笔记 下边我就直接放一张网络组成图,并对代码进行解释,我以一种比 ...
- python代码大全表解释-python操作列表的函数使用代码详解
python的列表很重要,学习到后面你会发现使用的地方真的太多了.最近在写一些小项目时经常用到列表,有时其中的方法还会忘哎! 所以为了复习写下了这篇博客,大家也可以来学习一下,应该比较全面和详细了 列 ...
最新文章
- 使用Intellij IDEA 解决Java8的数据流问题
- 《新一代城市大脑建设与发展》专家研讨会在京举办
- JQuery学习笔记 [Ajax] (6-2)
- github可视化_Github上 10 个超好看可视化面板
- 学习笔记(32):Python网络编程并发编程-线程queue
- Net Core下使用RabbitMQ比较完备两种方案(虽然代码有点惨淡,不过我会完善)
- corda_Corda服务的异步流调用
- mysql8.0默认端口_mysql 8.0.19 安装 及 端口修改
- “程序员年薪50万到底有多累、多辛苦?”,句句扎心
- 可执行MIPS指令的单周期CPU
- 2021-03-16PyCharm3.0默认快捷键(翻译的)PyCharm Default Keymap
- Oracle EBS 统计数据收集模式(Gather Schema Statistics)报错处理
- APP游戏开发十诫!第一个雏型就要搞定的事
- win10显卡相关配置
- Torchlight(火炬之光)特效载入
- 商家自建流量池:10种微信引流的方法,值得学习社群营销的商家收藏 !
- 大型企业选择私有云的原因
- 计算机在校学校目标和措施,学校信息化工作方案
- 【历史上的今天】8 月 20 日:传奇程序员诞生日;谷歌发布 Pixel 4a
- 计算机网络系统承接查验,智能化系统承接查验标准及方法基本.docx