Skip to content

Pallas: generalize permutation detection for N-D tensor transposes#176952

Closed
v0i0 wants to merge 1 commit intomainfrom
v0i0/pallas-permutation-detection
Closed

Pallas: generalize permutation detection for N-D tensor transposes#176952
v0i0 wants to merge 1 commit intomainfrom
v0i0/pallas-permutation-detection

Conversation

@v0i0
Copy link
Copy Markdown
Contributor

@v0i0 v0i0 commented Mar 9, 2026

Stack

This is part of a PR stack. Merge order:

  1. Pallas: add strided access support via reshape + static indexing #176622 — Pallas: strided access support (ready to merge)
  2. Pallas: generalize permutation detection for N-D tensor transposes #176952 — Pallas: permutation detection (this PR)
  3. Pallas: add scalar prefetch and indirect access support #177212 — Pallas: scalar prefetch (depends on this PR)

Summary

  • Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
  • Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
  • Skip permutation detection on GPU (inputs are flattened to 1D by pallas_gpu_pad_inputs)
  • Add comprehensive permutation tests for 2D and 3D tensors

Test plan

  • python -m pytest test/inductor/test_pallas.py -v -k "permute" — all permutation tests pass on TPU
  • Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176952

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 2 Unrelated Failures

As of commit 4f940a2 with merge base e07e28d (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 9, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from d020c14 to 38d0c4f Compare March 11, 2026 16:32
@linux-foundation-easycla
Copy link
Copy Markdown

linux-foundation-easycla bot commented Mar 11, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: v0i0 / name: Markus Hoehnerbach (4f940a2)

@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from 38d0c4f to f773243 Compare March 11, 2026 17:34
v0i0 added a commit to v0i0/pytorch that referenced this pull request Mar 11, 2026
Includes permutation detection generalization (pytorch#176952) as prerequisite.

Co-authored-by: Claude <noreply@anthropic.com>
@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from f773243 to baaf867 Compare March 12, 2026 14:30
@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 12, 2026

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 12, 2026
pytorchmergebot pushed a commit that referenced this pull request Mar 12, 2026
…6622)

## Stack

This is part of a PR stack. Merge order:
1. **#176622 — Pallas: strided access support (this PR)**
2. #176952 — Pallas: permutation detection
3. #177212 — Pallas: scalar prefetch (depends on #176952)

## Summary

- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan

- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"` — all strided access tests pass
Pull Request resolved: #176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch 10 times, most recently from 898f092 to d291bf8 Compare March 12, 2026 23:45
@v0i0 v0i0 requested review from norx1991 and oulgen March 13, 2026 00:36
@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from d291bf8 to 0e54c55 Compare March 13, 2026 03:36
@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from 0e54c55 to 6819ccd Compare March 13, 2026 04:48
Generalize the permutation detection logic in the Pallas backend to
handle arbitrary N-D tensor transposes, not just 2D swaps. Adds
collapsed-dimension detection for cases where iteration dimensions
don't match tensor dimensions directly. Skips permutation detection
on GPU where inputs are flattened to 1D.

Also adds cat_unbacked tests to expected failures (pre-existing
dlpack issue on TPU, unrelated to this change).

Authored with assistance from Claude
@v0i0 v0i0 force-pushed the v0i0/pallas-permutation-detection branch from 6819ccd to 4f940a2 Compare March 13, 2026 05:39
@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 13, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 13, 2026
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@skip_if_tpu
def test_permute_contiguous_4d(self):
"""Test all 23 non-identity 4D permutations with multi-tile grids."""
all_perms = [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use itertools.permutations?

Comment on lines 1222 to 1223
@skip_if_tpu # stack+where fusion doesn't broadcast correctly on TPU yet
def test_rope_interleaved(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work now if the collapsed dimension detection is supported?

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 16, 2026

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 16, 2026

@pytorchbot merge -f "tpu tests are passing"

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Mar 25, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
Copilot AI pushed a commit that referenced this pull request Mar 27, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen

Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…orch#176622)

## Stack

This is part of a PR stack. Merge order:
1. **pytorch#176622 — Pallas: strided access support (this PR)**
2. pytorch#176952 — Pallas: permutation detection
3. pytorch#177212 — Pallas: scalar prefetch (depends on pytorch#176952)

## Summary

- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan

- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"` — all strided access tests pass
Pull Request resolved: pytorch#176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ytorch#176952)

## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. **pytorch#176952 — Pallas: permutation detection (this PR)**
3. pytorch#177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…ytorch#176952)

## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. **pytorch#176952 — Pallas: permutation detection (this PR)**
3. pytorch#177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. pytorch#176952 — Pallas: permutation detection
3. **pytorch#177212 — Pallas: scalar prefetch (this PR, depends on pytorch#176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
pytorch-bot bot pushed a commit that referenced this pull request Apr 2, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. pytorch#176952 — Pallas: permutation detection
3. **pytorch#177212 — Pallas: scalar prefetch (this PR, depends on pytorch#176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants