torch.compiler
torch.compiler
是一个命名空间,通过它向用户开放了一些内部编译器方法。该命名空间中的主要功能和特性是 torch.compile
。
torch.compile
是 PyTorch 2.x 引入的一个函数,旨在解决 PyTorch 中精确图捕获的问题,最终帮助软件工程师加速运行他们的 PyTorch 程序。torch.compile
使用 Python 编写,标志着 PyTorch 从 C++ 向 Python 的过渡。
torch.compile
利用了以下底层技术:
- TorchDynamo (torch._dynamo) 是一个内部 API,它使用 CPython 的 Frame Evaluation API 功能来安全捕获 PyTorch 计算图。通过
torch.compiler
命名空间向 PyTorch 用户开放可用方法。 - TorchInductor 是
torch.compile
默认的深度学习编译器,为多种加速器和后端生成快速代码。需要通过后端编译器才能实现torch.compile
的加速效果。对于 NVIDIA、AMD 和 Intel GPU,它使用 OpenAI Triton 作为关键构建块。 - AOT Autograd 不仅能捕获用户级代码,还能捕获反向传播,实现"提前"捕获反向传递。这使得 TorchInductor 能够同时加速前向和反向传递。
注意:在本文档中,术语 torch.compile
、TorchDynamo 和 torch.compiler
有时会互换使用。
如上所述,要通过 TorchDynamo 运行更快的工作流,torch.compile
需要一个后端将捕获的计算图转换为快速机器码。不同的后端会带来不同的优化效果。默认后端是 TorchInductor(也称为 inductor)。TorchDynamo 还支持由合作伙伴开发的一系列后端,可以通过运行 torch.compiler.list_backends()
查看,每个后端都有其可选依赖项。
一些最常用的后端包括:
训练和推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="inductor") |
使用 TorchInductor 后端。了解更多 |
torch.compile(m, backend="cudagraphs") |
使用 AOT Autograd 的 CUDA 图。了解更多 |
torch.compile(m, backend="ipex") |
在 CPU 上使用 IPEX。了解更多 |
torch.compile(m, backend="onnxrt") |
使用 ONNX Runtime 在 CPU/GPU 上进行训练。了解更多 |
仅推理后端
后端 | 描述 |
---|---|
torch.compile(m, backend="tensorrt") |
使用 Torch-TensorRT 进行推理优化。需要在调用脚本中 import torch_tensorrt 来注册后端。了解更多 |
torch.compile(m, backend="ipex") |
在 CPU 上使用 IPEX 进行推理。了解更多 |
torch.compile(m, backend="tvm") |
使用 Apache TVM 进行推理优化。了解更多 |
torch.compile(m, backend="openvino") |
使用 OpenVINO 进行推理优化。了解更多 |
延伸阅读
PyTorch 用户入门指南
- 快速入门
- torch.compiler API 参考
- torch.compiler.config 配置
- TorchDynamo 细粒度追踪 API
- AOTInductor: Torch.Export 模型的预编译方案
- TorchInductor GPU 性能分析
- torch.compile 性能剖析指南
- 常见问题解答
- torch.compile 故障排查
- PyTorch 2.0 性能看板
PyTorch 开发者深度解析
PyTorch 后端供应商指南
torch.fft
离散傅里叶变换及相关函数。
快速傅里叶变换
fft |
计算input 的一维离散傅里叶变换 |
---|---|
ifft |
计算input 的一维离散傅里叶逆变换 |
fft2 |
计算input 的二维离散傅里叶变换 |
ifft2 |
计算input 的二维离散傅里叶逆变换 |
fftn |
计算input 的N维离散傅里叶变换 |
ifftn |
计算input 的N维离散傅里叶逆变换 |
rfft |
计算实数input 的一维傅里叶变换 |
irfft |
计算rfft() 的逆变换 |
rfft2 |
计算实数input 的二维离散傅里叶变换 |
irfft2 |
计算rfft2() 的逆变换 |
rfftn |
计算实数input 的N维离散傅里叶变换 |
irfftn |
计算rfftn() 的逆变换 |
hfft |
计算Hermitian对称input 信号的一维离散傅里叶变换 |
ihfft |
计算hfft() 的逆变换 |
hfft2 |
计算Hermitian对称input 信号的二维离散傅里叶变换 |
ihfft2 |
计算实数input 的二维离散傅里叶逆变换 |
hfftn |
计算Hermitian对称input 信号的N维离散傅里叶变换 |
ihfftn |
计算实数input 的N维离散傅里叶逆变换 |
辅助函数
fftfreq |
计算大小为 n 的信号的离散傅里叶变换采样频率。 |
---|---|
rfftfreq |
计算大小为 n 的信号在使用 rfft() 时的采样频率。 |
fftshift |
对由 fftn() 提供的 n 维 FFT 数据进行重新排序,使负频率项优先。 |
ifftshift |
fftshift() 的逆操作。 |
torch.func
torch.func(前身为"functorch")是为PyTorch提供的JAX风格可组合函数变换工具。
注意:该库目前处于测试阶段。
这意味着这些功能基本可用(除非另有说明),且我们(PyTorch团队)将持续推进该库的发展。但API可能会根据用户反馈进行调整,且尚未完全覆盖所有PyTorch操作。
如果您对API有改进建议,或希望支持特定使用场景,请提交GitHub issue或直接联系我们。我们非常期待了解您如何使用这个库。
什么是可组合的函数变换?
- 函数变换是一种高阶函数,它接受一个数值函数作为输入,并返回一个新函数来计算不同的量。
torch.func
提供了自动微分变换(例如grad(f)
返回计算f
梯度的函数)、向量化/批处理变换(例如vmap(f)
返回对输入批次执行f
的函数)等多种变换。- 这些函数变换可以任意组合使用。例如,组合
vmap(grad(f))
可以计算单样本梯度(per-sample-gradients),这是当前标准 PyTorch 无法高效计算的量。
为什么需要可组合的函数变换?
目前在 PyTorch 中实现以下用例较为棘手:
- 计算逐样本梯度(或其他逐样本量)
- 在单台机器上运行模型集成
- 在 MAML 内循环中高效批处理任务
- 高效计算雅可比矩阵和海森矩阵
- 高效计算批量雅可比矩阵和海森矩阵
通过组合使用 vmap()
、grad()
和 vjp()
变换,我们无需为每个用例单独设计子系统即可实现上述功能。这种可组合函数变换的理念源自 JAX 框架。
延伸阅读
torch.futures
该包提供了一种 Future
类型,用于封装异步执行过程,并提供一组实用函数来简化对 Future
对象的操作。目前,Future
类型主要被 分布式RPC框架 使用。
class torch.futures.Future(*, devices=None)
Wrapper around a torch._C.Future
which encapsulates an asynchronous
execution of a callable, e.g. rpc_async()
. It also exposes a set of APIs to add callback functions and set results.
Warning: GPU support is a beta feature, subject to changes.
add_done_callback(callback)
将给定的回调函数附加到此Future
上,该回调函数将在Future
完成时运行。可以向同一个Future
添加多个回调,但无法保证它们的执行顺序。回调函数必须接受一个参数,即对此Future
的引用。回调函数可以使用value()
方法获取值。请注意,如果此Future
已经完成,给定的回调将立即内联执行。
我们建议使用then()
方法,因为它提供了一种在回调完成后进行同步的方式。如果回调不返回任何内容,add_done_callback
可能更高效。但then()
和add_done_callback
在底层使用相同的回调注册API。
对于GPU张量,此方法的行为与then()
相同。
参数
callback (
Future)
– 一个可调用对象,接受一个参数,即对此Future
的引用。
注意:请注意,如果回调函数抛出异常,无论是由于原始future以异常完成并调用fut.wait()
,还是由于回调中的其他代码,都必须仔细处理错误。例如,如果此回调随后完成了其他future,这些future不会被标记为以错误完成,用户需要独立处理这些future的完成/等待。
示例:
>>> def callback(fut):
... print("This will run after the future has finished.")
... print(fut.wait())
>>> fut = torch.futures.Future()
>>> fut.add_done_callback(callback)
>>> fut.set_result(5)
This will run after the future has finished.
5
done()
如果该Future
已完成则返回True
。当Future
包含结果或异常时即视为完成。
如果值包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,Future.done()
仍会返回True
,因为在此阶段结果已可被使用(前提是执行适当的同步操作,参见wait()
)。
返回类型:bool
set_exception(result)
为这个 Future
设置一个异常,这将标记该 Future
以错误状态完成,并触发所有已附加的回调。请注意,当对此 Future
调用 wait()/value() 时,此处设置的异常
将被内联抛出。
参数
result ([BaseException](https://docs.python.org/3/library/exceptions.html#BaseException "(in Python v3.13)"))
– 该Future
的异常对象。
示例
>>> fut = torch.futures.Future()
>>> fut.set_exception(ValueError("foo"))
>>> fut.wait()
Traceback (most recent call last):
...
ValueError: foo
set_result(result)
为这个Future
设置结果,这将标记该Future
为已完成状态并触发所有关联的回调。需要注意的是,一个Future
不能被标记为已完成两次。
如果结果包含位于GPU上的张量,即使填充这些张量的异步内核尚未在设备上完成运行,只要调用此方法时这些内核所入队的流被设置为当前流,仍可调用此方法。简而言之,在启动这些内核后立即调用此方法是安全的,无需额外同步,前提是期间不切换流。此方法会在所有相关当前流上记录事件,并利用它们确保此Future
的所有消费者都能得到正确调度。
参数
result ( object )
- 该Future
的结果对象。
示例:
>>> import threading
>>> import time
>>> def slow_set_future(fut, value):
... time.sleep(0.5)
... fut.set_result(value)
>>> fut = torch.futures.Future()
>>> t = threading.Thread(
... target=slow_set_future,
... args=(fut, torch.ones(2) * 3)
... )
>>> t.start()
>>> print(fut.wait())
tensor([3., 3.])
>>> t.join()
then(callback)
将给定的回调函数附加到此Future
上,该回调函数将在Future
完成时运行。可以向同一个Future
添加多个回调,但无法保证它们的执行顺序(如需确保特定顺序,请考虑链式调用:fut.then(cb1).then(cb2)
)。回调函数必须接受一个参数,即对此Future
的引用。回调函数可通过value()
方法获取值。请注意,如果此Future
已完成,给定的回调将立即内联执行。
如果Future
的值包含位于GPU上的张量,回调可能在填充这些张量的异步内核尚未在设备上完成执行时就被调用。不过,回调将通过设置为当前的一些专用流(从全局池中获取)被调用,这些流将与那些内核同步。因此,回调对这些张量执行的任何操作都将在内核完成后调度到设备上。换句话说,只要回调不切换流,它就可以安全地操作结果而无需额外同步。这与wait()
的非阻塞行为类似。
类似地,如果回调返回的值包含位于GPU上的张量,即使生成这些张量的内核仍在设备上运行,回调也可以这样做,前提是回调在执行期间没有切换流。如果想要切换流,必须注意与原始流重新同步,即回调被调用时当前的流。
参数
callback (
Callable)
– 一个以该Future
为唯一参数的可调用对象。
返回
一个新的Future
对象,它持有callback
的返回值,并将在给定callback
完成时标记为已完成。
返回类型
Future[S]
注意:请注意,如果回调函数抛出异常,无论是通过原始future以异常完成并调用fut.wait()
,还是通过回调中的其他代码,then
返回的future将适当地标记为遇到错误。但是,如果此回调随后完成其他future,这些future不会标记为以错误完成,用户需负责独立处理这些future的完成/等待。
示例:
>>> def callback(fut):
... print(f"RPC return value is {fut.wait()}.")
>>> fut = torch.futures.Future()
>>> # The inserted callback will print the return value when
>>> # receiving the response from "worker1"
>>> cb_fut = fut.then(callback)
>>> chain_cb_fut = cb_fut.then(
... lambda x : print(f"Chained cb done. {x.wait()}")
... )
>>> fut.set_result(5)
RPC return value is 5、Chained cb done. None
value()
获取已完成的Future对象的值。
此方法仅应在调用wait()
完成后,或在传递给then()
的回调函数内部使用。其他情况下,该Future
可能尚未持有值,调用value()
可能会失败。
如果值包含位于GPU上的张量,此方法将不会执行任何额外的同步操作。此类同步应事先通过调用wait()
单独完成(回调函数内部除外,因为then()
已自动处理此情况)。
返回值
该Future
持有的值。如果创建该值的函数(回调或RPC)抛出错误,此value()
方法同样会抛出错误。
返回类型:T
wait()
等待直到该 Future
的值准备就绪。
如果值包含位于 GPU 上的张量,则会与设备上异步填充这些张量的内核执行额外的同步操作。此类同步是非阻塞的,这意味着 wait()
会在当前流中插入必要的指令,以确保后续在这些流上排队的操作能正确安排在异步内核之后执行。但一旦完成指令插入,即使这些内核仍在运行,wait()
也会立即返回。只要不切换流,在访问和使用这些值时无需进一步同步。
返回值:此 Future
持有的值。如果创建该值的函数(回调或 RPC)抛出错误,此 wait
方法同样会抛出错误。
返回类型:T
torch.futures.collect_all(futures)
将提供的 Future
对象收集到一个统一的组合 Future
中,该组合 Future 会在所有子 Future 完成时完成。
参数
futures (list)
– 一个包含Future
对象的列表。
返回
返回一个 Future
对象,该对象关联到传入的 Future 列表。
返回类型
Future[list [torch.jit.Future]]
示例
>>> fut0 = torch.futures.Future()
>>> fut1 = torch.futures.Future()
>>> fut = torch.futures.collect_all([fut0, fut1])
>>> fut0.set_result(0)
>>> fut1.set_result(1)
>>> fut_list = fut.wait()
>>> print(f"fut0 result = {fut_list[0].wait()}")
fut0 result = 0
>>> print(f"fut1 result = {fut_list[1].wait()}")
fut1 result = 1
torch.futures.wait_all(futures)
等待所有提供的 futures 完成,并返回已完成值的列表。如果任一 future 遇到错误,该方法将提前退出并报告错误,而不会等待其他 futures 完成。
参数
futures (list)
– 一个Future
对象列表。
返回值:已完成 Future
结果的列表。如果对任何 Future
调用 wait
时抛出错误,该方法也会抛出错误。
返回类型:list
torch.fx
概述
FX 是一个供开发者使用的工具包,用于转换 nn.Module
实例。FX 包含三个核心组件:符号追踪器、中间表示和 Python 代码生成。以下是这些组件的实际应用演示:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号追踪器(symbolic tracer)对Python代码执行"符号执行"。它通过代码传递称为Proxy的虚拟值,并记录对这些Proxy的操作。有关符号追踪的更多信息,请参阅symbolic_trace()
和Tracer
文档。
中间表示(intermediate representation)是符号追踪过程中记录操作的容器。它由一组节点组成,这些节点表示函数输入、调用点(指向函数、方法或torch.nn.Module
实例)以及返回值。有关IR的更多信息,请参阅Graph
文档。IR是应用转换的基础格式。
Python代码生成功能使FX成为Python到Python(或Module到Module)的转换工具包。对于每个Graph IR,我们都可以生成符合Graph语义的有效Python代码。这个功能被封装在GraphModule
中,它是一个torch.nn.Module
实例,包含一个Graph
以及从Graph生成的forward
方法。
这些组件(符号追踪→中间表示→转换→Python代码生成)共同构成了FX的Python到Python转换流程。此外,这些组件也可以单独使用。例如,符号追踪可以单独用于捕获代码形式进行分析(而非转换)目的。代码生成可以用于通过编程方式生成模型,例如从配置文件生成。FX有许多用途!
在示例库中可以找到几个转换示例。
编写转换函数
什么是FX转换?本质上,它是一个形如下列的函数。
import torch
import torch.fx
def transform(m: nn.Module, tracer_class : type = torch.fx.Tracer) -torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to # fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to # customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换器将接收一个 torch.nn.Module
,从中获取 Graph
,进行一些修改后返回一个新的 torch.nn.Module
。您应该将 FX 转换器返回的 torch.nn.Module
视为与常规 torch.nn.Module
完全相同——可以将其传递给另一个 FX 转换器、传递给 TorchScript 或直接运行它。确保 FX 转换器的输入和输出均为 torch.nn.Module
将有助于实现组合性。
注意:也可以直接修改现有的 GraphModule
而不创建新实例,例如:
import torch
import torch.fx
def transform(m : nn.Module) -nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,你必须调用 GraphModule.recompile()
方法,使生成的 forward()
方法与修改后的 Graph
保持同步。
假设你已经传入了一个经过追踪转换为 Graph
的 torch.nn.Module
,现在主要有两种方法来构建新的 Graph
。
图结构快速入门
关于图的语义完整说明可以参考 Graph
文档,这里我们主要介绍基础概念。Graph
是一种数据结构,用于表示 GraphModule
上的方法。其核心需要描述以下信息:
- 方法的输入参数是什么?
- 方法内部运行了哪些操作?
- 方法的输出(即返回值)是什么?
这三个概念都通过 Node
实例来表示。下面通过一个简单示例来说明:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
这里我们定义一个演示用的模块 MyModule
,实例化后进行符号追踪,然后调用 Graph.print_tabular()
方法打印该 Graph
的节点表格:
操作码 | 名称 | 目标 | 参数 | 关键字参数 |
---|---|---|---|---|
placeholder | x | x | () | {} |
get_attr | linear_weight | linear.weight | () | {} |
call_function | add_1 | <built-in function add | (x, linear_weight) | {} |
call_module | linear_1 | linear | (add_1,) | {} |
call_method | relu_1 | relu | (linear_1,) | {} |
call_function | sum_1 | <built-in method sum … | (relu_1,) | {‘dim’: -1} |
call_function | topk_1 | <built-in method topk … | (sum_1, 3) | {} |
output | output | output | (topk_1,) | {} |
通过这些信息,我们可以回答之前提出的问题:
- 方法的输入是什么?
在FX中,方法输入通过特殊的placeholder
节点指定。本例中有一个目标为x
的placeholder
节点,表示存在一个名为x的(非self)参数。 - 方法内部有哪些操作?
get_attr
、call_function
、call_module
和call_method
节点表示方法中的操作。这些节点的完整语义说明可参考Node
文档。 - 方法的返回值是什么?
在Graph
中,返回值由特殊的output
节点指定。
现在我们已经了解FX中代码表示的基本原理,接下来可以探索如何编辑 Graph
。
图操作
直接操作计算图
构建新Graph
的一种方法是直接操作原有计算图。为此,我们可以简单地获取通过符号追踪得到的Graph
并进行修改。例如,假设我们需要将所有torch.add()
调用替换为torch.mul()
调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of # nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the # Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph
重写操作,例如删除或追加节点。为了辅助这些转换,FX 提供了一些用于操作计算图的实用函数,这些函数可以在 Graph
文档中找到。
下面展示了一个使用这些 API 追加 torch.relu()
调用的示例。
# Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to # now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换操作的简单转换,您也可以使用子图重写器。
使用 replace_pattern() 进行子图重写
FX 在直接图操作的基础上提供了更高层次的自动化能力。replace_pattern()
API 本质上是一个用于编辑 Graph
的"查找/替换"工具。它允许你指定一个 pattern
(模式)和 replacement
(替换)函数,然后会追踪这些函数,在图中找到与 pattern
图匹配的操作组实例,并用 replacement
图的副本替换这些实例。这可以极大地自动化繁琐的图操作代码,随着转换逻辑变得复杂,手动操作会变得难以维护。
图操作示例
代理/回溯机制
另一种操作 Graph
的方式是复用符号追踪中使用的 Proxy
机制。例如,假设我们需要编写一个将 PyTorch 函数分解为更小操作的转换器:将每个 F.relu(x)
调用转换为 (x > 0) * x
。传统做法可能是通过图重写来插入比较和乘法操作,然后清理原始的 F.relu
。但借助 Proxy
对象,我们可以自动将操作记录到 Graph
中来实现这一过程。
具体实现时,只需将需要插入的操作写成常规 PyTorch 代码,并用 Proxy
对象作为参数调用该代码。这些 Proxy
对象会捕获对其执行的操作,并将其追加到 Graph
中。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies, # we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this # node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式的图操作外,使用Proxy
还允许您将重写规则指定为原生Python代码。对于需要大量重写规则的转换(如vmap或grad),这通常可以提高规则的可读性和可维护性。
需要注意的是,在调用Proxy
时,我们还传递了一个指向底层变量图的追踪器。这样做是为了防止当图中的操作是n元操作时(例如add是二元运算符),调用Proxy
不会创建多个图追踪器实例,否则可能导致意外的运行时错误。特别是在底层操作不能安全地假设为一元操作时,我们推荐使用这种Proxy
方法。
一个使用Proxy
进行Graph
操作的实际示例可以在这里找到。
解释器模式
在FX中,一个实用的代码组织模式是遍历Graph
中的所有Node
并执行它们。这种模式可用于多种场景,包括:
- 运行时分析流经计算图的值
- 通过
Proxy
重新追踪来实现代码转换
例如,假设我们想运行一个GraphModule
,并在运行时记录节点上torch.Tensor
的形状和数据类型属性。实现代码可能如下:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and element type for the output values of each operation on the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(load_arg(node.args), *load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(args, *kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](load_arg(node.args), *load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如你所见,为FX实现一个完整的解释器并不复杂,但却非常实用。为了简化这一模式的使用,我们提供了Interpreter
类,它封装了上述逻辑,允许通过方法重写来覆盖解释器执行的某些方面。
除了执行操作外,我们还可以通过向解释器传递Proxy
值来生成新的计算图。
类似地,我们提供了Transformer
类来封装这种模式。Transformer
的行为与Interpreter
类似,但不同于调用run
方法从模块获取具体输出值,你需要调用Transformer.transform()
方法来返回一个新的GraphModule
,该模块会应用你通过重写方法设置的任何转换规则。
解释器模式示例
调试
简介
在编写转换代码的过程中,我们的代码往往不会一开始就完全正确。这时就需要进行调试。关键在于采用逆向思维:首先检查调用生成模块的结果,验证其正确性;接着审查并调试生成的代码;最后追溯导致生成代码的转换过程并进行调试。
如果您不熟悉调试工具,请参阅辅助章节可用调试工具。
变换编写中的常见陷阱
set
迭代顺序的不确定性。在Python中,set
数据类型是无序的。例如,使用set
来存储Node
等对象集合可能导致意外的非确定性行为。比如当迭代一组Node
并将其插入Graph
时,由于set
数据类型是无序的,输出程序中操作的顺序将是非确定性的,且每次程序调用都可能变化。
推荐的替代方案是使用dict
数据类型。自Python 3.7起(以及cPython 3.6起),dict
保持了插入顺序。通过将需要去重的值存储在dict
的键中,可以等效地实现set
的功能。
检查模块的正确性
由于大多数深度学习模块的输出都是浮点型 torch.Tensor
实例,因此检查两个 torch.nn.Module
的结果是否相等并不像简单的相等性检查那样直接。为了说明这一点,我们来看一个示例:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用==
相等运算符来检查两个深度学习模型的值是否相等。然而,这种做法存在两个问题:首先,该运算符返回的是张量而非布尔值;其次,浮点数值的比较应考虑误差范围(或epsilon),以解决浮点运算不可交换性的问题(详见此处)。
我们可以改用torch.allclose()
函数,它会基于相对和绝对容差阈值进行近似比较:
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中的第一个工具,用于检查转换后的模块与参考实现相比是否按预期运行。
调试生成的代码
由于 FX 在 GraphModule
上生成 forward()
函数,使用传统的调试技术(如 print
语句或 pdb
)会不太直观。幸运的是,我们有多种方法可以用来调试生成的代码。
使用 pdb
通过调用 pdb
可以进入正在运行的程序进行调试。虽然表示 Graph
的代码不在任何源文件中,但当执行前向传播时,我们仍然可以手动使用 pdb
进入该代码进行调试。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码
如果需要多次运行相同的代码,使用pdb
逐步调试到目标代码可能会有些繁琐。这种情况下,一个简单的方法是将生成的forward
传递代码直接复制粘贴到你的代码中,然后在那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and # copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
中的 to_folder
函数
GraphModule.to_folder()
是 GraphModule
中的一个方法,它允许你将生成的 FX 代码导出到一个文件夹。虽然像打印生成的代码中那样直接复制前向传播代码通常已经足够,但使用 to_folder
可以更方便地检查模块和参数。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上述示例后,我们可以查看foo/module.py
中的代码,并根据需要进行修改(例如添加print
语句或使用pdb
)来调试生成的代码。
调试转换过程
既然我们已经确认是转换过程生成了错误代码,现在就该调试转换本身了。首先,我们会查阅文档中的符号追踪限制部分。在确认追踪功能按预期工作后,我们的目标就转变为找出GraphModule
转换过程中出现的问题。编写转换部分可能有快速解决方案,如果没有的话,我们还可以通过多种方式来检查追踪模块:
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add (x, y) {}
output output output (add,) {}
"""
通过使用上述工具函数,我们可以对比应用转换前后的追踪模块。有时,简单的视觉对比就足以定位错误。如果问题仍不明确,下一步可以尝试使用 pdb
这类调试器。
以上述示例为基础,请看以下代码:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
以上述示例为例,假设调用print(traced)
时发现转换过程中存在错误。我们需要通过调试器定位问题根源。启动pdb
调试会话后,可以在transform_graph(traced)
处设置断点,然后按s
键"步入"该函数调用,实时观察转换过程。
另一个有效方法是修改print_tabular
方法,使其输出图中节点的不同属性(例如查看节点的input_nodes
和users
关系)。
可用的调试器
最常用的Python调试器是pdb。你可以通过在命令行输入python -m pdb FILENAME.py
来以"调试模式"启动程序,其中FILENAME
是你要调试的文件名。之后,你可以使用pdb
的调试器命令逐步执行正在运行的程序。通常的做法是在启动pdb
时设置一个断点(b LINE-NUMBER
),然后调用c
让程序运行到该断点处。这样可以避免你不得不使用s
或n
逐行执行代码才能到达想要检查的部分。或者,你也可以在想中断的代码行前写入import pdb; pdb.set_trace()
。如果添加了pdb.set_trace()
,当你运行程序时它会自动进入调试模式(换句话说,你只需在命令行输入python FILENAME.py
而不用输入python -m pdb FILENAME.py
)。一旦以调试模式运行文件,你就可以使用特定命令逐步执行代码并检查程序的内部状态。网上有很多关于pdb
的优秀教程,包括RealPython的《Python Debugging With Pdb》。
像PyCharm或VSCode这样的IDE通常内置了调试器。在你的IDE中,你可以选择:a)通过调出IDE中的终端窗口(例如在VSCode中选择View → Terminal)使用pdb
,或者b)使用内置的调试器(通常是pdb
的图形化封装)。
符号追踪的局限性
FX 采用符号追踪(又称符号执行)系统,以可转换/可分析的形式捕获程序语义。该系统具有以下特点:
- 追踪性:通过实际执行程序(实际是
torch.nn.Module
或函数)来记录操作 - 符号性:执行过程中流经程序的数据并非真实数据,而是符号(FX术语中称为
Proxy
)
虽然符号追踪适用于大多数神经网络代码,但它仍存在一些局限性。
动态控制流
符号追踪的主要局限在于目前不支持动态控制流。也就是说,当循环或if
语句的条件可能依赖于程序输入值时,就无法处理。
例如,我们来看以下程序:
def func_to_trace(x):
if x.sum() 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
语句的条件依赖于x.sum()
的值,而该值又依赖于函数输入x
。由于x
可能发生变化(例如向追踪函数传入新的输入张量时),这就形成了动态控制流。回溯信息会沿着代码向上追溯,展示这种情况发生的位置。
静态控制流
另一方面,系统支持所谓的静态控制流。静态控制流指的是那些在多次调用中值不会改变的循环或if
语句。通常在PyTorch程序中,这种控制流出现在根据超参数决定模型架构的代码中。举个具体例子:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if self.do_activation
这个条件语句不依赖于任何函数输入,因此它是静态的。do_activation
可以被视为一个超参数,当 MyModule
的不同实例使用不同参数值时,生成的代码轨迹也会不同。这是一种有效模式,符号追踪功能支持这种模式。
许多动态控制流的实例在语义上其实是静态控制流。通过消除对输入值的数据依赖,这些实例可以支持符号追踪。具体方法包括:
- 将值移至
Module
属性中 - 在符号追踪期间将具体值绑定到参数上
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含此类代码的程序部分可以被追踪为对方法的调用(参见使用Tracer类自定义追踪)或函数调用(参见wrap()
),而不是直接追踪这些代码本身。
非torch
函数
FX采用__torch_function__
作为拦截调用的机制(更多技术细节请参阅技术概览)。某些函数(如Python内置函数或math
模块中的函数)不受__torch_function__
覆盖,但我们仍希望在符号追踪中捕获它们。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误提示表明内置函数 len
不被支持。
我们可以通过 wrap()
API 将此类函数记录为跟踪中的直接调用:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer
类自定义追踪功能
Tracer
类是 symbolic_trace
功能的基础实现类。通过继承 Tracer 类,可以自定义追踪行为,例如:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶子模块
叶子模块是指在符号追踪过程中作为调用出现,而不会被继续追踪的模块。默认的叶子模块集合由标准torch.nn
模块实例组成。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过重写 Tracer.is_leaf_module()
来自定义叶子模块集合。
杂项说明
- 当前无法追踪张量构造函数(如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
):- 确定性构造函数(
zeros
、ones
)仍可使用,其生成的值会作为常量嵌入追踪记录。仅当这些构造函数的参数涉及动态输入大小时才会出现问题,此时可改用ones_like
或zeros_like
作为替代方案。 - 非确定性构造函数(
rand
、randn
)会将单个随机值嵌入追踪记录,这通常不符合预期行为。变通方法是将torch.randn
包装在torch.fx.wrap
函数中并调用该包装函数。
- 确定性构造函数(
(注:保留所有代码块及技术术语原貌,被动语态转为主动表述,长句拆分后保持技术严谨性)
@torch.fx.wrap
def torch_randn(x, shape):
return torch.randn(shape)
def f(x):
return x + torch_randn(x, 5)
fx.symbolic_trace(f)
此行为可能在未来的版本中修复。
- 类型注解
支持 Python 3 风格的类型注解(例如
func(x : torch.Tensor, y : int) -torch.Tensor
),
并且会通过符号追踪保留这些注解。目前不支持 Python 2 风格的注释类型注解
# type: (torch.Tensor, int) -torch.Tensor
。目前不支持函数内部局部变量的类型注解。
- 关于
training
标志和子模块的注意事项
- 当使用像
torch.nn.functional.dropout
这样的函数时,通常会传入self.training
作为训练参数。在 FX 追踪过程中,这个值很可能会被固定为一个常量。
import torch
import torch.fx
class DropoutRepro(torch.nn.Module):
def forward(self, x):
return torch.nn.functional.dropout(x, training=self.training)
traced = torch.fx.symbolic_trace(DropoutRepro())
print(traced.code)
"""
def forward(self, x):
dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None
return dropout
"""
traced.eval()
x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
"""
AssertionError: Tensor-likes are not close!
Mismatched elements: 15 / 15 (100.0%)
Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
"""
然而,当使用标准的 nn.Dropout()
子模块时,训练标志会被封装起来,并且由于保留了 nn.Module
对象模型,可以对其进行修改。
class DropoutRepro2(torch.nn.Module):
def __init__(self):
super().__init__()
self.drop = torch.nn.Dropout()
def forward(self, x):
return self.drop(x)
traced = torch.fx.symbolic_trace(DropoutRepro2())
print(traced.code)
"""
def forward(self, x):
drop = self.drop(x); x = None
return drop
"""
traced.eval()
x = torch.randn(5, 3)
torch.testing.assert_close(traced(x), x)
由于这一差异,建议将与动态training
标志交互的模块标记为叶模块。
API 参考
torch.fx.symbolic_trace(root, concrete_args=None)
符号追踪 API
给定一个 nn.Module
或函数实例 root
,该 API 会返回一个 GraphModule
,这是通过记录追踪 root
时观察到的操作构建而成的。
concrete_args
参数允许你对函数进行部分特化,无论是为了移除控制流还是数据结构。
例如:
def f(a, b):
if b == True:
return a else:
return a * 2
由于控制流的存在,FX通常无法追踪此过程。不过,我们可以使用concrete_args
来针对变量b的值进行特化处理,从而实现追踪:
f = fx.symbolic_trace(f, concrete_args={"b": False})
assert f(3, False) == 6
请注意,虽然您仍可以传入不同的b值,但这些值将被忽略。
我们还可以使用concrete_args
来消除函数中对数据结构的处理。这将利用pytrees来展平您的输入。为了避免过度特化,对于不应特化的值,请传入fx.PH
。例如:
def f(x):
out = 0
for v in x.values():
out += v
return out
f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}})
assert f({"a": 1, "b": 2, "c": 4}) == 7
参数
root (Union[torch.nn.Module, Callable])
- 待追踪并转换为图表示形式的模块或函数concrete_args (Optional[Dict[str, any]])
- 需要部分特化的输入参数
返回从root
记录的操作所创建的模块。
返回类型:GraphModule
注意:此API保证向后兼容性。
torch.fx.wrap(fn_or_name)
该函数可在模块级作用域调用,将fn_or_name
注册为"叶子函数"。
"叶子函数"在FX跟踪中会保留为CallFunction节点,而不会被进一步跟踪。
# foo/bar/baz.py
def my_custom_function(x, y):
return x * x + y * y
torch.fx.wrap("my_custom_function")
def fn_to_be_traced(x, y):
# When symbolic tracing, the below call to my_custom_function will be inserted into
# the graph rather than tracing it.
return my_custom_function(x, y)
该函数也可以等效地用作装饰器:
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
return x * x + y * y
包装函数可以被视为"叶子函数",类似于"叶子模块"的概念,也就是说,这些函数在FX跟踪中会保留为调用点,而不会被进一步追踪。
参数
fn_or_name (Union[str, Callable])
- 当被调用时,要插入到图中的函数或全局函数名称
注意:此API保证向后兼容性。
class torch.fx.GraphModule(*args, **kwargs)
GraphModule 是由 fx.Graph 生成的 nn.Module。GraphModule 具有一个 graph
属性,以及从该 graph
生成的 code
和 forward
属性。
警告:当重新分配 graph
时,code
和 forward
将自动重新生成。但如果你编辑了 graph
的内容而没有重新分配 graph
属性本身,则必须调用 recompile()
来更新生成的代码。
注意:此 API 保证向后兼容性。
__init__(root, graph, class_name='GraphModule')
构建一个 GraphModule。
参数
root (Union[torch.nn.Module , Dict[str, Any])
–root
可以是 nn.Module 实例,也可以是将字符串映射到任意属性类型的字典。
当 root
是 Module 时,Graph 的 Nodes 中 target
字段对基于 Module 的对象(通过限定名称引用)的任何引用,都会从 root
的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。
当 root
是字典时,Node 的 target
中找到的限定名称将直接在字典的键中查找。字典映射的对象将被复制到 GraphModule 模块层次结构中的适当位置。
graph (Graph)
–graph
包含此 GraphModule 用于代码生成的节点class_name (str)
–name
表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自GraphModule
。将其设置为root
的原始名称或在转换上下文中合理的名称可能会有所帮助。
注意:此 API 保证向后兼容性。
add_submodule(target, m)
将给定的子模块添加到self
中。
如果target
是子路径且对应位置尚未存在模块,此方法会安装空的模块。
参数
target (str)
- 新子模块的完整限定字符串名称
(参见nn.Module.get_submodule
中的示例了解如何指定完整限定字符串)m (Module)
- 子模块本身;即我们想要安装到当前模块中的实际对象
返回
子模块是否能够被插入。要使该方法返回True,target
表示的链中每个对象必须满足以下条件之一:
a) 尚不存在,或
b) 引用的是nn.Module
(而非参数或其他属性)
返回类型:bool
注意:此API保证向后兼容性。
property code: str
返回从该 GraphModule
底层 Graph
生成的 Python 代码。
delete_all_unused_submodules()
***
Deletes all unused submodules from `self`.
A Module is considered “used” if any one of the following is true:
1、It has children that are used
2、Its forward is called directly via a `call_module` node
3、It has a non-Module attribute that is used from a `get_attr` node
This method can be called to clean up an `nn.Module` without
manually calling `delete_submodule` on each unused submodule.
***
Note: Backwards-compatibility for this API is guaranteed.
delete_submodule(target)
从self
中删除指定的子模块。
如果target
不是有效的目标,则不会删除该模块。
参数
target (str)
- 新子模块的完全限定字符串名称
(有关如何指定完全限定字符串的示例,请参阅nn.Module.get_submodule
)
返回值
表示目标字符串是否引用了我们要删除的子模块。返回值为False
意味着target
不是有效的子模块引用。
返回类型 : bool
注意:此API保证向后兼容性。
property graph: [Graph](https://pytorch.org/docs/stable/data.html#torch.fx.Graph "torch.fx.graph.Graph")
返回该 GraphModule
底层对应的 Graph
print_readable(print_output=True, include_stride=False, include_device=False, colored=False)
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码
警告:此 API 为实验性质,且不保证向后兼容性。
recompile()
根据其 graph
属性重新编译该 GraphModule。在编辑包含的 graph
后应调用此方法,否则该 GraphModule
生成的代码将过期。
注意:此 API 保证向后兼容性。
返回类型:PythonCode
to_folder(folder, module_name='FxModule')
将模块以 module_name
名称转储到 folder
目录下,以便可以通过 from <folder> import <module_name>
方式导入。
参数:
folder (Union [str, os.PathLike])
: 用于输出代码的目标文件夹路径
module_name (str): 在输出代码时使用的顶层模块名称
警告:此 API 为实验性质,不保证向后兼容性。
class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)
Graph
是 FX 中间表示中使用的主要数据结构。
它由一系列 Node
组成,每个节点代表调用点(或其他语法结构)。这些 Node
的集合共同构成了一个有效的 Python 函数。
例如,以下代码
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(
torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3
)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
将生成以下图表:
print(gm.graph)
graph(x):
%linear_weight : [num_users=1] = self.linear.weight
%add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
%linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
%relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
%topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
return topk_1
关于Graph
中操作的具体语义,请参阅Node
文档。
注意:本API保证向后兼容性。
__init__(owning_module=None, tracer_cls=None, tracer_extras=None)
构建一个空图。
注意:此 API 保证向后兼容性。
call_function(the_function, args=None, kwargs=None, type_expr=None)
在Graph
中插入一个call_function
类型的Node
。call_function
节点表示对Python可调用对象的调用,由the_function
指定。
参数
the_function (Callable[...*, Any])
– 要调用的函数。可以是任何PyTorch运算符、Python函数,或属于builtins
或operator
命名空间的成员。args (Optional[Tuple[Argument*, ...]])
– 传递给被调用函数的位置参数。kwargs (Optional[Dict[str, Argument]])
– 传递给被调用函数的关键字参数。type_expr (Optional[Any])
– 可选的类型注解,表示该节点输出值的Python类型。
返回
新创建并插入的call_function
节点。
返回类型
Node
注意:此方法的插入点和类型表达式规则与Graph.create_node()
相同。
注意:此API保证向后兼容性。
call_method(method_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一个call_method
节点。call_method
节点表示对args
第0个元素调用指定方法。
参数
method_name (str)
- 要应用于self参数的方法名称。例如,如果args[0]是一个表示Tensor
的Node
,那么要对该Tensor
调用relu()
方法时,需将relu
作为method_name
传入。args (Optional[Tuple[Argument*, ...]])
- 要传递给被调用方法的位置参数。注意这应该包含一个self参数。kwargs (Optional[Dict[str, Argument]])
- 要传递给被调用方法的关键字参数type_expr (Optional[Any])
- 可选的类型注解,表示该节点输出结果的Python类型。
返回
新创建并插入的call_method
节点。
返回类型
Node
注意:本方法的插入点和类型表达式规则与Graph.create_node()
相同。
注意:此API保证向后兼容性。
call_module(module_name, args=None, kwargs=None, type_expr=None)
向Graph
中插入一个call_module
类型的Node
节点。call_module
节点表示对Module
层级结构中某个Module
的forward()函数的调用。
参数
module_name (str)
- 要调用的Module
在层级结构中的限定名称。例如,若被追踪的Module
有一个名为foo
的子模块,而该子模块又包含名为bar
的子模块,则应以foo.bar
作为module_name
来调用该模块。args (Optional[Tuple[Argument*, ...]])
- 传递给被调用方法的位置参数。注意:此处不应包含self
参数。kwargs (Optional[Dict[str, Argument]])
- 传递给被调用方法的关键字参数type_expr (Optional[Any])
- 可选类型注解,表示该节点输出值的Python类型。
返回
新创建并插入的call_module
节点。
返回类型:Node
注意:本方法的插入点与类型表达式规则与Graph.create_node()
相同。
注意:本API保证向后兼容性。
create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)
创建一个 Node
并将其添加到当前插入点的 Graph
中。
注意:当前插入点可以通过 Graph.inserting_before()
和 Graph.inserting_after()
进行设置。
参数
op (str)
- 该节点的操作码。可选值包括 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。这些操作码的语义在Graph
的文档字符串中有详细说明。args (Optional[Tuple[Argument*, ...]])
- 该节点的参数元组。kwargs (Optional[Dict[str, Argument]])
- 该节点的关键字参数。name (Optional[str])
- 为Node
指定的可选字符串名称。这将影响生成的 Python 代码中赋值给该节点的变量名。type_expr (Optional[Any])
- 可选类型注解,表示该节点输出值的 Python 类型。
返回
新创建并插入的节点。
返回类型:Node
注意:此 API 保证向后兼容。
eliminate_dead_code(is_impure_node=None)
根据图中各节点的用户数量及是否具有副作用,移除所有死代码。调用前必须确保图已完成拓扑排序。
参数
is_impure_node (Optional[Callable[[Node],* [bool]]])
—— 用于判断节点是否为非纯函数的回调函数。若未提供该参数,则默认使用Node.is_impure
方法。
返回值:返回布尔值,表示该过程是否导致图结构发生变更。
返回类型:bool
示例
在消除死代码前,下方表达式 a = x + 1
中的变量 a
无用户引用,因此可从图中安全移除而不影响结果。
def forward(self, x):
a = x + 1
return x + self.attr_1
消除死代码后,a = x + 1
已被移除,前向传播部分的其他代码保留不变。
def forward(self, x):
return x + self.attr_1
警告:死代码消除机制虽然采用了一些启发式方法来避免删除具有副作用的节点(参见 Node.is_impure
),但总体覆盖率非常不理想。因此,除非你明确知道当前 FX 计算图完全由无副作用的操作构成,或者自行提供了检测副作用节点的自定义函数,否则不应假设调用此方法是安全可靠的。
注意:本 API 保证向后兼容性。
erase_node(to_erase)
从Graph
中删除一个Node
。如果该节点在Graph
中仍被使用,将抛出异常。
参数
to_erase (Node)
– 要从Graph
中删除的Node
。
注意:此API保证向后兼容性。
find_nodes(*, op, target=None, sort=True)
支持快速查询节点
参数
op (str)
– 操作名称target (Optional[Target])
– 节点目标。对于call_function操作,target为必填项;其他操作中target为可选参数。sort ([bool])
– 是否按节点在图中出现的顺序返回结果。
返回值:返回符合指定op和target条件的节点迭代器。
警告:此API为实验性质,且不保证向后兼容。
get_attr(qualified_name, type_expr=None)
向图中插入一个 get_attr
节点。get_attr
类型的 Node
表示从 Module
层次结构中获取某个属性。
参数
qualified_name (str)
- 要获取属性的全限定名称。例如,若被追踪的 Module 包含名为foo
的子模块,该子模块又包含名为bar
的子模块,而bar
拥有名为baz
的属性,则应将全限定名称foo.bar.baz
作为qualified_name
传入。type_expr (Optional[Any])
- 可选的类型注解,用于表示该节点输出值的 Python 类型。
返回
新创建并插入的 get_attr
节点。
返回类型:Node
注意:本方法的插入点与类型表达式规则与 Graph.create_node
方法保持一致。
注意:此 API 保证向后兼容性。
graph_copy(g, val_map, return_output_node=False)
将给定图中的所有节点复制到 self
中。
参数
g (Graph)
– 作为节点复制来源的原始图。val_map (Dict[Node,* Node])
– 用于存储节点映射关系的字典,键为g
中的节点,值为self
中的对应节点。注意:val_map
可预先包含值以实现特定值的复制覆盖。
返回值:如果 g
包含输出节点,则返回 self
中与 g
输出值等效的值;否则返回 None
。
返回类型:Optional[Union [tuple [Argument, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , [dtype](tensor_attributes.html#torch.dtype "torch.dtype"), Tensor , device , memory_format , layout , OpOverload, [SymInt](torch.html#torch.SymInt "torch.SymInt"), SymBool , SymFloat ]]
注意:本API保证向后兼容性。
inserting_after(n=None)
设置 create_node
及相关方法在图中插入节点的位置。当在 with
语句中使用时,这会临时设置插入点,并在 with
语句退出时恢复原位置。
with g.inserting_after(n):
... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) # set the insert point permanently
参数:
n (可选[Node]): 要在其之前插入的节点。如果为None,则会在整个图的起始位置之后插入。
返回:
一个资源管理器,它会在__exit__
时恢复插入点。
注意:此API保证向后兼容性。
inserting_before(n=None)
设置 create_node
及相关方法在图中插入节点的基准位置。当在 with
语句中使用时,这将临时设置插入点,并在 with
语句退出时恢复原位置。
with g.inserting_before(n):
... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) # set the insert point permanently
参数:
n (Optional[Node]): 要插入位置的前一个节点。如果为None,则会在整个图的起始位置前插入。
返回:
一个资源管理器,该管理器会在__exit__
时恢复插入点。
注意:此API保证向后兼容性。
lint()
对该图执行多项检查以确保其结构正确。具体包括:
- 检查节点是否具有正确的所有权(由本图所有)
- 检查节点是否按拓扑顺序排列
- 若该图拥有所属的GraphModule,则检查目标是否存在该GraphModule中
注:本API保证向后兼容性。
node_copy(node, arg_transform=<function Graph.<lambda>>)
将节点从一个图复制到另一个图中。arg_transform
需要将节点所在图的参数转换为目标图(self)的参数。示例:
# Copying all the nodes in `g` into `new_graph`
g: torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}for node in g.nodes:
value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
参数
node (Node)
– 要复制到self
中的节点。arg_transform (Callable[[Node], Argument])
– 一个函数,用于将节点args
和kwargs
中的Node
参数转换为self
中的等效参数。最简单的情况下,该函数应从原始图中节点到self
的映射表中检索值。
返回类型:Node
注意:此 API 保证向后兼容性。
property nodes: _node_list
获取构成该图的所有节点列表。
请注意,这个Node
列表是以双向链表的形式表示的。在迭代过程中进行修改(例如删除节点、添加节点)是安全的。
返回值:一个双向链表结构的节点列表。注意可以对该列表调用reversed
方法来切换迭代顺序。
on_generate_code(make_transformer)
在生成 Python 代码时注册转换器函数
参数:
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):返回待注册代码转换器的函数。
该函数由 on_generate_code 调用以获取代码转换器。
此函数的输入参数为当前已注册的代码转换器(若未注册则为 None),以便在不需要覆盖时使用。该机制可用于串联多个代码转换器。
返回值:一个上下文管理器,当在 with 语句中使用时,会自动恢复先前注册的代码转换器。
示例:
gm: fx.GraphModule = ...
# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
return ["import pdb; pdb.set_trace()\n", body]
# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(lambda _: insert_pdb)
# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
lambda current_trans: (
lambda body: insert_pdb(current_trans(body) if current_trans else body)
)
)
gm.recompile()
gm(inputs) # drops into pdb
该函数也可作为上下文管理器使用,其优势在于能自动恢复之前注册的代码转换器。
# ... continue from previous example
with gm.graph.on_generate_code(lambda _: insert_pdb):
# do more stuff with `gm`...
gm.recompile()
gm(inputs) # drops into pdb
# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告:此 API 为实验性质,且不向后兼容。
output(result, type_expr=None)
将 output
Node
插入到 Graph
中。output
节点代表 Python 代码中的 return
语句。result
是应当返回的值。
参数
result (Argument)
– 要返回的值。type_expr (Optional[Any])
– 可选的类型注解,表示此节点输出将具有的 Python 类型。
注意:此方法的插入点和类型表达式规则与 Graph.create_node
相同。
注意:此 API 保证向后兼容性。
output_node()
警告:此 API 为实验性质,且不向后兼容。
返回值类型:Node
placeholder(name, type_expr=None, default_value)
在图中插入一个placeholder
节点。placeholder
表示函数的输入参数。
参数
name (str)
- 输入值的名称。这对应于该Graph
所表示函数的位置参数名称。type_expr (Optional[Any])
- 可选的类型注解,表示该节点输出值的Python类型。在某些情况下(例如当函数后续用于TorchScript编译时),这是生成正确代码所必需的。default_value (Any)
- 该函数参数的默认值。注意:为了允许None作为默认值,当参数没有默认值时,应传递inspect.Signature.empty来指定。
返回类型:Node
注意:此方法的插入点和类型表达式规则与Graph.create_node
相同。
注意:此API保证向后兼容性。
print_tabular()
以表格形式打印图的中间表示。注意:此API需要安装tabulate
模块。
注:该API保证向后兼容性。
process_inputs(*args)
处理参数以便它们可以传递到 FX 计算图中。
警告:此 API 为实验性质,且不向后兼容。
process_outputs(out)
警告:此 API 为实验性质,且不向后兼容。
python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)
将这段Graph
转换为有效的Python代码。
参数
root_module (str)
– 用于查找限定名称目标的根模块名称。通常为’self’。
返回值:src: 表示该对象的Python源代码
globals: 包含src中全局名称及其引用对象的字典
返回类型:一个包含两个字段的PythonCode对象
注意:此API保证向后兼容性。
set_codegen(codegen)
警告:此 API 为实验性功能,且不向后兼容。
class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)
Node
是表示 Graph
中单个操作的数据结构。在大多数情况下,Node 表示对各种实体的调用点,例如运算符、方法和模块(某些例外包括指定函数输入和输出的节点)。每个 Node
都有一个由其 op
属性指定的函数。不同 op
值的 Node
语义如下:
placeholder
表示函数输入。name
属性指定该值的名称。target
同样是参数的名称。args
包含:1) 空值,或 2) 表示函数输入默认参数的单个参数。kwargs
无关紧要。占位符对应于图形输出中的函数参数(例如x
)。get_attr
从模块层次结构中检索参数。name
同样是获取结果后赋值的名称。target
是参数在模块层次结构中的完全限定名称。args
和kwargs
无关紧要。call_function
将自由函数应用于某些值。name
同样是赋值目标的名称。target
是要应用的函数。args
和kwargs
表示函数的参数,遵循 Python 调用约定。call_module
将模块层次结构中的forward()
方法应用于给定参数。name
同前。target
是要调用的模块在模块层次结构中的完全限定名称。args
和kwargs
表示调用模块时的参数(不包括 self 参数*)。call_method
调用值的方法。name
类似。target
是要应用于self
参数的方法名称字符串。args
和kwargs
表示调用模块时的参数(包括 self 参数*)。output
在其args[0]
属性中包含跟踪函数的输出。这对应于图形输出中的 “return” 语句。
注意:此 API 保证向后兼容。
property all_input_nodes: list ['Node']
Return all Nodes that are inputs to this Node. This is equivalent to iterating over `args` and `kwargs` and only collecting the values that are Nodes.
Returns
List of `Nodes` that appear in the `args` and `kwargs` of this `Node`, in that order.
append(x)
在图的节点列表中,将 x
插入到当前节点之后。
等价于调用 self.next.prepend(x)
参数
x (Node)
– 要插入到当前节点后的节点。必须属于同一个图。
注意:此 API 保证向后兼容。
property args: tuple [Union [tuple ['Argument',
...], collections.abc.Sequence ['Argument'], collections.abc.Mapping[str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType],
...]
该Node
的参数元组。参数的具体含义取决于节点的操作码(opcode)。更多信息请参阅Node
文档字符串。
允许对此属性进行赋值操作。所有关于使用情况和用户的记录都会在赋值时自动更新。
format_node(placeholder_names=None, maybe_return_typename=None)
返回一个描述性的字符串表示形式self
。
该方法可不带参数使用,作为调试工具。
此函数也用于Graph
的__str__
方法内部。placeholder_names
和maybe_return_typename
中的字符串共同构成了该Graph所属GraphModule中自动生成的forward
函数的签名。placeholder_names
和maybe_return_typename
不应在其他情况下使用。
参数
placeholder_names (Optional[list[str]])
- 一个列表,用于存储表示生成的forward
函数中占位符的格式化字符串。仅供内部使用。maybe_return_typename (Optional[list[str]])
- 一个单元素列表,用于存储表示生成的forward
函数输出的格式化字符串。仅供内部使用。
返回
如果1)我们在Graph
的__str__
方法中将format_node
用作内部辅助工具,且2)self
是一个占位符Node,则返回None
。否则,返回当前Node的描述性字符串表示形式。
返回类型:str
注意:此API保证向后兼容。
insert_arg(idx, arg)
在参数列表的指定索引位置插入一个位置参数。
参数
idx ( int )
– 要插入到self.args
中元素之前的索引位置。arg (Argument)
– 要插入到args
中的新参数值
注意:本API保证向后兼容性。
is_impure()
返回该操作是否为不纯操作,即判断其操作是否为占位符或输出,或者是否为不纯的call_function
或call_module
。
返回值:指示该操作是否不纯。
返回类型:bool
警告:此API为实验性质,且不向后兼容。
property kwargs: dict[str , Union [tuple ['Argument',
...], collections.abc.Sequence['Argument'], collections.abc.Mapping, [str , 'Argument'], slice , range , torch.fx.node.Node, str , int , float, bool , complex , torch.dtype , torch.Tensor, torch.device , torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]
该Node
的关键字参数字典。参数的解析取决于节点的操作码。更多信息请参阅Node
文档字符串。
允许对此属性进行赋值。所有关于使用情况和用户的统计都会在赋值时自动更新。
property next: Node
返回链表中下一个Node
节点。
返回值:链表中下一个Node
节点。
normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)
返回经过标准化的Python目标参数。这意味着当normalize_to_only_use_kwargs
为真时,args/kwargs将与模块/函数的签名匹配,并按位置顺序仅返回kwargs。
同时会填充默认值。不支持仅限位置参数或可变参数。
支持模块调用。
可能需要arg_types
和kwarg_types
来消除重载歧义。
参数
root (torch.nn.Module)
– 用于解析模块目标的基模块arg_types (Optional[Tuple[Any]])
– 参数的元组类型kwarg_types (Optional[Dict[str, Any]])
– 关键字参数的字典类型normalize_to_only_use_kwargs ([bool])
– 是否标准化为仅使用kwargs
返回
返回命名元组ArgsKwargsPair
,若失败则返回None
返回类型
Optional[ArgsKwargsPair]
警告:该API为实验性质,不保证向后兼容。
prepend(x)
在图的节点列表中,在此节点前插入x。示例:
Before: p -self
bx -x -ax
After: p -x -self
bx -ax
参数
x (Node)
– 要放置在该节点之前的节点。必须是同一图的成员。
注意:此 API 保证向后兼容。
property prev: Node
返回链表中当前节点的前一个Node
。
返回值:链表中当前节点的前一个Node
。
replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)
将图中所有使用 self
的地方替换为节点 replace_with
。
参数
replace_with (Node)
– 用于替换所有self
的节点。delete_user_cb (Callable)
– 回调函数,用于判断是否应移除某个使用原self
节点的用户节点。propagate_meta ([bool])
– 是否将原节点.meta
字段的所有属性复制到替换节点上。出于安全考虑,仅当替换节点本身没有.meta
字段时才允许此操作。
返回值
返回受此变更影响的节点列表。
返回类型:list [Node]
注意:此 API 保证向后兼容。
replace_input_with(old_input, new_input)
遍历 self
的输入节点,将所有 old_input
实例替换为 new_input
。
参数
old_input (Node)
– 需要被替换的旧输入节点。new_input (Node)
– 用于替换old_input
的新输入节点。
注意:此 API 保证向后兼容性。
property stack_trace: Optional[str ]
返回在追踪过程中记录的 Python 堆栈跟踪信息(如果有)。
当使用 fx.Tracer
进行追踪时,该属性通常由 Tracer.create_proxy
填充。若需在追踪过程中记录堆栈跟踪以用于调试,请在 Tracer 实例上设置 record_stack_traces = True
。
当使用 dynamo 进行追踪时,该属性默认会由 OutputGraph.create_proxy
填充。
stack_trace
的字符串末尾将包含最内层的调用帧。
update_arg(idx, arg)
更新现有位置参数以包含新值
调用后,self.args[idx] == arg
将成立。
参数
idx ( int )
- 要更新元素在self.args
中的索引位置arg (Argument)
- 要写入args
的新参数值
注意:此 API 保证向后兼容性。
update_kwarg(key, arg)
更新现有关键字参数以包含新值
arg
。调用后,self.kwargs[key] == arg
。
参数
key (str)
- 要更新的元素在self.kwargs
中的键名arg (Argument)
- 要写入kwargs
的新参数值
注意:此API保证向后兼容性。
class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())
Tracer
是实现 torch.fx.symbolic_trace
符号追踪功能的类。调用 symbolic_trace(m)
等价于执行 Tracer().trace(m)
。
可以通过继承 Tracer 类来覆盖追踪过程中的各种行为。具体可覆盖的行为详见该类方法的文档字符串。
注意:此 API 保证向后兼容。
call_module(m, forward, args, kwargs)
该方法定义了当Tracer
遇到对nn.Module
实例调用时的行为。
默认行为是通过is_leaf_module
检查被调用的模块是否为叶子模块。如果是,则在Graph
中生成指向m
的call_module
节点;否则正常调用该Module
,并跟踪其forward
函数中的操作。
可通过重写此方法实现自定义行为,例如:
- 创建嵌套的追踪GraphModules
- 实现跨
Module
边界追踪时的特殊处理
参数说明:
m (Module)
- 当前被调用的模块实例forward (Callable)
- 待调用模块的forward()方法args (Tuple)
- 模块调用点的参数元组kwargs (Dict)
- 模块调用点的关键字参数字典
返回值:
- 若生成
call_module
节点,则返回Proxy
代理值 - 否则返回模块调用的原始结果
返回类型:任意类型
注意:本API保证向后兼容性。
create_arg(a)
一种方法,用于指定在准备值作为Graph
中节点的参数时追踪的行为。
默认行为包括:
1、遍历集合类型(如元组、列表、字典)并递归地对元素调用create_args
。
2、给定一个Proxy对象,返回底层IR Node
的引用。
3、给定一个非Proxy的Tensor对象,为以下情况生成IR:
对于Parameter,生成一个引用该Parameter的
get_attr
节点。对于非Parameter的Tensor,将该Tensor存储在一个特殊属性中,并引用该属性。
可以重写此方法以支持更多类型。
参数
a (Any)
– 将被作为Argument
在Graph
中使用的值。
返回值:将值a
转换为适当的Argument
。
返回类型:Argument
注意:此API保证向后兼容。
create_args_for_root(root_fn, is_module, concrete_args=None)
为root
模块的签名创建对应的placeholder
节点。该方法会检查root模块的签名并据此生成这些节点,同时支持*args
和**kwargs
参数。
警告:此API为实验性质,且不向后兼容。
create_node(kind, target, args, kwargs, name=None, type_expr=None)
根据给定的目标、参数、关键字参数和名称插入一个图节点。
该方法可以被重写,用于在节点创建过程中对使用的值进行额外检查、验证或修改。例如,可能希望禁止记录原地操作。
注意:此API保证向后兼容性。
返回类型:Node
create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)
根据给定的参数创建一个节点,然后返回包裹在 Proxy 对象中的节点。
如果 kind = ‘placeholder’,则表示我们正在创建一个代表函数参数的节点。若需要编码默认参数,则使用 args
元组。对于 placeholder
类型的节点,args
在其他情况下为空。
注意:此 API 保证向后兼容性。
get_fresh_qualname(prefix)
获取一个基于前缀的新名称并返回。该函数确保生成的名称不会与图中现有属性发生冲突。
注意:此API保证向后兼容。
返回类型:str
getattr(attr, attr_val, parameter_proxy_cache)
该方法定义了当对nn.Module
实例调用getattr时,该Tracer
的行为表现。
默认情况下,其行为是返回该属性的代理值。同时会将代理值存入parameter_proxy_cache
中,以便后续调用能复用该代理而非新建。
可通过重写此方法来实现不同行为——例如在查询参数时不返回代理。
参数说明:
attr (str)
- 被查询的属性名称attr_val (Any)
- 该属性的值parameter_proxy_cache (Dict[str, Any])
- 属性名到代理值的映射缓存
返回值:
getattr调用的返回结果。
警告:此API属于实验性质,且不保证向后兼容。
is_leaf_module(m, module_qualified_name)
一种用于判断给定nn.Module
是否为"叶子"模块的方法。
叶子模块是指出现在IR(中间表示)中的原子单元,通过call_module
调用进行引用。默认情况下,PyTorch标准库命名空间(torch.nn)中的模块都属于叶子模块。除非通过本参数特别指定,否则其他所有模块都会被追踪并记录其组成操作。
参数说明:
m (Module)
- 被查询的模块module_qualified_name (str)
- 该模块到根模块的路径。例如,若模块层级结构中子模块foo
包含子模块bar
,而bar
又包含子模块baz
,则该模块的限定名将显示为foo.bar.baz
返回类型:bool
注意:本API保证向后兼容性。
iter(obj)
当代理对象被迭代时调用,例如在控制流中使用时。通常我们不知道如何处理,因为我们不知道代理的值,但自定义跟踪器可以通过 create_node
向图节点附加更多信息,并可以选择返回一个迭代器。
注意:此 API 保证向后兼容性。
返回类型:迭代器
keys(obj)
当代理对象的 keys()
方法被调用时触发。这是在代理对象上调用 **
时发生的情况。该方法应返回一个迭代器,如果 **
需要在自定义追踪器中生效。
注意:此 API 保证向后兼容。
返回类型:任意
path_of_module(mod)
这是一个辅助方法,用于在root
模块的层级结构中查找mod
的限定名称。例如,如果root
有一个名为foo
的子模块,而foo
又有一个名为bar
的子模块,那么将bar
传入此函数将返回字符串"foo.bar"。
参数
mod (str)
– 需要获取限定名称的Module
。
返回类型:str
注意:此API保证向后兼容性。
proxy(node)
注意:此 API 保证向后兼容性。
返回类型:Proxy
to_bool(obj)
当代理对象需要转换为布尔值时调用,例如在控制流中使用时。通常我们无法确定如何处理,因为不知道代理的具体值,但自定义追踪器可以通过create_node
向图节点附加更多信息,并选择返回一个值。
注意:此API保证向后兼容。
返回类型:bool
trace(root, concrete_args=None)
追踪 root
并返回对应的 FX Graph
表示形式。root
可以是 nn.Module
实例或 Python 可调用对象。
请注意,在此调用后,self.root
可能与传入的 root
不同。例如,当向 trace()
传递自由函数时,我们会创建一个 nn.Module
实例作为根节点,并添加嵌入的常量。
参数
root (Union[Module, Callable])
– 需要追踪的Module
或函数。该参数保证向后兼容性。concrete_args (Optional[Dict[str, any]])
– 不应被视为代理的具体参数。此参数为实验性功能,其向后兼容性不作保证。
返回值:表示传入 root
语义的 Graph
对象。
返回类型:Graph
注意:此 API 保证向后兼容性。
class torch.fx.Proxy(node, tracer=None)
Proxy
对象是Node
包装器,在符号追踪过程中流经程序,并记录它们接触到的所有操作(包括torch
函数调用、方法调用和运算符)到不断增长的FX Graph中。
如果需要进行图变换,您可以在原始Node
上封装自己的Proxy
方法,这样就可以使用重载运算符向Graph
添加额外内容。
Proxy
对象不可迭代。换句话说,如果在循环中或作为*args
/**kwargs
函数参数使用Proxy
,符号追踪器会抛出错误。
有两种主要解决方法:
1、将不可追踪的逻辑提取到顶层函数中,并使用fx.wrap
进行处理。
2、如果控制流是静态的(即循环次数基于某些超参数),可以保持代码在原位,并重构为类似形式:
for i in range(self.some_hyperparameter):
indexed_item = proxied_value[i]
如需更深入了解 Proxy 的内部实现细节,请查阅 torch/fx/README.md 文件中的 “Proxy” 章节。
注意:本 API 保证向后兼容性。
class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)
解释器(Interpreter)会逐节点(Node-by-Node)执行FX图。这种模式在许多场景下非常有用,包括编写代码转换器以及分析过程。
通过重写Interpreter类中的方法,可以自定义执行行为。以下是按调用层次结构划分的可重写方法映射:
run()
+-- run_node
+-- placeholder()
+-- get_attr()
+-- call_function()
+-- call_method()
+-- call_module()
+-- output()
示例
假设我们需要将所有 torch.neg
实例与 torch.sigmoid
互换(包括它们对应的 Tensor
方法等价形式)。我们可以通过如下方式继承 Interpreter 类:
class NegSigmSwapInterpreter(Interpreter):
def call_function(self, target: Target, args: Tuple, kwargs: Dict) -Any:
if target == torch.sigmoid:
return torch.neg(args, *kwargs)
return super().call_function(target, args, kwargs)
def call_method(self, target: Target, args: Tuple, kwargs: Dict) -Any:
if target == "neg":
call_self, args_tail = args
return call_self.sigmoid(args_tail, *kwargs)
return super().call_method(target, args, kwargs)
def fn(x):
return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
module ( torch.nn.Module )
– 待执行的模块garbage_collect_values ([bool])
– 是否在模块执行过程中最后一次使用后删除值。这能确保执行期间内存使用最优。可以禁用此功能,例如通过查看Interpreter.env
属性来检查执行中的所有中间值。graph (Optional[Graph])
– 如果传入该参数,解释器将执行此图而非module.graph,并使用提供的模块参数来满足任何状态请求。
注意:此API保证向后兼容性。
boxed_run(args_list)
通过解释方式运行模块并返回结果。该过程采用"boxed"调用约定,即传递一个参数列表(这些参数会被解释器自动清除),从而确保输入张量能够及时释放。
注意:本API保证向后兼容性。
call_function(target, args, kwargs)
执行一个call_function
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回类型:任意类型
返回值: 函数调用返回的值
注意:此API保证向后兼容性。
call_method(target, args, kwargs)
执行一个 call_method
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。有关语义的详细信息,请参阅 Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型:任意
返回值:方法调用返回的值
注意:此 API 保证向后兼容性。
call_module(target, args, kwargs)
执行一个call_module
节点并返回结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅
Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回类型:Any
返回值:模块调用返回的值
注意:此API保证向后兼容性。
fetch_args_kwargs_from_env(n)
从当前执行环境中获取节点n
的args
和kwargs
具体值
参数
n (Node)
– 需要获取args
和kwargs
的目标节点
返回值
节点n
对应的具体args
和kwargs
值
返回类型:Tuple[Tuple, Dict]
注意:本API保证向后兼容性
fetch_attr(target)
从 self.module
的 Module
层级结构中获取一个属性。
参数
target (str)
- 要获取属性的全限定名称
返回
该属性的值。
返回类型
任意类型
注意:此 API 保证向后兼容。
get_attr(target, args, kwargs)
执行一个 get_attr
节点。该操作会从 self.module
的 Module
层级结构中获取属性值。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅 Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值
获取到的属性值
返回类型
任意类型
注意:此 API 保证向后兼容性。
map_nodes_to_values(args, n)
递归遍历 args
并在当前执行环境中查找每个 Node
的具体值。
参数
args (Argument)
– 需要查找具体值的数据结构n (Node)
–args
所属的节点。仅用于错误报告。
返回类型:Optional[Union [tuple [Argument’, …], Sequence [Argument], Mapping [str , Argument], slice , range , Node, str , int , float, bool , complex , dtype, Tensor , device , memory_format , layout , OpOverload, SymInt, SymBool , SymFloat ]]
注意:此 API 保证向后兼容性。
output(target, args, kwargs)
执行一个output
节点。该操作实际上只是获取output
节点引用的值并返回它。
参数
target (Target)
– 该节点的调用目标。有关语义详情请参阅
Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值:输出节点引用的返回值
返回类型:任意类型
注意:此API保证向后兼容。
placeholder(target, args, kwargs)
执行一个placeholder
节点。请注意这是有状态的:
Interpreter
内部维护了一个针对run
方法传入参数的迭代器,本方法会返回该迭代器的next()结果。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅Nodeargs (Tuple)
– 本次调用的位置参数元组kwargs (Dict)
– 本次调用的关键字参数字典
返回值:获取到的参数值。
返回类型:任意类型
注意:此API保证向后兼容。
run(*args, initial_env=None, enable_io_processing=True)
通过解释执行模块并返回结果。
参数
*args
– 按位置顺序传递给模块的运行参数initial_env (Optional[Dict[Node, Any]])
– 可选的执行初始环境。这是一个将节点映射到任意值的字典。例如,可用于预先填充某些节点的结果,从而在解释器中仅进行部分求值。enable_io_processing ([bool])
– 如果为true,我们会在使用输入和输出之前,先用图的process_inputs和process_outputs函数对它们进行处理。
返回值:执行模块后返回的值
返回类型:任意
注意:此API保证向后兼容。
run_node(n)
运行特定节点 n
并返回结果。
根据 node.op
的类型,调用对应的占位符、get_attr、call_function、call_method、call_module 或 output 方法。
参数
n (Node)
– 需要执行的节点
返回值:执行节点 n
的结果
返回类型:任意类型
注意:此 API 保证向后兼容性。
class torch.fx.Transformer(module)
Transformer
是一种特殊类型的解释器,用于生成新的 Module
。它提供了一个 transform()
方法,返回转换后的 Module
。与 Interpreter
不同,Transformer
不需要参数即可运行,完全基于符号化方式工作。
示例
假设我们需要将所有 torch.neg
实例与 torch.sigmoid
互换(包括它们的 Tensor
方法等效形式)。可以通过如下方式子类化 Transformer
:
class NegSigmSwapXformer(Transformer):
def call_function(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if target == torch.sigmoid:
return torch.neg(*args, **kwargs)
return super().call_function(target, args, kwargs)
def call_method(
self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any]
) -> Any:
if target == "neg":
call_self, *args_tail = args
return call_self.sigmoid(*args_tail, **kwargs)
return super().call_method(target, args, kwargs)
def fn(x):
return torch.sigmoid(x).neg()
gm = torch.fx.symbolic_trace(fn)
transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数
module ([GraphModule](https://pytorch.org/docs/stable/data.html#torch.fx.GraphModule "torch.fx.GraphModule"))
– 待转换的Module
对象。
注意:此API保证向后兼容性。
call_function(target, args, kwargs)
注意:该 API 保证向后兼容。
返回类型
Any
call_module(target, args, kwargs)
注意:此 API 保证向后兼容。
返回类型
Any
get_attr(target, args, kwargs)
执行一个 get_attr
节点。在 Transformer
中,该方法被重写以便向输出图中插入新的 get_attr
节点。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅
Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型
Proxy
注意:此 API 保证向后兼容。
placeholder(target, args, kwargs)
执行一个 placeholder
节点。在 Transformer
中,该方法被重写以便向输出图中插入新的 placeholder
。
参数
target (Target)
– 该节点的调用目标。关于语义的详细信息请参阅 Nodeargs (Tuple)
– 该调用的位置参数元组kwargs (Dict)
– 该调用的关键字参数字典
返回类型:Proxy
注意:此 API 保证向后兼容。
transform()
转换 self.module
并返回转换后的 GraphModule
。
注意:此 API 保证向后兼容性。
返回类型 : GraphModule
torch.fx.replace_pattern(gm, pattern, replacement)
在GraphModule的图结构(gm
)中,匹配所有可能的非重叠运算符集及其数据依赖关系(pattern
),然后将每个匹配到的子图替换为另一个子图(replacement
)。
参数
gm (GraphModule)
- 封装待操作图的GraphModulepattern (Union[Callable, GraphModule])
- 需要在gm
中匹配并替换的子图replacement (Union[Callable, GraphModule])
- 用于替换pattern
的子图
返回值:返回一个Match
对象列表,表示原始图中与pattern
匹配的位置。如果没有匹配项则返回空列表。Match
定义如下:
class Match(NamedTuple):
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
返回类型:List[Match]
示例:
import torch
from torch.fx import symbolic_trace, subgraph_rewriter
class M(torch.nn.Module):
def __init__(self) -None:
super().__init__()
def forward(self, x, w1, w2):
m1 = torch.cat([w1, w2]).sum()
m2 = torch.cat([w1, w2]).sum()
return x + torch.max(m1) + torch.max(m2)
def pattern(w1, w2):
return torch.cat([w1, w2])
def replacement(w1, w2):
return torch.stack([w1, w2])
traced_module = symbolic_trace(M())
subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述代码会先在 traced_module
的 forward
方法中匹配 pattern
。模式匹配基于使用-定义关系而非节点名称进行。例如,若 pattern
中包含 p = torch.cat([a, b])
,则可以在原始 forward
函数中匹配到 m = torch.cat([a, b])
,即使变量名不同(p
与 m
)也不影响。
pattern
中的 return
语句仅根据其值进行匹配,它可能与更大图中的 return
语句匹配,也可能不匹配。换句话说,模式不必延伸至更大图的末尾。
当模式匹配成功时,它将从更大的函数中被移除,并由 replacement
替换。如果更大函数中存在多个 pattern
匹配项,每个非重叠的匹配项都会被替换。若出现匹配重叠的情况,则替换重叠匹配集中最先找到的匹配项(此处的"最先"定义为节点使用-定义关系拓扑排序中的第一个节点。大多数情况下,第一个节点是紧接 self
后出现的参数,而最后一个节点是函数返回的内容)。
需要特别注意:pattern
可调用对象的参数必须在该可调用对象内部使用,且 replacement
可调用对象的参数必须与模式匹配。第一条规则解释了为何上述代码块中 forward
函数有参数 x, w1, w2
,而 pattern
函数只有参数 w1, w2
——因为 pattern
未使用 x
,故不应将 x
指定为参数。
关于第二条规则的示例,考虑替换…
def pattern(x, y):
return torch.neg(x) + torch.relu(y)
with
def replacement(x, y):
return torch.relu(x)
在这种情况下,replacement
需要与pattern
相同数量的参数(包括x
和y
),即使参数y
在replacement
中并未使用。
调用subgraph_rewriter.replace_pattern
后,生成的Python代码如下所示:
def forward(self, x, w1, w2):
stack_1 = torch.stack([w1, w2])
sum_1 = stack_1.sum()
stack_2 = torch.stack([w1, w2])
sum_2 = stack_2.sum()
max_1 = torch.max(sum_1)
add_1 = x + max_1
max_2 = torch.max(sum_2)
add_2 = add_1 + max_2
return add_2
注意:该 API 保证向后兼容。
torch.fx.experimental
警告:这些API属于实验性质,可能会随时变更而不另行通知。
torch.fx.experimental.symbolic_shapes
ShapeEnv |
|
---|---|
DimDynamic |
控制如何为维度分配符号。 |
StrictMinMaxConstraint |
对客户端:该维度的大小必须在’vr’范围内(指定包含性上下界),且必须为非负数且不应为0或1(但参见下方注意事项)。 |
RelaxedUnspecConstraint |
对客户端:无显式约束;约束由追踪过程中的守卫隐式推断得出。 |
EqualityConstraint |
表示并判定输入源之间的各类相等性约束。 |
SymbolicContext |
数据结构,指定在create_symbolic_sizes_strides_storage_offset 中如何创建符号;例如,应为静态还是动态。 |
StatelessSymbolicContext |
通过DimDynamic 和DimConstraint 给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中创建符号。 |
StatefulSymbolicContext |
通过Source:Symbol缓存给定的symbolic_context判定,在create_symbolic_sizes_strides_storage_offset 中创建符号。 |
SubclassSymbolicContext |
可追踪张量子类的内部张量的正确符号上下文可能与外部符号上下文不同。 |
DimConstraints |
针对符号维度约束系统的自定义求解器。 |
ShapeEnvSettings |
封装所有可能影响FakeTensor调度的形状环境设置。 |
ConvertIntKey |
|
CallMethodKey |
|
PropagateUnbackedSymInts |
|
DivideByKey |
|
InnerTensorKey |
|
hint_int |
获取整数的提示值(基于运行时观察到的底层实际值)。 |
is_concrete_int |
检查SymInt底层对象是否为具体值的实用工具。 |
is_concrete_bool |
检查SymBool底层对象是否为具体值的实用工具。 |
is_concrete_float |
检查SymInt底层对象是否为具体值的实用工具。 |
has_free_symbols |
bool(free_symbols(val))的快速版本 |
has_free_unbacked_symbols |
bool(free_unbacked_symbols(val))的快速版本 |
definitely_true |
仅当能确定a为True时返回True,过程中可能引入守卫。 |
definitely_false |
仅当能确定a为False时返回True,过程中可能引入守卫。 |
guard_size_oblivious |
以无视大小的方式对符号布尔表达式执行守卫。 |
sym_eq |
类似==,但在列表/元组上运行时,会递归测试相等性并使用sym_and连接结果,不引入守卫。 |
constrain_range |
应用约束使传入的SymInt必须在min-max范围内(包含边界),且不引入SymInt的守卫(意味着可用于未绑定的SymInt)。 |
constrain_unify |
给定两个SymInt,约束它们必须相等。 |
canonicalize_bool_expr |
通过将布尔表达式转换为lt/le不等式并将所有非常量项移至右侧,实现规范化。 |
statically_known_true |
如果x可简化为常量且为真,则返回True。 |
lru_cache |
|
check_consistent |
测试两个"meta"值(通常为Tensor或SymInt)是否具有相同的值,例如在重追踪后。 |
compute_unbacked_bindings |
在运行fake tensor传播并生成example_value结果后,遍历example_value查找新绑定的未支持符号并记录其路径供后续使用。 |
rebind_unbacked |
假设我们正在重追踪一个已有FX图,该图先前进行过fake tensor传播(因此存在未支持的SymInt)。 |
resolve_unbacked_bindings |
|
is_accessor_node |
torch.fx.experimental.proxy_tensor
make_fx |
给定函数f,返回一个新函数。当使用有效参数执行该函数时,会返回一个FX GraphModule,表示执行过程中所执行的操作集合。 |
---|---|
handle_sym_dispatch |
调用当前活动的代理跟踪模式,对操作SymInt/SymFloat/SymBool参数的函数进行符号调度跟踪。 |
get_proxy_mode |
获取当前活动的代理跟踪模式,如果当前未处于跟踪状态则返回None。 |
maybe_enable_thunkify |
在此上下文管理器内,如果正在进行make_fx跟踪,将对所有SymNode计算进行thunkify处理,并避免将其跟踪到图中,除非确实需要。 |
maybe_disable_thunkify |
在某个上下文中禁用thunkification功能。 |
2025-05-10(六)