python pytorch fft_看PyTorch源代码的心路历程
1. 起因
曾经碰到过别人的模型prelu在内部的推理引擎算出的结果与其在原始框架PyTorch中不一致的情况,虽然理论上大家实现的都是一个算法,但是从参数上看,因为经过了模型转换,中间做了一些调整。为了确定究竟是初始参数传递就出了问题还是在后续传递过程中继续做了更改、亦或者是最终算法实现方面有着细微差别导致最终输出不同,就想着去看一看PyTorch一路下来是怎么做的。
但是代码跟着跟着就跟丢了,才会发现,PyTorch真的是一个很复杂的项目,但就像舌尖里面说的,环境越是恶劣,回报越是丰厚。为了以后再想跟踪的时候方便,因此决定以PReLU为例静态梳理一下PyTorch的代码结构。捣鼓的这些天,对如何构建一个带有C/C++代码的Python又有了新的了解,这也算是意外的收获吧。
2. 历程
首先,我们从PReLU的导入路径torch.nn.PReLU中知道,他应在径进torch\nn\之下,进入该路径虽然没看到,但是我们在该路径下的__init__.py中知道,其实它就在torch\nn\modules\activation.py中。类PReLU最终调用了从torch\nn\functional.py导入的prelu方法。顺腾摸瓜,找到prelu,它长下面这样:
def prelu(input, weight):
# type: (Tensor, Tensor) -> Tensor
if not torch.jit.is_scripting():
if type(input) is not Tensor and has_torch_function((input,)):
return handle_torch_function(prelu, (input,), input, weight)
return torch.prelu(input, weight)
经过人脑对代码的一番执行你会发现,第一个if条件满足,而第二个if不满足。因此,最终想看算法,得去看torch.prelu()。好吧,接着干……
一番搜寻之后你会发现,Python代码中在torch这个包下面你是找不到prelu的定义的。但是绝望之际我们在torch包的__init__.py之中看到看下面几行代码:
# pytorch\torch\__init__.py
# 为了简洁,省去不必要代码,详细代码参见pytorch\torch\__init__.py
try:
# _initExtension is chosen (arbitrarily) as a sentinel.
from torch._C import _initExtension
__all__ += [name for name in dir(_C)
if name[0] != '_' and
not name.endswith('Base')]
if TYPE_CHECKING:
# Some type signatures pulled in from _VariableFunctions here clash with
# signatures already imported. For now these clashes are ignored; see
# PR #43339 for details.
from torch._C._VariableFunctions import * # type: ignore
for name in dir(_C._VariableFunctions):
if name.startswith('__'):
continue
globals()[name] = getattr(_C._VariableFunctions, name)
__all__.append(name)
这是全村最后的希望了。我们知道__all__中的名字其实就是该模块有意暴露出去的API。
什么意思呢?也就是说虽然我们明文上已经看不到了prelu的定义,但是这几行代码表明有一大堆身份不明的API被暗搓搓的导入了,这其中就很有可能存在我们朝思暮想的prelu。
那么我们怎么凭借这么一点微弱的线索确定我们的猜测到底对不对呢?这里我们就用到了Python的一个关键知识:C/C++扩展。(戳这里《使用C语言编写Python模块-引子》《Python调用C++之PYBIND11简介》了解更多)
我们知道Python C/C++扩展有着固定的格式,只要我们找到模块初始化入口,就能顺藤摸瓜找到该模块暴露的给Python解释器所有函数。Python 3中的初始化函数样子为PyInit_,其中就是模块的名字。例如在前面提到的from torch._C import *中,模块torch下面必要有一个名字为_C的子模块。因此它的初始化函数应该为PyInit__C,我们搜索该名字就能找到模块入口。当然另外还有一种方法,就是查看setup.py文件中关于扩展的描述信息:
// pytorch\setup.py
main_sources = ["torch/csrc/stub.c"]
C = Extension("torch._C",
libraries=main_libraries,
sources=main_sources,
language='c',
extra_compile_args=main_compile_args + extra_compile_args,
include_dirs=[],
library_dirs=library_dirs,
extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
extensions.append(C)
不管是通过搜索还是查看setup.py,我们最终都成功定位到了位于pytorch\torch\csrc\stub.c下的模块初始化函数PyInit__C(void),并进一步跟踪其调用的函数initModule(),便可以知道具体都暴露了哪些API给Python解释器。
// pytorch\torch\csrc\stub.c
PyMODINIT_FUNC PyInit__C(void)
{
return initModule();
}
// pytorch\torch\csrc\Module.cpp
initModule()
进入initModule()寻找一番,你会发现,模块_C中依然没有prelu的Python接口。怎么办?莫慌,通过前面对torch.__init__.py的分析,我们知道我们还有希望——_C模块下的子模块_VariableFunctions,这真的是最后的希望了!没了别的路可以走了,只能是硬着头皮找。经过一番惊天地泣鬼神、艰苦卓绝的寻找,我们在initModule()的调用链initModule()->THPVariable_initModule(module)->torch::autograd::initTorchFunctions(module)中发现了_VariableFunctions的踪影。Aha,simple!
void initTorchFunctions(PyObject* module) {
if (PyType_Ready(&THPVariableFunctions) < 0) {
throw python_error();
}
Py_INCREF(&THPVariableFunctions);
// Steals
Py_INCREF(&THPVariableFunctions);
if (PyModule_AddObject(module, "_VariableFunctionsClass", reinterpret_cast(&THPVariableFunctions)) < 0) {
throw python_error();
}
// PyType_GenericNew returns a new reference
THPVariableFunctionsModule = PyType_GenericNew(&THPVariableFunctions, Py_None, Py_None);
// PyModule_AddObject steals a reference
if (PyModule_AddObject(module, "_VariableFunctions", THPVariableFunctionsModule) < 0) {
throw python_error();
}
}
但是!!别高兴太早!查看模块_VariableFunctions中暴露的接口你会发现,根本就没有我们想要的!如下面的代码所示:
static PyMethodDef torch_functions[] = {
{"arange", castPyCFunctionWithKeywords(THPVariable_arange),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"as_tensor", castPyCFunctionWithKeywords(THPVariable_as_tensor),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"dsmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"from_numpy", THPVariable_from_numpy, METH_STATIC | METH_O, NULL},
{"full", castPyCFunctionWithKeywords(THPVariable_full), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"hsmm", castPyCFunctionWithKeywords(THPVariable_hspmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"randint", castPyCFunctionWithKeywords(THPVariable_randint), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"numel", castPyCFunctionWithKeywords(THPVariable_numel), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
${py_method_defs}
{NULL}
};
上面的代码中我们找不到prelu的任何身影。会不会prelu可以绕开C/C++扩展的方式直接被Python使用呢?所以不会出现在这里?答案是不会,自古华山一条路,程序是不会跟你讲潜规则的。那么既然最终代码已经跟丢了,作者一定是使用了黑魔法,作为麻瓜的我无计可施,本文也该结束了……
等等,上面的C代码中好像混入了奇怪的东西——${py_method_defs}。这种语法好像C/C++语法里面是没有的,反而是Shell这类脚本里面才会有,难道是新特性?费劲查找了一圈,并没有发现C/C++中有这种语法,既然不是正经语法,那么混入C/C++中肯定会导致编译失败,但是它确实就在那里。那么真相只有一个:它就是个占位符,后面肯定会有真正的代码替换它!
接下来怎么办?搜索!使用py_method_defs作为关键字全局搜索,最终我们会发现,确实是有一个Python脚本对这个占位符进行了替换,而替换的结果就是我们一直寻找的prelu终于出现在了模块_VariableFunctions之中。好,破案了。
但是就像警察破案,即便有单个证据,也要找到其他证据形成完整证据链才能使得证据具有说服力。虽然我们通过搜索得知了prelu会出现在模块_VariableFunctions中,但是它究竟怎么来的目前还是很模糊:占位符在什么时候被谁调用的脚本进行了替换?
实际上,这一切都是有迹可循的。踪迹依旧在setup.py中。进入setup.py的主函数,在调用setup函数之前会看到一个名为build_deps()的函数调用,此函数最终会调用指定平台的CMake去按照根目录下CMakeLists.txt中的脚本进行构建。根目录下的CMakeLists.txt最终又会调用到caffe2目录下的CMakeLists.txt(add_subdirectory(caffe2)),而caffe2/CMakeLists.txt中就会调用到进行代码生成的Python脚本,如下所示:
代码生成脚本起调过程示意图
// pytorch\caffe2\CMakeLists.txt
add_custom_command( OUTPUT
${TORCH_GENERATED_CODE}
COMMAND
"${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py
--declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml"
--native-functions-path "aten/src/ATen/native/native_functions.yaml"
--nn-path "aten/src"
$:--disable-autograd>
$:--selected-op-list-path="${SELECTED_OP_LIST}">
--force_schema_registration
进行代码生成的主要流程如下面代码块所示,其大概流程是main()先解析传递给脚本的参数,之后将参数传递给generate_code()。结合caffe2/CMakeLists.txt中脚本调用时传递的参数可知,generate_code()中的是三个gen_*()函数都得到了调用,而在gen_autograd_python()会调用到一个名为create_python_bindings()的函数,这个函数就是真正执行代码生成的地方。
代码生成器调用流程示意图
// tools/setup_helpers/generate_code.py
def generate_code(ninja_global=None,
declarations_path=None,
nn_path=None,
native_functions_path=None,
install_dir=None,
subset=None,
disable_autograd=False,
force_schema_registration=False,
operator_selector=None):
if subset == "pybindings" or not subset:
gen_autograd_python(
declarations_path or DECLARATIONS_PATH,
native_functions_path or NATIVE_FUNCTIONS_PATH,
autograd_gen_dir,
autograd_dir)
if operator_selector is None:
operator_selector = SelectiveBuilder.get_nop_selector()
if subset == "libtorch" or not subset:
gen_autograd(
declarations_path or DECLARATIONS_PATH,
native_functions_path or NATIVE_FUNCTIONS_PATH,
autograd_gen_dir,
autograd_dir,
disable_autograd=disable_autograd,
operator_selector=operator_selector,
)
if subset == "python" or not subset:
gen_annotated(
native_functions_path or NATIVE_FUNCTIONS_PATH,
python_install_dir,
autograd_dir)
def main():
parser = argparse.ArgumentParser(description='Autogenerate code')
parser.add_argument('--declarations-path')
parser.add_argument('--native-functions-path')
parser.add_argument('--nn-path')
parser.add_argument('--ninja-global')
parser.add_argument('--install_dir')
parser.add_argument(
'--subset',
help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.'
)
parser.add_argument(
'--disable-autograd',
default=False,
action='store_true',
help='It can skip generating autograd related code when the flag is set',
)
parser.add_argument(
'--selected-op-list-path',
help='Path to the YAML file that contains the list of operators to include for custom build.',
)
parser.add_argument(
'--operators_yaml_path',
help='Path to the model YAML file that contains the list of operators to include for custom build.',
)
parser.add_argument(
'--force_schema_registration',
action='store_true',
help='force it to generate schema-only registrations for ops that are not'
'listed on --selected-op-list'
)
options = parser.parse_args()
generate_code(
options.ninja_global,
options.declarations_path,
options.nn_path,
options.native_functions_path,
options.install_dir,
options.subset,
options.disable_autograd,
options.force_schema_registration,
# options.selected_op_list
operator_selector=get_selector(options.selected_op_list_path, options.operators_yaml_path),
)
if __name__ == "__main__":
main()
// pytorch\tools\autograd\gen_autograd.py
def gen_autograd_python(aten_path, native_functions_path, out, autograd_dir):
from .load_derivatives import load_derivatives
differentiability_infos = load_derivatives(
os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)
template_path = os.path.join(autograd_dir, 'templates')
# Generate Functions.h/cpp
from .gen_autograd_functions import gen_autograd_functions_python
gen_autograd_functions_python(
out, differentiability_infos, template_path)
# Generate Python bindings
from . import gen_python_functions
deprecated_path = os.path.join(autograd_dir, 'deprecated.yaml')
gen_python_functions.gen(
out, native_functions_path, deprecated_path, template_path)
// pytorch\tools\autograd\gen_python_functions.py
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str) -> None:
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
methods = load_signatures(native_yaml_path, deprecated_yaml_path, method=True)
create_python_bindings(
fm, methods, is_py_variable_method, None, 'python_variable_methods.cpp', method=True)
functions = load_signatures(native_yaml_path, deprecated_yaml_path, method=False)
create_python_bindings(
fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_nn_function, 'torch.nn', 'python_nn_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_fft_function, 'torch.fft', 'python_fft_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
def create_python_bindings(
fm: FileManager,
pairs: Sequence[PythonSignatureNativeFunctionPair],
pred: Callable[[NativeFunction], bool],
module: Optional[str],
filename: str,
*,
method: bool,
) -> None:
"""Generates Python bindings to ATen functions"""
py_methods: List[str] = []
py_method_defs: List[str] = []
py_forwards: List[str] = []
grouped: Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
for pair in pairs:
if pred(pair.function):
grouped[pair.function.func.name.name].append(pair)
for name in sorted(grouped.keys(), key=lambda x: str(x)):
overloads = grouped[name]
py_methods.append(method_impl(name, module, overloads, method=method))
py_method_defs.append(method_def(name, module, overloads, method=method))
py_forwards.extend(forward_decls(name, overloads, method=method))
fm.write_with_template(filename, filename, lambda: {
'generated_comment': '@' + f'generated from {fm.template_dir}/{filename}',
'py_forwards': py_forwards,
'py_methods': py_methods,
'py_method_defs': py_method_defs,
})
最终通过查看native_functions.yaml的内容以及深入跟踪加载native_functions.yaml的代码发现,native_functions.yaml中的prelu最终会被写到以python_torch_functions.cpp为模板的文件中,也就是调用
create_python_bindings(
fm, functions, is_py_torch_function, 'torch', 'python_torch_functions.cpp', method=False)
的时候被生成。整个生成的过程其实是很繁琐的,一层层跟踪后可以发现,最终生成的代码可以实现将一个名为at::的函数暴露给Python。例如我们的prelu,暴露给Python的API最终会调用一个名为at::prelu()的函数来做真正的计算。那么这个at::(例如at::prelu())的定义又在哪里呢?
还是一样,故技重施!仍然使用Python脚本根据native_functions.yaml文件中的内容去以pytorch\aten\src\ATen\templates目录下的各种模板去生成对应的实际C++源文件。最终结果是得到at::,在这个函数中,它调用了Dispatcher这个类寻找到目标函数的句柄。通常情况下能够使用的函数句柄都通过一个叫Library的类来管理。Python脚本以RegisterSchema.cpp为模板,生成了注册这些目标函数的注册代码,并通过一个名为TORCH_LIBRARY的宏调用Library类来注册管理。
#define TORCH_LIBRARY(ns, m) \
static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
torch::Library::DEF, \
&TORCH_LIBRARY_init_ ## ns, \
#ns, c10::nullopt, __FILE__, __LINE__ \
); \
void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
class TorchLibraryInit final {
private:
using InitFn = void(Library&);
Library lib_;
public:
TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<:dispatchkey> k, const char* file, uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
PyTorch组成示意图
3. 总结
PyTorch虽然在使用上是非常的Pythonic,但实际上Python只不过是为了方便使用裹在C++代码上的一层糖衣。用起来虽然好用,但是看起来实在是非常费劲,特别是如果静态的梳理代码,很多用于连接Python C/C++接口与实际逻辑代码之间的C++代码都是通过Python脚本生成的。至此,整个大的线索已经摸清了,剩下的就是去查看具体细节的实现。
说实话,人脑执行Python代码之后再去理解C++代码实在是费劲,也费头发。因此我决定的让电脑去生成C++代码再接着看更具体的细节,比如究竟每一个算子是怎么注册到Library之中的。
4. Bonus
我真心怀疑我们生活在一个虚拟机里,为什么呢?因为到处可见运用于计算机里面的空间和时间局部性原理的实例。就在我写完这个博客的时候,意外的发现了一篇PyTorch工程师讲解PyTorch内部原理的博文,这对后续读代码应该会有很大帮助。等不及就戳它吧 http://blog.ezyang.com/2019/05/pytorch-internals/
python pytorch fft_看PyTorch源代码的心路历程相关推荐
- python pytorch fft_用Pytorch实现FFT
我尝试使用Pytorch中提供的conv1d函数来实现FFT.在 产生人工信号import numpy as np import torch from torch.autograd import Va ...
- Python:机器学习模块PyTorch【上】
点击访问:PyTorch中文API应用具体代码地址 自动求导机制 本说明将概述Autograd如何工作并记录操作.了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序, ...
- 【深度学习】基于Torch的Python开源机器学习库PyTorch卷积神经网络
[深度学习]基于Torch的Python开源机器学习库PyTorch卷积神经网络 文章目录 1 CNN概述 2 PyTorch实现步骤2.1 加载数据2.2 CNN模型2.3 训练2.4 可视化训练 ...
- 【深度学习】基于Torch的Python开源机器学习库PyTorch回归
[深度学习]基于Torch的Python开源机器学习库PyTorch回归 文章目录1 torch.autograd 2 torch.nn.functional 3 详细的回归DEMO3.1 DATAS ...
- 【深度学习】基于Torch的Python开源机器学习库PyTorch概述
[深度学习]基于Torch的Python开源机器学习库PyTorch概述 文章目录 1 PyTorch简介 2 环境搭建 3 Hello world!3.1 Tensors (张量)3.2 操作 4 ...
- 图像迁移风格保存模型_图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用...
原标题:图像风格迁移也有框架了:使用Python编写,与PyTorch完美兼容,外行也能用 选自Medium 作者:Philip Meier 机器之心编译 编辑:陈萍 易于使用的神经风格迁移框架 py ...
- 利用python安装opencv_科学网—Anaconda Python PyCharm PyQT5 OpenCV PyTorch TF2.0 安装指南 - 张重生的博文...
Anaconda Python PyCharm PyQT5 OpenCV PyTorch TF2.0 安装指南与资料汇总 (用Anaconda配置Python集成开发环境,含Python3, PyQT ...
- python pytorch语音识别_PyTorch通过ASR实现语音到文本端的模型以及pytorch语音识别(speech) - pytorch中文网...
ASR,英文的全称是Automated Speech Recognition,即自动语音识别技术,它是一种将人的语音转换为文本的技术.今天我们主要了解pytorch实现语音到文本的端到端模型. spe ...
- python绘制图形沙漏_pytorch-pose一个用于二维人体姿势估计的PyTorch工具包。 - pytorch中文网...
pytorch-pose PyTorch-Pose是2D单人姿态估计的一般流水线的PyTorch实现.其目的是为最流行的人体姿态数据库(如MPII人体姿态,LSP和FLIC)提供训练/推理/评估的接口 ...
最新文章
- 研究人员开发出最节能的 Wi-Fi 技术
- 教你控制Python多线程中线程数量
- Java实现有向图的拓扑排序
- struct和typedef struct的区别(转)
- 如何设置顶部和底部固定,中间填满
- 如何避免ajax重复请求?
- Java实现获取汉字的拼音(首拼)
- 在IGBT的开启过程中,IGBT的电压降低,电流上升,在IGBT的关断过程中IGBT的电压上升,电流下降,在一段时间内,电压和电流均不为0,由于功率等于电压乘以电流,即P=U×I,因此将产生损耗,开
- 户外航模试飞地踩点--杭州
- Spark视频王家林大神 第7课: Spark机器学习内幕剖析
- 用python画一朵鲜艳欲滴的红玫瑰
- 【数据科学家学习小组】之统计学(第二期)第一周(20191028-20191103)-momi
- 小米路由作二级路由,挂在上级路由之下,samba能被上级访问
- SPI协议(Standard/Dual/Qual)
- AI+医疗:使用神经网络进行医学影像识别分析 ⛵
- 从代理模式再出发!Proxy.newProxyInstance的秘密
- 微信视频应用、视频直播、流媒体服务、视频教学、在线教育类原创文章汇总
- Struts + hibernate +spring课堂笔记
- [招聘信息]QA Engineer@EMC
- vc编译,丢失mspdb100.DLL解决方法
热门文章
- CSS三大特性:层叠性、继承性、优先级
- 关于编译错误 fatal error C1083: Cannot open precompiled header file
- 《Adobe Illustrator CS5中文版经典教程》—第0课0.5节使用绘图模式
- 发送json给服务器
- (42) Aeroo 模板实战
- linux 内核 linux kernel travel
- 深入JavaScript与.NET Framework中的日期时间(1):基本概念与概述
- sql 函数 汉字转拼音
- dojo Quick Start/dojo入门手册--xmlhttp dojo.xhrGet
- 连接MongoDB 3.x 报 Authorization failed 解决办法(自己只用到了创建mongodb账号和密码的部分亲测)