Skip to content

Commit 5babb4d

Browse files
jamesjwu authored and pytorchmergebot committed
Add BundledAOTAutogradSerializableCallable (#162170)
This PR hooks up the python wrapper inductor backend to aot_compile. This is *not* the best way for us to grab the output of AOTAutograd; that involves a refactor to make AOTAutograd itself return a serializable callable. I'll do that refactor soon, but I want a basic interface to test with for now. In the medium term, we'll want aot_compile to call AOTAutograd directly, instead of using the TorchInductorWrapper's callback through compile_fx. Pull Request resolved: #162170 Approved by: https://github.com/zhxchen17 ghstack dependencies: #162169
1 parent eb9073a commit 5babb4d

File tree

2 files changed

+95
-1
lines changed

2 files changed

+95
-1
lines changed

test/dynamo/test_aot_compile.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,25 @@ def backend(gm, example_inputs):
207207
},
208208
).aot_compile((example_inputs, {}))
209209

210+
def test_aot_compile_basic_fn_inductor(self):
    """AOT-compile a simple elementwise add through the inductor backend,
    round-trip the result through save/load on disk, and check the loaded
    callable still produces the right answer without triggering a recompile.
    """

    def fn(x, y):
        return x + y

    compiled_fn = torch.compile(fn, fullgraph=True, backend="inductor").aot_compile(
        ((torch.randn(3, 4), torch.randn(3, 4)), {})
    )
    sample_args = (torch.randn(3, 4), torch.randn(3, 4))
    ref = fn(*sample_args)
    self.assertEqual(ref, compiled_fn(*sample_args))

    # Persist the compiled function, wipe dynamo state, and reload it;
    # "fail_on_recompile" guarantees the reloaded artifact is used as-is.
    compiled_fn.save_compiled_function(self.path())
    torch._dynamo.reset()
    with torch.compiler.set_stance("fail_on_recompile"):
        with open(self.path(), "rb") as f:
            compiled_fn = torch.compiler.load_compiled_function(f)
        self.assertEqual(ref, compiled_fn(*sample_args))
210229

211230
if __name__ == "__main__":
212231
from torch._dynamo.test_case import run_tests

torch/_dynamo/aot_compile.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@
22
import builtins
33
import importlib
44
import inspect
5+
import logging
56
import pickle
67
import types
78
from dataclasses import dataclass
89
from typing import Any, Callable, Optional
910

1011
import torch
1112
import torch.fx
13+
from torch._dynamo.precompile_context import PrecompileContext
1214

1315
from . import convert_frame
1416
from .hooks import Hooks
1517

1618

19+
log = logging.getLogger(__name__)
20+
21+
1722
class SerializableCallable(abc.ABC):
1823
@classmethod
1924
@abc.abstractmethod
@@ -119,6 +124,65 @@ def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
119124
return cls(artifacts)
120125

121126

127+
class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.
        """
        # Reconstruct the live callable, and keep the raw serialized bytes
        # around so serialize_compile_artifacts is a cheap passthrough.
        self.compiled_fn = artifact.after_deserialization()
        self.data = artifact.content

    def __getattr__(self, attr: str) -> Any:
        # __getattr__ is only invoked after normal attribute lookup has
        # already failed, so the previous `hasattr(self, attr)` check could
        # never be True here and recursed straight back into __getattr__
        # (hasattr -> getattr -> __getattr__ -> ...), raising RecursionError
        # for any missing attribute. Delegate directly to the wrapped
        # compiled function instead, guarding against the bootstrap case
        # where `compiled_fn` itself is not set yet (e.g. while the object
        # is being reconstructed, before __init__ has populated it).
        if attr == "compiled_fn":
            raise AttributeError(attr)
        return getattr(self.compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: if PrecompileContext holds no artifact for backend_id.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        # The artifact content captured in __init__ is already the
        # serialized form; no re-serialization work is needed.
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        # Imported locally to avoid an import cycle with the functorch cache
        # module at file-load time.
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The key in the artifact is not important here since we're not populating a cache,
        # we just want to grab the callable back out of the serialized entry
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.compiled_fn(*args, **kwargs)
122186
def aot_compile_fullgraph(
123187
model: Any,
124188
example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
@@ -191,12 +255,23 @@ def new_guard_filter_fn(
191255
assert check_fn.guards_state is not None
192256

193257
backend_input = capture_output.backend_input
258+
backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment]
194259
output_graph = dynamo_output.tracer_output.output_graph
195260
assert output_graph is not None
196261
import_sources = output_graph.import_sources
197-
with torch._guards.tracing(TracingContext(backend_input.fake_mode)):
262+
with (
263+
torch._guards.tracing(TracingContext(backend_input.fake_mode)),
264+
torch._functorch.config.patch("bundled_autograd_cache", True),
265+
):
198266
compiled_fn = backend(backend_input.graph_module, backend_input.example_inputs)
199267

268+
# If Inductor backend is used, grab the compiled_fn from PrecompileContext
269+
# TODO: this should be replaced once we make the backend return the SerializableCallable directly.
270+
if isinstance(backend, torch._TorchCompileInductorWrapper):
271+
compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
272+
backend_input.backend_id
273+
)
274+
200275
if not isinstance(compiled_fn, SerializableCallable):
201276
if hasattr(backend, "compiler_fn"):
202277
compiler_fn = backend.compiler_fn

0 commit comments

Comments
 (0)