• 本次源码解读的地址为:https://github.com/yanx27/Pointnet_Pointnet2_pytorch,这一版本的源码易读性高,主要是封装程度较低,注释较全,安装额外的库也比较少。

Pipeline

前向过程:数据加载、数据增强

 for epoch in range(start_epoch, args.epoch):log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))mean_correct = []classifier = classifier.train()scheduler.step()for batch_id, (points, target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader),smoothing=0.9):optimizer.zero_grad()# b x n x cpoints = points.data.numpy()# b x n x c:并不缩减point的数量,而是设定某个阈值,将小于某个阈值的点覆盖为第一个点的信息points = provider.random_point_dropout(points)# 针对不同的batch进行随机坐标缩放,将某batch所有的pc乘某个缩放系数points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])# 针对不同的batch进行随机坐标缩放,将某batch所有的pc同步进行坐标移动points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])points = torch.Tensor(points)points = points.transpose(2, 1) # b,c,n:conv1d是channel first因此要转换if not args.use_cpu:points, target = points.cuda(), target.cuda()# 喂入模型数据pred, trans_feat = classifier(points)

model

  • Classification Model和Segmentation Model,区别在于PointNetEncoder的参数global_feature=true | false
  • 这里以Classification Model为例:
class get_model(nn.Module):def __init__(self, k=40, normal_channel=True):super(get_model, self).__init__()if normal_channel:channel = 6else:channel = 3self.feat = PointNetEncoder(global_feat=True, feature_transform=True, channel=channel)    # classification modelself.fc1 = nn.Linear(1024, 512)self.fc2 = nn.Linear(512, 256)self.fc3 = nn.Linear(256, k)self.dropout = nn.Dropout(p=0.4)self.bn1 = nn.BatchNorm1d(512)self.bn2 = nn.BatchNorm1d(256)self.relu = nn.ReLU()def forward(self, x):x, trans, trans_feat = self.feat(x)   # 这里计算出x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.dropout(self.fc2(x))))x = self.fc3(x)# 每个点云计算对应的类别损失x = F.log_softmax(x, dim=1)   # bs, kreturn x, trans_feat
  • 如果是segmentation model,只需对如上代码做如下修改
    def forward(self, x):batchsize = x.size()[0]n_pts = x.size()[2]x, trans, trans_feat = self.feat(x)x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = F.relu(self.bn3(self.conv3(x)))x = self.conv4(x)x = x.transpose(2,1).contiguous()x = F.log_softmax(x.view(-1,self.k), dim=-1)x = x.view(batchsize, n_pts, self.k)return x, trans_feat

PointNetEncoder

class PointNetEncoder(nn.Module):def __init__(self, global_feat=True, feature_transform=False, channel=3):super(PointNetEncoder, self).__init__()self.stn = STN3d(channel)   # 计算第一个3x3的transformation矩阵self.conv1 = torch.nn.Conv1d(channel, 64, 1)self.conv2 = torch.nn.Conv1d(64, 128, 1)self.conv3 = torch.nn.Conv1d(128, 1024, 1)self.bn1 = nn.BatchNorm1d(64)self.bn2 = nn.BatchNorm1d(128)self.bn3 = nn.BatchNorm1d(1024)self.global_feat = global_featself.feature_transform = feature_transformif self.feature_transform:self.fstn = STNkd(k=64)  # 计算第二个transformation矩阵def forward(self, x):B, D, N = x.size()trans = self.stn(x) # bs, 3, 3x = x.transpose(2, 1)   # bs, n, 3if D > 3:feature = x[:, :, 3:]x = x[:, :, :3]x = torch.bmm(x, trans) #if D > 3:x = torch.cat([x, feature], dim=2)x = x.transpose(2, 1)   # bs, c, nx = F.relu(self.bn1(self.conv1(x))) # [bs, 3, n]->[bs, 64, n]if self.feature_transform:trans_feat = self.fstn(x)   # bs, 64, 64x = x.transpose(2, 1)x = torch.bmm(x, trans_feat)    # bs, n, 64x = x.transpose(2, 1)else:trans_feat = Nonepointfeat = xx = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))x = torch.max(x, 2, keepdim=True)[0]x = x.view(-1, 1024)    # bs, 1024:这里是global feature# 如果global_feat=true, 返回的x是global featureif self.global_feat:return x, trans, trans_feat # 分别是global feature, 3x3的转换矩阵, transform feature,用于instance-wise的classfication# 返回的x为每个点加入global feature之后的特征,用于segmentationelse:x = x.view(-1, 1024, 1).repeat(1, 1, N)# x:[bs, c, n] + pointfeat:[bs, c, n] -> [bs, c+c, n]return torch.cat([x, pointfeat], 1), trans, trans_feat  # 这里是point-wise的classfication

PointNet源码解读相关推荐

  1. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  2. Bert系列(三)——源码解读之Pre-train

    https://www.jianshu.com/p/22e462f01d8c pre-train是迁移学习的基础,虽然Google已经发布了各种预训练好的模型,而且因为资源消耗巨大,自己再预训练也不现 ...

  3. linux下free源码,linux命令free源码解读:Procps free.c

    linux命令free源码解读 linux命令free源码解读:Procps free.c 作者:isayme 发布时间:September 26, 2011 分类:Linux 我们讨论的是linux ...

  4. nodeJS之eventproxy源码解读

    1.源码缩影 !(function (name, definition) { var hasDefine = typeof define === 'function', //检查上下文环境是否为AMD ...

  5. PyTorch 源码解读之即时编译篇

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 作者丨OpenMMLab 来源丨https://zhuanlan.zhihu.com/ ...

  6. Alamofire源码解读系列(九)之响应封装(Response)

    本篇主要带来Alamofire中Response的解读 前言 在每篇文章的前言部分,我都会把我认为的本篇最重要的内容提前讲一下.我更想同大家分享这些顶级框架在设计和编码层次究竟有哪些过人的地方?当然, ...

  7. Feflow 源码解读

    Feflow 源码解读 Feflow(Front-end flow)是腾讯IVWEB团队的前端工程化解决方案,致力于改善多类型项目的开发流程中的规范和非业务相关的问题,可以让开发者将绝大部分精力集中在 ...

  8. spring-session源码解读 sesion

    2019独角兽企业重金招聘Python工程师标准>>> spring-session源码解读 sesion 博客分类: java spring 摘要: session通用策略 Ses ...

  9. 前端日报-20160527-underscore 源码解读

    underscore 源码解读 API文档浏览器 JavaScript 中加号操作符细节 抛弃 jQuery,拥抱原生 JS 从 0 开始学习 GitHub 系列之「加入 GitHub」 js实现克隆 ...

最新文章

  1. ButterKnife基本使用
  2. 阿里移动|《蚂蚁金服移动端高可用技术实践》
  3. spring基于注解程序开发
  4. linux常用运维命令【转】
  5. python编程(python调用dll程序)
  6. python 二维矩阵及转byte知识点
  7. 轮播图实现html,html、css、js实现轮播图
  8. [前端随笔][Javascript][物理引擎] 给元素添加简单的物理属性
  9. 转载--Github优秀java项目集合(中文版) - 涉及java所有的知识体系
  10. java的中文源代码
  11. oracle 下载 pb12.5,PowerBuilder
  12. Windows下修改hosts文件
  13. 机械制图国家标准的绘图模板_JS制图:映射
  14. w8系统服务器垃圾清理,win8系统盘太大怎么办?来给TA瘦身吧! | SDT技术网
  15. Python PIL库对阻挡文件blk进行解析,生成红绿色位图
  16. XJOI_3571_求十位数
  17. 如何在Excel中对工作表进行分组
  18. java中的Dao类是什么意思
  19. 传统行业程序员的深度焦虑?——快来互联网行业吧!
  20. bga bond焊盘 wire_封装模式: FC-BGA VS. WireBond ,谁是封装工艺中的真英雄?(图)

热门文章

  1. 一曲相思用计算机怎么按,抖音这人间袅袅炊烟是什么歌 抖音一曲相思完整版...
  2. 浪涌抑制专题-半导体放电管tss介绍
  3. 【HTML基础】CSS样式表
  4. 最喜欢的一首中文歌曲
  5. 电影分区发行新模式创造“中国电影市场的新增量”
  6. 民航飞行学院计算机分院云,ICC-数字课程云平台-中国民用航空飞行学院
  7. Java中调用ImageJ,与直接使用ImageJ软件处理所得图片黑白颠倒的问题
  8. 【北亚数据恢复】infortrend服务器raid6硬盘离线后进行上线操作导致服务器崩溃的数据恢复
  9. 圣路易斯华盛顿大学计算机科学排名,圣路易斯华盛顿大学CS的排名?真是应该稳重去看...
  10. 四种常见商务书信写作的排版格式