参考:https://pytorch.org/docs/stable/fx.html

PyTorch:FX 概要

  • Intro
  • FX 定义的类对象
  • Example for Transformation
    • Direct Graph Manipulation
      • 简单替换 Node (利用 Pattern)
      • 复杂替换 Node(利用 Proxy)
    • GraphModule Modification
  • 符号追踪的局限性(注意事项)
    • 控制流(Control Flow)
      • 动态控制流
      • 静态控制流
    • 非 torch 函数
  • 查看 Graph 内容
    • 通过 print() 函数
    • 通过 print_tabular() 函数
  • 总结

Intro

  FX 是针对 torch.nn.module 而开发的工具,其能动态地获取 model 前向传播的执行过程,以便动态地增加、删除、改动、检查运算操作。其由三个主要组件组成:符号追踪器(Symbolic Tracer)、中间表示(Intermediate Representation, IR)和 Python 代码生成。这三个组件常常同时出现,如下面的例子:

import torch
# 一个简单的模型
class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace
# 符号追踪。捕获模型的forward的内容
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)# 查看该模型的 IR图
print(symbolic_traced.graph)
"""
graph():%x : [#users=1] = placeholder[target=x]%param : [#users=1] = get_attr[target=param]%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# 查看由 IR 图生成的 Python 代码
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param;  x = param = Nonelinear = self.linear(add);  add = Noneclamp = linear.clamp(min = 0.0, max = 1.0);  linear = Nonereturn clamp
"""
  • 符号追踪器会对代码执行“符号”。其喂入(自己生成的)假数据(Proxy),来执行代码。由 Proxy 经过、执行的代码会被记录下来。
  • 中间表示是在 Trace 期间记录各种操作的“容器”。其由一个节点(Node)列表组成,这些节点表示了函数的输入、名字和返回值。
  • Python 代码生成是一种代码生成工具,可以根据当前 IR 图的内容生成正确、可执行的 Python 代码。这代码是可以复制出来黏贴使用的,可以用于进一步配置模型的(forward)定义。

  总的来说,FX 的使用流程为:符号跟踪->中间表示->转换->Python代码生成。这是一种 Python-to-Python 的方法。FX 的精髓在于“Dynamic Transformation”,即当你需要对模型进行额外改动设计(如插入量化节点、算子 Fusion)时,不需要繁琐地针对模型的每一个部分来修改代码,只需要按照 FX 的流程来高效自动化地实现。

FX 定义的类对象

  • GraphModule:是由 fx.Graph 生成而来的 nn.Module,其有对应的 graphcode 成员变量。当 graph 成员变量被重新赋值过,code 变量和 forward() 函数回自动重新生成。如果你编辑过 graph 的内容却没有重新赋值过,那你必须调用 recompile() 函数来更新信息。torch.fx.symbolic_trace() 函数作用完后 return 的就是 GraphModule
  • Graph:是 FX 的 IR 图的主要数据结构,由一系列有序的 Node 组成。这一一系列的 Node 就构成了执行逻辑。torch.fx.Tracer.trace() 函数作用完后 return 的就是 Graph
  • Node:是 graph 中操作的单位数据结构。大多数情况下,Node 代表了各种实体的调用方式,如输入(Input)、输出(Output)、算子(Operator)、已执行的成员函数(Method)和子模型(Module)。每个 Node 都有一个 op 属性,具体分类如下:
    • placeholder:表示整个模型的输入。
    • get_attr:表示从模型层次结构中检索参数。
    • call_function:表示将自由函数应用于某些值。
    • call_module:表示将模型层次结构的 forward() 成员函数中的子模块应用于给定参数。
    • call_method:表示对某值调用成员函数。
    • output:这与打印 graph 输出中的 return 语句内容相对应。
  • Proxy:在符号追踪期间会用到。其本质上是一个 Node Wrapper,用于流经程序的执行过程并记录下所有的操作(被调用的 torch function、method 和 operator)。若没有主动设置的话,Pytorch 会生成默认的 Proxy 用于符号追踪 。

Example for Transformation

  对模型的图进行额外改动的方法有很多,如直接获取图并修改图(Direct Graph Manipulation),或通过在 GraphModule 模型上间接获取图来修改图(GraphModule Modification)。

Direct Graph Manipulation

简单替换 Node (利用 Pattern)

  1. 遍历 GraphModuleGraph 中的所有 Node
  2. 判断当前 Node 是否满足替换要求(可以用 target 属性作为判断条件)。
  3. 创建一个新的 Node 并插入到 Graph 中。
  4. 使用 FX 内置的 replace_all_uses_with 函数来将要被替换 Node 的输入输出流(flow)重新定向到新 Node 身上。
  5. Graph 中删除旧 Node。
  6. 调用 recompile() 函数来更新 GraphModule

  下面一个例子展示 FX 如何将任何加法操作替换成二进制与(AND)运算:

import torch
from torch.fx import symbolic_trace
import operator# 定义一个简单的模型
class M(torch.nn.Module):def forward(self, x, y):return x + y, torch.add(x, y), x.add(y)# 进行符号追踪
traced = symbolic_trace(M())# 加法操作有三种:
#     1. x + y,其成为 Node 时的 target 为 operator.add。
#     2. torch.add(x, y),其成为 Node 时的 target 为 torch.add.
#     3. x.add(y),其成为 Node 时的 target 为字符串 "add".
patterns = set([operator.add, torch.add, "add"])# 遍历 Graph 中所有 node
for n in traced.graph.nodes:# 如果满足 pattren 之一if any(n.target == pattern for pattern in patterns):# 在指定位置插入新 node (还没建立连接关系)with traced.graph.inserting_after(n):new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)# 建立连接,将旧 node 的连接关系重定向到新 node 上。n.replace_all_uses_with(new_node)# 从 Graph 中删除旧 nodetraced.graph.erase_node(n)# 必须 recompile!
traced.recompile()

复杂替换 Node(利用 Proxy)

  另一个修改 Graph 的方式是利用 Proxy,再在一次主动 Trace 的过程中复制 Node、构建新 Node 来组成新的 Graph

import torch
import torch.fx as fx
import torch.nn.functional as F
# 定义一个简单的模型
class M(torch.nn.Module):def forward(self, x, y):o = F.relu(x) + F.relu(y)return o# 数学定义
def relu_decomposition(x):return (x > 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose_relu(model: torch.nn.Module,tracer_class : type = fx.Tracer) -> torch.nn.Module:graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()# 这相当于一个探针,将旧graph里需要用到的node的名字映射到新graph里对应的node。mapping_table = {}  # {old node name : new node object}# 遍历 nodefor node in graph.nodes:# 判断是否是 relu函数。if node.op == 'call_function' and node.target in decomposition_rules:# 用于记录 proxy当前绑定proxy_args = []# node的arg即输入/上一个node。这一步其实就为该 node 生成对应输入的 proxy。for x in node.args:if isinstance(x, fx.Node):proxy_args.append(fx.Proxy(mapping_table[x.name])) else:proxy_args.append(x)# 这一步就是在“穿线”。穿线完毕后,与Proxy绑定的Graph也自动完成了:# 在原末尾插新加入的node并建立连接。proxy会自动绑定到下一个(输出)node上,# 依次类推,最后就变成 output_proxyoutput_proxy = decomposition_rules[node.target](*proxy_args)# 获取 当前proxy绑定的node。new_node = output_proxy.nodemapping_table[node.name] = new_nodeelse: # 当 node 不需要被拆解时,只需要复制到新graph里就好。# 该函数就是实现旧node与新node的映射关系。def node_mapping(x):return mapping_table[x.name]# node_copy 确实吧 node 拷贝过来了,同时还建立了连接。# 其会访问 node 的 原来所有输入的node,然后再利用opera# 重定向来给新生成的node建立在目标Graph上的连接。new_node = new_graph.node_copy(node, node_mapping)mapping_table[node.name] = new_node# 最后返回的模型绑定的是新 graphreturn fx.GraphModule(model, new_graph)decompose_relu(M())

  Proxy 可以想象为一个“穿线器”:绑定 Node 后,在经过新的 Node 时能自动“串”好连接关系并加入到原 Graph 中。能记录此时的“线头”,即记录访问到的 Node

GraphModule Modification

下面一个例子展示 FX 是如何通过 GraphModule 间接替换 torch.add() 为 torch.mul() 的:

import torch
import torch.fx as fx# 定义一个简单模型
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)
# 下面尝试用替换 target 的方式来改动 graph (不提倡,因为对应的 node 的 name 没有改动!)
def transform(m: torch.nn.Module) -> torch.nn.Module:gm : fx.GraphModule = fx.symbolic_trace(m)# FX 的 IR 图是顺序储存节点,所以可以遍历for node in gm.graph.nodes:# 检查该节点是否是函数操作 (i.e: torch.add)if node.op == 'call_function':# 确认是该节点是函数操作时if node.target == torch.add:node.target = torch.mulgm.recompile()  # 重新编译 GraphModule,更新 code 属性gm.graph.lint() # 最后需要检查修改过的 IR 图是否符合FX语法return gmtransform(M())

符号追踪的局限性(注意事项)

控制流(Control Flow)

  PyTorch 官方将 if 语句循环语句等具有选择/判断性质的语句称为控制流。在 FX语境中,控制流又可以分为动态控制流(Dynamic Control Flow)和静态控制流(Static Control Flow)。

  FX 无法 trace 动态控制流,但可以 trace 判断条件明确的静态控制流。

动态控制流

  若控制流的判断条件含有运算变量(Input Tensor)参与,那么该控制流就称为动态控制流,如:

def func_to_trace(x):if x.sum() > 0:# 可以看到x变量既参与计算,又参与判断return torch.relu(x)else:return torch.neg(x)

  此时对该函数使用 trace 功能就会报错:

"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

静态控制流

  类推可知,若控制流的判断条件无运算变量参与,也即判断条件的变量不参与流(Flow)计算,那么该控制流就称为静态控制流,如:

import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# 该if语句就是静态控制流if self.do_activation:x = torch.relu(x)return x

  若想 trace 静态控制流,就需要明确判断条件,即给判断变量显式赋值:

without_activation = MyModule(do_activation=False)
# 然后就可以 trace
traced_without_activation = torch.fx.symbolic_trace(without_activation)

非 torch 函数

  有些函数没有__torch_function__属性,例如 Python 自带的函数或 math 库中的函数,无法被 trace 追踪。例如,当你的模型里调用了 len() 函数,那么进行 trace 时会报错:

"""
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want ")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

  那么需要使用 wrap() API 来将普通函数包装成 torch 性质的函数:

torch.fx.wrap('len')
# 然后就可以正常 trace 了
traced = torch.fx.symbolic_trace(normalize)

查看 Graph 内容

通过 print() 函数

如:

# 模型定义过程就不展示了
print(traced_model.graph)
"""
graph():%x : [#users=1] = placeholder[target=x]%param : [#users=1] = get_attr[target=param]%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""

通过 print_tabular() 函数

  通过调用 print_tabular() 函数就可以以 tabular 的格式输出 IR 图:

# 模型定义过程就不展示了
traced_model.graph.print_tabular()
"""
opcode         name      target                   args         kwargs
-------------  --------  -----------------------  -----------  ------------------------
placeholder    x         x                        ()           {}
get_attr       param     param                    ()           {}
call_function  add_1     <built-in function add>  (x, param)   {}
call_module    linear_1  linear                   (add_1,)     {}
call_method    clamp_1   clamp                    (linear_1,)  {'min': 0.0, 'max': 1.0}
output         output    output                   (clamp_1,)   {}
"""

总结

  目前,FX 没有提供任何方式来保证/验证运算符在语法上是有效的。也就是说,任何新(定义)加入的运算符都必须由用户自己来保证其正确性。

  最后官网建议的一点是,你在对 Graph 做变换时,应该让整个程序的输入 torch.nn.Module,然后获取对应的 Graph,做出修改,最后再返回一个 torch.nn.Module。这样更方便后续工作,比如又传入下一段 FX 代码中。

  以上总结如有谬误,还请包涵、指正。

PyTorch Python API:FX || Intro相关推荐

  1. PyTorch Python API详解大全(持续更新ing...)

    诸神缄默不语-个人CSDN博文目录 具体内容以官方文档为准. 最早更新时间:2021.4.23 最近更新时间:2023.1.9 文章目录 0. 常用入参及函数统一解释 1. torch 1.1 Ten ...

  2. 使用Maple的Python API :OpenMaple(Windows下的解决方案)

    在Maple 2023(按照软件文档,Maple 2018及以上版本均适用:我目前测试的版本为2023)的安装目录下,有软件附带的解释器,如D:\Program Files\Maple 2023\Py ...

  3. thinkcmf5调用指定分类的二级_Tengine快速上手系列教程amp;视频:基于Python API的图片分类应用入门丨附彩蛋...

    前言:近期,Tengine团队加班加点,好消息接踵而来,OpenCV 4.3.0发布,OPEN AI LAB AIoT智能开发平台Tengine与OpenCV合作共同加速边缘智能,Tengine再获业 ...

  4. Sublime Text 4 首个稳定版终于来了:支持 GPU 渲染、兼容旧版本、Python API 升级

    技术编辑:小魔丨发自 思否编辑部 公众号:SegmentFault Sublime Text 是一个轻量.简洁.高效.跨平台的编辑器,支持 Linux.Windows 和 Mac OS X 操作系统, ...

  5. python枪_大疆机甲大师教育机器人Python API中文化之一:枪亮枪暗

    之前开始整理机甲的Python API,但纸上得来终觉浅,而且发现有些API与即使官方qq群的教程文档也有少许出入,于是打算逐个测试.这一系列将附上真机运行视频,以便以后直观看到最终演示效果. 先从灯 ...

  6. 强化学习系列文章(二十三):AirSim Python API图像与图像处理

    强化学习系列文章(二十三):AirSim Python API图像与图像处理 参考网址:https://microsoft.github.io/AirSim/image_apis/#segmentat ...

  7. openstack二次开发:Python API

    2019独角兽企业重金招聘Python工程师标准>>> 作 为 OpenStack 用户或管理员,您常常需要编写脚本来自动化常见任务.除了 REST 和命令行接口之外,OpenSta ...

  8. OpenPose Python API调用:ImportError: cannot import name 'pyopenpose' from 'openpose'

    OpenPose Python API 调用方法 cmake-gui选项中勾选BUILD_PYTHON选项 cd build sudo make -j'' sudo make install cd ~ ...

  9. Python 人工智能:16~20

    原文:Artificial Intelligence with Python 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MTPE) ...

  10. keras cnn注意力机制_TensorFlow、PyTorch、Keras:NLP框架哪家强

    全文共3412字,预计学习时长7分钟 在对TensorFlow.PyTorch和Keras做功能对比之前,先来了解一些它们各自的非竞争性柔性特点吧. 非竞争性特点 下文介绍了TensorFlow.Py ...

最新文章

  1. javafx官方文档学习之二Scene体系学习一
  2. mysql文本数据_mysql操作文本数据
  3. tableau必知必会之用蝴蝶图(旋风图)实现数据之间对比
  4. 20175213 2018-2019-2 《Java程序设计》第6周学习总结
  5. 2013年第四届蓝桥杯C/C++ A组国赛 —— 第二题:骰子迷题
  6. C语言重新定位文件,C语言代码重定位 (原创)
  7. cba比赛比分预测_【CBA直播】深圳vs广东前瞻:深圳战广东再掀反攻?
  8. lisp 揭 院长_HISLISPACSRIS EMR系统简介
  9. R语言学习 - 热图美化 (数值标准化和调整坐标轴顺序)
  10. lLinux系统安全sudo+pam
  11. java实例 内存_一个分析和解决Java应用程序内存浪费的实战例子,值得收藏!
  12. break 退出循环
  13. sam机架和kx连线图_创新声卡KX 3552驱动连线搭载SAM机架
  14. 创业维艰,且行且珍惜
  15. 百度地图绘制3D棱柱
  16. 学数学,要“直觉”还是要“严谨”?
  17. 微信小程序之小程序审核
  18. 生鲜电商O2O 可以怎么做?
  19. Java中 构造方法 和 成员方法 的区别(图文介绍)
  20. 2022年11月20日-2022年11月26日学习周报

热门文章

  1. UE4 制作玻璃材质总结
  2. 第4章 网络信息资源检索
  3. 记录一次pre环境OOM异常解决过程
  4. c语言递归阶乘汉诺塔文曲星游戏词典制作文件调用整合
  5. c++ QT 反走样
  6. 放大电路静态工作点的稳定
  7. Redis java如何清除缓存 redisTemplate
  8. 最小生成树的第三种求法-Borůvka (Sollin) 算法
  9. PTA翁恺7-6 厘米换算英尺英寸 (15 分)
  10. 匠心独运: python打造GUI图形小窗口