[inductor] Fix `index_reduce_` on view inputs raising AssertionError in `assert_functional_graph` #176606

Closed
aorenste wants to merge 3 commits into gh/aorenste/207/base from gh/aorenste/207/head

@aorenste
Contributor

@aorenste aorenste commented Mar 5, 2026

Stack from ghstack (oldest at bottom):

The `_index_fill` decomposition used mutable `empty_like + copy_` to
restore strides when `index_copy` returned a contiguous tensor, which
broke the functional graph invariant. Replace with the functional
`prims.copy_strided` prim that does the same thing as a single op.

Fixes #144846

Authored with Claude.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo
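
For readers who want to see the two patterns side by side, here is a minimal eager-mode sketch. It is not the code from this PR: the helper names `restore_strides_mutable` and `restore_strides_functional` are made up for illustration, and it assumes `torch.ops.prims.copy_strided` can be called directly in eager mode with a tensor and a stride tuple.

```python
import torch


def restore_strides_mutable(out: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the old pattern: allocate, then mutate.
    # The in-place `copy_` is the kind of op a functional graph should not contain.
    new_out = torch.empty_like(like)
    new_out.copy_(out)
    return new_out


def restore_strides_functional(out: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the new pattern: one op, fresh result.
    return torch.ops.prims.copy_strided(out, like.stride())


# A transposed view: shape (4, 3) with non-contiguous strides (1, 4).
x = torch.arange(12.0).reshape(3, 4).t()
# Stand-in for an op result that came back contiguous, with strides (3, 1).
out = x.contiguous()

a = restore_strides_mutable(out, x)
b = restore_strides_functional(out, x)
```

Both helpers should return a tensor with the values of `out` laid out with the strides of `x`; only the second does so without an in-place write.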

…in `assert_functional_graph`

The `_index_fill` decomposition used mutable `empty_like + copy_` to
restore strides when `index_copy` returned a contiguous tensor, which
broke the functional graph invariant. Replace with the functional
`prims.copy_strided` prim that does the same thing as a single op.

Fixes #144846

Authored with Claude.

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Mar 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176606

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (7 Unrelated Failures)

As of commit 766cfa8 with merge base d87ebee (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot

pytorch-bot bot commented Mar 5, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

aorenste added a commit that referenced this pull request Mar 5, 2026
…in `assert_functional_graph`

ghstack-source-id: 8468ef5
Pull Request resolved: #176606
…rtionError in `assert_functional_graph`"

[ghstack-poisoned]
aorenste added a commit that referenced this pull request Mar 5, 2026
…in `assert_functional_graph`

ghstack-source-id: 292cd14
Pull Request resolved: #176606
@aorenste aorenste marked this pull request as ready for review March 5, 2026 20:59
@aorenste aorenste requested review from Lucaskabela and angelayi March 5, 2026 20:59
out2 = run_session(100, 16, 64, self.device)
self.assertEqual(out2.device.type, self.device)

def test_index_reduce_on_view_input(self):
Contributor

@Lucaskabela Lucaskabela Mar 5, 2026

Do we know why this causes a failure for the Pallas backend? Let's make sure those signals are good before landing.

Contributor Author

According to Claude, the Pallas backend doesn't support `index_reduce_`, so the only solution was to mark the test as an expected failure.
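
As a generic illustration of "expect it to fail", the sketch below uses plain `unittest.expectedFailure`. Everything in it is hypothetical (`run_on_backend`, `UnsupportedOpError`, the `unsupported` table); it is not how the PyTorch test suite marks Pallas failures, only the general shape of the idea.

```python
import unittest


class UnsupportedOpError(NotImplementedError):
    """Hypothetical error a backend raises for an op it cannot lower."""


def run_on_backend(backend: str, op_name: str) -> str:
    # Hypothetical stand-in for compiling and running `op_name` on `backend`.
    unsupported = {"pallas": {"index_reduce_"}}
    if op_name in unsupported.get(backend, set()):
        raise UnsupportedOpError(f"{backend} cannot lower {op_name}")
    return "ok"


class GenericBackendTests(unittest.TestCase):
    def test_index_reduce_on_view_input(self):
        self.assertEqual(run_on_backend("generic", "index_reduce_"), "ok")


class PallasBackendTests(unittest.TestCase):
    # The same test, marked as a known failure on the unsupported backend.
    @unittest.expectedFailure
    def test_index_reduce_on_view_input(self):
        self.assertEqual(run_on_backend("pallas", "index_reduce_"), "ok")


def run_suite() -> unittest.TestResult:
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite(
        [
            loader.loadTestsFromTestCase(GenericBackendTests),
            loader.loadTestsFromTestCase(PallasBackendTests),
        ]
    )
    result = unittest.TestResult()
    suite.run(result)
    return result


result = run_suite()
```

The upside of this style is that the run stays green while the gap remains visible: if the backend later gains support, the test is reported as an unexpected success rather than passing silently.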

@aorenste
Contributor Author

aorenste commented Mar 6, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 6, 2026
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team Raised by workflow job

@aorenste aorenste added the topic: not user facing topic category label Mar 6, 2026
@aorenste
Contributor Author

aorenste commented Mar 6, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@jeffdaily
Collaborator

This broke ROCm CI.

PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 PYTORCH_TEST_WITH_ROCM=1 python test/test_meta.py TestMetaCUDA.test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32

test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 GH job link HUD commit link

@jeffdaily
Collaborator

@pytorchbot revert -c weird -m "trunk was passing pre-merge, but failures appeared post-merge only. see test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 GH job link HUD commit link"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@malfet
Contributor

malfet commented Mar 7, 2026

@jeffdaily already pinged the team, but feel free to find the commit that causes the conflict and revert it as well

@zou3519
Contributor

zou3519 commented Mar 7, 2026

@pytorchbot revert -c weird -m "trunk was passing pre-merge, but failures appeared post-merge only. see test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 GH job link HUD commit link"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

Reverting PR 176606 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit 354b0ff88ca5350b68da5d1dee8f3e32dcede563 returned non-zero exit code 1

Auto-merging test/inductor/test_torchinductor.py
CONFLICT (content): Merge conflict in test/inductor/test_torchinductor.py
error: could not revert 354b0ff88ca... [inductor] Fix `index_reduce_` on view inputs raising AssertionError in `assert_functional_graph` (#176606)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team Raised by workflow job

@zou3519
Contributor

zou3519 commented Mar 7, 2026

rip I thought I got the merge conflict

pytorchmergebot added a commit that referenced this pull request Mar 7, 2026
…hader codegen (#176436)"

This reverts commit 4926192.

Reverted #176436 on behalf of https://github.com/zou3519 due to sorry I need to revert this in order to revert #176606 ([comment](#176436 (comment)))
pytorchmergebot referenced this pull request Mar 7, 2026
…degen (#176436)

Metal Shading Language rejects implicit float-to-bfloat conversions, so
bare float literals like `0.0` in generated shaders cause compilation
failures when the target variable is `bfloat` (or `half`). Three codegen
methods were affected:

- `constant()` ignored its `dtype` parameter and returned raw literals.
- `masked()` assigned a bare literal in the else-branch (`} else tmp = 0.0;`).
- `where()` passed a bare literal through the ternary without casting.

All three now emit `static_cast<bfloat>(...)` / `static_cast<half>(...)`
where needed. Tests added for half-precision constants, reductions, and
conditionals.

Pull Request resolved: #176436
Approved by: https://github.com/malfet
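
To make the commit message above concrete, here is a small sketch of a dtype-aware literal emitter. It is not the actual Inductor Metal codegen: the function names, signatures, and the `HALF_PRECISION_TYPES` set are assumptions chosen to mirror the `constant()` and `masked()` behaviour the message describes.

```python
# Hypothetical helpers; not the actual Inductor Metal codegen.
HALF_PRECISION_TYPES = {"bfloat", "half"}


def constant(value: float, dtype: str) -> str:
    """Emit a shader literal, casting explicitly for half-precision targets."""
    literal = repr(float(value))
    if dtype in HALF_PRECISION_TYPES:
        # A bare float literal would need an implicit float-to-bfloat/half
        # conversion, which the commit message says Metal rejects.
        return f"static_cast<{dtype}>({literal})"
    return literal


def masked_else_branch(var: str, other: float, dtype: str) -> str:
    """Emit the else-branch of a masked load using the dtype-aware literal."""
    return f"}} else {var} = {constant(other, dtype)};"


bfloat_zero = constant(0.0, "bfloat")
float_zero = constant(0.0, "float")
else_line = masked_else_branch("tmp", 0.0, "bfloat")
```

Routing every literal through one dtype-aware helper keeps the cast in a single place, so the masked and ternary paths cannot drift apart.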
@zou3519
Contributor

zou3519 commented Mar 7, 2026

@pytorchbot revert -c weird -m "trunk was passing pre-merge, but failures appeared post-merge only. see test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 GH job link HUD commit link"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 7, 2026
…onError in `assert_functional_graph` (#176606)"

This reverts commit 354b0ff.

Reverted #176606 on behalf of https://github.com/zou3519 due to trunk was passing pre-merge, but failures appeared post-merge only. see test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/22771325935/job/66057190983) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/354b0ff88ca5350b68da5d1dee8f3e32dcede563) ([comment](#176606 (comment)))
@pytorchmergebot
Collaborator

@aorenste your PR has been successfully reverted.

…rtionError in `assert_functional_graph`"

[ghstack-poisoned]
aorenste added a commit that referenced this pull request Mar 7, 2026
…in `assert_functional_graph`

ghstack-source-id: ad3ad51
Pull Request resolved: #176606
@aorenste
Contributor Author

aorenste commented Mar 9, 2026

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…in `assert_functional_graph` (pytorch#176606)

Pull Request resolved: pytorch#176606
Approved by: https://github.com/Lucaskabela
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…hader codegen (pytorch#176436)"

This reverts commit 4926192.

Reverted pytorch#176436 on behalf of https://github.com/zou3519 due to sorry I need to revert this in order to revert pytorch#176606 ([comment](pytorch#176436 (comment)))
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…onError in `assert_functional_graph` (pytorch#176606)"

This reverts commit 354b0ff.

Reverted pytorch#176606 on behalf of https://github.com/zou3519 due to trunk was passing pre-merge, but failures appeared post-merge only. see test/test_meta.py::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_all_strides_index_fill_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/22771325935/job/66057190983) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/354b0ff88ca5350b68da5d1dee8f3e32dcede563) ([comment](pytorch#176606 (comment)))
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…in `assert_functional_graph` (pytorch#176606)

Pull Request resolved: pytorch#176606
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos
@github-actions github-actions bot deleted the gh/aorenste/207/head branch April 8, 2026 02:23