|
2 | 2 | import builtins |
3 | 3 | import importlib |
4 | 4 | import inspect |
| 5 | +import logging |
5 | 6 | import pickle |
6 | 7 | import types |
7 | 8 | from dataclasses import dataclass |
8 | 9 | from typing import Any, Callable, Optional |
9 | 10 |
|
10 | 11 | import torch |
11 | 12 | import torch.fx |
| 13 | +from torch._dynamo.precompile_context import PrecompileContext |
12 | 14 |
|
13 | 15 | from . import convert_frame |
14 | 16 | from .hooks import Hooks |
15 | 17 |
|
16 | 18 |
|
| 19 | +log = logging.getLogger(__name__) |
| 20 | + |
| 21 | + |
17 | 22 | class SerializableCallable(abc.ABC): |
18 | 23 | @classmethod |
19 | 24 | @abc.abstractmethod |
@@ -119,6 +124,65 @@ def deserialize(cls, data: bytes) -> "AOTCompiledFunction": |
119 | 124 | return cls(artifacts) |
120 | 125 |
|
121 | 126 |
|
class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.
        """
        # The deserialized, directly-callable compiled function.
        self.compiled_fn = artifact.after_deserialization()
        # The raw serialized bytes, kept so serialize_compile_artifacts
        # can return them without re-serializing.
        self.data = artifact.content

    def __getattr__(self, attr: str) -> Any:
        # __getattr__ is only invoked when normal attribute lookup on this
        # wrapper fails, so delegate straight to the wrapped compiled function.
        #
        # NOTE(bugfix): the previous implementation called hasattr(self, attr)
        # here, which re-enters __getattr__ and recurses until RecursionError.
        # Guard "compiled_fn" itself so lookups that happen before __init__
        # has run (e.g. during unpickling) raise AttributeError instead of
        # recursing.
        if attr == "compiled_fn":
            raise AttributeError(attr)
        return getattr(self.compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: if PrecompileContext has no artifact recorded
                under backend_id.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        # The serialized bytes were already captured from the artifact in
        # __init__, so this is a simple field read.
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The key in the artifact is not important here since we're not populating a cache,
        # we just want to grab the callable back out of the serialized entry
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward calls directly to the wrapped compiled function.
        return self.compiled_fn(*args, **kwargs)
| 185 | + |
122 | 186 | def aot_compile_fullgraph( |
123 | 187 | model: Any, |
124 | 188 | example_inputs: tuple[tuple[Any, ...], dict[str, Any]], |
@@ -191,12 +255,23 @@ def new_guard_filter_fn( |
191 | 255 | assert check_fn.guards_state is not None |
192 | 256 |
|
193 | 257 | backend_input = capture_output.backend_input |
| 258 | + backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment] |
194 | 259 | output_graph = dynamo_output.tracer_output.output_graph |
195 | 260 | assert output_graph is not None |
196 | 261 | import_sources = output_graph.import_sources |
197 | | - with torch._guards.tracing(TracingContext(backend_input.fake_mode)): |
| 262 | + with ( |
| 263 | + torch._guards.tracing(TracingContext(backend_input.fake_mode)), |
| 264 | + torch._functorch.config.patch("bundled_autograd_cache", True), |
| 265 | + ): |
198 | 266 | compiled_fn = backend(backend_input.graph_module, backend_input.example_inputs) |
199 | 267 |
|
| 268 | + # If Inductor backend is used, grab the compiled_fn from PrecompileContext |
| 269 | + # TODO: this should be replaced once we make the backend return the SerializableCallable directly. |
| 270 | + if isinstance(backend, torch._TorchCompileInductorWrapper): |
| 271 | + compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id( |
| 272 | + backend_input.backend_id |
| 273 | + ) |
| 274 | + |
200 | 275 | if not isinstance(compiled_fn, SerializableCallable): |
201 | 276 | if hasattr(backend, "compiler_fn"): |
202 | 277 | compiler_fn = backend.compiler_fn |
|
0 commit comments