卷积计算是深度学习模型的常见算子,在3D项目中,比如点云分割,由于点云数据是稀疏的,使用常规的卷积计算,将会加大卷积计算时间,不利于模型推理加速。由此SECOND网络提出了稀疏卷积的概念。

稀疏卷积的主要理念就是由正常的全部数据进行卷积运算,优化了为只计算有效的输入点的卷积结果。稀疏卷积的思路网上已经有很多简明扼要的文章,比如知乎的这一篇就很清晰,本文就是根据这一篇的思路实现的一个简单的稀疏卷积流程。建议先看一下先了解。

稀疏卷积的输入是有效输入点的索引坐标(哈希表)和对应的features值,大概流程是:

1,根据输入坐标得到输出点的索引坐标(哈希表)。每一个输入点,可以最多和kernel个点(比如3d卷积,kernel=3,则kernel点个数是3*3*3=27)相乘,得到kernel点个数的输出坐标。所以rulebook可以建立成kenel点个数的字典,每个kenel对应一个或多个输入点索引和输出点索引。

2,将输入点和对应kernel点进行矩阵乘法,得到卷积结果。

3,将同一个输出点坐标的卷积结果进行累加,根据输出点索引与真实坐标的关系,将结果还原到输出位置,即完成了稀疏卷积运算。

下面是实现的一个简单示例代码,其中稀疏卷积结果和普通卷积结果进行了对比,误差为0。

输入坐标和输出点坐标的映射关系,是遍历每个输出点的坐标,根据输出点坐标,kernel,stride可以得到相关的kernel点个输出点的坐标,如果在有点输出点列表里面,则表示这是一个有效输出点,更新输出点索引哈希表和rulebook字典。

这种方法的时间复杂度较大,需要遍历所有输出点,后面有优化方案,直接有公式计算输入点对应的输出点坐标。但是可以大概看一下整体流程。

# -*- coding: utf-8 -*-
import timeimport torch
import torch.nn as nn
import itertools
import numpy as npdef generate_sparse_data(shape,num_points,num_channels,integer=False,data_range=(-1, 1),with_dense=True,dtype=np.float32):dense_shape = shapendim = len(dense_shape)num_points = np.array(num_points)batch_size = len(num_points)batch_indices = []coors_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]),axis=-1)coors_total = coors_total.reshape(-1, ndim)for i in range(batch_size):np.random.shuffle(coors_total)inds_total = coors_total[:num_points[i]]inds_total = np.pad(inds_total, ((0, 0), (0, 1)),mode="constant",constant_values=i)batch_indices.append(inds_total)if integer:sparse_data = np.random.randint(data_range[0],data_range[1],size=[num_points.sum(),num_channels]).astype(dtype)else:sparse_data = np.random.uniform(data_range[0],data_range[1],size=[num_points.sum(),num_channels]).astype(dtype)res = {"features": sparse_data.astype(dtype),}if with_dense:dense_data = np.zeros([batch_size, num_channels, *dense_shape],dtype=sparse_data.dtype)start = 0for i, inds in enumerate(batch_indices):for j, ind in enumerate(inds):dense_slice = (i, slice(None), *ind[:-1])dense_data[dense_slice] = sparse_data[start + j]start += len(inds)res["features_dense"] = dense_data.astype(dtype)batch_indices = np.concatenate(batch_indices, axis=0)res["indices"] = batch_indices.astype(np.int32)return resdef get_Pin2Pout_Rulebook_3d(n,ho, wo,do,ks,stride, in_indice):'''根据有效的输入点位置,得到有效的输出点位置,并建立kernel, in_idx, out_indx字典关系。in_indice:有效点的坐标 [[hi,wi,ni],[hi1,wi1,ni1],...]return:offset, {k0:[[pin_idx, pout_idx],...], k2:[[pin_idx, pout_idx],...]}pout_indice, same to in_indice'''offset = {i: [] for i in range(ks**3)}pout_indice = []out_count = 0for b, i, j, d in itertools.product(range(n), range(ho), range(wo), range(do)):flag = Falsefor kh, kw, kd in itertools.product(range(ks),range(ks),range(ks)):if [stride*i + kh, stride*j + kw,stride*d+kd,b] in in_indice:flag = Trueoffset[kh*ks*ks+kw*ks+kd].append([in_indice.index([ stride*i + kh, stride*j + kw,stride*d+kd,b]), out_count])  # [in_index,out_index]if flag == True:pout_indice.append([b, i, j,d])out_count += 1return offset, pout_indicedef get_output_3d(rulebook,in_data,weight_data,out_indice,out_data):'''遍历每一个kernel, 通过查找pin_idx和对应的kernel, 矩阵乘得到pout的值,并放回位置。同一个pout结果累加'''for key in rulebook.keys():cur_book=rulebook[key]w_data=weight_data[key]for i in range(len(cur_book)):x=in_data[cur_book[i][0],:]n,ho,wo,do=out_indice[cur_book[i][1]]out_data[n,:,ho,wo,do]+=np.matmul(x,w_data)return out_datadef test_conv3d(sparse_dict,ci,co,kernel,stride):features=sparse_dict['features']features_dense=sparse_dict['features_dense']in_indices=sparse_dict['indices'] #conv3d=nn.Conv3d(ci,co, kernel,stride=stride, bias=False)weight = conv3d.weight.detach().numpy()  # co,ci,kh,kwweight = weight.reshape(co, ci, kernel ** 3).transpose(2, 1, 0)ref_out=conv3d(torch.tensor(features_dense))bs,co,ho,wo,do=ref_out.shapespconv_out=np.zeros([bs,co,ho,wo,do])rulebook,pout_indice=get_Pin2Pout_Rulebook_3d(bs,ho,wo,do,kernel,stride, in_indices.tolist())spconv_out=get_output_3d(rulebook,features,weight,pout_indice,spconv_out)dif=np.abs(ref_out.detach().numpy()-spconv_out)print('max diff is:',round(np.max(dif),4))print('sparse conv3d test over')return spconv_outif __name__ =="__main__":shapes=(9,19,18)  # conv3d:(h,w,d)bs=1  #batch_sizeks=3 #kernel_sizestride=2ci=7co=32num_points = [100] * bs  # 100个有效点个数sparse_dict=generate_sparse_data(shapes,num_points,ci)test_conv3d(sparse_dict, ci, co, ks,stride)  

备注

该示例代码默认无padding,可以任意定义输入shapes, 其中generate_sparse_data是spconv的github代码里面给产生的稀疏数据代码。

基于pytorch简单实现稀疏3d卷积(SECOND)相关推荐

  1. 基于pytorch的模型稀疏训练与模型剪枝示例

    基于pytorch的模型稀疏训练与模型剪枝示例 稀疏训练+模型剪枝代码下载地址:下载地址 CIFAR10-VGG16BN Baseline Trained with Sparsity (1e-4) P ...

  2. 基于Pytorch再次解读NiN现代卷积神经网络和批量归一化

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  3. 基于Pytorch再次解读DenseNet现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  4. 基于Pytorch再次解读LeNet-5现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  5. 基于Pytorch再次解读ResNet现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  6. 基于Pytorch再次解析AlexNet现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  7. 基于Pytorch再次解读GoogLeNet现代卷积神经网络

    个人简介:CSDN百万访问量博主,普普通通男大学生,深度学习算法.医学图像处理专攻,偶尔也搞全栈开发,没事就写文章,you feel me? 博客地址:lixiang.blog.csdn.net 基于 ...

  8. 图卷积神经网络笔记——第六章:(1)基于PyTorch的时序数据处理(交通流量数据)

    在前面说了PyG这个框架,但是这个框架处理数据其实没那么简单,并且有时候我们想要改变底层的图卷积框架时就无能为力了,所以这一章说一下用PyTorch怎么写出图卷积并且实现交通流量数据的预测.但在这之前 ...

  9. 基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

    基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络 所用工具 文件结构: 数据: 代码: 结果: 改进思路 拓展 本文是一个基于pytorch使用CNN在生物信息学上进行位 ...

最新文章

  1. 10分钟搭建你的第一个图像识别模型 | 附完整代码
  2. PCL点云库用Poisson网格化实现点云的表面重建
  3. jmeter之ip欺骗
  4. servlet容器_Tomcat 容器与servlet的交互原理
  5. 20165301 预备作业二:学习基础和C语言基础调查
  6. 数字效率Evernote超效率数字笔记术
  7. Mysql5.7开启远程
  8. yum 查看java版本_如何查找YUM安装的JAVA_HOME环境变量详解
  9. 电商数据库设计及架构优化实战(一) - 制定数据库开发规范
  10. 卷积神经网络图像识别_[源码和文档分享]基于CUDA的卷积神经网络算法实现
  11. #436. 子串的最大差(单调栈)
  12. wiki.openwrt.org无法打开的解决办法
  13. 产品经理的第一堂课(四):质量还是质量
  14. 32 位和 64 位版本的 Office 2010 之间的兼容性,同样适用于AutoCAD的VBA兼容性--VBA 64 32 调用dll的区别
  15. 系统主题修改桌面嵌入html,更换主题桌面主题 Win7桌面动态主题怎么更换
  16. VMware安装centOS镜像
  17. php+laravel框架七牛云存储+图片审核+文字审核
  18. Mongrel无法启动解决方案
  19. 实现查找关键字高亮显示
  20. macromedia_Macromedia.com的想法…

热门文章

  1. opencv 与dlib 结合实现人脸融合
  2. 使用PyTorch+OpenCV进行人脸识别(附代码演练)
  3. GRIN透镜的构造和建模
  4. 博主-橄榄山软件创始人-其人其事
  5. 什么是销售管理软件及其重要性
  6. 支持5G和C-V2X的L3级量产车预计2021年上市,值得期待?...
  7. 初入NLP领域的一些小建议 1
  8. (附源码)计算机毕业设计ssm Sketch2Mod网站
  9. Proteus仿真运行流水灯程序
  10. 短信转发器 SmsForwarder,备用机必备神器,开源免费