|
40 | 40 | Source, |
41 | 41 | TracingContext, |
42 | 42 | ) |
43 | | -from torch._subclasses.fake_tensor import FakeTensor |
44 | 43 | from torch._utils_internal import signpost_event |
45 | 44 | from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] |
46 | 45 | from torch.fx.experimental._backward_state import BackwardState |
47 | | -from torch.fx.experimental.symbolic_shapes import ( |
48 | | - free_symbols, |
49 | | - guard_scalar, |
50 | | - is_symbolic, |
51 | | - ShapeEnv, |
52 | | -) |
| 46 | +from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv |
53 | 47 | from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts |
54 | 48 | from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
55 | 49 |
|
@@ -1343,8 +1337,6 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): |
1343 | 1337 | ncalls = count_calls(self.graph) |
1344 | 1338 | counters["stats"]["calls_captured"] += ncalls |
1345 | 1339 |
|
1346 | | - self.remove_tensorify_specialized_graphargs() |
1347 | | - |
1348 | 1340 | # free a bit of memory |
1349 | 1341 | self.real_value_cache.clear() |
1350 | 1342 |
|
@@ -1681,40 +1673,6 @@ def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): |
1681 | 1673 | # Make sure we delete later occurrences of the same symbol |
1682 | 1674 | used_symbols.remove(symbol) |
1683 | 1675 |
|
1684 | | - def remove_tensorify_specialized_graphargs(self) -> None: |
1685 | | - # This is a pretty interesting function. Basically we have this problem |
1686 | | - # where our compiler tends to choke when we have unused inputs. The way |
1687 | | - # we support dynamic float arguments is by doing a joint fx pass and |
1688 | | - # tensorifying away as many symfloats as we can. For the remaining symfloats |
1689 | | - # we have no choice but to specialize... HOWEVER at that point in time |
1690 | | - # we can no longer remove graph inputs. So our sledgehammer solution is to |
1691 | | - # save the state of what inputs we should have specialized in dynamo and |
1692 | | - # restart analysis. This function incorporates this "view from the future" |
1693 | | - # state and specializes inputs that we know we won't be able to tensorify |
1694 | | - # away in the joint pass. In principle we shouldn't choke on unused inputs |
1695 | | - # and so this shouldn't be necessary. In practice CUDA graphs choke on |
1696 | | - # unused inputs so we need this for now. |
1697 | | - |
1698 | | - # Import here to prevent circular import |
1699 | | - from torch._dynamo.symbolic_convert import TensorifyState |
1700 | | - |
1701 | | - for node in self.graph.nodes: |
1702 | | - example_value = node.meta.get("example_value") |
1703 | | - if ( |
1704 | | - isinstance(example_value, FakeTensor) |
1705 | | - and example_value.item_memo is not None |
1706 | | - and hasattr(example_value.item_memo.node._expr, "name") |
1707 | | - and all(u.target == "item" for u in node.users) |
1708 | | - and TensorifyState.should_specialize( |
1709 | | - # We use _expr instead of expr b/c we want the symbol not the replacement |
1710 | | - example_value.item_memo.node._expr.name |
1711 | | - ) |
1712 | | - ): |
1713 | | - for u in list(node.users): |
1714 | | - u.replace_all_uses_with(guard_scalar(example_value.item_memo)) |
1715 | | - self.remove_node(u) |
1716 | | - self.remove_node(node) |
1717 | | - |
1718 | 1676 | def add_output_instructions(self, prefix: List[Instruction]) -> None: |
1719 | 1677 | """ |
1720 | 1678 | We call this on the creation of a new compiled subgraph that is inserted |
|
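Note on the removed helper: `remove_tensorify_specialized_graphargs` walked the FX graph, found tensor inputs whose only consumers were `.item()` calls and whose backing symbol Dynamo had already decided to specialize, replaced those `.item()` results with the guarded scalar, and erased the now-dead nodes. Below is a rough standalone sketch of that graph-pruning pattern using only public `torch.fx` APIs, not the Dynamo internals deleted in this diff; the module `M` and the constant `specialized_value` are illustrative stand-ins.

```python
import torch
import torch.fx as fx


class M(torch.nn.Module):
    def forward(self, x, scale):
        # `scale` is only ever consumed as a Python scalar via .item()
        s = scale.item()
        return x * s


gm = fx.symbolic_trace(M())

# Stand-in for the value Dynamo decided to specialize the input to.
specialized_value = 2.0

for node in list(gm.graph.nodes):
    if node.op != "placeholder":
        continue
    users = list(node.users)
    # Mirror the `all(u.target == "item" for u in node.users)` check from the
    # removed helper: only prune inputs consumed exclusively via .item().
    if users and all(u.op == "call_method" and u.target == "item" for u in users):
        for item_node in users:
            # Replace every use of the .item() result with a concrete scalar,
            # then erase the dead nodes (the input last, once it has no users).
            item_node.replace_all_uses_with(specialized_value)
            gm.graph.erase_node(item_node)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.code)  # forward(self, x) now multiplies by the baked-in 2.0
```

In the actual helper, the replacement value came from `guard_scalar(example_value.item_memo)` and the specialization decision was delegated to `TensorifyState.should_specialize`; this diff deletes that helper, its call site in `compile_and_call_fx_graph`, and the now-unused `FakeTensor` and `guard_scalar` imports.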