
Commit 6e16a7e

ColinPeppler authored and facebook-github-bot committed

[reland] Slicing with backed should produce backed output when possible (#178899)

Summary: Original PR: #175819. It was reverted internally (D98767572), so it must be relanded as a new internal diff and then exported again (hence this diff).

* `x[0:s1]` where `x.size(0)` is `s0 - 1` should produce `Min(s1, s0 - 1)`.
* Before this PR, it would produce an unbacked symint `u0`.

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: sevenEng

Differential Revision: D98937973

Pulled By: ColinPeppler

Parent: 8cc5b51
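A minimal repro sketch of the behavior change (not part of the commit; it assumes a recent PyTorch with `torch.export` and the `Dim.DYNAMIC`/`Dim.STATIC` markers, and the module name `Slice` is made up for illustration):

```python
import torch
from torch.export import Dim, export

class Slice(torch.nn.Module):
    def forward(self, x, y):
        # x[:-1] has backed size s0 - 1; the slice end y.size(0) is backed s1.
        return x[:-1][: y.size(0)]

ep = export(
    Slice(),
    (torch.randn(8, 4), torch.randn(5, 4)),
    dynamic_shapes={"x": (Dim.DYNAMIC, Dim.STATIC), "y": (Dim.DYNAMIC, Dim.STATIC)},
)
# After this PR the sliced dim should be Min(s1, s0 - 1) rather than a fresh
# unbacked u0; inspect the graph's size expressions to see which one appears.
print(ep.graph_module.graph)
```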

6 files changed: 56 additions & 13 deletions


test/distributed/tensor/test_dtensor_export.py
1 addition & 12 deletions

```diff
@@ -541,18 +541,7 @@ def forward(self, x):
 %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
 %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
 %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
-%getitem : [num_users=3] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
-%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%getitem, _local_tensor), kwargs = {})
-%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
-%sym_size_int_1 : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getitem, 0), kwargs = {})
-%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
-%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
-%le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
-%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
-%ge_3 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int_1, 0), kwargs = {})
-%_assert_scalar_default_3 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_3, Runtime assertion failed for expression u1 >= 0 on node 'ge_3'), kwargs = {})
-%le_1 : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int_1, 4), kwargs = {})
-%_assert_scalar_default_4 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le_1, Runtime assertion failed for expression u1 <= 4 on node 'le_1'), kwargs = {})
+%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%l_x_, slice(None, item, None)), kwargs = {})
 return (getitem,)""", # noqa: B950
 )
```

test/export/test_draft_export.py
1 addition & 1 deletion

```diff
@@ -427,7 +427,7 @@ def forward(self, x, y):
         for node in _ep.graph.nodes:
             if bindings := node.meta.get("unbacked_bindings"):
                 unbacked_binding_symbols.update(bindings.keys())
-        self.assertEqual(len(unbacked_binding_symbols), 2)
+        self.assertEqual(len(unbacked_binding_symbols), 1)
 
     def test_offsets(self):
         class M(torch.nn.Module):
```

test/inductor/test_aot_inductor.py
37 additions & 0 deletions

```diff
@@ -1920,6 +1920,43 @@ def forward(self, x, y, lengths):
         self.check_model(model, example_inputs, dynamic_shapes=spec)
         torch.cuda.caching_allocator_enable(True)
 
+    @skipIfMPS
+    @config.patch({"triton.autotune_at_compile_time": None})
+    @torch.fx.experimental._config.patch("backed_size_oblivious", True)
+    def test_slice_independent_backed_symints_no_unbacked(self):
+        # x[0:s1] where x.size(0) = s0-1 should produce Min(s1, s0-1),
+        # not an unbacked symint with a bad fallback value.
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires triton")
+
+        INNER_DIM = 4224
+
+        class Repro(torch.nn.Module):
+            def forward(self, x, y):
+                x_trimmed = x[:-1]
+                sliced = x_trimmed[: y.size(0)]
+                reshaped = sliced.reshape(-1, 128, 33)
+                expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8)
+                shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64)
+                return (expanded >> shifts) & 255
+
+        torch.cuda.caching_allocator_enable(False)
+        try:
+            model = Repro()
+            example_inputs = (
+                torch.randint(
+                    0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64
+                ),
+                torch.randn(50, 8, device=self.device),
+            )
+            spec = {
+                "x": (Dim.DYNAMIC, Dim.STATIC),
+                "y": (Dim.DYNAMIC, Dim.STATIC),
+            }
+            self.check_model(model, example_inputs, dynamic_shapes=spec)
+        finally:
+            torch.cuda.caching_allocator_enable(True)
+
     @config.patch({"triton.autotune_at_compile_time": None})
     def test_stride_with_unbacked_expr(self):
         class Repro(torch.nn.Module):
```

torch/_inductor/lowering.py
3 additions & 0 deletions

```diff
@@ -1432,6 +1432,9 @@ def compute_slice_index(index, size, default=None):
             return size
         elif fn(sympy.Lt(index, -size)):
             return 0
+        elif fn(sympy.Ge(index, 0)):
+            # If index >= 0, the resolved index is at most min(index, size).
+            return sympy.Min(index, size)
         return None
 
     start_index, end_index = None, None
```
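As a standalone sketch (mine, not the commit's) of what the new branch resolves, using only `sympy`:

```python
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
index, size = s1, s0 - 1  # x[0:s1] where x.size(0) == s0 - 1

# When the end index is provably >= 0 but cannot be ordered against the size,
# the resolved slice bound is Min(index, size) rather than a fresh unbacked u0.
print(sympy.Min(index, size))  # Min(s0 - 1, s1) (argument order may vary)
```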

torch/_inductor/sizevars.py
11 additions & 0 deletions

```diff
@@ -601,6 +601,17 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr:
             if right == gcd:
                 return right
 
+        # Min/Max fallback: we can prove Min(a, b) <= c when any arg <= c, but
+        # sympy doesn't simplify this yet. So, evaluate it here.
+        for lhs, rhs in [(left, right), (right, left)]:
+
+            def le_rhs(a: Expr) -> bool:
+                return self.guard_or_false(sympy.Le(a, rhs))
+
+            # Min(Min(a, b), c) ==> Min(a, b) if (a <= c) or (b <= c).
+            if isinstance(lhs, sympy.Min) and any(le_rhs(a) for a in lhs.args):
+                return lhs
+
         raise TypeError(
             f"evaluate_min({left}, {right}) with unbacked symints"
         ) from None
```
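A self-contained sketch of the fallback rule with a stubbed prover standing in for `guard_or_false` (the helper names here are hypothetical, not inductor's real entry points):

```python
import sympy

def min_fallback(left, right, can_prove_le):
    # Min(Min(a, b), c) ==> Min(a, b) if (a <= c) or (b <= c), checked in
    # both operand orders, mirroring the new branch in evaluate_min.
    for lhs, rhs in [(left, right), (right, left)]:
        if isinstance(lhs, sympy.Min) and any(can_prove_le(a, rhs) for a in lhs.args):
            return lhs
    return None

s0, s1, u0 = sympy.symbols("s0 s1 u0", integer=True)

def prove(a, b):
    # Stub: pretend the shape env's hints establish s1 <= u0.
    return (a, b) == (s1, u0)

print(min_fallback(sympy.Min(s1, s0 - 1), u0, prove))  # the inner Min survives
```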

torch/_subclasses/fake_impls.py
3 additions & 0 deletions

```diff
@@ -961,6 +961,9 @@ def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType |
         return 0
     elif guard_or_false(index > size):
         return size
+    elif guard_or_false(index >= 0):
+        return torch.sym_min(index, size)
+
     return None
```
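For context on the helper used above: `torch.sym_min` is public PyTorch API that behaves like plain `min` on concrete ints and defers to a symbolic `Min` on SymInts, so the meta kernel can clamp a non-negative slice index without guarding on which operand is smaller. A quick check on concrete values:

```python
import torch

# On plain Python ints, torch.sym_min is just min; with SymInts it produces a
# symbolic Min expression instead of forcing a guard.
assert torch.sym_min(5, 3) == 3
assert torch.sym_min(2, 3) == 2
```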
