点上方蓝字计算机视觉联盟获取更多干货

在右上方 ··· 设为星标 ★,与你不见不散

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:作者丨奔腾的黑猫@知乎

来源丨https://zhuanlan.zhihu.com/p/158643792

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

关于PyTorch构建扩展的一些基础操作,官方往往已经出具了完整的教程。本文对这些官方教程的链接进行了整理,以供读者查阅。

在做毕设的时候,需要实现一个PyTorch原生代码中没有的并行算子,所以用到了这部分的知识,再不总结就要忘光了= =

本文内容主要是PyTorch的官方教程的各种传送门,这些官方教程写的都很好,以后就可以不用再浪费时间在百度上了。由于图神经网络计算框架PyG的代码实现也是采用了扩展的方法,因此也可以当成下面总结PyG源码文章的前导知识吧 。

第一种情况:使用PyThon扩展PyTorch

使用PyThon扩展PyTorch准确的来说是在PyTorch的Python前端实现自定义算子或者模型,不涉及底层C++的实现。这种扩展方式是所有扩展方式中最简单的,也是官方首先推荐的,这是因为PyTorch在NVIDIA cuDNN,Intel MKL或NNPACK之类的库的支持下已经对可能出现的CPU和GPU操作进行了高度优化,因此用Python扩展的代码通常足够快。

比如要扩展一个新的PyThon算子(torch.nn)只需要继承torch.nn.Module并实现其forward方法即可。详细的过程请参考官方教程传送门:

https://pytorch.org/docs/master/notes/extending.html

第二种情况:使用pybind11构建共享库形式的C++和CUDA扩展

但是如果我们想对代码进行进一步优化,比如对自己的算子添加并行的CUDA实现或者连接个OpenCV的库什么的,那么仅仅使用Python进行扩展就不能满足需求;其次如果我们想序列化模型,在一个没有Python环境的生产环境下部署,也需要我们使用C++重写算法;最后考虑到考虑到多线程执行和性能原因,一般Python代码也并不适合做部署。因此在对性能有要求或者需要序列化模型的场景下我们还是会用到C++扩展。

下面我先把官方教程传送门放在这里:

https://pytorch.org/tutorials/advanced/cpp_extension.html

对于一种典型的扩展情况,比如我们要设计一个全新的C++底层算子,其过程其实就三步:

第一步:使用C++编写算子的forward函数和backward函数

第二步:将该算子的forward函数和backward函数使用**pybind11**绑定到python上

第三步:使用setuptools/JIT/CMake编译打包C++工程为so文件

注意到在第一步中,我们不仅仅要实现forward函数也要实现backward函数,这是因为在C++端PyTorch目前不支持自动根据forward函数推导出backward函数,所以我们必须要对自己算子的反向传播过程完全清楚。一个需要注意的地方是,你可以选择直接在C++中继承torch::autograd类进行扩展;也可以像官方教程中那样在C++代码中实现forward和backward的核心过程,而在python端继承PyTorch的torch.autograd.Function类。

在C++端扩展forward函数和backward函数的需要注意以下规则:

(1)首先无论是forward函数还是backward函数都需要声明为静态函数

(2)forward函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将[torch::autograd::AutogradContext](https://link.zhihu.com/?target=https%3A//pytorch.org/cppdocs/api/structtorch_1_1autograd_1_1_autograd_context.html%23structtorch_1_1autograd_1_1_autograd_context) 作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以<std::string,at::IValue>pairs的形式保存在一个map中。

(3)backward函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward输出的变量数量相等。它应该返回和forward输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。

请注意,backward的输入参数是自动微分系统反传回来的参数梯度值,其需要和forward函数的返回值位置一一对应的;而backward的返回值是对各参数根据自动微分规则求导后的梯度值,其需要和forward函数的输入参数位置一一对应,对于不需要求导的参数也需要使用空Variable占位。

// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展// 下面是PyG的一个ScatterSum算子的扩展示例// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可class ScatterSum : public torch::autograd::Function<ScatterSum> {public:  // AutogradContext *ctx指针可以操作  static variable_list forward(AutogradContext *ctx, Variable src,                               Variable index, int64_t dim,                               torch::optional<Variable> optional_out,                               torch::optional<int64_t> dim_size) {    dim = dim < 0 ? src.dim() + dim : dim;    ctx->saved_data["dim"] = dim;    ctx->saved_data["src_shape"] = src.sizes();    index = broadcast(index, src, dim);    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");    auto out = std::get<0>(result);    ctx->save_for_backward({index});    // 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记    if (optional_out.has_value())      ctx->mark_dirty({optional_out.value()});      return {out};  } // grad_outs是out参数反传回来的梯度值  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {    auto grad_out = grad_outs[0];    auto saved = ctx->get_saved_variables();    auto index = saved[0];    auto dim = ctx->saved_data["dim"].toInt();    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());    auto grad_in = torch::gather(grad_out, dim, index, false);    // 不需要求导的参数需要空Variable占位    return {grad_in, Variable(), Variable(), Variable(), Variable()};  }};

由于涉及到在C++环境下操作张量和反向传播等操作,因此我们需要对PyTorch的C++后端的库有所了解,主要就是Torch和Aten这两个库,下面我简要介绍一下这两兄弟。

其中Torch是PyTorch的C++底层实现(PS:其实是先有的Torch后有的PyTorch,从名字也能看出来),FB在编码PyTorch的时候就有意将PyTorch的接口和Torch的接口设计的十分类似,因此如果你对PyTorch很熟悉的话那么你也会很快的对Torch上手。

Torch官方文档传送门:

https://pytorch.org/cppdocs/frontend.html

安装PyTorch的C++前端的官方教程:

https://pytorch.org/cppdocs/installing.html

而Aten是ATen从根本上讲是一个张量库,在PyTorch中几乎所有其他Python和C ++接口都在其上构建。它提供了一个核心Tensor类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。

Aten源代码传送门:

https://github.com/zdevito/ATen/tree/master/aten/srcgithub.com

使用Aten声明和操作张量的教程:

https://pytorch.org/cppdocs/notes/tensor_basics.html

由于Pyorch的C++后端文档比较少,因此要多参考官方的例子,尝试去模仿官方教程的代码,同时可以通过Python前端的接口猜测后端接口的功能,如果没有文档了就读一读源码,还是有不少注释的,还能理解实现的逻辑。

第三种情况:为TORCHSCRIPT添加C++和CUDA扩展

首先简单解释一下TorchScript是什么,如果用官方的定义来说:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从一个Python进程中保存并可以在一个没有Python环境的进程中被加载。”通俗来说TorchScript就是一个序列化模型(即Inference)的工具,它可以让你的PyTorch代码方便的在生产环境中部署,同时在将PyTorch代码转化TorchScript代码时还会对你的模型进行一些性能上的优化。使用TorchScript完成模型的部署要比我们之前提到的使用C++重写要简单的多,因为是自动生成的。

TorchScript包含两种序列化模型的方法:tracingscript,两种方法各有其适用场景,由于和本文关系不大就不详细展开了,具体的官方教程传送门在此:

https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

但是,TorchScript只能自动化的构造PyTorch的原生代码,如果我们需要序列化自定义的C++扩展算子,则需要我们显式的将这些自定义算子注册到TorchScript中,所幸的是,这一过程其实非常简单,整个过程和第二小节中使用pybind11构建共享库的形式的C++和CUDA扩展十分类似。官方教程传送门如下:

https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

而对于自定义的C++类,如果要注册到TorchScript要稍微复杂一些,官方教程传送门如下:

https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html?highlight=registeroperators

另外需要注意的是,如果想要编写能够被TorchScript编译器理解的代码,需要注意在C++自定义扩展算子参数中的数据类型,目前被TorchScript支持的参数数据类型有torch::Tensortorch::Scalar(标量类型),doubleint64_tstd::vector,而像float,int,short这些是不能作为自定义扩展算子的参数数据类型的。

目前就先总结这么多吧~

end

这是我的私人微信,还有少量坑位,可与相关学者研究人员交流学习 

目前开设有人工智能、机器学习、计算机视觉、自动驾驶(含SLAM)、Python、求职面经、综合交流群扫描添加CV联盟微信拉你进群,备注:CV联盟

王博的公众号,欢迎关注,干货多多

王博的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章(上)

博士笔记 | 周志华《机器学习》手推笔记第八章(下)

博士笔记 | 周志华《机器学习》手推笔记第九章

点个在看支持一下吧

PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结相关推荐

  1. python interpreter 中没有torch_PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结

    在做毕设的时候需要实现一个PyTorch原生代码中没有的并行算子,所以用到了这部分的知识,再不总结就要忘光了= =,本文内容主要是PyTorch的官方教程的各种传送门,这些官方教程写的都很好,以后就可 ...

  2. Pytorch 使用不同版本的 cuda,跟使用不同版本的cuda进行编译扩展库,其实TensorFlow也是一样

    在使用 Pytorch 时,由于 Pytorch 和 cuda 版本的更新,可能出现程序运行时需要特定版本的 cuda 进行运行环境支持的情况,如使用特定版本的 cuda 编译 CUDAExtensi ...

  3. PyTorch基础-自定义数据集和数据加载器(2)

    处理数据样本的代码可能会变得混乱且难以维护: 理想情况下,我们想要数据集代码与模型训练代码解耦,以获得更好的可读性和模块化.PyTorch 域库提供了许多预加载的数据(例如 FashionMNIST) ...

  4. 详解PyTorch编译并调用自定义CUDA算子的三种方式

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 在上一篇教程中,我们实现了一个自定义的CUDA算子add2,用来实现两个Tensor的相加.然后用Py ...

  5. 自定义 C++ 和 CUDA 扩展

    来源 官方文档 前言 PyTorch 提供了大量与神经网络.随机张量代数(arbitrary tensor algebra).数据整合(data wrangling)以及其他目的相关的操作.但是,您仍 ...

  6. PyTorch 源码解读之 cpp_extension:讲解 C++/CUDA 算子实现和调用全流程

    "Python 用户友好却运行效率低","C++ 运行效率较高,但实现一个功能代码量会远大于 Python".平常学习工作中你是否常听到类似的说法?在 Pyth ...

  7. pytorch扩展——如何自定义前向和后向传播

    python 端扩展 pytorch 版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明. 本文链接: https://blog.csdn.net/ ...

  8. 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍

    作者丨PENG Bo@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/476297195 编辑丨极市平台 本文的代码,在 win10 和 linux 均可直接编译运行: ...

  9. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

最新文章

  1. python数据库应用开发实例_纯Python开发的nosql数据库CodernityDB介绍和使用实例
  2. C语言 小游戏之贪吃蛇
  3. python 写入第二列_python读写Excel表格的实例代码(简单实用)
  4. SectionIndexer中的getSectionForPosition()与getPositionForSection()解惑
  5. ichat在线客服jQuery插件(可能是历史上最灵活的)
  6. 实战 SQL Server 2008 数据库误删除数据的恢复 (转)
  7. PyTorch 1.0 中文官方教程:Torchvision 模型微调
  8. 笔记:网络管理与检测命令
  9. docker中如何制作自己的基础镜像
  10. HTTP、HTTP2、HTTPS、SPDY等的理解及在spring-boot中的使用
  11. 转速恒压频比交流变频调速系统Simulink仿真,可观察到电压频率的变比情况以及电动机的转速波形。
  12. 为解放程序员而生,网易重磅推“场景化云服务”,强势进军云计算市场
  13. “CHK文件恢复”和“文件恢复”有什么区别?
  14. Java实现 蓝桥杯VIP 算法提高 盾神与砝码称重
  15. 量子计算机的成熟度模型,全球首家:紫光展锐通过 TMMi 软件测试成熟度模型集成 5 级认证...
  16. wifi有网可以连接,但打不开网页了,找不到 服务器 dns 地址
  17. ABAQUS|多重约束的解决办法!(过约束/螺栓预紧力)
  18. 网络工程师高薪就业行业有哪些
  19. matlab白光干涉,matlab白光干涉
  20. 耳机不分主从是什么意思_不疯魔不成活!红魔TWS蓝牙耳机告诉你什么是“低延怪兽”...

热门文章

  1. 方法~作用于对象~失败_消息三:ActiveMQ Topic 消息失败重发
  2. php 下拉菜单 不提交 选中的值,在html中怎样可以做到下拉菜单提交后保留选中值不返回默认值...
  3. sqlserver导入execl数据ACE.OLEDB.12.0错误
  4. java batch size_java – @BatchSize但在@ManyToOne案例中有很多往返
  5. gitbash登录码云报错_手把手教你入门git仓库和关联码云
  6. c++ jna 数据类型_JNA实战笔记汇总一 简单认识JNA|成功调用JNA
  7. 开启mongodb数据库命令行_【赵强老师】使用MongoDB的命令行工具:mongoshell
  8. 速读《文献管理与信息分析》笔记
  9. 急急急 大神帮忙给个思路和步骤吧 万分感谢
  10. layui流加载及传参