Commit 354b0ff

aorenste authored and pytorchmergebot committed
[inductor] Fix index_reduce_ on view inputs raising AssertionError in assert_functional_graph (#176606)
The `_index_fill` decomposition used mutable `empty_like` + `copy_` to restore strides when `index_copy` returned a contiguous tensor, which broke the functional graph invariant. Replace with the functional `prims.copy_strided` prim that does the same thing as a single op.

Fixes #144846

Authored with Claude.

Pull Request resolved: #176606
Approved by: https://github.com/Lucaskabela
1 parent 45d619c commit 354b0ff

3 files changed

Lines changed: 14 additions & 3 deletions

test/inductor/pallas_expected_failures/CpuTests.test_index_reduce_on_view_input_cpu

Whitespace-only changes.

test/inductor/test_torchinductor.py

Lines changed: 13 additions & 0 deletions
@@ -15912,6 +15912,19 @@ def run_session(x_param, y_param, size, device):
             out2 = run_session(100, 16, 64, self.device)
             self.assertEqual(out2.device.type, self.device)

+    def test_index_reduce_on_view_input(self):
+        # Regression test for https://github.com/pytorch/pytorch/issues/144846
+        def fn(x, index, source):
+            return x.index_reduce_(2, index, source, "mean", include_self=False)
+
+        x_base = torch.randn(4, 34, 64, device=self.device)
+        index = torch.randint(0, 34, (64,), device=self.device)
+        source = torch.randn(4, 32, 64, device=self.device)
+
+        expected = fn(x_base.clone()[:, 2:, :], index, source)
+        result = torch.compile(fn)(x_base.clone()[:, 2:, :], index, source)
+        self.assertEqual(result, expected)
+
     # end of class CommonTemplate - add new tests here

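The regression test passes a sliced view, `x_base.clone()[:, 2:, :]`, as the input. A rough sketch of why that matters for a stride comparison such as `out.stride() != x.stride()`: basic slicing keeps the base tensor's strides, so they differ from the strides of a freshly allocated contiguous tensor of the same shape. The helper `contiguous_strides` below is a hypothetical name used only for illustration; it computes row-major strides in plain Python and does not call PyTorch.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


# Shapes taken from the regression test.
base_shape = (4, 34, 64)   # x_base
view_shape = (4, 32, 64)   # x_base.clone()[:, 2:, :]

# Slicing with step 1 keeps the base strides; only shape and offset change.
view_strides = contiguous_strides(base_shape)
# A freshly allocated contiguous tensor of the view's shape is packed tighter.
fresh_strides = contiguous_strides(view_shape)

print(view_strides)                   # (2176, 64, 1)
print(fresh_strides)                  # (2048, 64, 1)
print(view_strides != fresh_strides)  # True
```

With a contiguous input the two tuples would match, which is presumably why the failure only showed up on view inputs.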

torch/_refs/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -4319,9 +4319,7 @@ def _index_fill(
             out = out.squeeze(0).clone()
         # index_fill preserves the strides. index_copy always returns contiguous tensors
         if out.stride() != x.stride():
-            new_out = torch.empty_like(x)
-            new_out.copy_(out)
-            out = new_out
+            out = prims.copy_strided(out, x.stride())
         return out

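To illustrate the idea behind the change, here is a toy model in plain Python of a strided copy written as a pure function: it takes values and a target stride and returns a new buffer, with no in-place write to a previously allocated one. This is only a sketch under stated assumptions. `copy_strided_model` and its flat-list representation are hypothetical and are not how `prims.copy_strided` is implemented.

```python
from itertools import product


def copy_strided_model(values, shape, src_stride, dst_stride):
    """Toy model: return a NEW flat buffer holding `values` laid out with dst_stride.

    The input list is only read, never written, which mirrors the
    "no mutation" property a functional graph relies on.
    """
    # Size the buffer so the last element addressable under dst_stride fits.
    size = 1 + sum((n - 1) * s for n, s in zip(shape, dst_stride))
    out = [None] * size
    for idx in product(*(range(n) for n in shape)):
        src = sum(i * s for i, s in zip(idx, src_stride))
        dst = sum(i * s for i, s in zip(idx, dst_stride))
        out[dst] = values[src]
    return out


# A 2x2 contiguous result, re-laid-out with a row stride of 3 instead of 2
# (as if it were a view that skips one element per row).
contiguous = [10, 11, 20, 21]
restrided = copy_strided_model(contiguous, (2, 2), (2, 1), (3, 1))

print(restrided)   # [10, 11, None, 20, 21]
print(contiguous)  # [10, 11, 20, 21]
```

The removed `empty_like` + `copy_` sequence reaches the same layout in two steps, the second of which mutates a tensor; the single-op form expresses it as one value-returning call.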
