Skip to content

Commit 9a7ae22

Browse files
ColinPeppler authored and pytorchmergebot committed
Support negative index slicing with backed symints (#177308)
Pull Request resolved: #177308 Approved by: https://github.com/laithsakka ghstack dependencies: #175819
1 parent 79184f4 commit 9a7ae22

6 files changed

Lines changed: 63 additions & 2 deletions

File tree

test/inductor/test_aot_inductor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,6 +1939,41 @@ def forward(self, x, y):
19391939
shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64)
19401940
return (expanded >> shifts) & 255
19411941

1942+
torch.cuda.caching_allocator_enable(False)
1943+
model = Repro()
1944+
example_inputs = (
1945+
torch.randint(
1946+
0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64
1947+
),
1948+
torch.randn(50, 8, device=self.device),
1949+
)
1950+
spec = {
1951+
"x": (Dim.DYNAMIC, Dim.STATIC),
1952+
"y": (Dim.DYNAMIC, Dim.STATIC),
1953+
}
1954+
self.check_model(model, example_inputs, dynamic_shapes=spec)
1955+
torch.cuda.caching_allocator_enable(True)
1956+
1957+
@skipIfMPS
1958+
@config.patch({"triton.autotune_at_compile_time": None})
1959+
@torch.fx.experimental._config.patch("backed_size_oblivious", True)
1960+
def test_slice_negative_index_backed_symints_no_unbacked(self):
1961+
# x[-s1:] where x.size(0) = s0-1 should produce Max(s0-1 - s1, 0),
1962+
# not an unbacked symint with a bad fallback value.
1963+
if self.device != GPU_TYPE:
1964+
raise unittest.SkipTest("requires triton")
1965+
1966+
INNER_DIM = 4224
1967+
1968+
class Repro(torch.nn.Module):
1969+
def forward(self, x, y):
1970+
x_trimmed = x[:-1]
1971+
sliced = x_trimmed[-y.size(0) :]
1972+
reshaped = sliced.reshape(-1, 128, 33)
1973+
expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8)
1974+
shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64)
1975+
return (expanded >> shifts) & 255
1976+
19421977
torch.cuda.caching_allocator_enable(False)
19431978
try:
19441979
model = Repro()

torch/_inductor/ir.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3643,6 +3643,11 @@ def loader(idx: Sequence[Expr]) -> OpsValue:
36433643

36443644

36453645
class SliceView(View):
3646+
"""View that represents a slice along a single dimension.
3647+
3648+
Corresponds to tensor[..., start:end:step, ...].
3649+
"""
3650+
36463651
@classmethod
36473652
def normalize_start_end(
36483653
cls, x: IRNode, dim: int, start: int, end: int
@@ -3657,6 +3662,14 @@ def normalize_start_end(
36573662
if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
36583663
min_func = sympy.Min
36593664
max_func = sympy.Max
3665+
elif any(
3666+
# Only needed when backed_size_oblivious is on.
3667+
x.has(sympy.Min, sympy.Max)
3668+
for x in (start, end, dim_size)
3669+
if isinstance(x, Expr)
3670+
):
3671+
min_func = sympy.Min
3672+
max_func = sympy.Max
36603673
else:
36613674
min_func = sizevars.evaluate_min
36623675
max_func = sizevars.evaluate_max

torch/_inductor/lowering.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,9 @@ def compute_slice_index(index, size, default=None):
14321432
elif fn(sympy.Ge(index, 0)):
14331433
# If index >= 0, the resolved index is at most min(index, size).
14341434
return sympy.Min(index, size)
1435+
elif fn(sympy.Lt(index, 0)):
1436+
# If index < 0, wrap and clamp: the resolved index is at least 0.
1437+
return sympy.Max(index + size, 0)
14351438
return None
14361439

14371440
start_index, end_index = None, None

torch/_inductor/sizevars.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr:
550550
return right
551551

552552
# Min/Max fallback: we can prove Min(a, b) <= c when any arg <= c, but
553-
# sympy doesn't simplify this yet. So, evaluate it here.
553+
# sympy doesn't simplify this yet. So, evaluate it here. Same for Max.
554554
for lhs, rhs in [(left, right), (right, left)]:
555555

556556
def le_rhs(a: Expr) -> bool:
@@ -559,6 +559,9 @@ def le_rhs(a: Expr) -> bool:
559559
# Min(Min(a, b), c) ==> Min(a, b) if (a <= c) or (b <= c).
560560
if isinstance(lhs, sympy.Min) and any(le_rhs(a) for a in lhs.args):
561561
return lhs
562+
# Min(Max(a, b), c) ==> Max(a, b) if (a <= c) and (b <= c).
563+
if isinstance(lhs, sympy.Max) and all(le_rhs(a) for a in lhs.args):
564+
return lhs
562565

563566
raise TypeError(
564567
f"evaluate_min({left}, {right}) with unbacked symints"

torch/_subclasses/fake_impls.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,8 @@ def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType |
963963
return size
964964
elif guard_or_false(index >= 0):
965965
return torch.sym_min(index, size)
966+
elif guard_or_false(index < 0):
967+
return torch.sym_max(index + size, 0)
966968

967969
return None
968970

@@ -1008,6 +1010,12 @@ def slice_forward(
10081010
new_size = (end_index - start_index + step - 1) // step
10091011
elif guard_or_false(start_index >= end_index):
10101012
new_size = 0
1013+
else:
1014+
# Both indices are resolved but we can't statically determine their
1015+
# ordering (e.g., when they involve Min/Max). Compute the size via
1016+
# max(end - start, 0) to avoid creating an unbacked symint.
1017+
diff = torch.sym_max(end_index - start_index, 0)
1018+
new_size = (diff + step - 1) // step
10111019

10121020
# create unbacked if case unknown
10131021
if new_size is None:

torch/testing/_internal/common_ops_unbacked.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
153153
xfail("nn.functional.fractional_max_pool2d"),
154154
xfail("nn.functional.fractional_max_pool3d"),
155155
xfail("nn.functional.gaussian_nll_loss"),
156-
xfail("nn.functional.glu"),
157156
xfail("nn.functional.grid_sample"),
158157
xfail("nn.functional.group_norm"),
159158
xfail("nn.functional.huber_loss"),

0 commit comments

Comments
 (0)