PyTorch Python API:FX || Intro
参考:https://pytorch.org/docs/stable/fx.html
PyTorch:FX 概要
- Intro
- FX 定义的类对象
- Example for Transformation
- Direct Graph Manipulation
- 简单替换 Node (利用 Pattern)
- 复杂替换 Node(利用 Proxy)
- GraphModule Modification
- 符号追踪的局限性(注意事项)
- 控制流(Control Flow)
- 动态控制流
- 静态控制流
- 非 torch 函数
- 查看 Graph 内容
- 通过 print() 函数
- 通过 print_tabular() 函数
- 总结
Intro
FX 是针对 torch.nn.module
而开发的工具,其能动态地获取 model 前向传播的执行过程,以便动态地增加、删除、改动、检查运算操作。其由三个主要组件组成:符号追踪器(Symbolic Tracer)、中间表示(Intermediate Representation, IR)和 Python 代码生成。这三个组件常常同时出现,如下面的例子:
import torch
# 一个简单的模型
class MyModule(torch.nn.Module):def __init__(self):super().__init__()self.param = torch.nn.Parameter(torch.rand(3, 4))self.linear = torch.nn.Linear(4, 5)def forward(self, x):return self.linear(x + self.param).clamp(min=0.0, max=1.0)module = MyModule()from torch.fx import symbolic_trace
# 符号追踪。捕获模型的forward的内容
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)# 查看该模型的 IR图
print(symbolic_traced.graph)
"""
graph():%x : [#users=1] = placeholder[target=x]%param : [#users=1] = get_attr[target=param]%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""# 查看由 IR 图生成的 Python 代码
print(symbolic_traced.code)
"""
def forward(self, x):param = self.paramadd = x + param; x = param = Nonelinear = self.linear(add); add = Noneclamp = linear.clamp(min = 0.0, max = 1.0); linear = Nonereturn clamp
"""
- 符号追踪器会对代码执行“符号”。其喂入(自己生成的)假数据(Proxy),来执行代码。由 Proxy 经过、执行的代码会被记录下来。
- 中间表示是在 Trace 期间记录各种操作的“容器”。其由一个节点(Node)列表组成,这些节点表示了函数的输入、名字和返回值。
- Python 代码生成是一种代码生成工具,可以根据当前 IR 图的内容生成正确、可执行的 Python 代码。这代码是可以复制出来黏贴使用的,可以用于进一步配置模型的(
forward
)定义。
总的来说,FX 的使用流程为:符号跟踪->中间表示->转换->Python代码生成。这是一种 Python-to-Python 的方法。FX 的精髓在于“Dynamic Transformation”,即当你需要对模型进行额外改动设计(如插入量化节点、算子 Fusion)时,不需要繁琐地针对模型的每一个部分来修改代码,只需要按照 FX 的流程来高效自动化地实现。
FX 定义的类对象
- GraphModule:是由
fx.Graph
生成而来的nn.Module
,其有对应的graph
、code
成员变量。当graph
成员变量被重新赋值过,code
变量和forward()
函数回自动重新生成。如果你编辑过graph
的内容却没有重新赋值过,那你必须调用recompile()
函数来更新信息。torch.fx.symbolic_trace()
函数作用完后return
的就是GraphModule
。 - Graph:是 FX 的 IR 图的主要数据结构,由一系列有序的
Node
组成。这一一系列的Node
就构成了执行逻辑。torch.fx.Tracer.trace()
函数作用完后 return 的就是Graph
。 - Node:是
graph
中操作的单位数据结构。大多数情况下,Node
代表了各种实体的调用方式,如输入(Input)、输出(Output)、算子(Operator)、已执行的成员函数(Method)和子模型(Module)。每个Node
都有一个op
属性,具体分类如下:placeholder
:表示整个模型的输入。get_attr
:表示从模型层次结构中检索参数。call_function
:表示将自由函数应用于某些值。call_module
:表示将模型层次结构的forward()
成员函数中的子模块应用于给定参数。call_method
:表示对某值调用成员函数。output
:这与打印graph
输出中的return
语句内容相对应。
- Proxy:在符号追踪期间会用到。其本质上是一个
Node Wrapper
,用于流经程序的执行过程并记录下所有的操作(被调用的 torch function、method 和 operator)。若没有主动设置的话,Pytorch 会生成默认的Proxy
用于符号追踪 。
Example for Transformation
对模型的图进行额外改动的方法有很多,如直接获取图并修改图(Direct Graph Manipulation),或通过在 GraphModule
模型上间接获取图来修改图(GraphModule Modification)。
Direct Graph Manipulation
简单替换 Node (利用 Pattern)
- 遍历
GraphModule
的Graph
中的所有Node
。 - 判断当前
Node
是否满足替换要求(可以用target
属性作为判断条件)。 - 创建一个新的
Node
并插入到Graph
中。 - 使用 FX 内置的 replace_all_uses_with 函数来将要被替换
Node
的输入输出流(flow)重新定向到新Node
身上。 - 从
Graph
中删除旧 Node。 - 调用
recompile()
函数来更新GraphModule
。
下面一个例子展示 FX 如何将任何加法操作替换成二进制与(AND)运算:
import torch
from torch.fx import symbolic_trace
import operator# 定义一个简单的模型
class M(torch.nn.Module):def forward(self, x, y):return x + y, torch.add(x, y), x.add(y)# 进行符号追踪
traced = symbolic_trace(M())# 加法操作有三种:
# 1. x + y,其成为 Node 时的 target 为 operator.add。
# 2. torch.add(x, y),其成为 Node 时的 target 为 torch.add.
# 3. x.add(y),其成为 Node 时的 target 为字符串 "add".
patterns = set([operator.add, torch.add, "add"])# 遍历 Graph 中所有 node
for n in traced.graph.nodes:# 如果满足 pattren 之一if any(n.target == pattern for pattern in patterns):# 在指定位置插入新 node (还没建立连接关系)with traced.graph.inserting_after(n):new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)# 建立连接,将旧 node 的连接关系重定向到新 node 上。n.replace_all_uses_with(new_node)# 从 Graph 中删除旧 nodetraced.graph.erase_node(n)# 必须 recompile!
traced.recompile()
复杂替换 Node(利用 Proxy)
另一个修改 Graph
的方式是利用 Proxy
,再在一次主动 Trace
的过程中复制 Node
、构建新 Node
来组成新的 Graph
。
import torch
import torch.fx as fx
import torch.nn.functional as F
# 定义一个简单的模型
class M(torch.nn.Module):def forward(self, x, y):o = F.relu(x) + F.relu(y)return o# 数学定义
def relu_decomposition(x):return (x > 0) * xdecomposition_rules = {}
decomposition_rules[F.relu] = relu_decompositiondef decompose_relu(model: torch.nn.Module,tracer_class : type = fx.Tracer) -> torch.nn.Module:graph : fx.Graph = tracer_class().trace(model)new_graph = fx.Graph()# 这相当于一个探针,将旧graph里需要用到的node的名字映射到新graph里对应的node。mapping_table = {} # {old node name : new node object}# 遍历 nodefor node in graph.nodes:# 判断是否是 relu函数。if node.op == 'call_function' and node.target in decomposition_rules:# 用于记录 proxy当前绑定proxy_args = []# node的arg即输入/上一个node。这一步其实就为该 node 生成对应输入的 proxy。for x in node.args:if isinstance(x, fx.Node):proxy_args.append(fx.Proxy(mapping_table[x.name])) else:proxy_args.append(x)# 这一步就是在“穿线”。穿线完毕后,与Proxy绑定的Graph也自动完成了:# 在原末尾插新加入的node并建立连接。proxy会自动绑定到下一个(输出)node上,# 依次类推,最后就变成 output_proxyoutput_proxy = decomposition_rules[node.target](*proxy_args)# 获取 当前proxy绑定的node。new_node = output_proxy.nodemapping_table[node.name] = new_nodeelse: # 当 node 不需要被拆解时,只需要复制到新graph里就好。# 该函数就是实现旧node与新node的映射关系。def node_mapping(x):return mapping_table[x.name]# node_copy 确实吧 node 拷贝过来了,同时还建立了连接。# 其会访问 node 的 原来所有输入的node,然后再利用opera# 重定向来给新生成的node建立在目标Graph上的连接。new_node = new_graph.node_copy(node, node_mapping)mapping_table[node.name] = new_node# 最后返回的模型绑定的是新 graphreturn fx.GraphModule(model, new_graph)decompose_relu(M())
Proxy
可以想象为一个“穿线器”:绑定 Node
后,在经过新的 Node
时能自动“串”好连接关系并加入到原 Graph
中。能记录此时的“线头”,即记录访问到的 Node
。
GraphModule Modification
下面一个例子展示 FX 是如何通过 GraphModule 间接替换 torch.add() 为 torch.mul() 的:
import torch
import torch.fx as fx# 定义一个简单模型
class M(torch.nn.Module):def forward(self, x, y):return torch.add(x, y)
# 下面尝试用替换 target 的方式来改动 graph (不提倡,因为对应的 node 的 name 没有改动!)
def transform(m: torch.nn.Module) -> torch.nn.Module:gm : fx.GraphModule = fx.symbolic_trace(m)# FX 的 IR 图是顺序储存节点,所以可以遍历for node in gm.graph.nodes:# 检查该节点是否是函数操作 (i.e: torch.add)if node.op == 'call_function':# 确认是该节点是函数操作时if node.target == torch.add:node.target = torch.mulgm.recompile() # 重新编译 GraphModule,更新 code 属性gm.graph.lint() # 最后需要检查修改过的 IR 图是否符合FX语法return gmtransform(M())
符号追踪的局限性(注意事项)
控制流(Control Flow)
PyTorch 官方将 if 语句、循环语句等具有选择/判断性质的语句称为控制流。在 FX语境中,控制流又可以分为动态控制流(Dynamic Control Flow)和静态控制流(Static Control Flow)。
FX 无法 trace
动态控制流,但可以 trace
判断条件明确的静态控制流。
动态控制流
若控制流的判断条件含有运算变量(Input Tensor)参与,那么该控制流就称为动态控制流,如:
def func_to_trace(x):if x.sum() > 0:# 可以看到x变量既参与计算,又参与判断return torch.relu(x)else:return torch.neg(x)
此时对该函数使用 trace
功能就会报错:
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
静态控制流
类推可知,若控制流的判断条件无运算变量参与,也即判断条件的变量不参与流(Flow)计算,那么该控制流就称为静态控制流,如:
import torch
import torch.fxclass MyModule(torch.nn.Module):def __init__(self, do_activation : bool = False):super().__init__()self.do_activation = do_activationself.linear = torch.nn.Linear(512, 512)def forward(self, x):x = self.linear(x)# 该if语句就是静态控制流if self.do_activation:x = torch.relu(x)return x
若想 trace
静态控制流,就需要明确判断条件,即给判断变量显式赋值:
without_activation = MyModule(do_activation=False)
# 然后就可以 trace
traced_without_activation = torch.fx.symbolic_trace(without_activation)
非 torch 函数
有些函数没有__torch_function__
属性,例如 Python 自带的函数或 math
库中的函数,无法被 trace
追踪。例如,当你的模型里调用了 len() 函数,那么进行 trace
时会报错:
"""
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want ")
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
那么需要使用 wrap()
API 来将普通函数包装成 torch 性质的函数:
torch.fx.wrap('len')
# 然后就可以正常 trace 了
traced = torch.fx.symbolic_trace(normalize)
查看 Graph 内容
通过 print() 函数
如:
# 模型定义过程就不展示了
print(traced_model.graph)
"""
graph():%x : [#users=1] = placeholder[target=x]%param : [#users=1] = get_attr[target=param]%add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})%linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})%clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})return clamp
"""
通过 print_tabular() 函数
通过调用 print_tabular()
函数就可以以 tabular 的格式输出 IR 图:
# 模型定义过程就不展示了
traced_model.graph.print_tabular()
"""
opcode name target args kwargs
------------- -------- ----------------------- ----------- ------------------------
placeholder x x () {}
get_attr param param () {}
call_function add_1 <built-in function add> (x, param) {}
call_module linear_1 linear (add_1,) {}
call_method clamp_1 clamp (linear_1,) {'min': 0.0, 'max': 1.0}
output output output (clamp_1,) {}
"""
总结
目前,FX 没有提供任何方式来保证/验证运算符在语法上是有效的。也就是说,任何新(定义)加入的运算符都必须由用户自己来保证其正确性。
最后官网建议的一点是,你在对 Graph
做变换时,应该让整个程序的输入 torch.nn.Module
,然后获取对应的 Graph
,做出修改,最后再返回一个 torch.nn.Module
。这样更方便后续工作,比如又传入下一段 FX 代码中。
以上总结如有谬误,还请包涵、指正。
PyTorch Python API:FX || Intro相关推荐
- PyTorch Python API详解大全(持续更新ing...)
诸神缄默不语-个人CSDN博文目录 具体内容以官方文档为准. 最早更新时间:2021.4.23 最近更新时间:2023.1.9 文章目录 0. 常用入参及函数统一解释 1. torch 1.1 Ten ...
- 使用Maple的Python API :OpenMaple(Windows下的解决方案)
在Maple 2023(按照软件文档,Maple 2018及以上版本均适用:我目前测试的版本为2023)的安装目录下,有软件附带的解释器,如D:\Program Files\Maple 2023\Py ...
- thinkcmf5调用指定分类的二级_Tengine快速上手系列教程amp;视频:基于Python API的图片分类应用入门丨附彩蛋...
前言:近期,Tengine团队加班加点,好消息接踵而来,OpenCV 4.3.0发布,OPEN AI LAB AIoT智能开发平台Tengine与OpenCV合作共同加速边缘智能,Tengine再获业 ...
- Sublime Text 4 首个稳定版终于来了:支持 GPU 渲染、兼容旧版本、Python API 升级
技术编辑:小魔丨发自 思否编辑部 公众号:SegmentFault Sublime Text 是一个轻量.简洁.高效.跨平台的编辑器,支持 Linux.Windows 和 Mac OS X 操作系统, ...
- python枪_大疆机甲大师教育机器人Python API中文化之一:枪亮枪暗
之前开始整理机甲的Python API,但纸上得来终觉浅,而且发现有些API与即使官方qq群的教程文档也有少许出入,于是打算逐个测试.这一系列将附上真机运行视频,以便以后直观看到最终演示效果. 先从灯 ...
- 强化学习系列文章(二十三):AirSim Python API图像与图像处理
强化学习系列文章(二十三):AirSim Python API图像与图像处理 参考网址:https://microsoft.github.io/AirSim/image_apis/#segmentat ...
- openstack二次开发:Python API
2019独角兽企业重金招聘Python工程师标准>>> 作 为 OpenStack 用户或管理员,您常常需要编写脚本来自动化常见任务.除了 REST 和命令行接口之外,OpenSta ...
- OpenPose Python API调用:ImportError: cannot import name 'pyopenpose' from 'openpose'
OpenPose Python API 调用方法 cmake-gui选项中勾选BUILD_PYTHON选项 cd build sudo make -j'' sudo make install cd ~ ...
- Python 人工智能:16~20
原文:Artificial Intelligence with Python 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MTPE) ...
- keras cnn注意力机制_TensorFlow、PyTorch、Keras:NLP框架哪家强
全文共3412字,预计学习时长7分钟 在对TensorFlow.PyTorch和Keras做功能对比之前,先来了解一些它们各自的非竞争性柔性特点吧. 非竞争性特点 下文介绍了TensorFlow.Py ...
最新文章
- javafx官方文档学习之二Scene体系学习一
- mysql文本数据_mysql操作文本数据
- tableau必知必会之用蝴蝶图(旋风图)实现数据之间对比
- 20175213 2018-2019-2 《Java程序设计》第6周学习总结
- 2013年第四届蓝桥杯C/C++ A组国赛 —— 第二题:骰子迷题
- C语言重新定位文件,C语言代码重定位 (原创)
- cba比赛比分预测_【CBA直播】深圳vs广东前瞻:深圳战广东再掀反攻?
- lisp 揭 院长_HISLISPACSRIS EMR系统简介
- R语言学习 - 热图美化 (数值标准化和调整坐标轴顺序)
- lLinux系统安全sudo+pam
- java实例 内存_一个分析和解决Java应用程序内存浪费的实战例子,值得收藏!
- break 退出循环
- sam机架和kx连线图_创新声卡KX 3552驱动连线搭载SAM机架
- 创业维艰,且行且珍惜
- 百度地图绘制3D棱柱
- 学数学,要“直觉”还是要“严谨”?
- 微信小程序之小程序审核
- 生鲜电商O2O 可以怎么做?
- Java中 构造方法 和 成员方法 的区别(图文介绍)
- 2022年11月20日-2022年11月26日学习周报