背景

在使用PyTorch深度学习框架的时候,不管是训练还是测试,代码中引入PyTorch的第一句总是:

import torch

在Gemfield前述专栏文章里,我们已经得知,torch/csrc/stub.cpp链接libshm.so、libtorch_python.so、libcaffe2_gpu.so生成了_C.cpython-37m-x86_64-linux-gnu.so库,而像前述方式import torch的时候,按照python规范,会找到torch package目录下的__init__.py,在这个文件中进一步会调用:

from torch._C import *

其中torch._C就是_C.cpython-37m-x86_64-linux-gnu.so。因为(以Python3为例)按照Python规范,由于默认的引擎都是CPython,而CPython的C/C++扩展是一个共享库,并且这个共享库安装在PYTHONPATH目录下,并且文件名(不包含后缀)要和module的名字一样,并且这个共享库中要实现PyInit_modulename符号来作为import时候的逻辑入口。

对于PyTorch来说这个modulename 是_C,因此我们可以揣测,在torch/csrc/stub.cpp中一定实现了PyInit_C这个函数。是的,PyTorch就是这么做的,torch/csrc/stub.cpp中的代码就是下面这样:

#include

extern PyObject* initModule();

PyMODINIT_FUNC PyInit__C()

{

return initModule();

}

本文将从initModule函数展开,全面阐述PyTorch框架的初始化工作。initModule就是PyTorch初始化时候的第一层调用栈了,因为所有的初始化工作都是在这个函数内完成的,内容比较多,gemfield将其划分为7部分:

1,torch._C的诞生:

这一步就是产生torch._C类,并在这个python类上面注册众多函数:

PyObject* initModule() {

//openmp的设置

THInferNumThreads();

THPUtils_addPyMethodDefs(methods, TorchMethods);

THPUtils_addPyMethodDefs(methods, DataLoaderMethods);

THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());

THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());

THPUtils_addPyMethodDefs(methods, THCPModule_methods());

THPUtils_addPyMethodDefs(methods, THCUDNN_methods());

THPUtils_addPyMethodDefs(methods, THDPModule_methods());

THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());

module = Py_InitModule("torch._C", methods.data());

......

}

其中TorchMethods注册了29个方法,都是THPModule_前缀的函数;DataLoaderMethods注册了4个方法,都是THPModule_前缀的函数;torch::autograd::python_functions注册了4个方法;torch::multiprocessing::python_functions注册了1个方法;THCPModule_methods注册了37个CUDA相关的函数,前缀都是THCPModule_;THCUDNN_methods注册了1个方法;THDPModule_methods注册了28个方法;torch::distributed::c10d::python_functions注册了1个方法。

总而言之,在这一小步,我们达到了这样一个里程碑,torch._C符号诞生,并且向torch._C注册了一百余个函数,涉及torch、dataloader、autograd、multiprocess、cuda、cudnn、distribute、c10d方面。

2,一些关键类型

以下代码先后初始化了torch._C._PtrWrapper、torch._C.Generator(含5个方法)、FatalError、torch.Size、torch.dtype、torch.iinfo、torch.layout、torch.device:

PyObject* initModule() {

......

THPWrapper_init(module);

THPGenerator_init(module);

THPException_init(module);

THPSize_init(module);

THPDtype_init(module);

THPDTypeInfo_init(module);

THPLayout_init(module);

THPDevice_init(module);

THPVariable_initModule(module);

THPFunction_initModule(module);

THPEngine_initModule(module);

......

}

3,torch._C._TensorBase的诞生

Gemfield将以下三个初始化函数归为这一小节:

PyObject* initModule() {

......

THPVariable_initModule(module);

THPFunction_initModule(module);

THPEngine_initModule(module);

......

}

为什么呢?因为地位太显赫了。

THPVariable_initModule(module) 创建了torch._C._TensorBase,这是一切Tensor的基类,在Gemfield的其它专栏文章里将单独解释;

THPFunction_initModule(module)创建了torch._C._FunctionBase,在torch/autograd/function.py中,以下两个类以torch._C._FunctionBase为基类:

class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin))

class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin)

这个Function继承体系就构成了DAG的基础。

THPEngine_initModule(module)创建了torch._C._EngineBase,_EngineBase这个类负责动态图执行之前的preprocess,_EngineBase会将torch.autograd的backward之类的请求预处理后送给真正的Engine去执行。

4,pybind11绑定

这一小节的初始化内容都是和pybind11相关的:

PyObject* initModule() {

......

// NOTE: We need to be able to access OperatorExportTypes from ONNX for use in

// the export side of JIT, so this ONNX init needs to appear before the JIT

// init.

torch::onnx::initONNXBindings(module);

torch::jit::initJITBindings(module);

torch::autograd::initNNFunctions(module);

torch::autograd::init_legacy_variable(module);

torch::python::init_bindings(module);

torch::cuda::initModule(module);

......

}

initONNXBindings是ONNX的python binding:torch._C._onnx.TensorProtoDataType和torch._C._onnx.OperatorExportTypes:

>>> dir(torch._C._onnx.TensorProtoDataType)

['BOOL', 'COMPLEX128', 'COMPLEX64', 'DOUBLE', 'FLOAT', 'FLOAT16', 'INT16', 'INT32', 'INT64', 'INT8', 'STRING', 'UINT16', 'UINT32', 'UINT64', 'UINT8', 'UNDEFINED', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__int__', '__le__', '__lt__', '__members__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', 'name']

>>> dir(torch._C._onnx.OperatorExportTypes)

['ONNX', 'ONNX_ATEN', 'ONNX_ATEN_FALLBACK', 'RAW', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__int__', '__le__', '__lt__', '__members__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', 'name']

initJITBindings则是通过pybind11往torch._C上注册了一堆和JIT相关的C++函数/对象;

initNNFunctions初始化了一个torch._C._nn 对象,并注册了一些nn相关的函数:

>>> dir(torch._C._nn)

['__doc__', '__loader__', '__name__', '__package__', '__spec__', '_parse_to', 'adaptive_avg_pool2d', 'adaptive_avg_pool3d', 'adaptive_max_pool2d', 'adaptive_max_pool3d', 'avg_pool2d', 'avg_pool3d', 'binary_cross_entropy', 'elu', 'elu_', \

'fractional_max_pool2d', 'glu', 'hardtanh', 'hardtanh_', 'l1_loss', 'leaky_relu', 'leaky_relu_', 'log_sigmoid', 'max_pool2d_with_indices', 'max_pool3d_with_indices', 'max_unpool2d', 'max_unpool3d', 'mse_loss', 'multi_margin_loss', \

'multilabel_margin_loss', 'nll_loss', 'nll_loss2d', 'reflection_pad1d', 'reflection_pad2d', 'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'rrelu_with_noise', 'rrelu_with_noise_', 'smooth_l1_loss', 'soft_margin_loss', \

'softplus', 'softshrink', 'thnn_conv2d', 'thnn_conv3d', 'thnn_conv_depthwise2d', 'thnn_conv_dilated2d', 'thnn_conv_dilated3d', 'thnn_conv_transpose2d', 'thnn_conv_transpose3d', 'upsample_bilinear2d', 'upsample_linear1d', 'upsample_nearest1d', \

'upsample_nearest2d', 'upsample_nearest3d', 'upsample_trilinear3d']

init_legacy_variable注册了torch._C._LegacyVariableBase:

>>> dir(torch._C._LegacyVariableBase)

['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', \

'__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__le__', \

'__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', \

'__setattr__', '__sizeof__', '__str__', '__subclasshook__']

_LegacyVariableBase类会派生出Variable类(该类的_execution_engine会初始化为torch._C._EngineBase):

class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase))

init_bindings是通过pybind11往torch._C上注册一些函数,torch::cuda::initModule类似,也是通过pybind11往torch._C上注册一些函数,只不过内容是和cuda相关的。

5,在torch._C上注册StorageBase类

PyObject* initModule() {

......

THPDoubleStorage_init(module);

THPFloatStorage_init(module);

THPHalfStorage_init(module);

THPLongStorage_init(module);

THPIntStorage_init(module);

THPShortStorage_init(module);

THPCharStorage_init(module);

THPByteStorage_init(module);

THCPDoubleStorage_init(module);

THCPFloatStorage_init(module);

THCPHalfStorage_init(module);

THCPLongStorage_init(module);

THCPIntStorage_init(module);

THCPShortStorage_init(module);

THCPCharStorage_init(module);

THCPByteStorage_init(module);

THCPStream_init(module);

......

}

这些初始化工作主要就是往torch._C上注册了以下类:

CudaByteStorageBase

CudaCharStorageBase

CudaDoubleStorageBase

CudaFloatStorageBase

CudaHalfStorageBase

CudaIntStorageBase

CudaLongStorageBase

CudaShortStorageBase

ByteStorageBase

CharStorageBase

DoubleStorageBase

FloatStorageBase

HalfStorageBase

IntStorageBase

LongStorageBase

ShortStorageBase

比如以FloatStorageBase为例的话,我们可以这样查看它注册的方法:

>>> dir(torch._C.FloatStorageBase)

['__class__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__le__', '__len__', '__lt__', \

'__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', '_cdata', '_expired', '_free_weak_ref', \

'_get_shared_fd', '_new_shared_fd', '_new_shared_filename', '_new_using_fd', '_new_using_filename', '_new_with_file', '_new_with_weak_ptr', '_set_cdata', '_set_from_file', '_share_fd_', \

'_share_filename_', '_shared_decref', '_shared_incref', '_weak_ref', '_write_file', 'copy_', 'data_ptr', 'element_size', 'fill_', 'from_buffer', 'from_file', 'is_pinned', 'is_shared', 'new', \

'resize_', 'size']

这些类会在python体系中被继承:

class FloatStorage(_C.FloatStorageBase, _StorageBase)

另外注意下这块代码使用了一些宏来复用不同storage的代码,如下所示:

aten/src/TH/THGenerateLongType.h:10:#define Real Long

aten/src/TH/THGenerateHalfType.h:10:#define Real Half

aten/src/TH/THGenerateIntType.h:10:#define Real Int

aten/src/TH/THGenerateFloatType.h:9:#define Real Float

aten/src/TH/THGenerateShortType.h:10:#define Real Short

aten/src/TH/THGenerateCharType.h:8:#define Real Char

aten/src/TH/THGenerateByteType.h:8:#define Real Byte

aten/src/TH/THGenerateDoubleType.h:9:#define Real Double

aten/src/THC/THCGenerateIntType.h:7:#define Real Int

aten/src/THC/THCGenerateLongType.h:7:#define Real Long

aten/src/THC/THCGenerateCharType.h:7:#define Real Char

aten/src/THC/THCGenerateFloatType.h:9:#define Real Float

aten/src/THC/THCGenerateDoubleType.h:7:#define Real Double

aten/src/THC/THCGenerateHalfType.h:9:#define Real Half

aten/src/THC/THCGenerateShortType.h:7:#define Real Short

aten/src/THC/THCGenerateByteType.h:7:#define Real Byte

6,ATen的初始化

本小节会进行ATen的global context的初始化,然后使用at::globalContext().defaultGenerator(at::kCPU)进行generator的初始化。

另外,PyTorch会根据编译环境和用户配置,然后向torch._C上注册一些flag。这些flag有has_cudnn、has_mkl、has_lapack、_GLIBCXX_USE_CXX11_ABI:

PyObject* initModule() {

......

PyObject *has_cudnn = Py_True;

set_module_attr("has_cudnn", has_cudnn);

at::init();

py::reinterpret_borrow<:module>(module).def("_demangle", &c10::demangle);

::c10::Warning::set_warning_handler(&warning_handler);

set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False);

set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False);

set_module_attr("_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False);

auto& defaultGenerator = at::globalContext().defaultGenerator(at::kCPU);

THPDefaultGenerator = (THPGenerator*)THPGenerator_NewWithGenerator(defaultGenerator);

set_module_attr("default_generator", (PyObject*)THPDefaultGenerator, /* incref= */ false);

7,torch._C._THNN和torch._C._THCUNN的初始化

PyTorch在这一小节里注册了torch._C._THNN和torch._C._THCUNN类:

PyObject* initModule() {

......

torch::nn::init__THNN(module);

torch::nn::init__THCUNN(module);

......

}

这两个类都拥有数量巨大的op函数,一个是CPU版的,一个是CUDA版的。

initModule之后

在initModule()函数初始化完毕之后,import torch的初始化工作还没有结束。因为在这之后,python的初始化脚本还要调用以下2个API才算真正完成全部的初始化:

_C._initExtension(manager_path())

_C._init_names(list(torch._storage_classes))

其中主要的工作都是在_C._initExtension中,这个初始化做了以下的工作:

torch::utils::initializeLayouts();

torch::utils::initializeDtypes();

torch::tensors::initialize_python_bindings();

THPDoubleStorage_postInit(module);

THPFloatStorage_postInit(module);

THPHalfStorage_postInit(module);

THPLongStorage_postInit(module);

THPIntStorage_postInit(module);

THPShortStorage_postInit(module);

THPCharStorage_postInit(module);

THPByteStorage_postInit(module);

THPBoolStorage_postInit(module);

//定义在THPStorage_(postInit)函数中,因为THPStorage_会被宏替换THPDoubleStorage_ \

//THPFloatStorage_、THPHalfStorage_、THPLongStorage_......

THPAutograd_initFunctions();

最后的THPAutograd_initFunctions()则是初始化了torch的自动微分系统,这是PyTorch动态图框架的基础。

总结

在PyTorch的初始化阶段,(python)torch模块先后初始化产生torch._C、torch._C._TensorBase、pybind11绑定、torch._C.*StorageBase、torch._C._THNN、torch._C._THCUNN,并进行了ATen context的初始化。在initModule()结束之后,初始化工作还进行了_C._initExtension()的初始化。

pytorch默认初始化_PyTorch的初始化相关推荐

  1. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...

  2. pytorch默认初始化_“最全PyTorch分布式教程”来了!

    前言 本文对使用pytorch进行分布式训练(单机多卡)的过程进行了详细的介绍,附加实际代码,希望可以给正在看的你提供帮助.本文分三个部分展开,分别是: 先验知识 使用过程框架 代码解析 若想学习分布 ...

  3. C++经典问题:如果对象A中有对象成员B,对象B没有默认构造函数,那么对象A必须在初始化列表中初始化对象B?

    对象成员特点总结: (1)实例化对象A时,如果对象A有对象成员B,那么先执行对象B的构造函数,再执行A的构造函数. (2)如果对象A中有对象成员B,那么销毁对象A时,先执行对象A的析构函数,再执行B的 ...

  4. C语言 数组的初始化 数组不初始化会怎样 数组的默认初始值

    本程序用于测试:数组的初始化. (1)定义数组后必须要初始化,不要认为不初始化,系统就会自动初始化为O;如果不初始化,局部变量在栈上,各数组元素的值将是随机数; (2)数组初始化:程序员至少必须把数组 ...

  5. 初始化、赋值、默认初始化、列表初始化、类内初始值、直接初始化与拷贝初始化

    文章目录 初始化和赋值的区别 什么是默认初始化? 列表初始化 列表初始化的使用场景 不适合使用列表初始化的场景 类内初始值 混用string对象和C风格字符串 数组与vector对象 关于vector ...

  6. Spark源码剖析 - SparkContext的初始化(八)_初始化管理器BlockManager

    8.初始化管理器BlockManager 无论是Spark的初始化阶段还是任务提交.执行阶段,始终离不开存储体系.Spark为了避免Hadoop读写磁盘的I/O操作成为性能瓶颈,优先将配置信息.计算结 ...

  7. 28.构造函数中,成员变量一定要通过初始化列表来初始化的?

    首先要明确:如果对象成员是const或者引用的话,必须将其初始化! 构造函数中,成员变量一定要通过初始化列表来初始化的的几种情况! 1)对象成员是const或者引用 #include <iost ...

  8. 【Java4】实例初始化,类初始化,/接口,多态,final/static,权限修饰符/native

    文章目录 1.实例初始化过程:有几个构造器,就会有几个实例初始化方法 2.实例初始化和类初始化结合:先类(静态)后实 3.接口:只有abstract可省 3.1 鸟类案例:Flyable相当于父类的一 ...

  9. C++直接初始化与复制初始化的区别深入解析

    首先:这是原文地址,这个哥们的文章解决了我的问题.谢谢这个哥们了.下面把原文地址放在这里: https://www.jb51.net/article/54773.htm C++中直接初始化与复制初始化 ...

最新文章

  1. oracle schedule stop,Oracle调度Schedule特性(第八部分)-Windows和Window Groups
  2. java gc的工作原理、如何优化GC的性能、如何和GC进行有效的交互
  3. 设计模式的六大原则(个人笔记)
  4. NFS为lamp提供共享存储实践
  5. android开发rn插件,在Android原生应用中嵌入React Native
  6. ofstream、ifstream、fstream
  7. should,would,could,must,might,may,can有什么区别
  8. #!(sha-bang)--脚本的开始
  9. 关于collectionview布局的坑
  10. 13凯越门锁继电器在哪里_凯越中控门锁不工作.更换中央门锁装置故障依旧.
  11. 洛谷P3386 【模板】二分图匹配
  12. linux 编译java web_linux:搭建java web环境
  13. 如何优雅的注入Java Agent内存马
  14. 【论文笔记】基于交易的以太坊智能合约分类检测方法
  15. JavaScript介绍及其特点
  16. 关于Jetson TX2刷机各种问题(刷机后键盘等等奇葩错误)
  17. 关于数据导出成excel表
  18. Python 实现哥德巴赫猜想
  19. 前端入门教程(四)head内常用标签与body内常用标签
  20. 华为哪款手机是鸿蒙系统_华为鸿蒙系统不会用于手机?

热门文章

  1. c语言注释参与程序设计的编译,提高C语言程序设计教学的有益探索
  2. python中scale_Python中的Log-scale mathplotlib?
  3. 树莓派装系统,配置,换源,远程操控
  4. PyTorch框架学习二——基本数据结构(张量)
  5. QT事件过滤器eventFilter函数
  6. 传统手工特征--opencv
  7. iBatis 事务控制 与 两表操作将SQL语句写入单表
  8. master分支删除文件_Git分支基础简介;创建分支;合并分支;删除分支;
  9. html网页定位,HTML_定位网页元素(示例代码)
  10. shell linux教程,Shell入门基础知识