Skip to content

Commit 8a25bc6

Browse files
committed
Update on "Add support for nonzero, some improvements to reduce guards"
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit# It is essentially #95222 but squashed and with changes that are unnecessary given that we assume nonzero returns > 1. What's in the PR: * nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops`, it will return a tensor with an unbacked SymInt representing the size in question. * The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise. * PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInt from empty_strided (tested in `test_dynamic_pointwise_scalar`) * Convolution is updated to skip backend selection if batch is unbacked, to avoid guarding on unbacked SymInt (tested in `test_unbacked_batch_resnet`) * I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now but maybe someone will find them useful. * Added `constrain_unify` to let you specify two unbacked SymInts must have the same value Signed-off-by: Edward Z. Yang <ezyang@meta.com> cc soumith voznesenskym yanboliang penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire [ghstack-poisoned]
1 parent c619edc commit 8a25bc6

4 files changed

Lines changed: 19 additions & 13 deletions

File tree

torch/_decomp/decompositions.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -19,11 +19,7 @@
     _safe_copy_out,
     out_wrapper,
 )
-from torch.fx.experimental.symbolic_shapes import (
-    definitely_true,
-    guard_int,
-    tensor_has_hints,
-)
+from torch.fx.experimental.symbolic_shapes import guard_int
 from torch.utils._pytree import tree_flatten, tree_map
 
 DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
```

torch/_prims_common/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -77,6 +77,7 @@ def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]):
     torch.Tensor.device.__get__,  # type: ignore[attr-defined]
     torch.Tensor.requires_grad.__get__,  # type: ignore[attr-defined]
     torch.Tensor.layout.__get__,  # type: ignore[attr-defined]
+    torch.Tensor.is_contiguous,
     # For TorchRefsMode only
     torch.Tensor.__format__,
     torch.Tensor.__repr__,
```

torch/_subclasses/fake_tensor.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -426,11 +426,22 @@ def nonzero(fake_mode, func, arg):
         raise DynamicOutputShapeException(func)
     nnz = fake_mode.shape_env.create_unbacked_symint()
 
-    from torch.fx.experimental.symbolic_shapes import constrain_range
+    from torch.fx.experimental.symbolic_shapes import (
+        constrain_range,
+        definitely_true,
+        guard_int,
+    )
+
     # This is unsound, but it works well in practice
     # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
     # TODO: Add a config knob to turn off this unsound behavior
-    constrain_range(nnz, min=min(2, arg.numel()))
+    lower = 2
+    upper = None
+    # But don't give totally unsatisfiable bounds if we know it's too small!
+    if definitely_true(arg.numel() < 2):
+        lower = 0
+        upper = guard_int(arg.numel())
+    constrain_range(nnz, min=lower, max=upper)
 
     return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
 
```

torch/fx/experimental/symbolic_shapes.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -1355,12 +1355,10 @@ def create_symbol(self, val: int, source: Source, dyn=False) -> "sympy.Expr":
         if not dyn:
             # Non explicitly marked dynamic dims register to val_to_var to get duck shaped
             self.val_to_var[val] = sympy_expr
-            # We also infer that they must not be 0/1
-            lower = 2 if self.specialize_zero_one else 0
-            self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
-        else:
-            # Avoid up front 0/1 specializing dynamic dims
-            self.var_to_range[sympy_expr] = ValueRanges(0, sympy.oo)
+
+        # We also infer that it must be not 0/1
+        lower = 2 if self.specialize_zero_one else 0
+        self.var_to_range[sympy_expr] = ValueRanges(lower, sympy.oo)
 
         if not dyn and self.duck_shape:
             # This implements duck-shaping: input sizes that match are assigned
```

0 commit comments

Comments (0)