3D Sparse Pattern Reconstruction with MinkowskiEngine
This post walks through a simple demo that trains a 3D convolutional neural network to reconstruct a 3D sparse pattern from a one-hot vector, similar to Octree Generating Networks (ICCV'17). The input one-hot vector encodes the index of a chair among the 3D computer-aided design (CAD) chair models in the ModelNet40 dataset.
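As a minimal illustration of the input encoding (the function name and the model count of 8 are hypothetical, chosen only for this sketch), the one-hot vector can be built like this:

```python
import numpy as np

def one_hot(index, num_classes):
    """Return a one-hot row vector selecting a single CAD model index."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[index] = 1.0
    return vec

# Hypothetical: 8 chair models in the dataset, reconstruct model #3.
z = one_hot(3, 8)
```

In the actual training code further below, the same effect is achieved by indexing a zero feature matrix with the batch labels.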
Using MinkowskiEngine.MinkowskiConvolutionTranspose together with MinkowskiEngine.MinkowskiPruning, the network repeatedly upsamples the voxels by a factor of 2 and then removes some of the upsampled voxels to generate the target shape. The overall network architecture looks similar to the figure below, though the details may differ.

Before proceeding, please read the pages on training and data loading first.
Creating a sparse pattern reconstruction network
To create a sparse tensor defined in a 3D grid world from a vector, we start from a single voxel at 1×1×1 resolution and progressively upsample it. The network is built from blocks of MinkowskiEngine.MinkowskiConvolutionTranspose, MinkowskiEngine.MinkowskiConvolution, and MinkowskiEngine.MinkowskiPruning.
In the forward pass, the network creates two paths: 1) the main features and 2) a sparse voxel classification that is used to remove unnecessary voxels.
out = upsample_block(z)
out_cls = classification(out).F
out = pruning(out, out_cls > 0)
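The upsample-then-prune step above can be sketched in plain NumPy on coordinates alone (this ignores features and learned weights, and is not the MinkowskiEngine API; the function names and the classifier logits are hypothetical):

```python
import itertools
import numpy as np

def upsample_coords(coords):
    """2x upsampling: each voxel at stride 2s spawns 8 children at stride s."""
    offsets = np.array(list(itertools.product([0, 1], repeat=3)))  # (8, 3) corner offsets
    return (coords[:, None, :] * 2 + offsets[None, :, :]).reshape(-1, 3)

def prune(coords, keep):
    """Keep only voxels whose classification logit is positive."""
    return coords[keep]

parent = np.array([[0, 0, 0], [1, 2, 3]])
children = upsample_coords(parent)            # 2 voxels -> 16 candidate voxels
logits = np.array([1.0] * 8 + [-1.0] * 8)     # hypothetical classifier output
kept = prune(children, logits > 0)            # children of the second voxel dropped
```

In MinkowskiEngine, the child-coordinate generation is what MinkowskiConvolutionTranspose with generate_new_coords=True does, and the masking is what MinkowskiPruning does on a sparse tensor.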
The network repeats this series of upsampling and pruning operations, removing unnecessary voxels along the way, until the sparse tensor reaches the target resolution. The results are visualized in the figure below. Note that the final reconstruction captures the target geometry very accurately. The hierarchical upsample-and-prune reconstruction process is also visualized.
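At each level, the supervision target is the set of coarse voxels that contain at least one ground-truth point. The get_target method in the full code below computes this with the coordinate manager; a hedged NumPy sketch of the same idea (function name hypothetical):

```python
import numpy as np

def target_at_stride(gt_coords, stride):
    """Occupancy targets at a coarser level: a voxel is positive if any
    ground-truth coordinate falls inside it after quantization to the stride."""
    quantized = (gt_coords // stride) * stride
    return np.unique(quantized, axis=0)

gt = np.array([[0, 0, 0], [1, 0, 0], [5, 4, 4], [7, 6, 6]])
t2 = target_at_stride(gt, 2)  # stride-2 occupancy targets
t4 = target_at_stride(gt, 4)  # stride-4 occupancy targets
```

The voxel classifier at each block is then trained with a binary cross-entropy loss against these occupancy targets.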

Running the example
To train the network, go to the Minkowski Engine root directory and type:
python -m examples.reconstruction --train
To visualize network predictions, or to try a pretrained model, type:
python -m examples.reconstruction

The program will visualize two 3D shapes. The one on the left is the target 3D shape, and the one on the right is the reconstructed network prediction.
The complete code can be found at examples/reconstruction.py.
import os
import sys
import subprocess
import argparse
import logging
import glob
import numpy as np
from time import time
import urllib
# Must be imported before large libs
try:
    import open3d as o3d
except ImportError:
    raise ImportError('Please install open3d and scipy with `pip install open3d scipy`.')

import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim

import MinkowskiEngine as ME

from examples.modelnet40 import InfSampler, resample_mesh

M = np.array([[0.80656762, -0.5868724, -0.07091862],
              [0.3770505, 0.418344, 0.82632997],
              [-0.45528188, -0.6932309, 0.55870326]])

assert int(
    o3d.__version__.split('.')[1]
) >= 8, f'Requires open3d version >= 0.8, the current version is {o3d.__version__}'

if not os.path.exists('ModelNet40'):
    logging.info('Downloading the fixed ModelNet40 dataset...')
    subprocess.run(["sh", "./examples/download_modelnet40.sh"])

###############################################################################
# Utility functions
###############################################################################
def PointCloud(points, colors=None):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    if colors is not None:
        pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd


def collate_pointcloud_fn(list_data):
    coords, feats, labels = list(zip(*list_data))

    # Concatenate all lists
    return {
        'coords': coords,
        'xyzs': [torch.from_numpy(feat).float() for feat in feats],
        'labels': torch.LongTensor(labels),
    }


class ModelNet40Dataset(torch.utils.data.Dataset):

    def __init__(self, phase, transform=None, config=None):
        self.phase = phase
        self.files = []
        self.cache = {}
        self.data_objects = []
        self.transform = transform
        self.resolution = config.resolution
        self.last_cache_percent = 0

        self.root = './ModelNet40'
        fnames = glob.glob(os.path.join(self.root, 'chair/train/*.off'))
        fnames = sorted([os.path.relpath(fname, self.root) for fname in fnames])
        self.files = fnames
        assert len(self.files) > 0, "No file loaded"
        logging.info(
            f"Loading the subset {phase} from {self.root} with {len(self.files)} files"
        )
        self.density = 30000

        # Ignore warnings in obj loader
        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        mesh_file = os.path.join(self.root, self.files[idx])
        if idx in self.cache:
            xyz = self.cache[idx]
        else:
            # Load a mesh, over sample, copy, rotate, voxelization
            assert os.path.exists(mesh_file)
            pcd = o3d.io.read_triangle_mesh(mesh_file)
            # Normalize to fit the mesh inside a unit cube while preserving aspect ratio
            vertices = np.asarray(pcd.vertices)
            vmax = vertices.max(0, keepdims=True)
            vmin = vertices.min(0, keepdims=True)
            pcd.vertices = o3d.utility.Vector3dVector(
                (vertices - vmin) / (vmax - vmin).max())

            # Oversample points and copy
            xyz = resample_mesh(pcd, density=self.density)

            self.cache[idx] = xyz
            cache_percent = int((len(self.cache) / len(self)) * 100)
            if cache_percent > 0 and cache_percent % 10 == 0 and cache_percent != self.last_cache_percent:
                logging.info(
                    f"Cached {self.phase}: {len(self.cache)} / {len(self)}: {cache_percent}%"
                )
                self.last_cache_percent = cache_percent

        # Use color or other features if available
        feats = np.ones((len(xyz), 1))

        if len(xyz) < 1000:
            logging.info(
                f"Skipping {mesh_file}: does not have sufficient CAD sampling density after resampling: {len(xyz)}."
            )
            return None

        if self.transform:
            xyz, feats = self.transform(xyz, feats)

        # Get coords
        xyz = xyz * self.resolution
        coords = np.floor(xyz)
        inds = ME.utils.sparse_quantize(coords, return_index=True)

        return (coords[inds], xyz[inds], idx)


def make_data_loader(phase, augment_data, batch_size, shuffle, num_workers,
                     repeat, config):
    dset = ModelNet40Dataset(phase, config=config)

    args = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'collate_fn': collate_pointcloud_fn,
        'pin_memory': False,
        'drop_last': False
    }

    if repeat:
        args['sampler'] = InfSampler(dset, shuffle)
    else:
        args['shuffle'] = shuffle

    loader = torch.utils.data.DataLoader(dset, **args)

    return loader


ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    format=os.uname()[1].split('.')[0] + ' %(asctime)s %(message)s',
    datefmt='%m/%d %H:%M:%S',
    handlers=[ch])

parser = argparse.ArgumentParser()
parser.add_argument('--resolution', type=int, default=128)
parser.add_argument('--max_iter', type=int, default=30000)
parser.add_argument('--val_freq', type=int, default=1000)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--stat_freq', type=int, default=50)
parser.add_argument(
    '--weights', type=str, default='modelnet_reconstruction.pth')
parser.add_argument('--load_optimizer', type=str, default='true')
parser.add_argument('--train', action='store_true')
parser.add_argument('--max_visualization', type=int, default=4)

###############################################################################
# End of utility functions
###############################################################################


class GenerativeNet(nn.Module):

    CHANNELS = [1024, 512, 256, 128, 64, 32, 16]

    def __init__(self, resolution, in_nchannel=512):
        nn.Module.__init__(self)

        self.resolution = resolution

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        # Block 1
        self.block1 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                in_nchannel,
                ch[0],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[0], ch[0], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolutionTranspose(
                ch[0],
                ch[1],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[1], ch[1], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
        )

        self.block1_cls = ME.MinkowskiConvolution(
            ch[1], 1, kernel_size=1, has_bias=True, dimension=3)

        # Block 2
        self.block2 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                ch[1],
                ch[2],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[2], ch[2], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
        )

        self.block2_cls = ME.MinkowskiConvolution(
            ch[2], 1, kernel_size=1, has_bias=True, dimension=3)

        # Block 3
        self.block3 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                ch[2],
                ch[3],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[3], ch[3], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
        )

        self.block3_cls = ME.MinkowskiConvolution(
            ch[3], 1, kernel_size=1, has_bias=True, dimension=3)

        # Block 4
        self.block4 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                ch[3],
                ch[4],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[4], ch[4], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
        )

        self.block4_cls = ME.MinkowskiConvolution(
            ch[4], 1, kernel_size=1, has_bias=True, dimension=3)

        # Block 5
        self.block5 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                ch[4],
                ch[5],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[5], ch[5], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
        )

        self.block5_cls = ME.MinkowskiConvolution(
            ch[5], 1, kernel_size=1, has_bias=True, dimension=3)

        # Block 6
        self.block6 = nn.Sequential(
            ME.MinkowskiConvolutionTranspose(
                ch[5],
                ch[6],
                kernel_size=2,
                stride=2,
                generate_new_coords=True,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[6], ch[6], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
        )

        self.block6_cls = ME.MinkowskiConvolution(
            ch[6], 1, kernel_size=1, has_bias=True, dimension=3)

        # pruning
        self.pruning = ME.MinkowskiPruning()

    def get_batch_indices(self, out):
        return out.coords_man.get_row_indices_per_batch(out.coords_key)

    def get_target(self, out, target_key, kernel_size=1):
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool)
            cm = out.coords_man
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0], force_creation=True)
            ins, outs = cm.get_kernel_map(
                out.coords_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1)
            for curr_in in ins:
                target[curr_in] = 1
        return target

    def valid_batch_map(self, batch_map):
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def forward(self, z, target_key):
        out_cls, targets = [], []

        # Block 1
        out1 = self.block1(z)
        out1_cls = self.block1_cls(out1)
        target = self.get_target(out1, target_key)
        targets.append(target)
        out_cls.append(out1_cls)
        keep1 = (out1_cls.F > 0).cpu().squeeze()

        # If training, force target shape generation, use net.eval() to disable
        if self.training:
            keep1 += target

        # Remove voxels 32
        out1 = self.pruning(out1, keep1.cpu())

        # Block 2
        out2 = self.block2(out1)
        out2_cls = self.block2_cls(out2)
        target = self.get_target(out2, target_key)
        targets.append(target)
        out_cls.append(out2_cls)
        keep2 = (out2_cls.F > 0).cpu().squeeze()

        if self.training:
            keep2 += target

        # Remove voxels 16
        out2 = self.pruning(out2, keep2.cpu())

        # Block 3
        out3 = self.block3(out2)
        out3_cls = self.block3_cls(out3)
        target = self.get_target(out3, target_key)
        targets.append(target)
        out_cls.append(out3_cls)
        keep3 = (out3_cls.F > 0).cpu().squeeze()

        if self.training:
            keep3 += target

        # Remove voxels 8
        out3 = self.pruning(out3, keep3.cpu())

        # Block 4
        out4 = self.block4(out3)
        out4_cls = self.block4_cls(out4)
        target = self.get_target(out4, target_key)
        targets.append(target)
        out_cls.append(out4_cls)
        keep4 = (out4_cls.F > 0).cpu().squeeze()

        if self.training:
            keep4 += target

        # Remove voxels 4
        out4 = self.pruning(out4, keep4.cpu())

        # Block 5
        out5 = self.block5(out4)
        out5_cls = self.block5_cls(out5)
        target = self.get_target(out5, target_key)
        targets.append(target)
        out_cls.append(out5_cls)
        keep5 = (out5_cls.F > 0).cpu().squeeze()

        if self.training:
            keep5 += target

        # Remove voxels 2
        out5 = self.pruning(out5, keep5.cpu())

        # Block 6
        out6 = self.block6(out5)
        out6_cls = self.block6_cls(out6)
        target = self.get_target(out6, target_key)
        targets.append(target)
        out_cls.append(out6_cls)
        keep6 = (out6_cls.F > 0).cpu().squeeze()

        # Last layer does not require keep
        # if self.training:
        #     keep6 += target

        # Remove voxels 1
        out6 = self.pruning(out6, keep6.cpu())

        return out_cls, targets, out6


def train(net, dataloader, device, config):
    in_nchannel = len(dataloader.dataset)

    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)

    crit = nn.BCEWithLogitsLoss()

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    logging.info(f'LR: {scheduler.get_lr()}')
    for i in range(config.max_iter):

        s = time()
        data_dict = train_iter.next()
        d = time() - s

        optimizer.zero_grad()
        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict['labels']] = 1

        sin = ME.SparseTensor(
            feats=in_feat,
            coords=init_coords,
            allow_duplicate_coords=True,  # for classification, it doesn't matter
            tensor_stride=config.resolution,
        ).to(device)

        # Generate target sparse tensor
        cm = sin.coords_man
        target_key = cm.create_coords_key(
            ME.utils.batched_coordinates(data_dict['xyzs']),
            force_creation=True,
            allow_duplicate_coords=True)

        # Generate from a dense tensor
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(),
                             target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            loss += curr_loss / num_layers

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f'Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_iter': i,
                }, config.weights)

            scheduler.step()
            logging.info(f'LR: {scheduler.get_lr()}')

            net.train()


def visualize(net, dataloader, device, config):
    in_nchannel = len(dataloader.dataset)
    net.eval()
    crit = nn.BCEWithLogitsLoss()
    n_vis = 0

    for data_dict in dataloader:

        init_coords = torch.zeros((config.batch_size, 4), dtype=torch.int)
        init_coords[:, 0] = torch.arange(config.batch_size)

        in_feat = torch.zeros((config.batch_size, in_nchannel))
        in_feat[torch.arange(config.batch_size), data_dict['labels']] = 1

        sin = ME.SparseTensor(
            feats=in_feat,
            coords=init_coords,
            allow_duplicate_coords=True,  # for classification, it doesn't matter
            tensor_stride=config.resolution,
        ).to(device)

        # Generate target sparse tensor
        cm = sin.coords_man
        target_key = cm.create_coords_key(
            ME.utils.batched_coordinates(data_dict['xyzs']),
            force_creation=True,
            allow_duplicate_coords=True)

        # Generate from a dense tensor
        out_cls, targets, sout = net(sin, target_key)
        num_layers, loss = len(out_cls), 0
        for out_cl, target in zip(out_cls, targets):
            loss += crit(out_cl.F.squeeze(),
                         target.type(out_cl.F.dtype).to(device)) / num_layers

        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            pcd = PointCloud(coords)
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M)
            opcd = PointCloud(data_dict['xyzs'][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M)
            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            if n_vis > config.max_visualization:
                return


if __name__ == '__main__':
    config = parser.parse_args()
    logging.info(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataloader = make_data_loader(
        'val',
        augment_data=True,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        repeat=True,
        config=config)
    in_nchannel = len(dataloader.dataset)

    net = GenerativeNet(config.resolution, in_nchannel=in_nchannel)
    net.to(device)

    logging.info(net)

    if config.train:
        train(net, dataloader, device, config)
    else:
        if not os.path.exists(config.weights):
            logging.info(
                f'Downloading pretrained weights. This might take a while...')
            urllib.request.urlretrieve(
                "https://bit.ly/36d9m1n", filename=config.weights)

        logging.info(f'Loading weights from {config.weights}')
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint['state_dict'])

        visualize(net, dataloader, device, config)
