配套视频

1. GPU 并行概论

可参考:gpu与cuda概论
grid -> block -> thread

2. 算法并行分析

feats:(N,8,F)
N表示有N个正方体,8表示8个特征点,F表示每个特征点的表示方式,F应该等于3吧。

points:(N,3)
N表示有N个正方体,3表示坐标

每个点的内插都是独立的,每个点对于每个特征点来说,也是独立的,因此可在这两个方向上并行。

3. 算法

1. 编写 .cu 文件

既然要用到 cuda 那么,就要 .cu(cuda) 文件,这里并没有并行。而是做一个调用的demo。

#include <torch/extension.h>#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)torch::Tensor trilinear_fw_cu(torch::Tensor feats,torch::Tensor points
){return feats;
}

2. C++ 文件

新建头文件 utils.h,位于./include,定义 CUDA 有哪些函数

#include <torch/extension.h>#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)torch::Tensor trilinear_fw_cu(torch::Tensor feats,torch::Tensor points
);

新建 C++桥梁

#include "utils.h"torch::Tensor trilinear_interpolation(torch::Tensor feats,torch::Tensor points
){CHECK_INPUT(feats);CHECK_INPUT(points);// GPU函数return trilinear_fw_cu(feats, points);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){m.def("trilinear", &trilinear_interpolation,"test");
}

3. setup.py

和之前的c++程序相比,有以下变化

  • 源码的输入方式,因为源码有点多,所以用了 glop 工具
  • 引入了头文件,所以加了 include_dirs
  • 因为用了CUDA ,所以将 CppExtension 改为 CUDAExtension

如果安装失败,就去 site-packages 里面手动删除

import glob
import os.path as osp
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtensionROOT_DIR = osp.dirname(osp.abspath(__file__))
include_dirs = [osp.join(ROOT_DIR, "include")]sources = glob.glob('*.cpp')+glob.glob('*.cu')setup(name='cppcuda_tutorial',version='1.0',author='kwea123',author_email='kwea123@gmail.com',description='cppcuda_tutorial',long_description='cppcuda_tutorial',ext_modules=[CUDAExtension(name='cppcuda_tutorial',sources=sources,include_dirs=include_dirs,extra_compile_args={'cxx': ['-O2'],'nvcc': ['-O2']})],cmdclass={'build_ext': BuildExtension}
)

4. 测试文件

当运行时:

import torch
#from torch.utils.cpp_extension import load
#cppcuda = load(name="test", sources=['interpolation.cpp'], verbose=False,extra_cflags=["-O2"])
import cppcuda_tutorial
feats = torch.ones(2)
point = torch.ones(2)out = cppcuda_tutorial.trilinear(feats, point)print(out)

报错:

dell/pytorch_c++_cuda/example_2/test.py
Traceback (most recent call last):File "/home/dell/pytorch_c++_cuda/example_2/test.py", line 8, in <module>out = cppcuda_tutorial.trilinear(feats, point)
RuntimeError: feats must be a CUDA tensor

修改之后,改为:

import torch
#from torch.utils.cpp_extension import load
#cppcuda = load(name="test", sources=['interpolation.cpp'], verbose=False,extra_cflags=["-O2"])
import cppcuda_tutorial
feats = torch.ones(2,device="cuda")
point = torch.ones(2,device="cuda")
out = cppcuda_tutorial.trilinear(feats, point)
print(out)

运行成功!

Pytorch+cpp_cuda extension 课程二相关推荐

  1. Pytorch+cpp_cuda extension 课程一

    以下学习来源于 youtube AI 葵老师的系列课程 为了方便后续学习我将它上传到了我的 BliBli 上,国内的同学可以点击访问. github code 如果github打不开,可以用我们国内的 ...

  2. 小白学习pytorch源码(二):setup.py最详细解读

    小白学习pytorch源码(二) pytorch setup.py最全解析 setup.py与setuptools setup.py最详细解读 setup.py 环境检查 setup.py setup ...

  3. pytorch学习笔记(二):gradien

    pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030

  4. PyTorch框架学习十二——损失函数

    PyTorch框架学习十二--损失函数 一.损失函数的作用 二.18种常见损失函数简述 1.L1Loss(MAE) 2.MSELoss 3.SmoothL1Loss 4.交叉熵CrossEntropy ...

  5. PyTorch学习笔记(二):PyTorch简介与基础知识

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  6. 切图案例实操课程二-姜威-专题视频课程

    切图案例实操课程二-199人已学习 课程介绍         本课程以主要目的是引导初入前端的小白,了解前端是如何工作的,通过正确建立构建环境,解构任务, 课程收益      讲师介绍     姜威 ...

  7. 一步步读懂Pytorch Chatbot Tutorial代码(二) - 数据处理

    文章目录 自述 代码出处 目录 代码 Create formatted data file (为了方便理解,把代码的顺序略微改一下, 此章节略长.) 1. `loadLines` 将文件的每一行拆分为 ...

  8. PointNet.pytorch程序注释(二)点云分割

    PointNet.pytorch程序注释(二)点云分割 论文及程序地址 运行环境 训练train 测试test 论文及程序地址 论文原文 PointNet: Deep Learning on Poin ...

  9. PyTorch框架学习(二) — 一文精通张量操作

    PyTorch框架学习(二) - 张量操作与线性回归 1 张量的操作 1.1 拼接 1.2 切分 1.3 索引 1.4 变换 2 张量的数学运算 2.1 加法运算 2.2 减法运算 2.3 哈达玛积运 ...

最新文章

  1. 交换机的端口工作模式
  2. MongoDB的下载与安装
  3. Django REST framework快速入门
  4. man thread_join
  5. 国家电网和南方电网还傻傻分不清?
  6. Parse a document from a String
  7. calltreetest中文_calltree查看工程代码中的函数调用关系
  8. ORB特征提取算法解析
  9. response对象设置返回状态_爬虫代理之设置
  10. Axure实例:即刻 app 产品需求文档
  11. 计算机部分应用显示模糊,win10系统打开部分软件字体总显示模糊的解决方法
  12. python如何抓取网页里面的文字_如何利用python抓取网页文字、图片内容?
  13. java pdfbox 转jpg_java实现PDF转图片的方法
  14. sp经营许可证适用范围是什么?
  15. 设置VSS2005使支持通过Internet访问
  16. P3386 【模板】二分图最大匹配(匈牙利算法,网络流)
  17. C++多线程详细讲解
  18. 怎么卸载apowerrec_录屏王ApowerREC Mac版卸载后,如何彻底删除Apowersoft Audio Device声音设备?...
  19. centos单机部署greenplum
  20. 二代身份证阅读器 C#、JAVA调用教程

热门文章

  1. 教你用Python画一棵圣诞树
  2. 山东畜牧兽医职业学院计算机考试,山东畜牧兽医职业学院计算机自编word15套试题11Word模拟试题(1-15).doc...
  3. 01.JS基础_前端的语法(4)
  4. 第六十六章 Caché 函数大全 $TRANSLATE 函数
  5. win7cmd闪退_win7系统运行bat批处理文件出现闪退的解决方法
  6. 通过CSS样式缩放图片导致图片模糊的解决方案
  7. html中出现弹窗偏右,打印机打印某些网页时,右边总是打印不全,怎么办
  8. 精选(26)面试官:讲讲你对ThreadLocal的理解
  9. 中软外包创维面试,尬聊半小时
  10. Threaded Binary Tree