Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 101 additions & 8 deletions test/dynamo/test_aot_autograd_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torch._functorch._aot_autograd.schemas import AOTConfig
from torch._guards import TracingContext
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomRuntimeEstimator
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomRuntimeEstimator
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.runtime.triton_compat import tl, triton
from torch._inductor.test_case import TestCase as InductorTestCase
Expand All @@ -53,16 +53,24 @@
)


def custom_pre_grad_pass_remove_ident_muls(g: torch.fx.Graph) -> None:
class CustomPreGradPassRemoveIdentMuls(CustomGraphPass):
"""
Pre-grad pass that removes redundant identity multiplications (1 * x).
"""
for n in g.nodes:
if n.op == "call_function" and n.target is operator.mul:
lhs, rhs = n.args
if lhs == 1:
n.replace_all_uses_with(rhs)
g.erase_node(n)

def __call__(self, g: torch.fx.Graph) -> None:
for n in g.nodes:
if n.op == "call_function" and n.target is operator.mul:
lhs, rhs = n.args
if lhs == 1:
n.replace_all_uses_with(rhs)
g.erase_node(n)

def uuid(self):
return "custom_pre_grad_pass_remove_ident_muls_v1"


custom_pre_grad_pass_remove_ident_muls = CustomPreGradPassRemoveIdentMuls()


def aot_eager_regional_inductor():
Expand Down Expand Up @@ -2784,6 +2792,91 @@ def fn(x, y):
self.assertEqual(counters2.get("autograd_cache_miss", 0), 0)
self.assertEqual(counters2.get("autograd_cache_hit", 0), 1)

def test_cache_hit_across_processes_pre_grad_custom_pass(self):
"""
Verify that different pre-grad custom passes produce different cache keys
across processes, and that the same pass produces a cache hit.
"""
import subprocess
import sys
import tempfile
import textwrap

with tempfile.TemporaryDirectory() as cache_dir:
# Script template that accepts a custom pass UUID
script_template = textwrap.dedent(
"""
import json
import operator
import torch
import torch._dynamo
from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass

inductor_config.fx_graph_cache = True
inductor_config.fx_graph_remote_cache = False
torch._dynamo.reset()

class TestPreGradPass(CustomGraphPass):
def __init__(self, pass_uuid):
self._uuid = pass_uuid

def __call__(self, g):
for n in g.nodes:
if n.op == "call_function" and n.target is operator.mul:
lhs, rhs = n.args
if lhs == 1:
n.replace_all_uses_with(rhs)
g.erase_node(n)

def uuid(self):
return self._uuid

inductor_config.pre_grad_custom_pass = TestPreGradPass("{pass_uuid}")

def fn(x, y):
return 1 * x + y

compiled_fn = torch.compile(fn)
x = torch.randn(10)
y = torch.randn(10)
compiled_fn(x, y)

print(json.dumps(dict(counters["aot_autograd"])))
"""
)

env = {**os.environ, "TORCHINDUCTOR_CACHE_DIR": cache_dir}

def run_script(pass_uuid):
script = script_template.format(pass_uuid=pass_uuid)
result = subprocess.run(
[sys.executable, "-c", script],
env=env,
capture_output=True,
text=True,
)
self.assertEqual(result.returncode, 0, result.stderr)
import json

return json.loads(result.stdout.splitlines()[-1])

# First run with pass_uuid_A - expect cache miss
c1 = run_script("pass_uuid_A")
self.assertEqual(c1.get("autograd_cache_miss", 0), 1)
self.assertEqual(c1.get("autograd_cache_hit", 0), 0)

# Second run with same pass_uuid_A - expect cache hit
c2 = run_script("pass_uuid_A")
self.assertEqual(c2.get("autograd_cache_miss", 0), 0)
self.assertEqual(c2.get("autograd_cache_hit", 0), 1)

# Third run with different pass_uuid_B - expect cache miss
c3 = run_script("pass_uuid_B")
self.assertEqual(c3.get("autograd_cache_miss", 0), 1)
self.assertEqual(c3.get("autograd_cache_hit", 0), 0)


@functorch_config.patch({"bundled_autograd_cache": True})
class AOTAutogradCacheBundledTests(AOTAutogradCacheTests):
Expand Down
12 changes: 10 additions & 2 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,10 @@ def __init__(
self.torch_version = torch_key()
self.system_info = CacheBase.get_system()
self.inductor_config = config.save_config_portable(ignore_private_configs=False)
# Custom post grad passes should provide an ID to hash.
# Custom passes should provide an ID to hash.
self.pre_grad_custom_pass = self._get_custom_pass_detail(
config.pre_grad_custom_pass
)
self.post_grad_custom_pre_pass = self._get_custom_pass_detail(
config.post_grad_custom_pre_pass
)
Expand Down Expand Up @@ -1580,8 +1583,13 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
Check some conditions that would preclude caching and raise BypassFxGraphCache
to bypass in case caching is not possible.
"""
# Post grad custom passes must implement the CustomGraphPass or we don't
# Custom passes must implement the CustomGraphPass or we don't
# know how to include them in the cache key calculation.
if config.pre_grad_custom_pass and (
not isinstance(config.pre_grad_custom_pass, CustomGraphPass)
or not config.pre_grad_custom_pass.uuid()
):
raise BypassFxGraphCache("Unsupported pre grad custom pass")
for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass):
if p and (not isinstance(p, CustomGraphPass) or not p.uuid()):
raise BypassFxGraphCache("Unsupported post grad custom pass")
Expand Down
3 changes: 2 additions & 1 deletion torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def prologue_fusion_enabled() -> bool:
# Registers a custom pregrad pass. Note that the pre-grad IR is 1.
# non-functional, 2. non-normalized, and 3. prone to change. Ideally we should
# use post-grad passes.
pre_grad_custom_pass: Callable[[torch.fx.graph.Graph], None] | None = None
pre_grad_custom_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None

# Registers a custom pass to be run right before fusion in Inductor scheduler.
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
Expand Down Expand Up @@ -2484,6 +2484,7 @@ class trace:
"post_grad_custom_pre_pass",
"joint_custom_pre_pass",
"joint_custom_post_pass",
"pre_grad_custom_pass",
"_fuse_ddp_communication_passes",
"_pre_fusion_custom_pass",
# tests assume that changes here don't invalidate cache
Expand Down
Loading