在做毕设的时候需要实现一个PyTorch原生代码中没有的并行算子,所以用到了这部分的知识,再不总结就要忘光了= =,本文内容主要是PyTorch的官方教程的各种传送门,这些官方教程写的都很好,以后就可以不用再浪费时间在百度上了。由于图神经网络计算框架PyG的代码实现也是采用了扩展的方法,因此也可以当成下面总结PyG源码文章的前导知识吧 。

第一种情况:使用PyThon扩展PyTorch

使用PyThon扩展PyTorch准确的来说是在PyTorch的Python前端实现自定义算子或者模型,不涉及底层C++的实现。这种扩展方式是所有扩展方式中最简单的,也是官方首先推荐的,这是因为PyTorch在NVIDIA cuDNN,Intel MKL或NNPACK之类的库的支持下已经对可能出现的CPU和GPU操作进行了高度优化,因此用Python扩展的代码通常足够快。

比如要扩展一个新的PyThon算子(torch.nn)只需要继承torch.nn.Module并实现其forward方法即可。详细的过程请参考官方教程传送门:

Extending PyTorch​pytorch.org

第二种情况:使用pybind11构建共享库形式的C++和CUDA扩展

但是如果我们想对代码进行进一步优化,比如对自己的算子添加并行的CUDA实现或者连接个OpenCV的库什么的,那么仅仅使用Python进行扩展就不能满足需求;其次如果我们想序列化模型,在一个没有Python环境的生产环境下部署,也需要我们使用C++重写算法;最后考虑到考虑到多线程执行和性能原因,一般Python代码也并不适合做部署。因此在对性能有要求或者需要序列化模型的场景下我们还是会用到C++扩展。

下面我先把官方教程传送门放在这里:

CUSTOM C++ AND CUDA EXTENSIONS​pytorch.org

对于一种典型的扩展情况,比如我们要设计一个全新的C++底层算子,其过程其实就三步:

第一步:使用C++编写算子的forward函数和backward函数

第二步:将该算子的forward函数和backward函数使用pybind11绑定到python上

第三步:使用setuptools/JIT/CMake编译打包C++工程为so文件

注意到在第一步中,我们不仅仅要实现forward函数也要实现backward函数,这是因为在C++端PyTorch目前不支持自动根据forward函数推导出backward函数,所以我们必须要对自己算子的反向传播过程完全清楚。一个需要注意的地方是,你可以选择直接在C++中继承torch::autograd类进行扩展;也可以像官方教程中那样在C++代码中实现forward和backward的核心过程,而在python端继承PyTorch的torch.autograd.Function类。

在C++端扩展forward函数和backward函数的需要注意以下规则:

(1)首先无论是forward函数还是backward函数都需要声明为静态函数

(2)forward函数可以接受任意多的参数并且应该返回一个 variable list或者variable;forward函数需要将torch::autograd::AutogradContext 作为自己的第一个参数。Variables可以被使用ctx->save_for_backward保存,而其他数据类型可以使用ctx->saved_data以<std::string,at::IValue>pairs的形式保存在一个map中。

(3)backward函数第一个参数同样需要为torch::autograd::AutogradContext,其余的参数是一个variable_list,包含的变量数量与forward输出的变量数量相等。它应该返回和forward输入一样多的变量。保存在forward中的Variable变量可以通过ctx->get_saved_variables而其他的数据类型可以通过ctx->saved_data获取。

请注意,backward的输入参数是自动微分系统反传回来的参数梯度值,其需要和forward函数的返回值位置一一对应的;而backward的返回值是对各参数根据自动微分规则求导后的梯度值,其需要和forward函数的输入参数位置一一对应,对于不需要求导的参数也需要使用空Variable占位。

// PyG的C++扩展就选择的是直接继承PyTorch的C++端的torch::autograd类进行扩展
// 下面是PyG的一个ScatterSum算子的扩展示例
// 不用纠结这个算子的具体内容,对扩展的算子的结构有一个大致了解即可
class ScatterSum : public torch::autograd::Function<ScatterSum> {public:// AutogradContext *ctx指针可以操作static variable_list forward(AutogradContext *ctx, Variable src,Variable index, int64_t dim,torch::optional<Variable> optional_out,torch::optional<int64_t> dim_size) {dim = dim < 0 ? src.dim() + dim : dim;ctx->saved_data["dim"] = dim;ctx->saved_data["src_shape"] = src.sizes();index = broadcast(index, src, dim);auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");auto out = std::get<0>(result);ctx->save_for_backward({index});// 如果在扩展的C++代码中使用非Aten内建操作修改了tensor的值,需要对其进行脏标记if (optional_out.has_value())ctx->mark_dirty({optional_out.value()});  return {out};}// grad_outs是out参数反传回来的梯度值static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {auto grad_out = grad_outs[0];auto saved = ctx->get_saved_variables();auto index = saved[0];auto dim = ctx->saved_data["dim"].toInt();auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());auto grad_in = torch::gather(grad_out, dim, index, false);// 不需要求导的参数需要空Variable占位return {grad_in, Variable(), Variable(), Variable(), Variable()};}
};

由于涉及到在C++环境下操作张量和反向传播等操作,因此我们需要对PyTorch的C++后端的库有所了解,主要就是Torch和Aten这两个库,下面我简要介绍一下这两兄弟。

其中Torch是PyTorch的C++底层实现(PS:其实是先有的Torch后有的PyTorch,从名字也能看出来),FB在编码PyTorch的时候就有意将PyTorch的接口和Torch的接口设计的十分类似,因此如果你对PyTorch很熟悉的话那么你也会很快的对Torch上手。

Torch官方文档传送门:

The C++ Frontend​pytorch.org

安装PyTorch的C++前端的官方教程:

INSTALLING C++ DISTRIBUTIONS OF PYTORCH​pytorch.org

而Aten是ATen从根本上讲是一个张量库,在PyTorch中几乎所有其他Python和C ++接口都在其上构建。它提供了一个核心Tensor类,在其上定义了数百种操作。这些操作大多数都具有CPU和GPU实现,Tensor该类将根据其类型向其动态调度。和Torch相比Aten更接近底层和核心逻辑。

Aten源代码传送门:

https://github.com/zdevito/ATen/tree/master/aten/src​github.com

使用Aten声明和操作张量的教程:

TENSOR BASICS​pytorch.org

由于Pyorch的C++后端文档比较少,因此要多参考官方的例子,尝试去模仿官方教程的代码,同时可以通过Python前端的接口猜测后端接口的功能,如果没有文档了就读一读源码,还是有不少注释的,还能理解实现的逻辑。

第三种情况:为TORCHSCRIPT添加C++和CUDA扩展

首先简单解释一下TorchScript是什么,如果用官方的定义来说:“TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从一个Python进程中保存并可以在一个没有Python环境的进程中被加载​​。”通俗来说TorchScript就是一个序列化模型(即Inference)的工具,它可以让你的PyTorch代码方便的在生产环境中部署,同时在将PyTorch代码转化TorchScript代码时还会对你的模型进行一些性能上的优化。使用TorchScript完成模型的部署要比我们之前提到的使用C++重写要简单的多,因为是自动生成的。

TorchScript包含两种序列化模型的方法:tracingscript,两种方法各有其适用场景,由于和本文关系不大就不详细展开了,具体的官方教程传送门在此:

INTRODUCTION TO TORCHSCRIPT​pytorch.org

但是,TorchScript只能自动化的构造PyTorch的原生代码,如果我们需要序列化自定义的C++扩展算子,则需要我们显式的将这些自定义算子注册到TorchScript中,所幸的是,这一过程其实非常简单,整个过程和第二小节中使用pybind11构建共享库的形式的C++和CUDA扩展十分类似。官方教程传送门如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS​pytorch.org

而对于自定义的C++类,如果要注册到TorchScript要稍微复杂一些,官方教程传送门如下:

EXTENDING TORCHSCRIPT WITH CUSTOM C++ CLASSES​pytorch.org

另外需要注意的是,如果想要编写能够被TorchScript编译器理解的代码,需要注意在C++自定义扩展算子参数中的数据类型,目前被TorchScript支持的参数数据类型有torch::Tensortorch::Scalar(标量类型),doubleint64_tstd::vector,而像float,int,short这些是不能作为自定义扩展算子的参数数据类型的。

目前就先总结这么多吧,这点东西居然写了一天,好累啊(*  ̄︿ ̄)。

python interpreter 中没有torch_PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结相关推荐

  1. PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨奔腾的黑猫@知乎 来源丨https://zhuanla ...

  2. 如何在Python Interpreter中重新导入更新的包? [重复]

    本文翻译自:How to re import an updated package while in Python Interpreter? [duplicate] This question alr ...

  3. Python:python语言中与时间有关的库函数简介、安装、使用方法之详细攻略

    Python:python语言中与时间有关的库函数简介.安装.使用方法之详细攻略 目录 与时间有关的库函数 案例应用 1.打印程序块前后运行时间 #T1.采用time库

  4. python中mod运算符_自定义 Python 类中的运算符和函数重载(上)

    Python部落(python.freelycode.com)组织翻译,禁止转载,欢迎转发. 如果你对 Python 中的str对象使用过 + 或 * 运算符,你一定注意到了它的操作与 int 或 f ...

  5. python打包为可执行文件的扩展名,Python脚本文件(.py)打包为可执行文件(.exe)即避免命令行中包含Python解释器...

    在最近的软件工程作业中用到了将Python脚本转化为exe文件这一过程,网上各种博客介绍了很多,有些东西都不完全,我也是综合了很多种方法最后才实现的,我就把这些整理出来,希望可以帮到大家~ 一.环境和 ...

  6. 在python程序中嵌入浏览器_用Python中的wxPython实现最基本的浏览器功能

    通常,大多数应用程序通过保持 HTML 简单来解决大多数浏览器问题 ― 或者说,根据最低共同特性来编写.然而,即便如此,也仍然存在字体和布局的问题,发行新浏览器和升级现有浏览器时,也免不了测试应用程序 ...

  7. python语言中浮点数_举例说明python如何生成一系列浮点数

    Python部落(python.freelycode.com)组织翻译,禁止转载,欢迎转发. 在这篇文章中,我将向您解释如何用python生成一系列浮点数.我已经用python写了几个示例,演示了如何 ...

  8. python语言中一切皆对象_2 python 中一切皆对象

    python 中一皆对象 在 python 编程语言中所有的一切都是对象 , python 语言相对于 C++ 或 java 这样的语言 (静态语言), 面向对象这个思想更加的彻底. 2.1 函数和类 ...

  9. python interpreter 中没有torch_python自动化办公之 Python 解析 PDF

    上次给大家介绍了 Python 如何操作 Word 和 Excel ,而今天想为大家再介绍下,用 Python 如何解析 PDF ,PDF 格式不像前面两个那么规范,从它的表现来看,它更像是一张图片, ...

最新文章

  1. 云计算重构渠道商的价值基础,推动渠道商向服务商转型
  2. 从opensuse 12.3 升级到 opensuse13.1体验
  3. Python之web开发(五):WEB开发html语句经典应用
  4. 跳出所有循环的语句_从零开始的Java之旅2.0 流程控制语句
  5. 阿里云Kubernetes服务上使用Tekton完成应用发布初体验
  6. MySQL回闪_MySQL进行BINLOG回闪
  7. 相量除法能用计算机吗,电路相量的加减乘除运算
  8. 【kafka】kafka DefaultRecordBatch. The older message format classes only support conversion from class
  9. python入门经典100例-【python】编程语言入门经典100例--37
  10. Delphi XE 10.2.3如何添加PDF阅读器组件
  11. H3CSE培训阶段1
  12. 淘宝特价版事业部java面试,含爱奇艺,小米,腾讯,阿里
  13. DS18B20 单总线多器件的ROM 搜索, ALARM 检测, CRC 校验 源码实现, 基于 STM32F103
  14. Python带_的变量或函数命名,带下划线的方法
  15. centos7 gitlab14搭建完成后,无法访问的问题处理(“error“:“badgateway: failed to receive response: dial unix /var/opt)
  16. 时间改变一切—兄弟连IT教育
  17. java代码走读,Kafka代码走读-LogManager
  18. Python 三维绘图
  19. 晶振03——晶振烧坏的原因
  20. linux下rs422串口通信,RS232/RS422/RS485通信接口區別

热门文章

  1. 如果公司的网络屏蔽了游戏【英雄联盟】的链接请求,使用这种方法玩游戏。
  2. 【Python】IDLE中文本进度条的单行动态刷新无法实现分析
  3. 使用外部同步的 Boost.Test 调用在 MT 环境中测试单元测试框架的可用性
  4. boost::spirit模块实现任意元组的解析器的测试程序
  5. boost::multiprecision模块hash相关的测试程序
  6. boost::mpl模块BOOST_MPL_ASSERT_MSG相关的测试程序
  7. boost::hana::take_while用法的测试程序
  8. boost::endian模块实现data的测试程序
  9. boost::describe模块实现overloaded的测试程序
  10. ITK:使用写访问权迭代图像中的区域