Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions test/dynamo/test_aot_autograd_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,145 @@ def wrap_run_pre_grad_passes(

self.assertEqual(result1, result2)

@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch("enable_autograd_cache", True)
def test_pre_grad_passes_default_timing_with_uuid(self):
    """
    Default timing plus a custom pass that exposes a UUID: the pre-grad
    passes are deferred until after the cache lookup, so they run only on
    a cache miss and are skipped entirely on a hit.
    """
    from torch._inductor.compile_fx import run_pre_grad_passes

    def fn(x, y):
        return 1 * x + y

    num_pass_runs = 0

    def counting_pre_grad_passes(
        model: GraphModule, example_inputs: Sequence[InputType]
    ) -> GraphModule:
        # Count invocations, then delegate to the real implementation.
        # NOTE(review): the delegate's return value is discarded here;
        # presumably it mutates `model` in place — confirm.
        nonlocal num_pass_runs
        num_pass_runs += 1
        run_pre_grad_passes(model, example_inputs)
        return model

    inputs = (torch.randn(10), torch.randn(10))

    patch_passes = unittest.mock.patch(
        "torch._inductor.compile_fx.run_pre_grad_passes",
        counting_pre_grad_passes,
    )
    patch_pass_config = inductor_config.patch(
        "pre_grad_custom_pass", custom_pre_grad_pass_remove_ident_muls
    )
    with patch_passes, patch_pass_config:
        self._clear_all_caches()

        # First compile: must miss the cache and run the passes once.
        result1 = torch.compile(fn)(*inputs)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(num_pass_runs, 1)

        torch._dynamo.reset()

        # Second compile hits the cache, so late-timed passes do not run.
        result2 = torch.compile(fn)(*inputs)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(num_pass_runs, 1)

        self.assertEqual(result1, result2)

@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch("enable_autograd_cache", True)
def test_pre_grad_passes_default_timing_without_uuid(self):
    """
    Default timing plus a custom pass lacking a UUID: the pass cannot be
    folded into the cache key, so the passes run early — on every
    compile, including cache hits.
    """
    from torch._inductor.compile_fx import run_pre_grad_passes

    class NoUuidPass(CustomGraphPass):
        def __call__(self, g: torch.fx.Graph) -> None:
            pass

        def uuid(self):
            return None

    def fn(x, y):
        return x + y

    num_pass_runs = 0

    def counting_pre_grad_passes(
        model: GraphModule, example_inputs: Sequence[InputType]
    ) -> GraphModule:
        # Count invocations, then delegate to the real implementation.
        nonlocal num_pass_runs
        num_pass_runs += 1
        run_pre_grad_passes(model, example_inputs)
        return model

    inputs = (torch.randn(10), torch.randn(10))

    with unittest.mock.patch(
        "torch._inductor.compile_fx.run_pre_grad_passes",
        counting_pre_grad_passes,
    ):
        with inductor_config.patch("pre_grad_custom_pass", NoUuidPass()):
            self._clear_all_caches()

            result1 = torch.compile(fn)(*inputs)
            self.assertEqual(num_pass_runs, 1)

            torch._dynamo.reset()

            # Even on a cache hit the early-timed passes must run again.
            result2 = torch.compile(fn)(*inputs)
            self.assertEqual(num_pass_runs, 2)

            self.assertEqual(result1, result2)

@inductor_config.patch("fx_graph_cache", True)
@functorch_config.patch("enable_autograd_cache", True)
@inductor_config.patch("pre_grad_pass_timing", "late")
def test_pre_grad_pass_late_timing_without_uuid_raises(self):
    """
    Forcing late timing while the custom pass has no UUID is an error:
    the cache key cannot account for such a pass, so compilation must
    raise a RuntimeError.
    """

    class NoUuidPass(CustomGraphPass):
        def __call__(self, g: torch.fx.Graph) -> None:
            pass

        def uuid(self):
            return None

    def fn(x, y):
        return x + y

    args = (torch.randn(10), torch.randn(10))

    with inductor_config.patch("pre_grad_custom_pass", NoUuidPass()):
        self._clear_all_caches()
        compiled = torch.compile(fn)
        with self.assertRaisesRegex(
            RuntimeError, "pre_grad_custom_pass must implement uuid"
        ):
            compiled(*args)

def test_cache_hit_across_processes(self):
"""
Verify that a second subprocess gets a cache hit from the first subprocess's
Expand Down
24 changes: 15 additions & 9 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time
from contextlib import nullcontext
from functools import wraps
from typing import Any, TYPE_CHECKING
from typing import Any, Literal, TYPE_CHECKING
from typing_extensions import ParamSpec, TypeVar
from unittest.mock import patch

Expand Down Expand Up @@ -1067,6 +1067,12 @@ def prepare_aot_module_simplified(
)


def _resolve_default_pre_grad_pass_timing() -> Literal["early", "late"]:
    """Return the effective pre-grad pass timing, delegating to inductor.

    The import is deferred to call time (presumably to avoid a circular
    import between functorch and inductor — confirm).
    """
    from torch._inductor.codecache import resolve_pre_grad_pass_timing as _resolve

    return _resolve()


def aot_module_simplified(
mod: torch.fx.GraphModule | torch._dynamo.utils.GmWrapper,
args: Iterable[Any],
Expand Down Expand Up @@ -1130,11 +1136,13 @@ def aot_module_simplified(

compiled_fn = None

# "early" timing: run pre-grad passes before cache lookup so the
# cache key is computed from the already-transformed graph.
pre_grad_pass_timing: Literal["early", "late"] = (
_resolve_default_pre_grad_pass_timing()
)

if (
torch._inductor.config.pre_grad_pass_timing == "early"
and pre_grad_passes is not None
pre_grad_pass_timing == "early"
and pre_grad_passes
and isinstance(mod, torch.fx.GraphModule)
):
mod = pre_grad_passes(mod, fake_flat_args)
Expand All @@ -1158,11 +1166,9 @@ def aot_module_simplified(
)

if compiled_fn is None:
# "late" timing (default): run pre-grad passes after cache lookup,
# only on cache miss, to cache pre-grad transforms.
if (
torch._inductor.config.pre_grad_pass_timing == "late"
and pre_grad_passes is not None
pre_grad_pass_timing == "late"
and pre_grad_passes
and isinstance(mod, torch.fx.GraphModule)
):
mod = pre_grad_passes(mod, fake_flat_args)
Expand Down
42 changes: 38 additions & 4 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from tempfile import _TemporaryFileWrapper
from time import time, time_ns
from types import ModuleType
from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar
from typing import Any, cast, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
from typing_extensions import override, Self

import torch
Expand Down Expand Up @@ -781,6 +781,38 @@ class BypassFxGraphCache(Exception):
"""


def resolve_pre_grad_pass_timing() -> Literal["early", "late"]:
    """Translate ``config.pre_grad_pass_timing`` into a concrete timing.

    The "default" setting is decided by whether the configured custom pass
    can be hashed: when there is no custom pass, or the pass is a
    ``CustomGraphPass`` whose ``uuid()`` is not ``None``, the passes run
    "late" (after the cache lookup); otherwise they run "early" (before
    the lookup, so the cache key sees the transformed graph).

    Raises:
        RuntimeError: if a custom pass without a UUID is explicitly
            configured to run "late" — the cache key cannot account
            for it.
    """
    configured: Literal["early", "late", "default"] = config.pre_grad_pass_timing
    pass_obj = config.pre_grad_custom_pass

    # Normalize to a plain bool: True only for a CustomGraphPass that
    # reports a non-None uuid().
    pass_has_uuid = bool(
        pass_obj
        and isinstance(pass_obj, CustomGraphPass)
        and pass_obj.uuid() is not None
    )

    if configured == "default":
        # No pass at all, or a hashable one, is safe to defer past the
        # cache lookup; anything else must run before it.
        if pass_obj is None or pass_has_uuid:
            configured = "late"
        else:
            configured = "early"

    if configured == "late" and pass_obj and not pass_has_uuid:
        raise RuntimeError(
            "pre_grad_custom_pass must implement uuid() to run late "
            "(after cache lookup). Either implement uuid() or set "
            "pre_grad_pass_timing to 'early'."
        )

    return configured


class FxGraphHashDetails:
"""
Object to capture all the details for a compiled FX graph relevant to computing
Expand Down Expand Up @@ -908,8 +940,8 @@ def __init__(
self.torch_version = torch_key()
self.system_info = CacheBase.get_system()
self.inductor_config = config.save_config_portable(ignore_private_configs=False)
# Custom passes should provide an ID to hash.
if config.pre_grad_pass_timing == "late":
# Custom passes should provide an ID to hash when they run late (after cache lookup).
if resolve_pre_grad_pass_timing() != "early":
self.pre_grad_custom_pass = self._get_custom_pass_detail(
config.pre_grad_custom_pass
)
Expand Down Expand Up @@ -1586,7 +1618,9 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
"""
# Custom passes must implement the CustomGraphPass or we don't
# know how to include them in the cache key calculation.
if config.pre_grad_pass_timing == "late":
# When timing is EARLY, pre-grad passes already ran before the cache
# lookup so there's nothing to validate here.
if resolve_pre_grad_pass_timing() != "early":
if config.pre_grad_custom_pass and (
not isinstance(config.pre_grad_custom_pass, CustomGraphPass)
or not config.pre_grad_custom_pass.uuid()
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def _nvgemm_max_profiling_configs_default() -> int | None:
# Use fx graph passes
use_pre_grad_passes: bool = True

pre_grad_pass_timing: Literal["early", "late"] = "early"

pre_grad_pass_timing: Literal["early", "late", "default"] = "default"


use_joint_graph_passes: bool = True
use_post_grad_passes: bool = True
Expand Down
Loading