torch.export Tutorial#

Created On: Oct 02, 2023 | Last Updated: Mar 24, 2026 | Last Verified: Nov 05, 2024

Author: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

Warning

torch.export and its related features are in prototype status and are subject to backwards-incompatible changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.

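torch.export() is the PyTorch 2.X way to export PyTorch models into standardized model representations, intended to be run in different (i.e. Python-less) environments. See the official torch.export documentation for the full reference.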

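In this tutorial, you will learn how to use torch.export() to extract ExportedPrograms (i.e. single-graph representations) from PyTorch programs. We also detail some considerations and modifications that you may need to make in order to make your model compatible with torch.export.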

Basic Usage#

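torch.export extracts single-graph representations from PyTorch programs by tracing the target function, given example inputs. torch.export.export() is the main entry point for torch.export.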

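In this tutorial, torch.export and torch.export.export() are practically synonymous, though torch.export generally refers to the PyTorch 2.X export process, and torch.export.export() generally refers to the actual function call.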

The signature of torch.export.export() is:

export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

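torch.export.export() traces the tensor computation graph by calling mod(*args, **kwargs) and wraps it in an ExportedProgram, which can be serialized or executed later with different inputs. To execute the ExportedProgram, we can call .module() on it to obtain a callable torch.nn.Module that behaves like the original program. We will detail the dynamic_shapes argument later in the tutorial.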

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))
<class 'torch.export.exported_program.ExportedProgram'>
tensor([[0.6300, 0.0000, 0.0000, 0.0000, 1.3291, 0.0000, 0.5392, 0.3743, 0.0000,
         1.2836],
        [0.3362, 0.2798, 1.1324, 1.0829, 0.2591, 0.1500, 0.8622, 0.0000, 1.6608,
         0.0000],
        [1.2444, 0.4373, 0.0000, 0.0000, 0.0000, 0.0000, 1.3091, 0.0000, 0.0000,
         1.3469],
        [2.0964, 0.4620, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2557,
         0.5486],
        [0.0000, 0.7975, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4902,
         0.6657],
        [0.8525, 0.4503, 0.3919, 0.0000, 1.0072, 0.1359, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.4611, 0.0000, 0.6018, 1.0647, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9631, 0.0000, 0.6032, 0.6220,
         0.1830]], grad_fn=<ReluBackward0>)

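Let's review some attributes of ExportedProgram that are of interest.

The graph attribute is an FX graph traced from the function we exported, that is, the computation graph of all PyTorch operations. The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.

The graph_signature attribute gives a more detailed description of the input and output nodes in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.

The range_constraints attribute will be covered later.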

print(exported_mod)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

            # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias);  add = p_lin_weight = p_lin_bias = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
            relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear);  linear = None
            return (relu_,)

Graph signature:
    # inputs
    p_lin_weight: PARAMETER target='lin.weight'
    p_lin_bias: PARAMETER target='lin.bias'
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    relu_: USER_OUTPUT

Range constraints: {}

See the torch.export documentation for more details.

Graph Breaks#

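Although torch.export shares components with torch.compile, the key limitation of torch.export, especially when compared to torch.compile, is that it does not support graph breaks. This is because handling graph breaks involves interpreting the unsupported operation with default Python evaluation, which is incompatible with the export use case. Therefore, to make your model code compatible with torch.export, you will need to modify your code to remove graph breaks.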

Graph breaks arise in cases such as:

  • data-dependent control flow

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
def forward(self, arg0_1: "f32[3, 3]"):
    # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0:
    sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(gt, 0);  gt = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None




def forward(self, arg0_1: "f32[3, 3]"):
    # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0:
    sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
    gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(gt, 0);  gt = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in <module>
    export(Bad1(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2006, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2136, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1914, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2965, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2867, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2828, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1297, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1673, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2402, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 912, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1743, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1798, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1536, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2120, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1798, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1869, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1167, in __torch_function__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 576, in guard_bool
    r = self.evaluate()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 550, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7860, in evaluate_sym_node
    return self.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7956, in evaluate_expr
    return self._inner_evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 297, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7979, in _inner_evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 8212, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

consider using data-dependent friendly APIs such as guard_or_false, guard_or_true and statically_known_true.
Caused by: (_export/non_strict_utils.py:1167 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward
    if x.sum() > 0:


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
  • accessing tensor data with .data

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 135, in <module>
    export(Bad2(), (torch.randn(3, 3),))
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2326, in _export_for_training
    raise RuntimeError(error_msg)
RuntimeError: We found a fake tensor in the exported program constant's list. This typically means our tracing system encountered an op that we can't trace through. For the potential source, you can refer to following model attribute: lifted_tensor_1. Please file an issue on github.
  • calling unsupported functions (such as many built-in functions)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

Non-Strict Export#

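To trace the program, torch.export uses TorchDynamo by default, a bytecode analysis engine, to symbolically analyze the Python code and build a graph based on the results. This analysis allows torch.export to provide stronger guarantees about safety, but not all Python code is supported, which causes these graph breaks.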

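To address this issue, PyTorch 2.3 introduced a new export mode called non-strict mode. In this mode we trace through the program using the Python interpreter, executing it exactly as it would run in eager mode, which allows us to skip over unsupported Python features. It is enabled by passing the strict=False flag.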

Looking at some of the previous examples which resulted in graph breaks:

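  • Calling unsupported functions (such as many built-in functions) now traces through, but in this case id(x) gets specialized as a constant integer in the graph. This is because id(x) is not a tensor operation, so the operation is not recorded in the graph.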

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1);  x = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
            add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 139820054845360);  add = None
            return (add_1,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    add_1: USER_OUTPUT

Range constraints: {}

tensor([[1.3982e+14, 1.3982e+14, 1.3982e+14],
        [1.3982e+14, 1.3982e+14, 1.3982e+14],
        [1.3982e+14, 1.3982e+14, 1.3982e+14]])

However, there are still some features that require rewrites to the original module:

Control Flow Ops#

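torch.export actually does support data-dependent control flow, but it must be expressed using control flow ops. For example, we can fix the control flow example above using the cond op, like so: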

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
            sum_1: "f32[]" = torch.ops.aten.sum.default(x)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

            # File: <eval_with_key>.36:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x,));  gt = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 3]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                # File: <eval_with_key>.37:6 in forward, code: sin = torch.sin(l_args_3_0__1);  l_args_3_0__1 = None
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None
                return (sin,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                # File: <eval_with_key>.38:6 in forward, code: cos = torch.cos(l_args_3_0__1);  l_args_3_0__1 = None
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(x);  x = None
                return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}

tensor([[0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415],
        [0.8415, 0.8415, 0.8415]])
tensor([[0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403],
        [0.5403, 0.5403, 0.5403]])

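There are limitations to cond that one should be aware of:

  • The predicate (i.e. x.sum() > 0) must result in a boolean or a single-element tensor.

  • The operands (i.e. [x]) must be tensors.

  • The branch function (i.e. true_fn and false_fn) signatures must match the operands, and both must return a single tensor with the same metadata (for example, dtype, shape, etc.).

  • Branch functions cannot mutate inputs or global variables.

  • Branch functions cannot access closure variables, except for self if the function is defined in the scope of a method.

For more details about cond, check out the cond documentation.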

We can also use map, which applies a function across the first dimension of the first tensor argument.

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
            # File: <eval_with_key>.96:9 in forward, code: map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_flat_xs_0_], [l_flat_args_0_, l_flat_args_1_]);  map_body_0 = l_flat_xs_0_ = l_flat_args_0_ = l_flat_args_1_ = None
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]);  body_graph_0 = xs = y = z = None
            getitem: "f32[6, 4]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                # File: <eval_with_key>.97:5 in forward, code: add = child + l_flat_args_0_;  child = l_flat_args_0_ = None
                add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None

                # File: <eval_with_key>.97:6 in forward, code: add_1 = add + l_flat_args_1_;  add = l_flat_args_1_ = None
                add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z);  add = z = None
                return (add_1,)

Graph signature:
    # inputs
    xs: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {}

tensor([[10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.],
        [10., 10., 10., 10.]])

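Other control flow ops include while_loop, associative_scan, and scan. For more documentation on each operator, please refer to the PyTorch documentation on control flow operators.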

Constraints/Dynamic Shapes#

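This section covers the dynamic behavior and representation of exported programs. Dynamic behavior depends on the particular model being exported, so for most of this tutorial we will focus on this particular toy model (with the resulting tensor shapes annotated):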

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

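By default, torch.export produces a static program. One consequence is that at runtime the program will not work on inputs with different shapes, even if those inputs are valid in eager mode.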

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 949, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 461, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 447, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1884, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1832, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.133", line 8, in forward
    _guards_fn = self._guards_fn(w, x, y, z);  _guards_fn = None
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 239, in inner
    return func(*args, **kwargs)
  File "<string>", line 6, in _
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2272, in _assert
    raise AssertionError(message)
AssertionError: Guard failed: y.size()[0] == 8

Basic concepts: symbols and guards#

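To enable dynamism, export() provides a dynamic_shapes argument. The easiest way to work with dynamic shapes is to use Dim.AUTO and look at the program that is returned. Dynamic behavior is specified at the input dimension level; for each input we can specify a tuple of values: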

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

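Before we look at the program that is produced, let's understand what specifying dynamic_shapes entails, and how it interacts with export. For every input dimension where a Dim object is specified, a symbol is allocated, taking on a range of [2, inf] (why not [0, inf] or [1, inf]? We will explain later in the 0/1 specialization section).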

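Export then runs model tracing, looking at each operation performed by the model. Each individual operation can emit what are called "guards": boolean conditions that must be true for the program to be valid. When guards involve symbols allocated for input dimensions, the program contains restrictions on which input shapes are valid; this is the program's dynamic behavior. The symbolic shapes subsystem is responsible for taking in all the emitted guards and producing a final program representation that adheres to all of them. Before we see this "final representation" in an ExportedProgram, let's look at the guards emitted by the toy model we are tracing.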

Here, each forward input tensor is annotated with the symbol allocated at the start of tracing:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

Let's understand each of the operations and the emitted guards:

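  • x0 = x + y: This is an element-wise add with broadcasting, since x is a 1-d tensor and y is a 2-d tensor. x is broadcast along the last dimension of y, emitting the guard s2 == s4.

  • x1 = self.l(w): Calling nn.Linear() performs a matrix multiplication with the model parameters. In export, parameters, buffers, and constants are considered program state, which is treated as static, so this is a matmul between a dynamic input (w: [s0, s1]) and a statically-shaped tensor. This emits the guard s1 == 5.

  • x2 = x0.flatten(): This call actually doesn't emit any guards! (At least none relevant to input shapes.)

  • x3 = x2 + z: x2 has shape [s3*s4] after flattening, and this element-wise add emits s3 * s4 == s5.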
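
The shape rules behind these guards can be checked in eager mode with concrete sizes. The sketch below uses torch.broadcast_shapes, which is not part of the tutorial's own code, to replay the shape arithmetic of the toy model on plain tuples and to show that a mismatched trailing dimension is rejected.

```python
import torch

# Concrete shapes from the toy model: x is [4], y is [8, 4], z is [32].
x_shape, y_shape, z_shape = (4,), (8, 4), (32,)

# x + y: broadcasting aligns trailing dimensions, so x's size has to match
# y's last size (size-1 dimensions aside).
x0_shape = torch.broadcast_shapes(x_shape, y_shape)
print(x0_shape)

# flatten() multiplies the sizes together; it adds no requirement on the inputs.
x2_shape = (x0_shape[0] * x0_shape[1],)

# x2 + z: both are 1-D, so their sizes have to agree.
x3_shape = torch.broadcast_shapes(x2_shape, z_shape)
print(x3_shape)

# A mismatched trailing dimension breaks the first rule and fails in eager mode too.
mismatch_rejected = False
try:
    torch.broadcast_shapes((3,), y_shape)
except RuntimeError as err:
    mismatch_rejected = True
    print("broadcast failed:", err)
```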

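Writing all of these guards down and summarizing them is almost like a mathematical proof, which is what the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have the following input shapes to be valid: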

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]
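
These conclusions can be restated as a plain Python predicate over concrete shape tuples. The helper name shapes_are_valid is hypothetical, and the sketch only encodes the three guards discussed above; it ignores the [2, inf] range placed on each symbol.

```python
from typing import Tuple


def shapes_are_valid(
    w: Tuple[int, int], x: Tuple[int], y: Tuple[int, int], z: Tuple[int]
) -> bool:
    """Check the three guards derived above against concrete input shapes."""
    s0, s1 = w
    (s2,) = x
    s3, s4 = y
    (s5,) = z
    return s1 == 5 and s2 == s4 and s3 * s4 == s5


print(shapes_are_valid((6, 5), (4,), (8, 4), (32,)))  # True: the example inputs
print(shapes_are_valid((6, 5), (4,), (3, 4), (12,)))  # True: a different y and z
print(shapes_are_valid((6, 5), (4,), (8, 4), (30,)))  # False: 8 * 4 != 30
```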

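And when we do finally print out the exported program to see our result, those shapes are what we see annotated on the corresponding inputs: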

print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s15, 5]", x: "f32[s77]", y: "f32[s17, s77]", z: "f32[s17*s77]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
            add: "f32[s17, s77]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

            # File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s15, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias);  w = p_l_weight = p_l_bias = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
            flatten: "f32[s17*s77]" = torch.ops.aten.flatten.using_ints(add);  add = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
            add_1: "f32[s17*s77]" = torch.ops.aten.add.Tensor(flatten, z);  flatten = z = None
            return (linear, add_1)

Graph signature:
    # inputs
    p_l_weight: PARAMETER target='l.weight'
    p_l_bias: PARAMETER target='l.bias'
    w: USER_INPUT
    x: USER_INPUT
    y: USER_INPUT
    z: USER_INPUT

    # outputs
    linear: USER_OUTPUT
    add_1: USER_OUTPUT

Range constraints: {s15: VR[2, int_oo], s77: VR[2, int_oo], s17: VR[2, int_oo], s17*s77: VR[4, int_oo]}

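Another feature to notice is the range_constraints field above, which contains a valid range for each symbol. This isn't very interesting yet, since this export call doesn't emit any guards related to symbol bounds and each base symbol has a generic bound, but it will come up later.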

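So far, because we have been exporting this toy model, the experience has not been representative of how hard it typically is to debug dynamic shape guards and issues. In most cases it isn't obvious which guards are being emitted, or which operations and parts of user code are responsible. For this toy model we can pinpoint the exact lines, and the guards are rather intuitive.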

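In more complicated cases, a helpful first step is always to enable verbose logging. This can be done either with the environment variable TORCH_LOGS="+dynamic", or interactively with torch._logging.set_logs(dynamic=10):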

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
I0603 01:00:15.773000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:15.774000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.775000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.777000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.778000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.778000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.780000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0603 01:00:15.786000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.786000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.787000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.788000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.789000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.790000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.790000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.791000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.792000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.792000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.794000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I0603 01:00:15.795000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = s77 (solve) VR[2, int_oo]
V0603 01:00:15.797000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.803000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2456 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V0603 01:00:15.803000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s21 = VR[5, 5] (update)
I0603 01:00:15.804000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V0603 01:00:15.816000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.818000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.820000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V0603 01:00:15.821000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s68 = VR[4, int_oo] (update)
I0603 01:00:15.822000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I0603 01:00:15.827000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:15.827000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[0] s15 None
V0603 01:00:15.827000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[1] 5 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[0] 5 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[1] 1 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].storage_offset() 0 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] s77 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 1 None
V0603 01:00:15.828000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] s17 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[1] s77 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] s77 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[1] 1 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:15.829000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].size()[0] s17*s77 None
V0603 01:00:15.830000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].stride()[0] 1 None
V0603 01:00:15.830000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].storage_offset() 0 None
V0603 01:00:15.842000 28681 torch/fx/experimental/symbolic_shapes.py:8105] eval 5 [trivial]

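Even for this simple toy model, this produces quite a lot of output. The log lines shown here have been trimmed at the front and the end to drop unnecessary information, but the logs contain the lines that correspond to what we described above, for example the allocation of symbols. (The symbol names in this excerpt, such as s0, differ from those in the full log above, such as s15.)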

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

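The lines with create_symbol show when a new symbol is allocated, and the logs also identify the tensor variable name and dimension each symbol is allocated for. In other lines we can also see the guards that are emitted.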

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

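Next to the [guard added] messages, we also see the user lines of code responsible for them; luckily, the model here is simple enough. In many real-world cases it is not so straightforward: high-level torch operations can have complicated fake-kernel implementations or operator decompositions, which makes it hard to tell where a guard was emitted and what it is. In such cases, the best way to dig deeper is to follow the suggestion in the logs and re-run with the environment variable TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..." to further attribute the guard of interest.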
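As a rough illustration of working with these logs, the sketch below pulls the allocated symbols and the emitted guards out of log text and prepares the environment for a re-run. The sample lines are copied from the excerpts above; the helper names (summarize, rerun_env) and the regular expressions are assumptions made for this example and are not part of torch.

```python
import os
import re

# Sample lines copied from the trimmed log excerpts in this tutorial.
SAMPLE_LOG = """\
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
"""

# Hypothetical patterns, written against the line format shown above.
SYMBOL_RE = re.compile(r"create_symbol (\w+) = (\d+) for (\S+)")
GUARD_RE = re.compile(
    r'\[guard added\].*TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="([^"]+)"'
)


def summarize(log_text):
    """Return ({symbol: (hint, source)}, [guard expressions]) found in the log."""
    symbols = {}
    guards = []
    for line in log_text.splitlines():
        match = SYMBOL_RE.search(line)
        if match:
            symbols[match.group(1)] = (int(match.group(2)), match.group(3))
        match = GUARD_RE.search(line)
        if match:
            guards.append(match.group(1))
    return symbols, guards


def rerun_env(guard):
    """Copy the current environment and add the variables the logs suggest."""
    env = dict(os.environ)
    env["TORCH_LOGS"] = "+dynamic"
    env["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"] = guard
    return env


symbols, guards = summarize(SAMPLE_LOG)
print(symbols)
print(guards)

# The resulting mapping could be passed as env= to subprocess.run when
# re-running your own export script (the script itself is not shown here).
env = rerun_env(guards[0])
print(env["TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED"])
```

Setting the variable for a single re-run, rather than globally, keeps the extra debug output limited to the guard under investigation.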

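Dim.AUTO is just one of the available options for interacting with dynamic_shapes; at the time of writing, two other options are available: Dim.DYNAMIC and Dim.STATIC. Dim.STATIC simply marks a dimension as static, while Dim.DYNAMIC behaves like Dim.AUTO in every way except one: it raises an error when the dimension is specialized to a constant. It is designed to preserve dynamism. For example, see what happens when a static guard is emitted on a dimension marked as dynamic.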

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0603 01:00:15.846000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:15.848000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.848000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.850000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.852000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.852000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.854000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0603 01:00:15.860000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.860000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.861000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.862000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.862000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.863000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.864000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.864000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.865000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.866000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.868000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
I0603 01:00:15.869000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = s77 (solve) VR[2, int_oo]
V0603 01:00:15.870000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.877000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2456 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V0603 01:00:15.877000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s21 = VR[5, 5] (update)
I0603 01:00:15.878000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V0603 01:00:15.890000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.893000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.895000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V0603 01:00:15.896000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s68 = VR[4, int_oo] (update)
I0603 01:00:15.897000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s68 = s17*s77 (solve) VR[4, int_oo]
I0603 01:00:15.902000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:15.902000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[0] s15 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False)
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[0] 5 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[1] 1 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].storage_offset() 0 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] s77 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 1 None
V0603 01:00:15.903000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] s17 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[1] s77 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] s77 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[1] 1 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].size()[0] s17*s77 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].stride()[0] 1 None
V0603 01:00:15.904000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2035, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2193, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 623, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 582, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5631, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6484, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2037, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['w'].size()[1] as dynamic but your code specialized it to be a constant (5). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

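Static guards are not always inherent to the model; they can also come from user specifications. In fact, a common pitfall that leads to shape specialization is when the user specifies conflicting markers for dimensions that must be equal: one is dynamic and the other is static. The same type of error is raised when this happens for x.shape[0] and y.shape[1].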

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0603 01:00:15.915000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:15.917000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.917000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.919000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.920000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.922000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0603 01:00:15.928000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.928000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.929000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.930000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.931000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.931000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.932000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:15.933000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:15.938000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s94, 4) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s94, 4)"
V0603 01:00:15.938000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s94 = VR[4, 4] (update)
I0603 01:00:15.939000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = 4 (range_refined_to_singleton) VR[4, 4]
I0603 01:00:15.946000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2456 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V0603 01:00:15.946000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s21 = VR[5, 5] (update)
I0603 01:00:15.947000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
I0603 01:00:15.968000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(4*s17, s68) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s17, s68)"
V0603 01:00:15.970000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s68 = VR[8, int_oo] (update)
I0603 01:00:15.971000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s68 = 4*s17 (solve) VR[8, int_oo]
I0603 01:00:15.976000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:15.976000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[0] s15 None
V0603 01:00:15.976000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[1] 5 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[0] 5 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[1] 1 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].storage_offset() 0 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] 4 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 1 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] s17 None
V0603 01:00:15.977000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False)
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 4 None
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[1] 1 None
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].size()[0] 4*s17 None
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].stride()[0] 1 None
V0603 01:00:15.978000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2035, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2193, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 623, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 582, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5631, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6484, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in <module>
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2037, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['y'].size()[1] as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

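At this point you might ask why export "specializes", that is, why this static/dynamic conflict is resolved by taking the static route. The answer lies in the symbolic shapes system of symbols and guards described above. When x.shape[0] is marked static, no symbol is allocated for it, and the shape is treated as the concrete integer 4 at compile time. A symbol is allocated for y.shape[1], so we end up emitting the guard s94 == 4 (logged above as Eq(s94, 4)), which leads to specialization.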
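The reasoning above can be sketched with a toy model. The ToyShapeEnv class below is a hypothetical stand-in written for this example; it is not PyTorch's ShapeEnv and only mimics the idea that a static dimension is a plain integer, a dynamic dimension is a symbol, and an equality check between the two records a guard that pins the symbol to the integer.

```python
class ToyShapeEnv:
    """Toy stand-in for a symbolic shape environment (not PyTorch's ShapeEnv)."""

    def __init__(self):
        self.replacements = {}  # symbol name -> int or another symbol name
        self.guards = []        # guard expressions recorded as strings

    def resolve(self, dim):
        # Follow replacements until we reach an int or an unreplaced symbol.
        while isinstance(dim, str) and dim in self.replacements:
            dim = self.replacements[dim]
        return dim

    def require_equal(self, a, b):
        """Model a shape check such as the one needed for x + y."""
        a, b = self.resolve(a), self.resolve(b)
        if a == b:
            return
        if isinstance(a, int) and isinstance(b, int):
            raise ValueError(f"shape mismatch: {a} != {b}")
        if isinstance(a, int):
            a, b = b, a  # keep the symbol on the left-hand side
        self.guards.append(f"{a} == {b}")
        self.replacements[a] = b


# x.shape[0] is marked static, so it is the concrete integer 4;
# y.shape[1] is marked dynamic, so it gets a symbol (named "s94" here).
env = ToyShapeEnv()
dynamic_marked = {"s94"}
env.require_equal(4, "s94")
print(env.guards)          # ['s94 == 4']
print(env.resolve("s94"))  # 4

# A symbol the user marked as dynamic that now resolves to an integer
# corresponds to the constraint violation reported above.
specialized = sorted(s for s in dynamic_marked if isinstance(env.resolve(s), int))
print(specialized)         # ['s94']

# When both dimensions are symbols, the guard relates the two symbols
# and nothing is specialized to a constant.
env2 = ToyShapeEnv()
env2.require_equal("s77", "s94")
print(env2.guards)         # ['s77 == s94']
```

In this toy, the static side can never become symbolic after the fact, which is why the conflict can only be resolved in the static direction.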

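One feature of export is that, during tracing, statements such as asserts, torch._check(), and if/else conditions also emit guards. See what happens when we augment the existing model with such statements.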

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()
I0603 01:00:15.987000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:15.989000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s15 = 6 for L['w'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s15" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.989000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s21 = 5 for L['w'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s21" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.991000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s77 = 4 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.992000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s17 = 8 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.993000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:15.994000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s68 = 32 for L['z'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s68" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0603 01:00:16.001000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.001000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.002000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.003000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.003000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.005000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.005000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.006000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.007000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.007000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:16.013000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval s15 <= 512 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s15 <= 512"
V0603 01:00:16.013000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s15 = VR[2, 512] (update)
I0603 01:00:16.017000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval s77 >= 4 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s77 >= 4"
V0603 01:00:16.017000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s77 = VR[4, int_oo] (update)
I0603 01:00:16.022000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s15, s77 + 2) [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s15, s77 + 2)"
V0603 01:00:16.024000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s77 = VR[4, 510] (update)
V0603 01:00:16.025000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s15 = VR[6, 512] (update)
I0603 01:00:16.026000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s15 = s77 + 2 (solve) VR[6, 512]
I0603 01:00:16.030000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s77, s94) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s94)"
V0603 01:00:16.031000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s94 = VR[4, 510] (update)
I0603 01:00:16.031000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = s77 (solve) VR[4, 510]
V0603 01:00:16.034000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:16.041000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s21, 5) [guard added] (_meta_registrations.py:2456 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s21, 5)"
V0603 01:00:16.042000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s21 = VR[5, 5] (update)
I0603 01:00:16.043000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s21 = 5 (range_refined_to_singleton) VR[5, 5]
V0603 01:00:16.058000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.060000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:16.069000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s17*s77, s68) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s17*s77, s68)"
V0603 01:00:16.070000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s68 = VR[8, int_oo] (update)
I0603 01:00:16.071000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s68 = s17*s77 (solve) VR[8, int_oo]
I0603 01:00:16.077000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.077000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[0] s77 + 2 None
V0603 01:00:16.077000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].size()[1] 5 None
V0603 01:00:16.077000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[0] 5 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].stride()[1] 1 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['w'].storage_offset() 0 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] s77 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 1 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] s17 None
V0603 01:00:16.078000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[1] s77 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] s77 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[1] 1 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].size()[0] s17*s77 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].stride()[0] 1 None
V0603 01:00:16.079000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['z'].storage_offset() 0 None
V0603 01:00:16.096000 28681 torch/fx/experimental/symbolic_shapes.py:8105] eval 5 [trivial]

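Each of these statements emits an additional guard, and the exported program reflects the changes: s0 is eliminated in favor of s2 + 2, and s2 now has lower and upper bounds, which show up in range_constraints. (In the log above, the same symbols appear as s15 and s77.)

For the if/else condition, you might ask why the True branch was taken, and why the guard emitted from tracing was not w.shape[0] != x.shape[0] + 2. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branches taken. If different sample input shapes were provided that fail the if condition, export would trace and emit guards corresponding to the else branch. You might also ask why only the if branch was traced, and whether it is possible to keep the control flow in the program so that both branches stay alive. For that, refer to the Control Flow Ops section above and rewrite your model code.

0/1 specialization#

Since we're talking about guards and specialization, it's a good time to discuss the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. Take a look at what happens at runtime when we export this linear layer: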

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()
I0603 01:00:16.101000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.114000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.114000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['input'].size()[0] 1 None
V0603 01:00:16.114000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['input'].size()[1] 4 None
V0603 01:00:16.114000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['input'].stride()[0] 4 None
V0603 01:00:16.114000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['input'].stride()[1] 1 None
V0603 01:00:16.115000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['input'].storage_offset() 0 None
W0603 01:00:16.117000 28681 torch/_export/non_strict_utils.py:654] dimension inputs['input'].shape[0] 0/1 specialized; Dim.AUTO was specified along with a sample input with hint = 1.
Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in <module>
    ep.module()(torch.randn(2, 4))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 949, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 461, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 447, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1884, in _call_impl
    return inner()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1832, in inner
    result = forward_call(*args, **kwargs)
  File "<eval_with_key>.163", line 9, in forward
    _guards_fn = self._guards_fn(input_1);  _guards_fn = None
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 239, in inner
    return func(*args, **kwargs)
  File "<string>", line 3, in _
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 2272, in _assert
    raise AssertionError(message)
AssertionError: Guard failed: input.size()[0] == 1

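Named Dims#

So far we've only covered 3 ways to specify dynamic shapes: Dim.AUTO, Dim.DYNAMIC, and Dim.STATIC. Their appeal is the low-friction user experience: all guards emitted during model tracing are adhered to, and dynamic behavior such as min/max ranges, relations, and static/dynamic dimensions is figured out automatically under the hood of export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program. The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of these models: maybe there is a strong desire for dynamism and specializations on particular dimensions must be avoided at all costs, or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or possibly to underlying decompositions or meta-kernels. Such changes won't be detected, and the export() call will most likely succeed, unless tests are in place that check the resulting ExportedProgram representation.

For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named Dims: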

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

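This style of dynamic shapes lets the user specify which symbols are allocated for input dimensions and the min/max bounds on those symbols, and it places restrictions on the dynamic behavior of the ExportedProgram produced: a ConstraintViolation error is raised if model tracing emits guards that conflict with the relations or static/dynamic specifications given. For example, the specification above asserts the following:

  • x.shape[0] has range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [2, 512], and is unrelated to any other dimension.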
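
To make these three assertions concrete, the following plain-Python sketch checks a pair of candidate shapes against them. It does not call into torch.export; satisfies_spec and its static_x1 parameter (the value of x.shape[1] in the sample input) are hypothetical names introduced only for illustration.

```python
def satisfies_spec(x_shape, y_shape, static_x1):
    """Return True if the shapes meet the assertions listed above."""
    x0, x1 = x_shape
    y0, y1 = y_shape
    return (
        4 <= x0 <= 256        # dx = Dim("dx", min=4, max=256)
        and y0 == 2 * x0      # y.shape[0] is declared as 2 * dx
        and x1 == static_x1   # x.shape[1] is static (None in the spec)
        and 2 <= y1 <= 512    # dh = Dim("dh", max=512), range [2, 512]
    )

print(satisfies_spec((4, 3), (8, 10), static_x1=3))      # True
print(satisfies_spec((4, 3), (9, 10), static_x1=3))      # False: y.shape[0] != 2 * x.shape[0]
print(satisfies_spec((300, 3), (600, 10), static_x1=3))  # False: x.shape[0] > 256
```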

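In this design, relations between dimensions can be specified with univariate linear expressions: A * dim + B can be specified for any dimension. This lets users express more complex constraints on dynamic dimensions, such as integer divisibility: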

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}
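
The range and divisibility stated in the comment above follow from simple arithmetic. The sketch below works it through in plain Python for a positive multiplier A; derived_range and is_valid_size are hypothetical helpers and not part of the torch.export API.

```python
def derived_range(a, b, lo, hi):
    """Range of a * dim + b when dim ranges over [lo, hi], assuming a > 0."""
    return a * lo + b, a * hi + b

def is_valid_size(size, a, b, lo, hi):
    """True if size == a * dim + b for some integer dim in [lo, hi]."""
    dim, remainder = divmod(size - b, a)
    return remainder == 0 and lo <= dim <= hi

# dx = Dim("dx", min=4, max=512) and x.shape[0] = 4 * dx
print(derived_range(4, 0, 4, 512))      # (16, 2048)
print(is_valid_size(20, 4, 0, 4, 512))  # True: 20 == 4 * 5
print(is_valid_size(18, 4, 0, 4, 512))  # False: 18 is not divisible by 4
```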

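Constraint violations, suggested fixes#

One common issue with this specification style (before Dim.AUTO was introduced) is that the specification would often be mismatched with what model tracing produced. That leads to ConstraintViolation errors and export-suggested fixes. For example, with the model and specification below, the model inherently requires dimension 0 of x and y to be equal, and requires dimension 1 to be static.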

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()
I0603 01:00:16.126000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.128000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s77 = 6 for L['x'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s77" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:16.130000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s27 = 4 for L['x'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s27" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:16.133000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s17 = 6 for L['y'].size()[0] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s17" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
I0603 01:00:16.133000 28681 torch/fx/experimental/symbolic_shapes.py:5523] create_symbol s94 = 4 for L['y'].size()[1] [2, int_oo] (_export/non_strict_utils.py:221 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s94" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0"
V0603 01:00:16.140000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.141000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.142000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.143000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.144000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
V0603 01:00:16.144000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:16.148000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s27, s94) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, s94)"
I0603 01:00:16.149000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = s27 (solve) VR[2, int_oo]
I0603 01:00:16.151000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s77, s17) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s77, s17)"
I0603 01:00:16.152000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s77 = s17 (solve) VR[2, int_oo]
V0603 01:00:16.154000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == False [statically known]
I0603 01:00:16.164000 28681 torch/fx/experimental/symbolic_shapes.py:7834] eval Eq(s27, 4) [guard added] (_subclasses/fake_impls.py:1520 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s27, 4)"
V0603 01:00:16.165000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s27 = VR[4, 4] (update)
I0603 01:00:16.165000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s27 = 4 (range_refined_to_singleton) VR[4, 4]
I0603 01:00:16.172000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.172000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range s94 = VR[4, 4] (update)
I0603 01:00:16.173000 28681 torch/fx/experimental/symbolic_shapes.py:7427] set_replacement s94 = 4 (find) VR[4, 4]
V0603 01:00:16.173000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0603 01:00:16.173000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0603 01:00:16.173000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 4 None
V0603 01:00:16.173000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[1] 1 None
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] s17 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 4 None
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[1] 1 None
V0603 01:00:16.174000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2035, in _export_to_aten_ir_make_fx
    produce_guards_callback(gm)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2193, in _produce_guards_callback
    return produce_guards_and_solve_constraints(
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 623, in produce_guards_and_solve_constraints
    raise constraint_violation_error
  File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 582, in produce_guards_and_solve_constraints
    shape_env.produce_guards(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5631, in produce_guards
    return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6484, in produce_guards_verbose
    raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
  d1 = 4
  dy = dx

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in <module>
    ep = export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2037, in _export_to_aten_ir_make_fx
    raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic".
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - You marked d1 as dynamic but your code specialized it to be a constant (4). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
  - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal.
Suggested fixes:
  d1 = 4
  dy = dx

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

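The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.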

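Lastly, there are a couple of nice-to-knows regarding the options for specification:

  • None is a good option for static behavior:

      • dynamic_shapes=None (the default) exports with the entire model being static.

      • Specifying None at the input level exports with all tensor dimensions static, and is also required for non-tensor inputs.

      • Specifying None at the dimension level specializes that dimension, though this is deprecated in favor of Dim.STATIC.

  • Specifying per-dimension integer values also produces static behavior, and additionally checks that the provided sample input matches the specification.

The input and dynamic shapes specification below combines these options: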

inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}

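Data-dependent errors#

While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression" or "Could not extract specialized integer from data-dependent expression". These errors exist because torch.export() compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties (e.g. sizes, strides, dtypes), they differ in that FakeTensors contain no data values. This avoids unnecessary memory usage and expensive computation, but it means export may be unable to compile, out of the box, parts of user code that rely on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it errors out, complaining that the value is not available.

Data-dependent values appear in many places, and common sources are calls like item(), tolist(), or torch.unbind() that extract scalar values from tensors. How are these values represented in the exported program? In the Constraints/Dynamic Shapes section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, in contrast to the "backed" symbols allocated for input dimensions. The "backed/unbacked" nomenclature refers to the presence or absence of a "hint" for the symbol: a concrete value backing the symbol that can inform the compiler on how to proceed.

For input shape symbols (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, so the compiler doesn't know the actual values (hints) that these symbols will take on.

Let's see how these show up in exported programs: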

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)
I0603 01:00:16.183000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.189000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.189000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
I0603 01:00:16.193000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.194000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u1]
I0603 01:00:16.195000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u2 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.195000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u2]
I0603 01:00:16.197000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.197000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.197000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] 2 None
V0603 01:00:16.198000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 1 None
V0603 01:00:16.198000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "i64[2]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist()
            unbind = torch.ops.aten.unbind.int(y);  y = None
            getitem: "i64[]" = unbind[0]
            getitem_1: "i64[]" = unbind[1];  unbind = None
            item_1: "Sym(u1)" = torch.ops.aten.item.default(getitem);  getitem = None
            item_2: "Sym(u2)" = torch.ops.aten.item.default(getitem_1);  getitem_1 = None
            return (item_1, item_2, item)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    item_1: USER_OUTPUT
    item_2: USER_OUTPUT
    item: USER_OUTPUT

Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo]}

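The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: 1 for the item() call, and 1 for each element of y through the tolist() call. Note from the range constraints field that these take on ranges of [-int_oo, int_oo], not the default [0, int_oo] range allocated to input shape symbols, since we have no information on what these values are: they don't represent sizes, so they don't necessarily have positive values.

Guards, torch._check()#

But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that matters is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where data-dependent guards are encountered: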

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

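Here we actually need the "hint", or the concrete value of a, for the compiler to decide whether to trace return y + 2 or return y * 5 as the output. Because we trace with FakeTensors, we don't know what a // 2 >= 5 actually evaluates to, and export errors out with "Could not guard on data-dependent expression u0 // 2 >= 5 (unhinted)".

So how do we export this toy model? Unlike torch.compile(), export requires full graph compilation, and we can't just graph break on this. Here are some basic options:

  1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code so that it contains only the specialized branch, or by using torch.compiler.is_compiling() to guard what is traced at compile time.

  2. torch.cond(): we could rewrite the control-flow code to use torch.cond(), so that we don't specialize on a branch.

While these options are valid, they have their pitfalls. Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and torch.cond() is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.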

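The generally recommended approach is to start with torch._check() calls. While these give the impression of being purely assert statements, they are in fact a system for informing the compiler about properties of symbols. A torch._check() call does act as an assertion at runtime, but when traced at compile time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored as symbol properties (provided the subsystem is smart enough to infer them). So even though unbacked symbols have no hints, if we can communicate properties that are generally true for these symbols via torch._check() calls, we can potentially bypass data-dependent guards without rewriting the offending model code.

For example, in the model above, inserting torch._check(a >= 10) tells the compiler that y + 2 can always be returned, and torch._check(a == 4) tells it to return y * 5. See what happens when we re-export this model.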

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)
I0603 01:00:16.205000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.211000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.211000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
I0603 01:00:16.213000 28681 torch/fx/experimental/symbolic_shapes.py:7834] runtime_assert u0 >= 10 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
V0603 01:00:16.213000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range u0 = VR[10, int_oo] (update)
I0603 01:00:16.217000 28681 torch/fx/experimental/symbolic_shapes.py:7834] runtime_assert u0 <= 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
V0603 01:00:16.218000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range u0 = VR[10, 60] (update)
V0603 01:00:16.224000 28681 torch/fx/experimental/symbolic_shapes.py:8149] eval False == True [statically known]
I0603 01:00:16.227000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.228000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.228000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] 4 None
V0603 01:00:16.228000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 1 None
V0603 01:00:16.228000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:16.230000 28681 torch/fx/experimental/symbolic_shapes.py:8369] runtime_assert u0 >= 10 == True [statically known]
V0603 01:00:16.231000 28681 torch/fx/experimental/symbolic_shapes.py:8369] runtime_assert u0 <= 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[4]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_2: "Sym(u0 >= 10)" = item >= 10
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 10 on node 'ge_2'");  ge_2 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 60)" = item <= 60;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
            add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2);  y = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {u0: VR[10, 60]}

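Export succeeds, and note from the range constraints field that u0 takes on a range of [10, 60].

So what information do torch._check() calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, the following are generally true:

  1. Equality with non-data-dependent expressions: torch._check() calls that communicate equalities such as u0 == s0 + 4 or u0 == 5.

  2. Range refinement: calls that provide lower or upper bounds for symbols, like the above.

  3. Some basic reasoning around more complicated expressions: inserting torch._check(a < 4) will typically tell the compiler that a >= 4 is false. Checks on complex expressions like torch._check(a ** 2 - 3 * a <= 10) will typically get you past identical guards.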
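
Range refinement (item 2) can be pictured with a toy model in plain Python. This is only a sketch of the idea and not how PyTorch's symbolic shapes subsystem is implemented; ToyRange and its methods are hypothetical. It shows why the bound a >= 10 is enough to decide a // 2 >= 5 without knowing the concrete value of a.

```python
import math

class ToyRange:
    """Toy value range for one unbacked symbol (illustration only)."""

    def __init__(self):
        self.lo, self.hi = -math.inf, math.inf

    def check_ge(self, bound):
        # Models torch._check(a >= bound): raise the lower bound.
        self.lo = max(self.lo, bound)

    def check_le(self, bound):
        # Models torch._check(a <= bound): lower the upper bound.
        self.hi = min(self.hi, bound)

    def floordiv_ge(self, divisor, threshold):
        """Decide `a // divisor >= threshold` from the range (divisor > 0).

        Returns True or False when the range settles it, else None.
        """
        if math.isfinite(self.lo) and self.lo // divisor >= threshold:
            return True
        if math.isfinite(self.hi) and self.hi // divisor < threshold:
            return False
        return None

u0 = ToyRange()
before = u0.floordiv_ge(2, 5)  # None: no bounds, the guard is undecidable
u0.check_ge(10)
u0.check_le(60)
after = u0.floordiv_ge(2, 5)   # True: 10 // 2 >= 5 for every value in [10, 60]
print(before, after, (u0.lo, u0.hi))
```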

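As mentioned previously, torch._check() calls are applicable beyond data-dependent control flow. For example, here is a model where inserting torch._check() prevails while manual specialization and torch.cond() do not: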

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()
I0603 01:00:16.237000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.243000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.243000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
I0603 01:00:16.245000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate u0 >= 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:498 in meta_select)
I0603 01:00:16.245000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0603 01:00:16.246000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate u0 < 0 due to data dependency, it was assumed to be False with no runtime assertions (_subclasses/fake_impls.py:500 in meta_select)
I0603 01:00:16.246000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0603 01:00:16.246000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:512 in meta_select)
I0603 01:00:16.248000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate u1 >= 0 due to data dependency, it was assumed to be True with no runtime assertions (utils/_stats.py:29 in wrapper)
I0603 01:00:16.248000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0603 01:00:16.248000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u1]
I0603 01:00:16.250000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.251000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.251000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] 60 None
V0603 01:00:16.251000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 1 None
V0603 01:00:16.251000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None

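This is a scenario where a torch._check() call is needed simply to keep an operation from failing. The export call fails with "Could not guard on data-dependent expression -u0 > 60", which means the compiler cannot tell whether this is a valid indexing operation, that is, whether the value of x is out of bounds for y. Here, manual specialization is too restrictive, and torch.cond() does not apply. Instead, telling the compiler the range of u0 is enough: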

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)  # tell the compiler that u0 >= 0
        torch._check(a < y.shape[0])  # tell the compiler that u0 < 60, so y[a] is in bounds
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)
I0603 01:00:16.255000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.261000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.261000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
I0603 01:00:16.262000 28681 torch/fx/experimental/symbolic_shapes.py:7834] runtime_assert u0 >= 0 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0603 01:00:16.262000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range u0 = VR[0, int_oo] (update)
I0603 01:00:16.265000 28681 torch/fx/experimental/symbolic_shapes.py:7834] runtime_assert u0 < 60 [guard added] (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60"
V0603 01:00:16.266000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range u0 = VR[0, 59] (update)
I0603 01:00:16.271000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.272000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.272000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] 60 None
V0603 01:00:16.272000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 1 None
V0603 01:00:16.272000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
V0603 01:00:16.275000 28681 torch/fx/experimental/symbolic_shapes.py:8369] runtime_assert u0 <= 59 == True [statically known]
V0603 01:00:16.276000 28681 torch/fx/experimental/symbolic_shapes.py:8369] runtime_assert u0 < 60 == True [statically known]
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le: "Sym(u0 <= 59)" = item <= 59
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 59 on node 'le'");  le = _assert_scalar_default_1 = None
            lt_1: "Sym(u0 < 60)" = item < 60
            _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'");  lt_1 = _assert_scalar_default_2 = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
            select: "f32[]" = torch.ops.aten.select.int(y, 0, item);  y = item = None
            return (select,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    select: USER_OUTPUT

Range constraints: {u0: VR[0, 59]}

Specialized values#

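Another category of data-dependent error occurs when the program tries to extract a concrete, data-dependent integer/float value while tracing. It looks like "Could not extract specialized integer from data-dependent expression" and is analogous to the previous class of errors: these errors arise when tracing needs a concrete integer/float value, whereas data-dependent guard errors arise when it needs a concrete boolean value.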

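This error typically occurs when there is an explicit or implicit int() cast on a data-dependent expression. For example, this list comprehension has a range() call that implicitly performs an int() cast on the size of the list.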

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = torch.cat([y for y in range(a)], dim=0)  # range(a) implicitly casts u0 to int
        return b + int(a)  # explicit int() cast on u0

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()
I0603 01:00:16.283000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.289000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.289000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169] Data dependent variable 'u0' allocated at:
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/bin/sphinx-build", line 6, in <module>
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     sys.exit(main())
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 339, in main
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return make_main(argv)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 213, in make_main
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return make_mode.run_make_mode(argv[1:])
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 181, in run_make_mode
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return make.run_generic_build(args[0])
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 169, in run_generic_build
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return build_main(args + opts)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 293, in build_main
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 272, in __init__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self._init_builder()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 343, in _init_builder
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self.events.emit('builder-inited')
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 97, in emit
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     results.append(listener.handler(self.app, *args))
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 757, in generate_gallery_rst
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ) = generate_dir_rst(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 606, in generate_dir_rst
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     results = parallel(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 607, in <genexpr>
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/var/lib/workspace/conf.py", line 85, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     p.start()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self._popen = self._Popen(self)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return _default_context.get_context().Process._Popen(process_obj)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return Popen(process_obj)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self._launch(process_obj)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     code = process_obj._bootstrap(parent_sentinel=child_r)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self.run()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     self._target(*self._args, **self._kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/var/lib/workspace/conf.py", line 73, in call_fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     result = func(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1374, in generate_file_rst
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     output_blocks, time_elapsed = execute_script(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1192, in execute_script
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     execute_code_block(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1048, in execute_code_block
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     is_last_expr, mem_max = _exec_and_get_memory(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 876, in _exec_and_get_memory
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     mem_max, _ = call_memory(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1725, in _sg_call_memory_noop
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return 0.0, func()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 794, in __call__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     exec(self.code, self.fake_main.__dict__)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     export(Foo(), inps, strict=False)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return _export(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ep = fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return function(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ep = _export_for_training(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ep = fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     export_artifact = export_func(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     aten_export_artifact = _to_aten_func(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2006, in _export_to_aten_ir_make_fx
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     gm, graph_signature = transform(_make_fx_helper)(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2136, in _aot_export_non_strict
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1914, in _make_fx_helper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     gm = make_fx(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2965, in wrapped
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return make_fx_tracer.trace(f, *args)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2867, in trace
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._trace_inner(f, *args)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2828, in _trace_inner
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     t = dispatch_trace(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 54, in inner
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return disable_fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1297, in _fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1673, in dispatch_trace
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2402, in trace
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     res = super().trace(root, concrete_args)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1297, in _fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 912, in trace
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     (self.create_arg(fn(*args)),),
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1743, in wrapped
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     out = f(*tensors)  # type:ignore[call-arg]
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "<string>", line 1, in <lambda>
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1798, in wrapped_fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return tuple(flat_fn(*args))
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     tree_out = fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1536, in functional_call
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     out = mod(*args[params_len:], **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self.call_module(mod, forward, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return Tracer.call_module(self, m, forward, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ret_val = forward(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return _orig_module_call(mod, *args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._call_impl(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return forward_call(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2120, in forward
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     tree_out = mod(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self.call_module(mod, forward, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return Tracer.call_module(self, m, forward, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     ret_val = forward(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return _orig_module_call(mod, *args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._call_impl(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return forward_call(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 747, in forward
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     a = x.item()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1798, in __torch_function__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return func(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1869, in __torch_function__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return func(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_export/non_strict_utils.py", line 1167, in __torch_function__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return func(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 994, in handler
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return torch._library.utils.handle_dispatch_mode(
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_library/utils.py", line 325, in handle_dispatch_mode
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 54, in inner
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return disable_fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1298, in _fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 29, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1929, in __torch_dispatch__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return proxy_call(self, func, self.pre_dispatch, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1268, in proxy_call
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     out = func(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 871, in __call__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._op(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 54, in inner
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return disable_fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1298, in _fn
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 29, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return fn(*args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1465, in __torch_dispatch__
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self.dispatch(func, types, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2242, in dispatch
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._cached_dispatch_impl(func, types, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1598, in _cached_dispatch_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return self._dispatch_impl(func, types, args, kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2908, in _dispatch_impl
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     op_impl_out = op_impl(self, func, *args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 205, in dispatch_to_op_implementations_dict
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 849, in local_scalar_dense
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     r = fake_mode.shape_env.create_unbacked_symint()
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]   File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 297, in wrapper
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]     return retlog(fn(*args, **kwargs))
V0603 01:00:16.290000 28681 torch/fx/experimental/symbolic_shapes.py:7169]



def forward(self, arg0_1: "i64[]", arg1_1: "f32[60]"):
    # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:747 in forward, code: a = x.item()
    item: "Sym(u0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None





Traceback (most recent call last):
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
    export(Foo(), inps, strict=False)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 205, in export
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 171, in export
    return _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2512, in _export
    ep = _export_for_training(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in wrapper
    raise e
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2300, in _export_for_training
    export_artifact = export_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2006, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2136, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1914, in _make_fx_helper
    gm = make_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2965, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2867, in trace
    return self._trace_inner(f, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2828, in _trace_inner
    t = dispatch_trace(
  File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1297, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1673, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2402, in trace
    res = super().trace(root, concrete_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1297, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 912, in trace
    (self.create_arg(fn(*args)),),
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1743, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1798, in wrapped_fn
    return tuple(flat_fn(*args))
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1536, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2120, in forward
    tree_out = mod(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 2491, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 577, in call_module
    ret_val = forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 879, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward
    b = torch.cat([y for y in range(a)], dim=0)
  File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 466, in __index__
    return self.node.int_()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 502, in int_
    return self.guard_int("", 0)  # NB: uses Python backtrace
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 556, in guard_int
    r = self.evaluate()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 550, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7860, in evaluate_sym_node
    return self.evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7956, in evaluate_expr
    return self._inner_evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 297, in wrapper
    return retlog(fn(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 7979, in _inner_evaluate_expr
    return self._evaluate_expr(
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 8212, in _evaluate_expr
    raise self._make_data_dependent_error(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: none)


Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

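For these errors, you have a few basic options:

  1. Avoid unnecessary int() cast calls, in this case the int(a) in the return statement.

  2. Use torch._check() calls; unfortunately, all you may be able to do in this case is specialize (with torch._check(a == 60)).

  3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a repeat() op, which does not involve an int() cast. The following rewrite avoids the data-dependent error.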

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)
I0603 01:00:16.310000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.317000 28681 torch/fx/experimental/symbolic_shapes.py:5124] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:849 in local_scalar_dense)
I0603 01:00:16.317000 28681 torch/fx/experimental/symbolic_shapes.py:1439] compute_unbacked_bindings [u0]
I0603 01:00:16.320000 28681 torch/fx/experimental/symbolic_shapes.py:7834] runtime_assert u0 >= 0 [guard added] (_meta_registrations.py:4434 in meta_repeat), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
V0603 01:00:16.320000 28681 torch/fx/experimental/symbolic_shapes.py:7256] _update_var_to_range u0 = VR[0, int_oo] (update)
I0603 01:00:16.324000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate Eq(u0, 0) due to data dependency, it was assumed to be False with no runtime assertions (utils/_stats.py:29 in wrapper)
I0603 01:00:16.324000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0603 01:00:16.330000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate 60*u0 < 2 due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:317 in is_contiguous)
I0603 01:00:16.330000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
I0603 01:00:16.331000 28681 torch/fx/experimental/symbolic_shapes.py:8002] could not evaluate Eq(u0, 1) due to data dependency, it was assumed to be False with no runtime assertions (_prims_common/__init__.py:280 in check_contiguous_sizes_strides)
I0603 01:00:16.331000 28681 torch/fx/experimental/symbolic_shapes.py:8002] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
V0603 01:00:16.333000 28681 torch/fx/experimental/symbolic_shapes.py:8369] runtime_assert True == True [statically known]
I0603 01:00:16.338000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.338000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
V0603 01:00:16.338000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].size()[0] 60 None
V0603 01:00:16.338000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].stride()[0] 1 None
V0603 01:00:16.338000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['y'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[60]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
            unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
            repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
            add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
            return (add,)

Graph signature:
    # inputs
    x: USER_INPUT
    y: USER_INPUT

    # outputs
    add: USER_OUTPUT

Range constraints: {u0: VR[0, int_oo]}

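Data-dependent errors can be much more involved, and there are more tools in your toolkit for dealing with them: torch._check_is_size(), guard_size_oblivious(), or real-tensor tracing, to name a few. For a more in-depth guide, please refer to the Export Programming Model, or Dealing with GuardOnDataDependentSymNode errors.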

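Custom Ops#

torch.export can export PyTorch programs with custom operators. Please refer to this page on how to author a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python so that it can be used by torch.export. The important thing to note is that the custom op must have a FakeTensor kernel.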

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

Here is an example of exporting a program with the custom op.

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))
I0603 01:00:16.354000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.364000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.364000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] 3 None
V0603 01:00:16.364000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[1] 3 None
V0603 01:00:16.364000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 3 None
V0603 01:00:16.364000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[1] 1 None
V0603 01:00:16.365000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
            sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
            custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

            # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
            cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
            return (cos,)

Graph signature:
    # inputs
    x: USER_INPUT

    # outputs
    cos: USER_OUTPUT

Range constraints: {}

custom_op called!
tensor([[1.0000, 1.0000, 1.0000],
        [0.7857, 1.0000, 1.0000],
        [0.7594, 0.7974, 1.0000]])

Note that in the ExportedProgram, the custom operator is included in the graph.

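IR/Decompositions#

The graph produced by torch.export contains only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR, which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not mutate or alias any of its inputs. You can find a list of all ATen operators here, and you can check whether an operator is functional by inspecting op._schema.is_mutable, for example: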

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)
False
True

This generic IR can be used to train in eager PyTorch Autograd.

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)
I0603 01:00:16.375000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:16.410000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:16.410000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] 1 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[1] 1 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[2] 3 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[3] 3 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 9 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[1] 9 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[2] 3 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[3] 1 None
V0603 01:00:16.411000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, False), kwargs = {})
    return (batch_norm,)

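We can then lower this exported program to an operator set that contains only functional ATen operators through the run_decompositions API, which decomposes ATen operators into the ones specified in the decomposition table and functionalizes the graph. By specifying an empty set, we perform only functionalization and do no additional decompositions. This results in an IR that contains about 2000 operators (instead of the 3000 above), which is ideal for inference use cases.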

/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

We can see that the previously mutable operator, torch.ops.aten.add_.Tensor, has now been replaced with torch.ops.aten.add.Tensor, a functional operator.

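We can also lower this exported program further, to an operator set that contains only the Core ATen Operator Set, a collection of only about 180 operators. This IR is optimal for backends that do not want to reimplement all ATen operators.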

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more "core" operator: operations such as conv1d and conv2d can both be implemented using the same op.

We can also specify our own decomposition behaviors.

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)
/usr/lib/python3.10/copyreg.py:101: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
graph():
    %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
    %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
    %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
    %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
    %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
    %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
    %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
    %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
    return (getitem_3, getitem_4, add, getitem)

Notice that instead of being decomposed into torch.ops.aten.convolution.default alone, torch.ops.aten.conv2d.default is now decomposed into torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.

ExportDB#

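torch.export will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with torch.export, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial, for example rewriting if-statements using cond.

ExportDB is the standard reference that documents supported and unsupported Python/PyTorch features for torch.export. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with torch.export. Examples are also tagged by category so that they can be searched more easily.

For example, let's use ExportDB to get a better understanding of how the predicate works in the cond operator. We can look at the example called cond_predicate, which has a torch.cond tag. The example code looks like: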

from torch import cond  # assumed import; the excerpt does not show where cond comes from

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the ``pred`` is tested on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

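More generally, ExportDB can be used as a reference when one of the following occurs:

  1. Before attempting torch.export, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know whether torch.export covers that feature.

  2. When attempting torch.export, there is a failure and it is unclear how to work around it.

ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by torch.export.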

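Running the Exported Program#

As torch.export is only a graph capturing mechanism, calling the artifact produced by torch.export eagerly is equivalent to running the eager module. To optimize the execution of the exported program, we can pass this exported artifact to backends such as Inductor through torch.compile, AOTInductor, or TensorRT.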

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
I0603 01:00:17.332000 28681 torch/fx/experimental/symbolic_shapes.py:4002] create_env
I0603 01:00:17.345000 28681 torch/fx/experimental/symbolic_shapes.py:5669] produce_guards
V0603 01:00:17.345000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[0] 2 None
V0603 01:00:17.346000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].size()[1] 3 None
V0603 01:00:17.346000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[0] 3 None
V0603 01:00:17.346000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].stride()[1] 1 None
V0603 01:00:17.346000 28681 torch/fx/experimental/symbolic_shapes.py:5902] track_symint L['x'].storage_offset() 0 None
tensor([[0.3286, 0.7846, 1.5571],
        [0.1822, 0.5528, 1.0456]], device='cuda:0', grad_fn=<AddmmBackward0>)
I0603 01:00:17.989000 28681 torch/fx/experimental/symbolic_shapes.py:4002] [2/0] create_env
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:320: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
I0603 01:00:19.289000 28681 torch/fx/experimental/symbolic_shapes.py:5669] [2/0] produce_guards
I0603 01:00:19.300000 28681 torch/fx/experimental/symbolic_shapes.py:5669] [2/0] produce_guards
V0603 01:00:19.300000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['x'].size()[0] 2 None
V0603 01:00:19.300000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['x'].size()[1] 3 None
V0603 01:00:19.300000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['x'].stride()[0] 3 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['x'].stride()[1] 1 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['x'].storage_offset() 0 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[0] 3 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].size()[1] 3 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[0] 3 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].stride()[1] 1 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['weight'].storage_offset() 0 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].size()[0] 3 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].stride()[0] 1 None
V0603 01:00:19.301000 28681 torch/fx/experimental/symbolic_shapes.py:5902] [2/0] track_symint L['self']._modules['linear']._parameters['bias'].storage_offset() 0 None
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['x'].size()[0] == 2
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['x'].size()[1] == 3
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['x'].stride()[0] == 3
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['x'].stride()[1] == 1
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['x'].storage_offset() == 0
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[0] == 3
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].size()[1] == 3
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[0] == 3
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].stride()[1] == 1
V0603 01:00:19.302000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['weight'].storage_offset() == 0
V0603 01:00:19.303000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].size()[0] == 3
V0603 01:00:19.303000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].stride()[0] == 1
V0603 01:00:19.303000 28681 torch/fx/experimental/symbolic_shapes.py:6139] [2/0] Skipping guard L['self']._modules['linear']._parameters['bias'].storage_offset() == 0
tensor([[0.3286, 0.7846, 1.5571],
        [0.1822, 0.5528, 1.0456]], device='cuda:0',
       grad_fn=<CompiledFunctionBackward>)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the compiled PT2 package in Python.
# To load and run it in a C++ environment, see:
# https://docs.pytorch.ac.cn/docs/stable/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

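Conclusion#

We introduced torch.export, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.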

Total running time of the script: (0 minutes 4.666 seconds)