Skip to content

Commit 8f27fde

Browse files
zhxchen17 authored and pytorchmergebot committed
[export] Log private api uses. (#119848)
Summary: as title. The following APIs are logged:
- capture_pre_autograd_graph
- torch._export.aot_compile
- external usage of _export_to_torch_ir (AOTInductor, Pippy)
- constraints API
- public use of torch._dynamo.export

Test Plan: CI

Differential Revision: D53735599

Pull Request resolved: #119848
Approved by: https://github.com/suo
1 parent 340b6fa commit 8f27fde

3 files changed

Lines changed: 25 additions & 10 deletions

File tree

torch/_dynamo/eval_frame.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import torch.utils.checkpoint
4444
from torch import _guards
4545
from torch._subclasses import fake_tensor
46+
from torch._utils_internal import log_export_usage
4647
from torch.export import Constraint
4748
from torch.export.dynamic_shapes import _process_dynamic_shapes
4849
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@@ -1115,6 +1116,7 @@ def export(
11151116
assume_static_by_default: bool = False,
11161117
same_signature: bool = True,
11171118
disable_constraint_solver: bool = False,
1119+
_log_export_usage: bool = True,
11181120
**extra_kwargs,
11191121
) -> Callable[..., ExportResult]:
11201122
"""
@@ -1179,6 +1181,9 @@ def export(
11791181
11801182
Note - this headerdoc was authored by ChatGPT, with slight modifications by the author.
11811183
"""
1184+
if _log_export_usage:
1185+
log_export_usage(event="export.private_api", flags={"_dynamo"})
1186+
11821187
# Deal with "local variable referenced before assignment"
11831188
_f = f
11841189
_assume_static_by_default = assume_static_by_default

torch/_export/__init__.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,25 @@
3131
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
3232
from torch._functorch.eager_transforms import functionalize
3333
from torch._guards import detect_fake_mode
34+
from torch._inductor import config
3435
from torch._ops import OpOverload
3536
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
3637
from torch._subclasses.functional_tensor import FunctionalTensor
38+
from torch._utils_internal import log_export_usage
3739
from torch.export._tree_utils import reorder_kwargs
40+
from torch.export._unlift import _create_stateful_graph_module
41+
from torch.export.dynamic_shapes import (
42+
_process_constraints,
43+
_process_dynamic_shapes,
44+
Constraint,
45+
dims,
46+
dynamic_dim,
47+
)
3848
from torch.export.exported_program import (
49+
_disable_prexisiting_fake_mode,
3950
ExportedProgram,
4051
ModuleCallEntry,
4152
ModuleCallSignature,
42-
_disable_prexisiting_fake_mode,
4353
)
4454
from torch.export.graph_signature import (
4555
_sig_to_specs,
@@ -53,14 +63,6 @@
5363
SymIntArgument,
5464
TensorArgument,
5565
)
56-
from torch.export.dynamic_shapes import (
57-
Constraint,
58-
dims,
59-
dynamic_dim,
60-
_process_constraints,
61-
_process_dynamic_shapes,
62-
)
63-
from torch.export._unlift import _create_stateful_graph_module
6466
from torch.fx import traceback as fx_traceback
6567
from torch.fx._compatibility import compatibility
6668
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
@@ -77,7 +79,6 @@
7779
_AddRuntimeAssertionsForInlineConstraintsPass,
7880
)
7981
from .wrappers import _wrap_submodules
80-
from torch._inductor import config
8182

8283

8384
@dataclasses.dataclass
@@ -139,6 +140,8 @@ def capture_pre_autograd_graph(
139140
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
140141
from torch.export.dynamic_shapes import _process_dynamic_shapes
141142

143+
log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
144+
142145
if kwargs is None:
143146
kwargs = {}
144147

torch/export/_trace.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,16 @@ def _export_to_torch_ir(
286286
preserve_module_call_signature: Tuple[str, ...] = (),
287287
disable_constraint_solver: bool = False,
288288
restore_fqn: bool = True,
289+
_log_export_usage: bool = True,
289290
) -> torch.fx.GraphModule:
290291
"""
291292
Traces either an nn.Module's forward function or just a callable with PyTorch
292293
operations inside and produce a torch.fx.GraphModule in torch IR.
293294
"""
294295

296+
if _log_export_usage:
297+
log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})
298+
295299
constraints = constraints or []
296300
kwargs = kwargs or {}
297301

@@ -318,6 +322,7 @@ def _export_to_torch_ir(
318322
assume_static_by_default=True,
319323
tracing_mode="symbolic",
320324
disable_constraint_solver=disable_constraint_solver,
325+
_log_export_usage=_log_export_usage,
321326
)(
322327
*args,
323328
**kwargs,
@@ -580,6 +585,7 @@ def _export(
580585
log_export_usage(event="export.enter", flags=flags)
581586

582587
if constraints is not None:
588+
log_export_usage(event="export.private_api", flags={"constraints"})
583589
warnings.warn(
584590
"Using `constraints` to specify dynamic shapes for export is DEPRECATED "
585591
"and will not be supported in the future. "
@@ -744,6 +750,7 @@ def forward(self, *args, **kwargs):
744750
constraints,
745751
preserve_module_call_signature=preserve_module_call_signature,
746752
restore_fqn=False, # don't need to restore because we will do it later
753+
_log_export_usage=False,
747754
)
748755

749756
# We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.

0 commit comments

Comments (0)