pointnet.pytorch的代码详细解释

  • 1. PointNet的Pytorch版本代码解析链接
  • 2. 代码解释
    • 2.1 代码结构思维导图
    • 2.2 代码注释
      • 2.2.1 build.sh
      • 2.2.2 render_balls_so.cpp
      • 2.2.3 download.sh
      • 2.2.4 train_classification.py
      • 2.2.5 dataset.py
      • 2.2.6 model
  • 参考文献

1. PointNet的Pytorch版本代码解析链接

pointnet.pytorch

2. 代码解释

2.1 代码结构思维导图

2.2 代码注释

2.2.1 build.sh

按照代码运行的顺序,先从pointnet.pytorch/scripts/build.sh开始解释:

#获取build.sh所在文件夹的绝对路径
SCRIPT=`realpath $0`
SCRIPTPATH=`dirname $SCRIPT`
echo $SCRIPTPATH #对../utils/render_balls_so.cpp进行编译,render_balls_so.cpp文件是用于可视化的C++代码
#-o参数用来指定生成程序的名字
#-shared参数表示编译动态库
#-O2用于优化编译文件
#-D_GLIBCXX_USE_CXX11_ABI用于区分有旧版(c++03规范)的libstdc++.so,和新版(c++11规范)的libstdc++.so两个库,-D_GLIBCXX_USE_CXX11_ABI=0 链接旧版库,-D_GLIBCXX_USE_CXX11_ABI=1 链接新版库
g++ -std=c++11 $SCRIPTPATH/../utils/render_balls_so.cpp -o $SCRIPTPATH/../utils/render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0

-O1, -O2, -O3编译参数的详细解释
-D_GLIBCXX_USE_CXX11_ABI参数的详细解释

2.2.2 render_balls_so.cpp

接下来再看pointnet.pytorch/utils/render_balls_so.cpp是如何进行可视化的:

#include <cstdio>
#include <vector>
#include <algorithm>
#include <math.h>
using namespace std;struct PointInfo{int x,y,z;float r,g,b;
};extern "C"{void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){r=max(r,1);//定义了容量为h*w,初始值为-2100000000的vectorvector<int> depth(h*w,-2100000000); vector<PointInfo> pattern;//将以r为半径球中所有整数点放入容器pattern中for (int dx=-r;dx<=r;dx++)for (int dy=-r;dy<=r;dy++)if (dx*dx+dy*dy<r*r){double dz=sqrt(double(r*r-dx*dx-dy*dy));PointInfo pinfo;pinfo.x=dx;pinfo.y=dy;pinfo.z=dz;pinfo.r=dz/r;pinfo.g=dz/r;pinfo.b=dz/r;pattern.push_back(pinfo);}//找到xyzs中z的最小值和最大值double zmin=0,zmax=0;for (int i=0;i<n;i++){if (i==0){zmin=xyzs[i*3+2]-r;zmax=xyzs[i*3+2]+r;}else{zmin=min(zmin,double(xyzs[i*3+2]-r));zmax=max(zmax,double(xyzs[i*3+2]+r));}}//for (int i=0;i<n;i++){int x=xyzs[i*3+0],y=xyzs[i*3+1],z=xyzs[i*3+2];for (int j=0;j<int(pattern.size());j++){int x2=x+pattern[j].x;int y2=y+pattern[j].y;int z2=z+pattern[j].z;if (!(x2<0 || x2>=h || y2<0 || y2>=w) && depth[x2*w+y2]<z2){depth[x2*w+y2]=z2;double intensity=min(1.0,(z2-zmin)/(zmax-zmin)*0.7+0.3);show[(x2*w+y2)*3+0]=pattern[j].b*c2[i]*intensity;show[(x2*w+y2)*3+1]=pattern[j].g*c0[i]*intensity;show[(x2*w+y2)*3+2]=pattern[j].r*c1[i]*intensity;}}}
}}//extern "C"

2.2.3 download.sh

下载数据集的脚本pointnet.pytorch/scripts/download.sh

#获取download.sh所在文件夹的绝对路径
SCRIPT=`realpath $0`
SCRIPTPATH=`dirname $SCRIPT`#进入download.sh所在文件夹的上一层
cd $SCRIPTPATH/..
#下载数据集压缩包、解压压缩包、删除压缩包
wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate
unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
rm shapenetcore_partanno_segmentation_benchmark_v0.zip#重新进入当前文件夹
cd -

2.2.4 train_classification.py

开始进行Pointnet的分类训练pointnet.pytorch/utils/train_classification.py:

#Python提供了__future__模块,把下一个新版本的特性导入到当前版本,于是我们就可以在当前版本中测试一些新版本的特性,见链接(1)
from __future__ import print_function
#argparse 是 Python 内置的一个用于命令项选项与参数解析的模块,可实现命令行中输入参数的传递,见链接(2)
import argparse
#提供了一些方便使用操作系统相关功能的函数
import os
import random
import torch
import torch.nn.parallel
#优化器模块
import torch.optim as optim
#处理数据集的模块
import torch.utils.data
#从pointnet.pytorch/pointnet/dataset.py和pointnet.pytorch/pointnet/model.py中导入库
#数据进行预处理的库
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
#pointnet的模型结构库
from pointnet.model import PointNetCls, feature_transform_regularizer
#封装好的类
import torch.nn.functional as F
#展示进度条的模块,见链接(3)
from tqdm import tqdm#使用argparse 的第一步是创建一个 ArgumentParser 对象
parser = argparse.ArgumentParser()
#添加程序参数信息
#终端键入batchsize大小
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
#默认的数据集每个点云是2500个点
parser.add_argument('--num_points', type=int, default=2500, help='input batch size')
#加载数据的进程数目
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
#epoch,训练多少轮
parser.add_argument('--nepoch', type=int, default=250, help='number of epochs to train for')
#输出文件夹名称
parser.add_argument('--outf', type=str, default='cls', help='output folder')
#预训练模型路径
parser.add_argument('--model', type=str, default='', help='model path')
#这里,数据集的路径必须手动设置
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
#数据集类型
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
#是否进行特征变换
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
#解析参数
opt = parser.parse_args()
print(opt)blue = lambda x: '\033[94m' + x + '\033[0m'#返回1~10000间的一个整数,作为随机种子 opt的类型为:<class 'argparse.Namespace'>
opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)#保证在有种子的情况下生成的随机数都是一样的,见链接(4)
random.seed(opt.manualSeed)
#设置一个用于生成随机数的种子,返回的是一个torch.Generator对象
torch.manual_seed(opt.manualSeed)#调用pointnet.pytorch/pointnet/dataset.py中的ShapeNetDataset类,创建针对shapenet数据集的类对象
if opt.dataset_type == 'shapenet':dataset = ShapeNetDataset(#训练集root=opt.dataset,classification=True,#打开分类的选项npoints=opt.num_points)test_dataset = ShapeNetDataset(#测试集root=opt.dataset,classification=True,split='test',#标记为测试npoints=opt.num_points,data_augmentation=False)
#调用pointnet.pytorch/pointnet/dataset.py中的ModelNetDataset类,创建针对modelnet40数据集的类对象
elif opt.dataset_type == 'modelnet40':dataset = ModelNetDataset(root=opt.dataset,npoints=opt.num_points,split='trainval')test_dataset = ModelNetDataset(root=opt.dataset,split='test',npoints=opt.num_points,data_augmentation=False)
else:exit('wrong dataset type')#用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
dataloader = torch.utils.data.DataLoader(dataset,batch_size=opt.batchSize,shuffle=True,#将数据集的顺序打乱num_workers=int(opt.workers))testdataloader = torch.utils.data.DataLoader(test_dataset,batch_size=opt.batchSize,shuffle=True,num_workers=int(opt.workers))print(len(dataset), len(test_dataset))# 12137 2874
num_classes = len(dataset.classes)
print('classes', num_classes)#classes 16#创建文件夹,若无法创建,进行异常检测
try:os.makedirs(opt.outf)
except OSError:pass#调用model.py的PointNetCls定义分类函数
classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)#如果有预训练模型,将预训练模型加载
if opt.model != '':classifier.load_state_dict(torch.load(opt.model))# 优化器:adam-Adaptive Moment Estimation(自适应矩估计),利用梯度的一阶矩和二阶矩动态调整每个参数的学习率
# betas:用于计算梯度一阶矩和二阶矩的系数
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
# 学习率调整:每个step_size次epoch后,学习率x0.5
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
# 将所有的模型参数移到GPU中
classifier.cuda()# 计算batch的数量
num_batch = len(dataset) / opt.batchSize#开始一趟一趟的训练
for epoch in range(opt.nepoch):scheduler.step() #调整学习率# 将一个可遍历对象组合为一个索引序列,同时列出数据和数据下标,(0, seq[0])...# __init__(self, iterable, start=0),参数为可遍历对象及起始位置for i, data in enumerate(dataloader, 0):points, target = data  #读取待训练对象点云与标签target = target[:, 0] # 取所有行的第0列points = points.transpose(2, 1) #改变点云的维度points, target = points.cuda(), target.cuda() # tensor转到cuda上optimizer.zero_grad() # 梯度清除,避免backward时梯度累加classifier = classifier.train() # 训练模式,使能BN和dropoutpred, trans, trans_feat = classifier(points)  # 网络结果预测输出# 损失函数:负log似然损失,在分类网络中使用了log_softmax,二者结合其实就是交叉熵损失函数loss = F.nll_loss(pred, target) #对feature_transform中64X64的变换矩阵做正则化,满足AA^T=Iif opt.feature_transform:loss += feature_transform_regularizer(trans_feat) * 0.001loss.backward() # loss反向传播optimizer.step() # 梯度下降,参数优化pred_choice = pred.data.max(1)[1] # max(1)返回每一行中的最大值及索引,[1]取出索引(代表着类别)correct = pred_choice.eq(target.data).cpu().sum() # 判断和target是否匹配,并计算匹配的数量print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize)))# 每10次batch之后,进行一次测试if i % 10 == 0: j, data = next(enumerate(testdataloader, 0))points, target = datatarget = target[:, 0]points = points.transpose(2, 1)points, target = points.cuda(), target.cuda()classifier = classifier.eval() # 测试模式,固定住BN和dropoutpred, _, _ = classifier(points)loss = F.nll_loss(pred, target)pred_choice = pred.data.max(1)[1]correct = pred_choice.eq(target.data).cpu().sum()print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))#保存权重文件在cls/cls_model_1.pthtorch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))#在测试集上验证模型的精度
total_correct = 0
total_testset = 0
for i,data in tqdm(enumerate(testdataloader, 0)):points, target = datatarget = target[:, 0]points = points.transpose(2, 1)points, target = points.cuda(), target.cuda()classifier = classifier.eval()pred, _, _ = classifier(points)pred_choice = pred.data.max(1)[1]correct = pred_choice.eq(target.data).cpu().sum()total_correct += correct.item()total_testset += points.size()[0]print("final accuracy {}".format(total_correct / float(total_testset)))

(1)from future import print_function 用法
(2)argparse用法
(3)详细介绍Python进度条tqdm的使用
(4)random模块中seed的用法
(5)try的用法

2.2.5 dataset.py

看一下如何处理数据集pointnet.pytorch/pointnet/dataset.py:

from __future__ import print_function
import torch.utils.data as data
import os
#os.path 模块主要用于获取文件的属性
import os.path
import torch
import numpy as np
#针对与Python解释器相关的变量和方法
import sys
from tqdm import tqdm
#用于存储和转换数据格式的语法
import json
#处理点云的文件,自行安装
from plyfile import PlyData, PlyElementdef get_segmentation_classes(root):catfile = os.path.join(root, 'synsetoffset2category.txt')cat = {}meta = {}with open(catfile, 'r') as f:for line in f:ls = line.strip().split()cat[ls[0]] = ls[1]for item in cat:dir_seg = os.path.join(root, cat[item], 'points_label')dir_point = os.path.join(root, cat[item], 'points')fns = sorted(os.listdir(dir_point))meta[item] = []for fn in fns:token = (os.path.splitext(os.path.basename(fn))[0])meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:for item in cat:datapath = []num_seg_classes = 0for fn in meta[item]:datapath.append((item, fn[0], fn[1]))for i in tqdm(range(len(datapath))):l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8)))if l > num_seg_classes:num_seg_classes = lprint("category {} num segmentation classes {}".format(item, num_seg_classes))f.write("{}\t{}\n".format(item, num_seg_classes))def gen_modelnet_id(root):classes = []with open(os.path.join(root, 'train.txt'), 'r') as f:for line in f:classes.append(line.strip().split('/')[0])classes = np.unique(classes)with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:for i in range(len(classes)):f.write('{}\t{}\n'.format(classes[i], i))class ShapeNetDataset(data.Dataset):def __init__(self,root,npoints=2500,classification=False,class_choice=None,split='train',data_augmentation=True):self.npoints = npointsself.root = rootself.catfile = os.path.join(self.root, 'synsetoffset2category.txt') #路径拼接 这个参数是在root路径中synsetoffset2category.txt的路径self.cat = {}self.data_augmentation = data_augmentation # 数据扩充self.classification = classificationself.seg_classes = {}# 读synsetoffset2category.txt中的数据,并以字典的形式存储到self.cat中with open(self.catfile, 'r') as f:# 打开目录txt文件,'r':open for readingfor line in f:# strip():移除字符串头尾指定的字符(默认为空格或换行符)# split():指定分隔符对字符串进行切片,返回分割后的字符串列表(默认为所有的空字符,包括空格、换行\n、制表符\t等)ls = line.strip().split() #ls的类型为list# cat为字典,通过[键]索引。键:类别;值:文件夹名称self.cat[ls[0]] = ls[1]#print(self.cat)# 类别选择,对那些种类物体进行分类if not class_choice is None:self.cat = {k: v for k, v in self.cat.items() if k in class_choice}self.id2cat = {v: k for k, v in self.cat.items()}# key和value互换self.meta = {}# json文件类似xml文件,可存储键值对和数组等# split=train# format():字符串格式化函数,使用{}代替之前的%splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))#from IPython import embed; embed()filelist = json.load(open(splitfile, 'r'))# for item in self.cat:item为键# for item in self.cat.values():item为值# for item in self.cat.items():item为键值对(元组的形式)# for k, v in self.cat.items():更为规范的键值对读取方式# meta为字典,键为类别,键值为空for item in self.cat:self.meta[item] = []# 读取shuffled_train_file_list.jsonfor file in filelist:_, category, uuid = file.split('/')# category为某一类别所在文件夹,uuid为某一类别的某一个#分类:把每一类物体的路径分到每一类物体的后面,格式为{'Airplane':[('*.pts','*.seg'), ...]}if category in self.cat.values():self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),os.path.join(self.root, category, 'points_label', uuid+'.seg')))self.datapath = []# cat存储类别及其所在文件夹,item访问键,即类别for item in self.cat:# meta为字典,fn访问值,即路径for fn in self.meta[item]:# item为类别,fn[0]为点云路径,fn[1]为用于分割的标签路径self.datapath.append((item, fn[0], fn[1]))# sorted():对所有可迭代兑现进行排序,默认为升序;sorted(self.cat)对字典cat中的键(种类)进行排序,排序结果的类型为list# zip():  函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组# dict(): 创建字典。dict(zip(['one', 'two'], [1, 2])) -> {'two': 2, 'one': 1}# 下列操作实现了对类别进行数字编码表示self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))print(self.classes)#读misc/num_seg_classes.txt中的数据with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:for line in f:ls = line.strip().split()self.seg_classes[ls[0]] = int(ls[1])#'Airplane'应该分成几类。num_seg_classes为对应的的类应该分成几类self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]print(self.seg_classes, self.num_seg_classes)# 该方法的实例对象可通过索引取值,自动调用该方法def __getitem__(self, index):fn = self.datapath[index]  # 获取类别、点云路径、分割标签路径元组cls = self.classes[self.datapath[index][0]] # 获取数字编码的类别标签point_set = np.loadtxt(fn[1]).astype(np.float32) # 读取pts点云seg = np.loadtxt(fn[2]).astype(np.int64)  # 读取分割标签#print(point_set.shape, seg.shape)# 重新采样到self.npoints个点choice = np.random.choice(len(seg), self.npoints, replace=True)#resamplepoint_set = point_set[choice, :]# 去中心化point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center#计算到原点的最远距离dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0)# 归一化point_set = point_set / dist #scale#默认False  开启旋转任意角度并加上一个bias,增强数据的抗干扰能力if self.data_augmentation:theta = np.random.uniform(0,np.pi*2)rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotationpoint_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitterseg = seg[choice]point_set = torch.from_numpy(point_set)#转换数据格式seg = torch.from_numpy(seg)cls = torch.from_numpy(np.array([cls]).astype(np.int64)) #cls为对应的代号,比如Airplane对应0if self.classification:return point_set, clselse:return point_set, segdef __len__(self):return len(self.datapath)class ModelNetDataset(data.Dataset):def __init__(self,root,npoints=2500,split='train',data_augmentation=True):self.npoints = npointsself.root = rootself.split = splitself.data_augmentation = data_augmentationself.fns = []with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:for line in f:self.fns.append(line.strip())self.cat = {}with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:for line in f:ls = line.strip().split()self.cat[ls[0]] = int(ls[1])print(self.cat)self.classes = list(self.cat.keys())def __getitem__(self, index):fn = self.fns[index]cls = self.cat[fn.split('/')[0]]with open(os.path.join(self.root, fn), 'rb') as f:plydata = PlyData.read(f)pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).Tchoice = np.random.choice(len(pts), self.npoints, replace=True)point_set = pts[choice, :]point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # centerdist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)point_set = point_set / dist  # scaleif self.data_augmentation:theta = np.random.uniform(0, np.pi * 2)rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotationpoint_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitterpoint_set = torch.from_numpy(point_set.astype(np.float32))cls = torch.from_numpy(np.array([cls]).astype(np.int64))return point_set, clsdef __len__(self):return len(self.fns)if __name__ == '__main__':dataset = sys.argv[1]datapath = sys.argv[2]if dataset == 'shapenet':d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])print(len(d))ps, seg = d[0]print(ps.size(), ps.type(), seg.size(),seg.type())d = ShapeNetDataset(root = datapath, classification = True)print(len(d))ps, cls = d[0]print(ps.size(), ps.type(), cls.size(),cls.type())# get_segmentation_classes(datapath)if dataset == 'modelnet':gen_modelnet_id(datapath)d = ModelNetDataset(root=datapath)print(len(d))print(d[0])

(1)加载os和os.path之间的关联和区别
(2)Python常用标准库之sys
(3)问题解决:NameError: name ‘file’ is not defined
(4)np.random.choice()的用法详解
(5)np.expand_dims()的用法详解

2.2.6 model

pointnet.pytorch/pointnet/model.py中看看如何定义分类器,这一部分如果有网络架构图就很容易理解了,建议参考大佬的PointNet网络架构图:

from __future__ import print_function
import torch
#nn全称为neural network,意思是神经网络,是torch中构建神经网络的模块
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F# T-Net: is a pointnet itself.获取3x3的变换矩阵,校正点云姿态;效果一般,后续的改进并没有再加入这部分
# 经过全连接层映射到9个数据,最后调整为3x3矩阵
class STN3d(nn.Module):def __init__(self):super(STN3d, self).__init__()#mlpself.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)#fcself.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, 9)#激活函数self.relu = nn.ReLU()#bnself.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)def forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)# Variable已被弃用,之前的版本中,pytorch的tensor只能在CPU计算,Variable将tensor转换成variable,具有三个属性(data\grad\grad_fn)# 现在二者已经融合,Variable返回tensor# iden生成单位变换矩阵# repeat(batchsize, 1),重复batchsize次,生成batchsize x 9的tensoriden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)#将单位矩阵送入GPUif x.is_cuda:iden = iden.cuda()x = x + iden# view()相当于numpy中的resize(),重构tensor维度,-1表示缺省参数由系统自动计算(为batchsize大小)# 返回结果为 batchsize x 3 x 3x = x.view(-1, 3, 3)return x# 数据为k维,用于mlp之后的高维特征,同上
class STNkd(nn.Module):def __init__(self, k=64):super(STNkd, self).__init__()self.conv1 = torch.nn.Conv1d(k, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k*k)self.relu = nn.ReLU()self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.bn4 = nn.BatchNorm1d(512)self.bn5 = nn.BatchNorm1d(256)self.k = kdef forward(self, x):batchsize = x.size()[0]x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)x = F.relu(self.bn4(self.fc1(x)))x = F.relu(self.bn5(self.fc2(x)))x = self.fc3(x)iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)if x.is_cuda:iden = iden.cuda()x = x + idenx = x.view(-1, self.k, self.k)return x
#包含变换矩阵的中间网络
class PointNetfeat(nn.Module):def __init__(self, global_feat = True, feature_transform = False):super(PointNetfeat, self).__init__()self.stn = STN3d()self.conv1 = torch.nn.Conv1d(3, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)def forward(self, x):n_pts = x.size()[2]# size()返回张量各个维度的尺度trans = self.stn(x) #得到3x3的坐标变换矩阵x = x.transpose(2, 1) #调整点的维度,将点云数据转换为nx3形式,便于和旋转矩阵计算x = torch.bmm(x, trans) #点的坐标和3x3的变换矩阵相乘x = x.transpose(2, 1) #再把点的坐标调整回来3xnx = F.relu(self.bn1(self.conv1(x))) #作者本来在这里用了两次mlpif self.feature_transform: trans_feat = self.fstn(x) #得到64x64的特征变换矩阵x = x.transpose(2,1) x = torch.bmm(x, trans_feat)x = x.transpose(2,1)else:trans_feat = Nonepointfeat = x # 保留经过第一次mlp的特征,便于后续分割进行特征拼接融合x = F.relu(self.bn2(self.conv2(x)))# 第二次mlp的第一层,64->128x = self.bn3(self.conv3(x))# 第二次mlp的第二层,128->1024x = torch.max(x, 2, keepdim=True)[0] # pointnet的核心操作,最大池化操作保证了点云的置换不变性(最大池化操作为对称函数)x = x.view(-1, 1024)# resize池化结果的形状,获得全局1024维特征if self.global_feat:return x, trans, trans_feat #返回特征、坐标变换矩阵、特征变换矩阵else:x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)return torch.cat([x, pointfeat], 1), trans, trans_feat #分割时候会用到的global特征、坐标变换矩阵、特征变换矩阵
#主干网络
class PointNetCls(nn.Module):def __init__(self, k=2, feature_transform=False): #k表示最后分为k类super(PointNetCls, self).__init__()self.feature_transform = feature_transformself.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform) #global_feat=True 表示只用于分类self.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.3)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):x, trans, trans_feat = self.feat(x)# 调用带变换的网络x = F.relu(self.bn1(self.fc1(x)))# 第三次mlp的第一层:1024->512x = F.relu(self.bn2(self.dropout(self.fc2(x)))) # 第三次mlp的第二层:512->256x = self.fc3(x)# 全连接得到k维return F.log_softmax(x, dim=1), trans, trans_feat# log_softmax分类,解决softmax在计算e的次方时容易造成的上溢出和下溢出问题#分割
class PointNetDenseCls(nn.Module):def __init__(self, k = 2, feature_transform=False):super(PointNetDenseCls, self).__init__()self.k = kself.feature_transform=feature_transformself.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)self.conv1 = torch.nn.Conv1d(1088, 512, 1)self.conv2 = torch.nn.Conv1d(512, 256, 1)self.conv3 = torch.nn.Conv1d(256, 128, 1)self.conv4 = torch.nn.Conv1d(128, self.k, 1)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.bn3 = nn.BatchNorm1d(128)def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans, trans_feat#特征变换矩阵的正则化
def feature_transform_regularizer(trans):d = trans.size()[1]batchsize = trans.size()[0]I = torch.eye(d)[None, :, :]if trans.is_cuda:I = I.cuda()loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))return loss#测试用的函数
if __name__ == '__main__':sim_data = Variable(torch.rand(32,3,2500))trans = STN3d()out = trans(sim_data)print('stn', out.size())print('loss', feature_transform_regularizer(out))sim_data_64d = Variable(torch.rand(32, 64, 2500))trans = STNkd(k=64)out = trans(sim_data_64d)print('stn64d', out.size())print('loss', feature_transform_regularizer(out))pointfeat = PointNetfeat(global_feat=True)out, _, _ = pointfeat(sim_data)print('global feat', out.size())pointfeat = PointNetfeat(global_feat=False)out, _, _ = pointfeat(sim_data)print('point feat', out.size())cls = PointNetCls(k = 5)out, _, _ = cls(sim_data)print('class', out.size())seg = PointNetDenseCls(k = 3)out, _, _ = seg(sim_data)print('seg', out.size())

(1)if name == ‘main’ 如何正确理解?

参考文献

1.PointNet.pytorch程序注释点云分类
2.PointNet网络结构详细解析
3.PointNet学习记录
4.PointNet代码学习(pytorch版本)
5.Dir-b/PointNet_Github
6.jiangdi1998/PointNet.pytorch_Github

PointNet代码详细解释(Pytorch版本)相关推荐

  1. 吴恩达机器学习 神经网络 作业1(用已经求好的权重进行手写数字分类) Python实现 代码详细解释

    整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...

  2. 吴恩达机器学习 逻辑回归 作业3(手写数字分类) Python实现 代码详细解释

    整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...

  3. 吴恩达机器学习 逻辑回归 作业2(芯片预测) Python实现 代码详细解释

    整个项目的github:https://github.com/RobinLuoNanjing/MachineLearning_Ng_Python 里面可以下载进行代码实现的数据集 题目介绍: In t ...

  4. 共享单车神经网络预测(pytorch )每行代码详细解释

    1.数据来源:本例将会使用一个国外的共享单车公开数据集(Capital Bikeshare)来完成我们的任务,数据集下载链接:www.capitalbikeshare.com/ system-data ...

  5. Gauss-Newton算法代码详细解释(转载+自己注释)

    这篇博客是对[1]中不详细的地方进行细节上的阐述, 并且每句代码都加了注释,使得更加容易理解 下面的论述(包括伪代码和算法)特指被最小化的目标函数是MSE的时候 需要注意,如果不是MSE为目标函数,那 ...

  6. ID3的REP(Reduced Error Pruning)剪枝代码详细解释+周志华《机器学习》决策树图4.5、图4.6、图4.7绘制

    处理数据对象:离散型数据 信息计算方式:熵 数据集:西瓜数据集2.0共17条数据 训练集(用来建立决策树):西瓜数据集2.0中的第1,2,3,6,7,10,14,15,16,17,4 请注意,书上说是 ...

  7. NLTK找出最频繁的名词标记的程序(代码详细解释)

    代码来自<Python自然语言处理>,我做了详细的代码解释. # -*- coding:utf-8 -*- import nltk def findtags(tag_prefix,tagg ...

  8. 超详细的pytorch版本yolov3安装教程--亲测有效!!!

    前言 最近在进行一个工程项目,需要使用yolo算法来实现.首先就选择了yolov3来进行demo实现,因为yolov3在YOLO系列中也是非常经典的一个版本.网上有很多环境配置教程,但是很多教程的讲述 ...

  9. iapp调用java点击换行,iapp部分基础代码详细解释

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 本文绝大部分内容是在iapp1.5版本左右编写的,可能与当前iapp版本代码稍有差别.如果发现,望指出,谢谢. 1. [用syso打印及tw调试app] ...

最新文章

  1. 卧槽!新基建背景下,这些姿势架构师必须懂!
  2. Hacker(六)----黑客藏匿之地--系统进程
  3. Ghost 系统的过程
  4. ifstat,iftop
  5. Activity-数据状态的保存
  6. 球体动画Android,Android自定义View实现简单炫酷的球体进度球实例代码
  7. CXF发布RestFul WebService和SOAP WebService
  8. Android批量图片加载经典系列——使用LruCache、AsyncTask缓存并异步加载图片
  9. 【百度网盘】老罗android开发视频教程[压缩后3.63G]
  10. 最好用的HDR图像处理器——Photomatix Pro新功能介绍及使用教程
  11. 通信原理(一) 通信原理概述
  12. 中国联通 光猫 吉比特 G-140W-UG 管理员 账号密码
  13. html5 dicom opensource,基于HTML5标准的Dicom图像显示.pdf
  14. 夸克服务器过载或暂停维修,服务器过载或CGI脚本出错
  15. asp.net动态网页制作视频教程
  16. 3dsmax2020安装报1603错误的解决方法
  17. 把kali linux 装进 U盘并实现数据可存储
  18. android 小米加载大图,Android手机拍照或从本地相册选取图片设置头像。适配小米、华为、7.0...
  19. 微信支付失败中关于“签名错误”的解决方案
  20. Excel技能培训-INDIRECT实现拼接动态引用单元格,trl+pageDown速切换工作簿,多工作表求和,多个工作簿合并和拆分

热门文章

  1. kubeedge设备添加以及mapper管理
  2. cdn 中移集采_中兴通讯中标中国移动融合CDN四期集采新建项目最大份额
  3. axure做微信小程序 - axure教程
  4. 全民java竞争有多激烈,全民Kotlin:你没有玩过的全新玩法
  5. LeetCode-994-腐烂的橘子
  6. 《R in action》《R语言实战》源代码_2
  7. 短视频剪辑都用什么软件 短视频剪辑软件推荐
  8. 微信好友太少?因为你还没掌握这些微信爆粉方法!(上篇)
  9. 七月算法机器学习笔记10 人工神经网络
  10. HTML5常用四种盒模型标签介绍与区分