Commit 30d8b30

bobrenjc93 authored and pytorchmergebot committed
refactor tensorify restart logic to use sources (#141517)
Differential Revision: [D67066706](https://our.internmc.facebook.com/intern/diff/D67066706)

Pull Request resolved: #141517
Approved by: https://github.com/ezyang
1 parent bdbdbee · commit 30d8b30

4 files changed: 35 additions & 77 deletions
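For orientation before the per-file diffs: the old implementation keyed restart-time specialization decisions on symfloat names (e.g. "zf0"), which only works while symbols keep consistent names across restarts (an invariant the old code called out explicitly); the new implementation keys on Dynamo sources, which identify where an input comes from and are stable by construction. A minimal sketch of the difference, using a hypothetical LocalSource stand-in rather than the real torch._guards.Source:

from dataclasses import dataclass

# Hypothetical stand-in for a Dynamo source; frozen => hashable and
# compares by value, so it means the same thing on every restart.
@dataclass(frozen=True)
class LocalSource:
    local_name: str

def mint_symbol_names(float_sources):
    # Symbol names are minted in encounter order ("zf0", "zf1", ...),
    # so they only line up across restarts if that order never changes.
    return {f"zf{i}": src for i, src in enumerate(float_sources)}

run1 = mint_symbol_names([LocalSource("x"), LocalSource("y")])
run2 = mint_symbol_names([LocalSource("y"), LocalSource("x")])
assert run1["zf0"] != run2["zf0"]            # name-keyed state goes stale
assert LocalSource("x") == LocalSource("x")  # source-keyed state survives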


torch/_dynamo/output_graph.py

Lines changed: 1 addition & 43 deletions
@@ -40,16 +40,10 @@
     Source,
     TracingContext,
 )
-from torch._subclasses.fake_tensor import FakeTensor
 from torch._utils_internal import signpost_event
 from torch.fx._lazy_graph_module import _make_graph_module  # type: ignore[attr-defined]
 from torch.fx.experimental._backward_state import BackwardState
-from torch.fx.experimental.symbolic_shapes import (
-    free_symbols,
-    guard_scalar,
-    is_symbolic,
-    ShapeEnv,
-)
+from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass

@@ -1343,8 +1337,6 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
         ncalls = count_calls(self.graph)
         counters["stats"]["calls_captured"] += ncalls

-        self.remove_tensorify_specialized_graphargs()
-
         # free a bit of memory
         self.real_value_cache.clear()

@@ -1681,40 +1673,6 @@ def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]):
                 # Make sure we delete later occurrences of the same symbol
                 used_symbols.remove(symbol)

-    def remove_tensorify_specialized_graphargs(self) -> None:
-        # This is a pretty interesting function. Basically we have this problem
-        # where our compiler tends to choke when we have unused inputs. The way
-        # we support dynamic float arguments is by doing a joint fx pass and
-        # tensorifying away as many symfloats as we can. For the remaining symfloats
-        # we have no choice but to specialize... HOWEVER at that point in time
-        # we can no longer remove graph inputs. So our sledgehammer solution is to
-        # save the state of what inputs we should have specialized in dynamo and
-        # restart analysis. This function incorporates this "view from the future"
-        # state and specializes inputs that we know we won't be able to tensorify
-        # away in the joint pass. In principle we shouldn't choke on unused inputs
-        # and so this shouldn't be necessary. In practice CUDA graphs choke on
-        # unused inputs so we need this for now.
-
-        # Import here to prevent circular import
-        from torch._dynamo.symbolic_convert import TensorifyState
-
-        for node in self.graph.nodes:
-            example_value = node.meta.get("example_value")
-            if (
-                isinstance(example_value, FakeTensor)
-                and example_value.item_memo is not None
-                and hasattr(example_value.item_memo.node._expr, "name")
-                and all(u.target == "item" for u in node.users)
-                and TensorifyState.should_specialize(
-                    # We use _expr instead of expr b/c we want the symbol not the replacement
-                    example_value.item_memo.node._expr.name
-                )
-            ):
-                for u in list(node.users):
-                    u.replace_all_uses_with(guard_scalar(example_value.item_memo))
-                    self.remove_node(u)
-                self.remove_node(node)
-
     def add_output_instructions(self, prefix: List[Instruction]) -> None:
         """
         We call this on the creation of a new compiled subgraph that is inserted
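The method removed above was the old integration point: after the joint pass recorded which symfloats it could not tensorify away, this code rewrote the finished graph in place. With source-keyed state the decision is consulted before the graph is built (see the wrap_symfloat change below), so the rewrite step disappears. A rough sketch of the save-state-and-restart protocol the removed comment describes, with hypothetical names and toy logic:

class RestartAnalysis(Exception):
    """Discard the current trace and re-trace the frame from scratch."""

force_specializations: set = set()  # stands in for TensorifyState

def trace_and_compile(frame):
    # Toy stand-in: floats whose sources were flagged on a previous
    # attempt are treated as baked-in constants; any remaining dynamic
    # floats are flagged now and we ask for a restart.
    remaining = frame["float_sources"] - force_specializations
    if remaining:
        force_specializations.update(remaining)  # "view from the future"
        raise RestartAnalysis
    return f"compiled; specialized {sorted(force_specializations)}"

def compile_frame(frame):
    while True:
        try:
            return trace_and_compile(frame)
        except RestartAnalysis:
            continue  # the re-trace sees the recorded state up front

print(compile_frame({"float_sources": {"L['lr']"}}))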

torch/_dynamo/symbolic_convert.py

Lines changed: 5 additions & 6 deletions
@@ -247,17 +247,16 @@ class DistributedState:


 class TensorifyState:
-    # These are the set of string symfloats names (eg. "zf0") that we collect
-    # from the tensorify_python_scalars.py joint fx pass to inform us about
-    # which float inputs we should specialize when we restart analysis.
-    force_specializations: Set[str] = set()
+    # These are the set of source that we collect from the tensorify_python_scalars.py joint
+    # fx pass to inform us about which float inputs we should specialize when we restart analysis.
+    force_specializations: Set[Source] = set()

     @classmethod
-    def specialize(cls, index: str) -> None:
+    def specialize(cls, index: Source) -> None:
         cls.force_specializations.add(index)

     @classmethod
-    def should_specialize(cls, index: str) -> bool:
+    def should_specialize(cls, index: Source) -> bool:
         return index in cls.force_specializations

     @classmethod
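A quick round-trip of the new API. The only requirement this diff places on Source is that it be hashable and compare equal across restarts, since force_specializations is now a Set[Source]; the sketch below uses a hypothetical frozen dataclass in its place:

from dataclasses import dataclass
from typing import Set

@dataclass(frozen=True)  # frozen => hashable, value equality
class FakeSource:  # hypothetical stand-in for torch._guards.Source
    name: str

class TensorifyState:
    force_specializations: Set[FakeSource] = set()

    @classmethod
    def specialize(cls, index: FakeSource) -> None:
        cls.force_specializations.add(index)

    @classmethod
    def should_specialize(cls, index: FakeSource) -> bool:
        return index in cls.force_specializations

TensorifyState.specialize(FakeSource("L['lr']"))
# A freshly constructed source with the same identity still matches
# after a restart, which string symbol names could not guarantee:
assert TensorifyState.should_specialize(FakeSource("L['lr']"))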

torch/_dynamo/variables/builder.py

Lines changed: 4 additions & 0 deletions
@@ -1901,6 +1901,9 @@ def wrap_symint(self, value):
         return unspec_var

     def wrap_symfloat(self, value):
+        # To prevent circular import
+        from ..symbolic_convert import TensorifyState
+
         # SymFloat wrapping is special. We first wrap it in the same way we
         # do an unspecialized primitive, and then we item() it into a
         # SymFloat. Removal of the item() call is left to a later FX pass,

@@ -1932,6 +1935,7 @@ def wrap_symfloat(self, value):
             or torch._inductor.config.triton.cudagraphs
             or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False)
             or frame_state_entry.scalar is not auto_dynamic
+            or TensorifyState.should_specialize(self.source)
         ):
             self.install_guards(GuardBuilder.CONSTANT_MATCH)
             return ConstantVariable.create(value=value, source=self.source)
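Condensing the gate above: wrap_symfloat falls back to a CONSTANT_MATCH guard (the float stays a Python constant) if any disjunct fires, and the new disjunct consults the restart state by source. A simplified sketch, with illustrative parameter names standing in for the real config and frame state:

def float_should_specialize(
    source,
    cudagraphs_enabled: bool,
    killswitch: bool,
    frame_says_static: bool,
    force_specializations: set,
) -> bool:
    return (
        cudagraphs_enabled
        or killswitch
        or frame_says_static
        or source in force_specializations  # the new source-keyed check
    )

# First compile: nothing recorded yet, the float may stay dynamic.
assert not float_should_specialize("L['eps']", False, False, False, set())
# After the joint pass flags L['eps'] and analysis restarts:
assert float_should_specialize("L['eps']", False, False, False, {"L['eps']"})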

torch/fx/passes/_tensorify_python_scalars.py

Lines changed: 25 additions & 28 deletions
@@ -223,19 +223,25 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
         val = node.meta.get("val")
         if isinstance(val, FakeTensor):
             for dim in val.shape:
-                if isinstance(dim, torch.SymInt):
-                    for s in dim.node.expr.free_symbols:
-                        name = str(s)
-                        if symbol_is_type(
-                            s, SymT.FLOAT
-                        ) and not TensorifyState.should_specialize(name):
-                            # In principle, we could support float input that
-                            # is used to do size compute. The problem is that
-                            # we don't actually want to tensorify the compute
-                            # in this case, which means we need codegen support for
-                            # all symfloats.
-                            TensorifyState.specialize(name)
-                            should_restart = True
+                if not isinstance(dim, torch.SymInt):
+                    continue
+
+                for symbol in dim.node.expr.free_symbols:
+                    if not symbol_is_type(symbol, SymT.FLOAT):
+                        continue
+
+                    sources = shape_env.var_to_sources.get(symbol)
+                    for source in sources:
+                        if TensorifyState.should_specialize(source):
+                            continue
+
+                        # In principle, we could support float input that
+                        # is used to do size compute. The problem is that
+                        # we don't actually want to tensorify the compute
+                        # in this case, which means we need codegen support
+                        # for all symfloats.
+                        TensorifyState.specialize(source)
+                        should_restart = True

         # Look for functions to convert
         if node.op == "call_function" and (

@@ -322,21 +328,12 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
             node.replace_all_uses_with(guard_scalar(val))
             graph.erase_node(node)

-    # Sometimes by the time we get to tensorify, there have already been
-    # specializations, eg. in python_arg_parser.h. In these cases,
-    # placeholder nodes no longer have a reference to their original
-    # symfloat and thus we need to deduce specializations have happend
-    # via shape_env.replacements. NB: there's an important invariant here
-    # that symfloats keep consistent names across restarts.
-    for k, v in shape_env.var_to_val.items():
-        if symbol_is_type(k, SymT.FLOAT) and isinstance(v, sympy.core.numbers.Float):
-            name = str(k)
-            if (
-                not TensorifyState.should_specialize(name)
-                and k not in tensorified_symbols
-            ):
-                TensorifyState.specialize(name)
-                should_restart = True
+    for symbol, sources in shape_env.var_to_sources.items():
+        if symbol_is_type(symbol, SymT.FLOAT) and symbol not in tensorified_symbols:
+            for source in sources:
+                if not TensorifyState.should_specialize(source):
+                    TensorifyState.specialize(source)
+                    should_restart = True

     if should_restart:
         # Sledgehammer time. Restart dynamo analysis, keeping track of which input sources
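Both hunks follow the same shape: walk the float symbols (from FakeTensor shapes in the first hunk, from shape_env.var_to_sources in the second), map each symbol back to its sources, and flag any source not yet specialized. A toy version of that sweep, with simplified stand-ins for shape_env and symbol_is_type:

import sympy

zf0, s0 = sympy.Symbol("zf0"), sympy.Symbol("s0")
# Stand-in for shape_env.var_to_sources: symbol -> originating sources.
var_to_sources = {zf0: ["L['scale']"], s0: ["L['x'].size()[0]"]}
force_specializations: set = set()

def is_float_symbol(sym) -> bool:
    # Stand-in for symbol_is_type(sym, SymT.FLOAT).
    return sym.name.startswith("zf")

should_restart = False
shape_exprs = [s0 * 2, zf0 + 1]  # pretend these came from FakeTensor shapes
for expr in shape_exprs:
    for symbol in expr.free_symbols:
        if not is_float_symbol(symbol):
            continue
        for source in var_to_sources[symbol]:
            if source in force_specializations:
                continue
            force_specializations.add(source)
            should_restart = True

print(force_specializations, should_restart)  # {"L['scale']"} True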
