pointnet.py pointnet模型各个模块的实现

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
​
# STN3d: T-Net 3*3 transform
# 类似一个mini-PointNet
class STN3d(nn.Module):def __init__(self, channel):super(STN3d, self).__init__()# torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)self.conv1 = torch.nn.Conv1d(channel, 64, 1) self.conv2 = torch.nn.Conv1d(64, 128, 1)#channel为输入他通道数, 为3指的是输入点云特征的三个通道(X,Y,Z),channel=6表示6个通道(X.Y.Z.和X,Y,Z方向的法向量)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9) # 9=3*3self.relu = nn.ReLU()
​self.bn1 = nn.BatchNorm1d(64)#输入特征的尺度会影响梯度下降算法的迭代步数以及梯度更新的难度,从而影响训练的收敛性。因此,我们需要对特征进行归一化,即使得各个特征有相似的尺度。self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)
​def forward(self, x):#前向传播函数batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))# Symmetric function: max poolingx = torch.max(x, 2, keepdim=True)[0]#最大池化 这里得到了一个全局的特征# x参数展平(拉直)x = x.view(-1, 1024) #将全局特征展成一个1024列的特征,其中-1代表相应的行
​x = F.relu(self.bn4(self.fc1(x)))#1024降维到512后对512维的特征做相应的bn4(512)归一化之后进行relu的激活函数进行非线性。x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)#9个元素
​# 展平的对角矩阵:np.array([1, 0, 0, 0, 1, 0, 0, 0, 1])iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(batchsize, 1)if x.is_cuda:iden = iden.cuda()x = x + iden # affine transformation,仿射变换,简单来说,“仿射变换”就是:“线性变换”+“平移”。# 用view,转换成batchsize*3*3的数组x = x.view(-1, 3, 3)return x
​
​
# STNkd: T-Net 64*64 transform,k默认是64
class STNkd(nn.Module):def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k * k)self.relu = nn.ReLU()
​self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)
​self.k = k
​def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))# Symmetric function: max poolingx = torch.max(x, 2, keepdim=True)[0]# 参数拉直(展平)x = x.view(-1, 1024)
​x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)
​# 展平的对角矩阵 iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(batchsize, 1)if x.is_cuda:iden = iden.cuda()x = x + iden # affine transformationx = x.view(-1, self.k, self.k)return x
​
# PointNet编码器
class PointNetEncoder(nn.Module):def __init__(self, global_feat=True, feature_transform=False, channel=3):super(PointNetEncoder, self).__init__()
​self.stn = STN3d(channel) # STN3d: T-Net 3*3 transformself.conv1 = torch.nn.Conv1d(channel, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_feat  #是否需要全局的特征self.feature_transform = feature_transform#是否需要特征变换if self.feature_transform:self.fstn = STNkd(k=64) # STNkd: T-Net 64*64 transform
​def forward(self, x):B, D, N = x.size() # batchsize,3(xyz坐标)或6(xyz坐标+法向量),1024(一个物体所取的点的数目)trans = self.stn(x) # STN3d T-Netx = x.transpose(2, 1) # 交换一个tensor的两个维度if D >3 :x, feature = x.split(3,dim=2)# 对输入的点云进行输入转换(input transform)    # input transform: 计算两个tensor的矩阵乘法# bmm是两个三维张量相乘, 两个输入tensor维度是(b×n×m)和(b×m×p), # 第一维b代表batch size,输出为(b×n×p)x = torch.bmm(x, trans)if D > 3:x = torch.cat([x,feature],dim=2) #矩阵的拼接,0代表竖着拼接,1代表横着拼接x = x.transpose(2, 1)x = F.relu(self.bn1(self.conv1(x))) # MLP
​if self.feature_transform:trans_feat = self.fstn(x) # STNkd T-Netx = x.transpose(2, 1)# 对输入的点云进行特征转换(feature transform)# feature transform: 计算两个tensor的矩阵乘法x = torch.bmm(x, trans_feat)x = x.transpose(2, 1)else:trans_feat = None
​pointfeat = x # 局部特征x = F.relu(self.bn2(self.conv2(x))) # MLPx = self.bn3(self.conv3(x)) # MLPx = torch.max(x, 2, keepdim=True)[0] # 最大池化得到全局特征x = x.view(-1, 1024) # 展平if self.global_feat: # 需要返回的是否是全局特征?return x, trans, trans_feat # 返回全局特征else:x = x.view(-1, 1024, 1).repeat(1, 1, N)# 返回局部特征与全局特征的拼接return torch.cat([x, pointfeat], 1), trans, trans_feat
​
# 对特征转换矩阵做正则化:
# constrain the feature transformation matrix to be close to orthogonal matrix
def feature_transform_reguliarzer(trans): #让feature transform接近于一个正交的矩阵d = trans.size()[1]I = torch.eye(d)[None, :, :] # torch.eye(n, m=None, out=None) 返回一个2维张量,对角线位置全1,其它位置全0。单位阵if trans.is_cuda:I = I.cuda()# 正则化损失函数loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2))) #a*a的转置-单位阵后求范数后再平均=正则化的lossreturn loss#用来将特征转换矩阵做正则化的约束
无序的。输入是欧几里德空间中点的子集。它有三个主要的属性:与图像中的像素阵列或体素网格中的体素阵列不同,点云是一组没有特定顺序的点。

点之间的交互作用。这些点来自一个有距离度量的空间。它意味着点不是孤立的,相邻点形成一个有意义的子集。因此,该模型需要能够从附近的点捕获局部结构,以及局部结构之间的组合相互作用。

变换不变性。作为一个几何对象,学习到的点集表示应该对某些变换是不变的。例如,同时旋转和平移点不应该改变全局点云的类别,也不应该改变点的分割。

它采用了两次STN(Spatial Transformer Networks),第一次input transform是对空间中点云进行调整,直观上理解是旋转出一个更有利于分类或分割的角度,比如把物体转到正面;第二次feature transform是对提取出的64维特征进行对齐,即在特征层面对点云进行变换。

网络分成了分类网络和分割网络两个部分,大体思路类似,都是设计表征的过程分类网络设计global feature,分割网络设计point-wise feature两者都是为了让表征尽可能discriminative,也就是同类的能分到一类,不同类的距离能拉开。

conclusion

第一篇直接用于点云数据处理的三为点云分割网络,结构简单,逻辑清晰,秒杀之前多视图,体素等方法。不用过多的去进行数据预处理。大大降低了点云模型的复杂度。影响力颇深。

2023.3.27 记录James.King缺席13场比赛后替补复出。出战29’31‘’拿下19分8篮板3助攻惜败公牛。

[PointNet代码详解]PointNet各模块代码实现超详细注释相关推荐

  1. [新手必备]Unity推箱子小游戏C#代码详解(第一篇-代码部分)

    完整项目请参考博客:https://blog.csdn.net/qq_41676090/article/details/96300302 本文为推箱子小游戏C#代码详解第一篇的代码部分,主要讲解 Sy ...

  2. 0-1背包问题详解(一步一步超详细)

    1.什么叫01背包问题? 背包问题通俗的说,就是假如你面前有5块宝石分别为a, b, c, d, e,每块宝石的重量不同,并且每块宝石所带来的价值也不同(注意:这里宝石的重量的价值没有特定关系),目前 ...

  3. 详解 K8S 高可用部署,超详细

    一.前言 二.基础环境部署 1)前期准备(所有节点) 2)安装容器 docker(所有节点) 3)配置 k8s yum 源(所有节点) 4)将 sandbox_image 镜像源设置为阿里云 goog ...

  4. fasterrcnn tensorflow代码详解_pytorch目标检测代码的一些bug调试

    这几天一直在做调包侠,是时候来总结总结了.记录一些我所遇到的不常见的问题. faster rcnn: 参考代码: jwyang/faster-rcnn.pytorch​github.com pytor ...

  5. java security 详解_Spring Security入门教程 通俗易懂 超详细 【内含案例】

    Spring Security的简单使用 简介 SSM 整合 Security 是比较麻烦的,虽然Security的功能比 Shiro 强大,相反却没有Shiro的使用量多 SpringBoot出现后 ...

  6. PointNet代码详解

    PointNet代码详解 最近在做点云深度学习的机器人抓取,这篇博客主要是把近期学习PointNet的一些总结的知识点汇总一下. PointNet概述详见以下网址和博客,这里也就不再赘述了. 三维深度 ...

  7. 基于U-Net的的图像分割代码详解及应用实现

    摘要 U-Net是基于卷积神经网络(CNN)体系结构设计而成的,由Olaf Ronneberger,Phillip Fischer和Thomas Brox于2015年首次提出应用于计算机视觉领域完成语 ...

  8. Transformer代码详解: attention-is-all-you-need-pytorch

    Transformer代码详解: attention-is-all-you-need-pytorch 前言 Transformer代码详解-pytorch版 Transformer模型结构 各模块结构 ...

  9. 超级超级详细的实体关系抽取数据预处理代码详解

    超级超级详细的实体关系抽取数据预处理代码详解 由于本人是代码小白,在学习代码过程中会出现很多的问题,所以需要一直记录自己出现的问题以及解决办法. 废话不多说,直接上代码!!! 一.data_proce ...

最新文章

  1. 微服务集成——《微服务设计》读书笔记
  2. html表单自动编号,自动编号插件
  3. arcgis中dem坐标定义_GIS基础-DEM Grid规则格网结构
  4. 《Google C++ 编码规范》小结
  5. HTML5七夕情人节表白网页制作【唯美满天星3D相册】HTML+CSS+JavaScript
  6. 能源在线监测管理系统
  7. 地图数字化步骤及问题总结
  8. 安装VMware16教程
  9. 2019上半年个人成长复盘
  10. Comma Separated Values Format
  11. 【thm】windows内网提权之Windows PrivEsc
  12. 【Kaggle 教程】Data Visualization 数据可视化-画图-各种图
  13. KK凯文.凯利:第一届中国社群领袖峰会演讲实录(全部版)
  14. FOJ 1968 Twinkling lights III
  15. 小萝莉说Crash(二): Unrecognized selector xxx 之 ForwardInvocation
  16. 190403内置模块
  17. 亚马逊买家秀视频怎么上传?上传买家秀视频的作用是什么
  18. idea2018版本 整合git不显示代码编辑记录和信息
  19. 苹果手机计算机隐藏照片app,‎App Store 上的“加密计算器 - 隐藏私人相册视频”...
  20. 手机火狐浏览html文件在哪里,火狐手机浏览器书签在哪?

热门文章

  1. React Native 启动速度优化——Native 篇(内含源码分析)
  2. Microsoft VBScript 运行时错误 错误 '800a01a8' 缺少对象: ''
  3. java计算机毕业设计共享单车使用满意度评价系统源码+mysql数据库+系统+lw文档+部署
  4. css中图片背景以及URL的介绍以及什么是css精灵
  5. 每天进步一点点---------kibana/Grafana场景2小学排名折线
  6. 四维图新亮相上汽大众合作伙伴技术展示日智能驾驶专题活动
  7. Mac 活动监视器 闪退 发热十分厉害 ssl3.plist
  8. input 实时监听输入框,判断最小值只能为1或其他数
  9. [励志文章]nbsp;一个计算机高手的成长
  10. 基于 Pytorch 疾病图片诊断识别 ResNet