Pytorch+cpp_cuda extension 课程二
配套视频
1. GPU 并行概论
可参考:gpu与cuda概论
grid -> block -> thread
2. 算法并行分析
feats:(N,8,F)
N表示有N个正方体,8表示8个特征点,F表示每个特征点的表示方式,F应该等于3吧。
points:(N,3)
N表示有N个正方体,3表示坐标
每个点的内插都是独立的,每个点对于每个特征点来说,也是独立的,因此可在这两个方向上并行。
3. 算法
1. 编写 .cu
文件
既然要用到 cuda 那么,就要 .cu(cuda) 文件,这里并没有并行。而是做一个调用的demo。
#include <torch/extension.h>#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)torch::Tensor trilinear_fw_cu(torch::Tensor feats,torch::Tensor points
){return feats;
}
2. C++ 文件
新建头文件 utils.h
,位于./include
,定义 CUDA
有哪些函数
#include <torch/extension.h>#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)torch::Tensor trilinear_fw_cu(torch::Tensor feats,torch::Tensor points
);
新建 C++桥梁
#include "utils.h"torch::Tensor trilinear_interpolation(torch::Tensor feats,torch::Tensor points
){CHECK_INPUT(feats);CHECK_INPUT(points);// GPU函数return trilinear_fw_cu(feats, points);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){m.def("trilinear", &trilinear_interpolation,"test");
}
3. setup.py
和之前的c++程序相比,有以下变化
- 源码的输入方式,因为源码有点多,所以用了
glop
工具 - 引入了头文件,所以加了
include_dirs
- 因为用了CUDA ,所以将
CppExtension
改为CUDAExtension
如果安装失败,就去 site-packages
里面手动删除
import glob
import os.path as osp
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtensionROOT_DIR = osp.dirname(osp.abspath(__file__))
include_dirs = [osp.join(ROOT_DIR, "include")]sources = glob.glob('*.cpp')+glob.glob('*.cu')setup(name='cppcuda_tutorial',version='1.0',author='kwea123',author_email='kwea123@gmail.com',description='cppcuda_tutorial',long_description='cppcuda_tutorial',ext_modules=[CUDAExtension(name='cppcuda_tutorial',sources=sources,include_dirs=include_dirs,extra_compile_args={'cxx': ['-O2'],'nvcc': ['-O2']})],cmdclass={'build_ext': BuildExtension}
)
4. 测试文件
当运行时:
import torch
#from torch.utils.cpp_extension import load
#cppcuda = load(name="test", sources=['interpolation.cpp'], verbose=False,extra_cflags=["-O2"])
import cppcuda_tutorial
feats = torch.ones(2)
point = torch.ones(2)out = cppcuda_tutorial.trilinear(feats, point)print(out)
报错:
dell/pytorch_c++_cuda/example_2/test.py
Traceback (most recent call last):File "/home/dell/pytorch_c++_cuda/example_2/test.py", line 8, in <module>out = cppcuda_tutorial.trilinear(feats, point)
RuntimeError: feats must be a CUDA tensor
修改之后,改为:
import torch
#from torch.utils.cpp_extension import load
#cppcuda = load(name="test", sources=['interpolation.cpp'], verbose=False,extra_cflags=["-O2"])
import cppcuda_tutorial
feats = torch.ones(2,device="cuda")
point = torch.ones(2,device="cuda")
out = cppcuda_tutorial.trilinear(feats, point)
print(out)
运行成功!
Pytorch+cpp_cuda extension 课程二相关推荐
- Pytorch+cpp_cuda extension 课程一
以下学习来源于 youtube AI 葵老师的系列课程 为了方便后续学习我将它上传到了我的 BliBli 上,国内的同学可以点击访问. github code 如果github打不开,可以用我们国内的 ...
- 小白学习pytorch源码(二):setup.py最详细解读
小白学习pytorch源码(二) pytorch setup.py最全解析 setup.py与setuptools setup.py最详细解读 setup.py 环境检查 setup.py setup ...
- pytorch学习笔记(二):gradien
pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030
- PyTorch框架学习十二——损失函数
PyTorch框架学习十二--损失函数 一.损失函数的作用 二.18种常见损失函数简述 1.L1Loss(MAE) 2.MSELoss 3.SmoothL1Loss 4.交叉熵CrossEntropy ...
- PyTorch学习笔记(二):PyTorch简介与基础知识
往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...
- 切图案例实操课程二-姜威-专题视频课程
切图案例实操课程二-199人已学习 课程介绍 本课程以主要目的是引导初入前端的小白,了解前端是如何工作的,通过正确建立构建环境,解构任务, 课程收益 讲师介绍 姜威 ...
- 一步步读懂Pytorch Chatbot Tutorial代码(二) - 数据处理
文章目录 自述 代码出处 目录 代码 Create formatted data file (为了方便理解,把代码的顺序略微改一下, 此章节略长.) 1. `loadLines` 将文件的每一行拆分为 ...
- PointNet.pytorch程序注释(二)点云分割
PointNet.pytorch程序注释(二)点云分割 论文及程序地址 运行环境 训练train 测试test 论文及程序地址 论文原文 PointNet: Deep Learning on Poin ...
- PyTorch框架学习(二) — 一文精通张量操作
PyTorch框架学习(二) - 张量操作与线性回归 1 张量的操作 1.1 拼接 1.2 切分 1.3 索引 1.4 变换 2 张量的数学运算 2.1 加法运算 2.2 减法运算 2.3 哈达玛积运 ...
最新文章
- 交换机的端口工作模式
- MongoDB的下载与安装
- Django REST framework快速入门
- man thread_join
- 国家电网和南方电网还傻傻分不清?
- Parse a document from a String
- calltreetest中文_calltree查看工程代码中的函数调用关系
- ORB特征提取算法解析
- response对象设置返回状态_爬虫代理之设置
- Axure实例:即刻 app 产品需求文档
- 计算机部分应用显示模糊,win10系统打开部分软件字体总显示模糊的解决方法
- python如何抓取网页里面的文字_如何利用python抓取网页文字、图片内容?
- java pdfbox 转jpg_java实现PDF转图片的方法
- sp经营许可证适用范围是什么?
- 设置VSS2005使支持通过Internet访问
- P3386 【模板】二分图最大匹配(匈牙利算法,网络流)
- C++多线程详细讲解
- 怎么卸载apowerrec_录屏王ApowerREC Mac版卸载后,如何彻底删除Apowersoft Audio Device声音设备?...
- centos单机部署greenplum
- 二代身份证阅读器 C#、JAVA调用教程
热门文章
- 教你用Python画一棵圣诞树
- 山东畜牧兽医职业学院计算机考试,山东畜牧兽医职业学院计算机自编word15套试题11Word模拟试题(1-15).doc...
- 01.JS基础_前端的语法(4)
- 第六十六章 Caché 函数大全 $TRANSLATE 函数
- win7cmd闪退_win7系统运行bat批处理文件出现闪退的解决方法
- 通过CSS样式缩放图片导致图片模糊的解决方案
- html中出现弹窗偏右,打印机打印某些网页时,右边总是打印不全,怎么办
- 精选(26)面试官:讲讲你对ThreadLocal的理解
- 中软外包创维面试,尬聊半小时
- Threaded Binary Tree