Pallas: generalize permutation detection for N-D tensor transposes#176952
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176952
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 2 Unrelated Failures
As of commit 4f940a2 with merge base e07e28d:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Force-pushed from d020c14 to 38d0c4f.
Force-pushed from 38d0c4f to f773243.
Includes permutation detection generalization (pytorch#176952) as prerequisite. Co-authored-by: Claude <noreply@anthropic.com>
Force-pushed from f773243 to baaf867.
@pytorchbot label "topic: not user facing"
…6622)

## Stack

This is part of a PR stack. Merge order:

1. **#176622 - Pallas: strided access support (this PR)**
2. #176952 - Pallas: permutation detection
3. #177212 - Pallas: scalar prefetch (depends on #176952)

## Summary

- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan

- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"`: all strided access tests pass

Pull Request resolved: #176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
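As a rough illustration of the "reshape + static indexing" idea mentioned in the summary above, a strided read such as `x[offset::stride]` can be expressed as a reshape followed by selecting one fixed column. The sketch below uses NumPy rather than Pallas, and the helper name `strided_via_reshape` is hypothetical; it is not taken from the PR and only assumes the length is divisible by the stride.

```python
import numpy as np


def strided_via_reshape(x: np.ndarray, stride: int, offset: int = 0) -> np.ndarray:
    """Emulate x[offset::stride] for a 1-D array using reshape + static indexing.

    Hypothetical helper for illustration only; assumes len(x) % stride == 0.
    """
    n = x.shape[0]
    if n % stride != 0:
        raise ValueError("length must be divisible by stride in this sketch")
    # View the data as rows of `stride` elements, then pick one fixed column.
    return x.reshape(n // stride, stride)[:, offset]


x = np.arange(12)
result = strided_via_reshape(x, stride=3, offset=1)
```

The column index is a compile-time constant here, which is the property that makes this form attractive when dynamic strided slicing is not available.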
Force-pushed from 898f092 to d291bf8.
Force-pushed from d291bf8 to 0e54c55.
Force-pushed from 0e54c55 to 6819ccd.
Generalize the permutation detection logic in the Pallas backend to handle arbitrary N-D tensor transposes, not just 2D swaps. Adds collapsed-dimension detection for cases where iteration dimensions don't match tensor dimensions directly. Skips permutation detection on GPU where inputs are flattened to 1D. Also adds cat_unbacked tests to expected failures (pre-existing dlpack issue on TPU, unrelated to this change). Authored with assistance from Claude
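To make the idea of N-D permutation detection concrete, here is a minimal sketch that recovers a permutation from a view's shape and strides. It is not the PR's implementation: the function name `detect_permutation` and the stride-sorting approach are assumptions chosen for illustration, and the sketch ignores collapsed dimensions and storage offsets.

```python
from typing import Optional, Sequence, Tuple


def detect_permutation(
    shape: Sequence[int], strides: Sequence[int]
) -> Optional[Tuple[int, ...]]:
    """Return `order` such that permuting the view's dims by `order` yields a
    contiguous (row-major) layout, or None if the strides do not describe a
    pure permutation of a contiguous tensor. Hypothetical helper."""
    if len(shape) != len(strides):
        return None
    # Sort dims from outermost (largest stride) to innermost (smallest stride).
    order = sorted(range(len(shape)), key=lambda d: -strides[d])
    expected = 1
    # Walk from the innermost dim outward and check row-major strides.
    for d in reversed(order):
        # Size-1 dims can carry any stride, so they are not checked.
        if shape[d] != 1 and strides[d] != expected:
            return None
        expected *= shape[d]
    return tuple(order)


# 2D transpose of a contiguous (2, 3) tensor: shape (3, 2), strides (1, 3).
swap_2d = detect_permutation((3, 2), (1, 3))        # (1, 0)
# Contiguous (2, 3, 4) tensor permuted by (2, 0, 1): shape (4, 2, 3), strides (1, 12, 4).
perm_3d = detect_permutation((4, 2, 3), (1, 12, 4))  # (1, 2, 0)
# A strided slice is not a pure permutation.
not_perm = detect_permutation((2, 3), (6, 1))        # None
```

The 2D swap falls out as a special case of the general check, which is the sense in which an N-D detector subsumes a 2D-only one.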
Force-pushed from 6819ccd to 4f940a2.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
    @skip_if_tpu
    def test_permute_contiguous_4d(self):
        """Test all 23 non-identity 4D permutations with multi-tile grids."""
        all_perms = [
Use itertools.permutations?
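One way the suggestion could look is sketched below. The hand-written list in the test is not shown in full here, so the variable name `non_identity_perms` and the filtering step are assumptions rather than the code that was merged.

```python
import itertools

ndim = 4
identity = tuple(range(ndim))

# All orderings of the 4 dims, minus the identity (0, 1, 2, 3).
non_identity_perms = [
    perm for perm in itertools.permutations(range(ndim)) if perm != identity
]
```

`itertools.permutations` yields tuples in lexicographic order, so the identity is always the first item and the remaining 4! - 1 = 23 permutations match the count named in the test docstring.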
    @skip_if_tpu  # stack+where fusion doesn't broadcast correctly on TPU yet
    def test_rope_interleaved(self):
Will this work now if the collapsed dimension detection is supported?
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge -f "tpu tests are passing"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
## Stack

This is part of a PR stack. Merge order:

1. #176622 - Pallas: strided access support (ready to merge)
2. #176952 - Pallas: permutation detection
3. **#177212 - Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v`: full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
…ytorch#176952)

## Stack

This is part of a PR stack. Merge order:

1. pytorch#176622 - Pallas: strided access support (ready to merge)
2. **pytorch#176952 - Pallas: permutation detection (this PR)**
3. pytorch#177212 - Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"`: all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
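## Stack

This is part of a PR stack. Merge order:

1. #176622 - Pallas: strided access support (ready to merge)
2. **#176952 - Pallas: permutation detection (this PR)**
3. #177212 - Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"`: all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo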