点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

本文转自:深度学习这件小事

背景

PyTorch的动态图框架主要是由torch/csrc/autograd下的代码实现的。这个目录下定义了3个主要的基类:Variable、Function、Engine,这三个基类及其继承体系共同构成了PyTorch动态图的根基。

为什么叫作动态图呢?图容易理解,Function是nodes/vertices,(Function, input_nr)是edges。那么动态体现在什么地方呢?每一次前向时构建graph,反向时销毁。本文就以torch/csrc/autograd/下的代码为基础,深入讲解PyTorch的动态图系统——这也可能是互联网上关于PyTorch动态图最详尽的文章了。

在专栏文章《PyTorch的初始化》(https://zhuanlan.zhihu.com/p/57571317)中,gemfield描述了PyTorch的初始化流程,在文末提到了THPAutograd_initFunctions()调用:“最后的THPAutograd_initFunctions()则是初始化了torch的自动微分系统,这是PyTorch动态图框架的基础”。而本文将以THPAutograd_initFunctions开始,带你走入到PyTorch的动态图世界中。首先为上篇,主要介绍Function、Variable、Engine的类的继承体系。

autograd初始化

THPAutograd_initFunctions这个函数实现如下:


void THPAutograd_initFunctions(){  THPObjectPtr module(PyModule_New("torch._C._functions"));  ......  generated::initialize_autogenerated_functions();  auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));}

用来初始化cpp_function_types表,这个表维护了从cpp类型的函数到python类型的映射:

static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types

这个表里存放的都是和autograd相关的函数的映射关系,起什么作用呢?比如我在python中print一个Variable的grad_fn:

>>> gemfield = torch.empty([2,2],requires_grad=True)>>> syszux = gemfield * gemfield>>> syszux.grad_fn<ThMulBackward object at 0x7f111621c350>

grad_fn是一个Function的实例,我们在C++中定义了那么多反向函数(参考下文),但是怎么在python中访问呢?就靠上面这个表的映射。实际上,cpp_function_types这个映射表就是为了在python中打印grad_fn服务的。

Variable

参考:https://zhuanlan.zhihu.com/p/64135058

以下面的代码片段作为例子:

gemfield = torch.ones(2, 2, requires_grad=True)syszux = gemfield + 2civilnet = syszux * syszux * 3gemfieldout = civilnet.mean()gemfieldout.backward()

需要指出的是,动态图是在前向的时候建立起来的。gemfieldout作为前向的最终输出,在反向传播的时候,却是计算的最初输入—在动态图中,我们称之为root。在下文介绍Engine的时候,你就会看到,我们会使用gemfieldout这个root来构建GraphRoot实例,以此作为Graph的输入。

Function

在开始介绍Function之前,还是以上面的代码为例,在一次前向的过程中,我们会创建出如下的Variable和Function实例:

#Variable实例gemfield --> grad_fn_ (Function实例)= None         --> grad_accumulator_ (Function实例)= AccumulateGrad实例0x55ca7f304500         --> output_nr_ = 0
#Function实例, 0x55ca7f872e90AddBackward0实例 --> sequence_nr_ (uint64_t) = 0            --> next_edges_ (edge_list) --> std::vector<Edge> = [(AccumulateGrad实例, 0),(0, 0)]            --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])]            --> alpha (Scalar) = 1            --> apply() --> 使用 AddBackward0 的apply
#Variable实例syszux --> grad_fn_ (Function实例)= AddBackward0实例0x55ca7f872e90       --> output_nr_ = 0
#Function实例, 0x55ca7ebba2a0MulBackward0 --> sequence_nr_ (uint64_t) = 1            --> next_edges_ (edge_list) = [(AddBackward0实例0x55ca7f872e90,0),(AddBackward0实例0x55ca7f872e90,0)]            --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])]            --> alpha (Scalar) = 1            --> apply() --> 使用 MulBackward0 的apply
# #Variable实例,syszux * syszux得到的tmptmp --> grad_fn_ (Function实例)= MulBackward0实例0x55ca7ebba2a0    --> output_nr_ = 0
#Function实例,0x55ca7fada2f0MulBackward0 --> sequence_nr_ (uint64_t) = 2 (每个线程内自增)            --> next_edges_ (edge_list) = [(MulBackward0实例0x55ca7ebba2a0,0),(0,0)]            --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])]            --> self_ (SavedVariable) = tmp的浅拷贝            --> other_ (SavedVariable) = 3的浅拷贝            --> apply() --> 使用 MulBackward0 的apply
#Variable实例civilnet --> grad_fn_ (Function实例)= MulBackward0实例0x55ca7fada2f0                                          -
#Function实例,0x55ca7eb358b0MeanBackward0 --> sequence_nr_ (uint64_t) = 3 (每个线程内自增)              --> next_edges_ (edge_list) = [(MulBackward0实例0x55ca7fada2f0,0)]              --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType|[]|cpu])]              --> self_sizes (std::vector<int64_t>) = (2, 2)              --> self_numel = 4              --> apply() --> 使用 MulBackward0 的apply#Variable实例gemfieldout --> grad_fn_ (Function实例)= MeanBackward0实例0x55ca7eb358b0            --> output_nr_ = 0

这些用于反向计算的Function实例之间通过next_edges_连接在一起,因为这些Function的实际运行都是在反向期间,因此,输出输出关系正好和前向期间是反过来的。它们通过next_edges_连接在一起。用一个图来概括,就是下面这样:

这就引入一个新的话题——Function类是如何抽象出来的。

#Function基类定义

Function的数据成员如下所示:

using edge_list = std::vector<Edge>;using variable_list = std::vector<Variable>;
struct TORCH_API Function {...  virtual variable_list apply(variable_list&& inputs) = 0;...  const uint64_t sequence_nr_;  edge_list next_edges_;  PyObject* pyobj_ = nullptr; // weak reference  std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;  std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;  std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;  at::SmallVector<InputMetadata, 2> input_metadata_;};

#Function call

Function类是抽象出来的基类,代表一个op(operation),每个op接收的参数是0个、1个或多个Variable实例(使用std::vector封装),并与此同时输出0个、1个或多个Variable实例。PyTorch中所有用于反向传播计算的函数都继承自Function类,并重写了Function类中的apply纯虚函数。因为Function类中实现了call函数:

variable_list operator()(variable_list&& inputs) {return apply(std::move(inputs));
}

所以依靠C++的多态,对op的call将转化为自身(子类)的apply调用。Function类中最重要的方法是call函数,call会调用apply,call函数接收vector封装的多个Variable实例,并输出vector封装的多个Variable实例。输入参数的vector长度可以由num_inputs()调用获得,对应的,输出的vector长度则由num_outputs()获得。

#Function的输入

Function成员input_metadata_代表input data的meta信息,界定了一个Function的输入:

struct InputMetadata {...  const at::Type* type_ = nullptr;  at::DimVector shape_;  at::Device device_ = at::kCPU;};

#Autograd graph的edge和vertices

如果将PyTorch的autograd系统看作是一个图(graph)的话,那么每个Function实例就是graph中的节点(nodes/vertices),各个Function实例之间则是通过Edge连接的。Edge是个结构体,通过 (Function, input_nr) 的配对来代表graph中的edge:

struct Edge {...  std::shared_ptr<Function> function;  uint32_t input_nr;};

Function的成员next_edges_正是一组这样的Edge实例,代表此function实例的返回值要输出到的(另外)function,也即next_edges_是function和function之间的纽带。

Function的输入输出都是Variable实例,因此,当一个graph被执行的时候,Variable实例就在这些edges之间来传输流动。当两个或者多个Edge指向同一个Function的时候(这个节点的入度大于1),这些edges的输出将会隐含的相加起来再送给指向的目标Function。

Function和Function之间通过next_edge接口连接在一起,你可以使用add_next_edge()来向Function添加一个edge, 通过next_edge(index)获取对应的edge,通过next_edges()方法获得迭代edge的迭代器。每一个Function都有一个sequence number,随着Function实例的不断构建而单调增长。你可以通过sequence_nr()方法来或者一个Function的sequence number。

Function继承体系

基类Function直接派生出TraceableFunction和以下这些Function:

CopySlices : public Function DelayedError : public Function Error : public Function Gather : public Function GraphRoot : public Function Scatter : public FunctionAccumulateGrad : public Function AliasBackward : public Function AsStridedBackward : public Function CopyBackwards : public Function DiagonalBackward : public Function ExpandBackward : public Function IndicesBackward0 : public Function IndicesBackward1 : public Function PermuteBackward : public Function SelectBackward : public Function SliceBackward : public Function SqueezeBackward0 : public Function SqueezeBackward1 : public Function TBackward : public Function TransposeBackward0 : public Function UnbindBackward : public Function UnfoldBackward : public Function UnsqueezeBackward0 : public Function ValuesBackward0 : public Function ValuesBackward1 : public Function ViewBackward : public Function
PyFunction : public Function

这其中,从基类Function派生出来的AccumulateGrad、TraceableFunction、GraphRoot是比较关键的类。

#派生类AccumulateGrad

先说说AccumulateGrad,AccumulateGrad正是Variable的grad_accumulator_成员的类型:


struct AccumulateGrad : public Function {  explicit AccumulateGrad(Variable variable_);  variable_list apply(variable_list&& grads) override;  Variable variable;};

可见一个AccumulateGrad实例必须用一个Variable构建,apply调用接收一个list的Variable的实例——这都是和Variable的grad_accumulator_相关的。

#派生类GraphRoot

对于GraphRoot,前向时候的最终输出——在反向的时候作为最初输入——是由GraphRoot封装的:


struct GraphRoot : public Function {  GraphRoot(edge_list functions, variable_list inputs)      : Function(std::move(functions)),        outputs(std::move(inputs)) {}  variable_list apply(variable_list&& inputs) override {    return outputs;  }  variable_list outputs;};

GraphRoot——正如Function的灵魂在apply一样——其apply函数仅仅返回它的输入!

#派生类TraceableFunction

再说说TraceableFunction:

struct TraceableFunction : public Function {using Function::Function;bool is_traceable() final {return true;}
};

TraceableFunction会进一步派生出372个子类(2019年4月),这些子类的名字都含有一个共同的部分:Backward。这说明什么呢?这些函数将只会用在反向传播中:

AbsBackward : public TraceableFunction AcosBackward : public TraceableFunction AdaptiveAvgPool2DBackwardBackward : public TraceableFunction AdaptiveAvgPool2DBackward : public TraceableFunction AdaptiveAvgPool3DBackwardBackward : public TraceableFunction AdaptiveAvgPool3DBackward : public TraceableFunction AdaptiveMaxPool2DBackwardBackward : public TraceableFunction AdaptiveMaxPool2DBackward : public TraceableFunction AdaptiveMaxPool3DBackwardBackward : public TraceableFunction AdaptiveMaxPool3DBackward : public TraceableFunction AddBackward0 : public TraceableFunction AddBackward1 : public TraceableFunction AddbmmBackward : public TraceableFunction AddcdivBackward : public TraceableFunction AddcmulBackward : public TraceableFunction AddmmBackward : public TraceableFunction AddmvBackward : public TraceableFunction AddrBackward : public TraceableFunction ......SoftmaxBackwardDataBackward : public TraceableFunction SoftmaxBackward : public TraceableFunction ......UpsampleBicubic2DBackwardBackward : public TraceableFunction UpsampleBicubic2DBackward : public TraceableFunction UpsampleBilinear2DBackwardBackward : public TraceableFunction UpsampleBilinear2DBackward : public TraceableFunction UpsampleLinear1DBackwardBackward : public TraceableFunction UpsampleLinear1DBackward : public TraceableFunction UpsampleNearest1DBackwardBackward : public TraceableFunction UpsampleNearest1DBackward : public TraceableFunction UpsampleNearest2DBackwardBackward : public TraceableFunction UpsampleNearest2DBackward : public TraceableFunction UpsampleNearest3DBackwardBackward : public TraceableFunction UpsampleNearest3DBackward : public TraceableFunction UpsampleTrilinear3DBackwardBackward : public TraceableFunction UpsampleTrilinear3DBackward : public TraceableFunction ......

这300多个Backward function都重写了apply函数,来实现自己的反向求导算法,比如加法的反向求导函数AddBackward0:

struct AddBackward0 : public TraceableFunction {  using TraceableFunction::TraceableFunction;  variable_list apply(variable_list&& grads) override;  Scalar alpha;};

这些apply函数是Function的灵魂,是反向传播计算时候的核心执行逻辑。

Engine

Engine类实现了从输出的variable(以及它的gradients)到root variables(用户创建的并且requires_grad=True)之间的反向传播。

gemfield = torch.ones(2, 2, requires_grad=True)syszux = gemfield + 2civilnet = syszux * syszux * 3gemfieldout = civilnet.mean()gemfieldout.backward()

还是以上面这个代码片段为例,Engine实现了从gemfieldout到gemfield的反向传播:

1,如何根据gemfieldout构建GraphRoot;

2,如何根据这些Function实例及它们上的metadata构建graph;

3,如何实现Queue来多线程完成反向计算的工作。

#Engine类定义

Engine类的定义如下:

struct Engine {  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;  using dependencies_type = std::unordered_map<Function*, int>;  virtual variable_list execute(const edge_list& roots,const variable_list& inputs,...const edge_list& outputs = {});  void queue_callback(std::function<void()> callback);protected:  void compute_dependencies(Function* root, GraphTask& task);  void evaluate_function(FunctionTask& task);  void start_threads();  virtual void thread_init(int device);  virtual void thread_main(GraphTask *graph_task);  std::vector<std::shared_ptr<ReadyQueue>> ready_queues;};

核心就是execute函数,它接收一组Edge——(Function, input number) pairs ——来作为函数的输入,然后通过next_edge不断的找到指向的下一个Edge,最终完成整个Graph的计算。

#派生类PythonEngine

然而我们实际使用的是Engine类的派生类:PythonEngine。PythonEngine子类重写了父类的execute,只不过仅仅提供了把C++异常翻译为Python异常的功能,核心工作还是由Engine基类来完成:

struct PythonEngine : public Engine

整个PyTorch程序全局只维护一个Engine实例,也就是PythonEngine实例。

BP调用栈

既然Engine是用来计算网络反向传播的,我们不妨看下这个调用栈是怎么到达Engine类的。如果我们对gemfieldout进行backward计算,则调用栈如下所示:

#torch/tensor.py,self is gemfieldoutdef backward(self, gradient=None, retain_graph=None, create_graph=False)|V#torch.autograd.backward(self, gradient, retain_graph, create_graph)#torch/autograd/__init__.pydef backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)|VVariable._execution_engine.run_backward(tensors, grad_tensors, retain_graph, create_graph,allow_unreachable=True)#转化为Variable._execution_engine.run_backward((gemfieldout,), (tensor(1.),), False, False,True)|V#torch/csrc/autograd/python_engine.cppPyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)|V#torch/csrc/autograd/python_engine.cppvariable_list PythonEngine::execute(const edge_list& roots, const variable_list& inputs, bool keep_graph, bool create_graph, const edge_list& outputs)|V#torch/csrc/autograd/engine.cpp

总结

在下段文章中,Gemfield将主要介绍Engine这个类是如何在gemfieldout.backward()中运行PyTorch动态图的。

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

详尽 | PyTorch动态图解析相关推荐

  1. 可能是最详尽的PyTorch动态图解析

    ↑ 点击蓝字 关注视学算法 作者丨Gemfield@@知乎 来源丨https://zhuanlan.zhihu.com/p/61765561.https://zhuanlan.zhihu.com/p/ ...

  2. 收藏 | 可能是最详尽的PyTorch动态图解析

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨Gemfield@@知乎 来源丨https://zhu ...

  3. pytorch入门学习(四)-----计算图与动态图

    计算图: 用来描述运算的有向无环图有两个主要元素,结点note 边edge结点表示数据,如向量,矩阵,张量边表示运算,如加减乘除使用计算图主要是为了求导方便, 只需要沿着计算图的方向找到需要求导对象的 ...

  4. 【深度学习】村通网之——谈谈Tensorflow Eager Execution机制之静态图和动态图的区别(一)

    文章目录 前言 介绍 搭建静态图 搭建动态图 前言 随着TensorFlow 1.4 Eager Execution的出现,TensorFlow的使用出现了革命性的变化. 介绍 我很早就听说过这样一句 ...

  5. 一文详解pytorch的“动态图”与“自动微分”技术

    前言 众所周知,Pytorch是一个非常流行且深受好评的深度学习训练框架.这与它的两大特性"动态图"."自动微分"有非常大的关系."动态图" ...

  6. python合成gif动图_Python图像处理之gif动态图的解析与合成操作详解

    本文实例讲述了Python图像处理之gif动态图的解析与合成操作.分享给大家供大家参考,具体如下: gif动态图是在现在已经司空见惯,朋友圈里也经常是一言不合就斗图.这里,就介绍下如何使用python ...

  7. 清华「计图」现在支持国产芯片了!动态图推理比PyTorch快了270倍

    明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 清华自研的深度学习框架计图(Jittor)在动态图推理速度上又一次完胜PyTorch. 最近,计图团队完成了在寒武纪芯片MLU270上的移植 ...

  8. python动态图-Python图像处理之gif动态图的解析与合成操作详解

    本文实例讲述了Python图像处理之gif动态图的解析与合成操作.分享给大家供大家参考,具体如下: gif动态图是在现在已经司空见惯,朋友圈里也经常是一言不合就斗图.这里,就介绍下如何使用python ...

  9. python绘制动态图-Python图像处理之gif动态图的解析与合成操作详解

    本文实例讲述了Python图像处理之gif动态图的解析与合成操作.分享给大家供大家参考,具体如下: gif动态图是在现在已经司空见惯,朋友圈里也经常是一言不合就斗图.这里,就介绍下如何使用python ...

最新文章

  1. eclipse导入myeclipse的web项目在eclipse中不能识别成web项目
  2. 菜鸟的 Sass 学习笔记
  3. 主板h110能装linux吗_H110主板好用吗 H110主板配CPU技巧介绍(DIY装机必看)
  4. NumPy简明教程(二、数组2)
  5. virtual function的一些心得
  6. 第六十三期:放下你手里的代码,小心被抓!
  7. 一台服务器上部署多个Terracotta的方法
  8. 上学帮:阿里云助力教育资讯平台防爬虫
  9. 了解 | 你必须了解的Mysql 三大日志
  10. .net 这些年发展 参考资料
  11. 使用动态代理爬取某房产平台信息并写入Excel(python)
  12. 【修正补发】Scratch2exe-ch将sb2文件转换为exe文件
  13. 局域网服务器共享不稳定怎么办,局域网计算机文件共享异常解决方案
  14. comsol兼容服务器系统,comsol 云服务器
  15. 无人机的电调及其工作原理是什么?
  16. 浅谈怎样入侵服务器,仅供学习用
  17. 语音文件格式转换:.amr 转 .MP3, .wav格式
  18. R语言逻辑运算符(Logical Operators,大于、小于、等于、不等于、与或非、是否为真)、R语言逻辑运算符(Logical Operators)实战示例
  19. std::this_thread::sleep_for 使用
  20. 第七章 - 类的详细设计

热门文章

  1. 阿里AI摘图像识别竞赛WebVision桂冠,万物识别准确率创世界纪录
  2. 英伟达发布RTX 2000系列显卡,“实时光线追踪”究竟能为游戏带来什么?
  3. IDEA + Vim = 得劲
  4. 设计模式在工作中的实践
  5. 分布式定时任务xxl-job的常用姿势都集齐了,So Easy!
  6. Spring Boot集成Sharding-jdbc + Mybatis-Plus实现分库分表
  7. 微服务架构之「 容错隔离 」
  8. 拨开云雾见天日:剖析单机事务原理
  9. 西瓜书公式推导讲解来了!
  10. LightGBM模型_相关资料整理