Pallas: add scalar prefetch and indirect access support #177212
v0i0 wants to merge 3 commits into pytorch:main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177212
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 5 Unrelated Failures
As of commit fc47a36 with merge base e26fce2:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
force-pushed from 8eaa2c3 to 1822f77
force-pushed from 5e0ee0d to 4ae179d
…6622)

## Stack
This is part of a PR stack. Merge order:
1. **#176622 - Pallas: strided access support (this PR)**
2. #176952 - Pallas: permutation detection
3. #177212 - Pallas: scalar prefetch (depends on #176952)

## Summary
- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan
- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"`: all strided access tests pass

Pull Request resolved: #176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
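For readers unfamiliar with the "reshape + static indexing" idea in the summary above, here is a minimal pure-Python sketch of it. The helper names `reshape_2d` and `strided_read` are hypothetical and exist only for illustration; this is not the code Inductor generates, just a small model of why a strided read `x[offset::stride]` can be replaced by a reshape followed by a fixed column index.

```python
def reshape_2d(flat, rows, cols):
    """Row-major reshape of a flat list into `rows` rows of `cols` items."""
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


def strided_read(flat, stride, offset=0):
    """Emulate flat[offset::stride] with a reshape plus a static column index.

    Assumes len(flat) is a multiple of `stride` (hypothetical restriction
    chosen to keep the sketch short).
    """
    assert len(flat) % stride == 0 and 0 <= offset < stride
    rows = reshape_2d(flat, len(flat) // stride, stride)
    # The column index `offset` is known ahead of time (static), so no
    # dynamic strided load is needed.
    return [row[offset] for row in rows]


data = list(range(12))
even = strided_read(data, stride=2)
every_third = strided_read(data, stride=3, offset=1)
```

The point of the rewrite is that both the reshape and the column index are known at compile time, which is presumably what makes the pattern friendlier to the backend than a dynamic strided load.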
…176952)

## Stack
This is part of a PR stack. Merge order:
1. #176622 - Pallas: strided access support (ready to merge)
2. **#176952 - Pallas: permutation detection (this PR)**
3. #177212 - Pallas: scalar prefetch (depends on this PR)

## Summary
- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan
- `python -m pytest test/inductor/test_pallas.py -v -k "permute"`: all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: #176952
Approved by: https://github.com/oulgen
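As a rough illustration of what N-D permutation detection can look like, the sketch below recovers a permutation from a view's shape and strides. This is an assumption about the general technique, not the PR's implementation: the function names are hypothetical, and the collapsed-dimension case mentioned in the summary is not handled.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a contiguous tensor of `shape`."""
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def detect_permutation(shape, strides):
    """Return `perm` such that a contiguous tensor permuted by `perm`
    has the given shape and strides, or None if no such perm exists.

    Size-1 dimensions can make the answer ambiguous; this sketch simply
    relies on the stable sort order in that case.
    """
    # Order view dims from outermost (largest stride) to innermost.
    order = sorted(range(len(shape)), key=lambda d: -strides[d])
    base_shape = tuple(shape[d] for d in order)
    base_strides = contiguous_strides(base_shape)
    if any(strides[d] != base_strides[i] for i, d in enumerate(order)):
        return None  # strided or overlapping view, not a pure permutation
    perm = [0] * len(shape)
    for base_dim, view_dim in enumerate(order):
        perm[view_dim] = base_dim
    return tuple(perm)


# A contiguous (2, 3, 4) tensor permuted with (2, 0, 1) has shape (4, 2, 3)
# and strides (1, 12, 4).
perm_3d = detect_permutation((4, 2, 3), (1, 12, 4))
perm_2d = detect_permutation((3, 2), (1, 3))
not_a_perm = detect_permutation((2, 3), (6, 2))
```

Sorting by stride generalizes the 2D "are the strides swapped?" check to any number of dimensions, and the final comparison against contiguous strides rejects views that are strided rather than merely permuted.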
force-pushed from 4ae179d to 1984cf6
Scale weights by 1/sqrt(fan_in) instead of hardcoded 0.02 so RMSNorm variance stays ~1 and rsqrt doesn't amplify reduction-order diffs. Authored with Claude.
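The reasoning in this commit message can be checked numerically with a small standard-library sketch. The sizes (`fan_in=256`, `fan_out=64`) and the helper name `output_mean_square` are assumptions chosen for illustration; the actual test shapes are not stated in this thread.

```python
import math
import random


def output_mean_square(fan_in, weight_std, fan_out=64, seed=0):
    """Mean square of y = W @ x for unit-variance x and N(0, weight_std^2) W."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
    y = [
        sum(rng.gauss(0.0, weight_std) * xi for xi in x)
        for _ in range(fan_out)
    ]
    return sum(v * v for v in y) / fan_out


fan_in = 256  # hypothetical layer width
ms_scaled = output_mean_square(fan_in, 1.0 / math.sqrt(fan_in))
ms_fixed = output_mean_square(fan_in, 0.02)

# RMSNorm multiplies by rsqrt(mean_square + eps); eps is omitted here.
gain_scaled = 1.0 / math.sqrt(ms_scaled)
gain_fixed = 1.0 / math.sqrt(ms_fixed)
```

With a weight standard deviation of 1/sqrt(fan_in), the expected mean square of the output is about 1, so the normalization gain stays near 1. With a fixed 0.02 the expected mean square is fan_in * 0.02^2, which is about 0.10 at this width, so the rsqrt gain is roughly three times larger and small differences in reduction order are scaled up by the same factor.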
force-pushed from 1984cf6 to f92c906
@pytorchbot label "topic: not user facing"
Add scalar prefetch support for TPU Pallas kernels, enabling efficient indirect access patterns (gather-like operations). Scalar prefetch allows index tensors to be loaded in a separate pipeline stage, improving TPU performance for gather/scatter patterns. Authored with assistance from Claude
force-pushed from f92c906 to 4bce3be
force-pushed from 888c2af to fc47a36
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
@pytorchbot merge
@pytorchbot merge -f "tpu tests are passing"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
## Stack
This is part of a PR stack. Merge order:
1. #176622 - Pallas: strided access support (ready to merge)
2. #176952 - Pallas: permutation detection
3. **#177212 - Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary
- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan
- `python -m pytest test/inductor/test_pallas.py -v`: full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen

Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
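To make the summary above concrete, here is a pure-Python emulation of the scalar prefetch pattern for a row gather. It is a conceptual sketch only: `gather_rows_with_prefetch` and `index_map` are hypothetical names, nothing here calls JAX or Pallas, and it does not reproduce the code this PR generates. In JAX Pallas on TPU the corresponding mechanism is, to my understanding, a grid spec with a `num_scalar_prefetch` argument, whose index maps receive the prefetched scalars as extra arguments.

```python
def gather_rows_with_prefetch(table, indices):
    """Emulate out[i] = table[indices[i]] as a grid of block copies."""
    # Stage 1 (prefetch): the index tensor is made available as plain
    # scalars before any grid step runs.
    prefetched = [int(i) for i in indices]

    # Stage 2: the index map consults the prefetched scalars to decide
    # which row block of `table` each grid step should load.
    def index_map(step, idx_ref):
        return idx_ref[step]

    out = []
    for step in range(len(prefetched)):
        row = index_map(step, prefetched)
        if not 0 <= row < len(table):
            raise IndexError(f"index {row} out of range at step {step}")
        # Kernel body: copy the selected block to the output block.
        out.append(list(table[row]))
    return out


table = [[0, 1], [10, 11], [20, 21]]
gathered = gather_rows_with_prefetch(table, [2, 0, 2])
```

Because the indices are known before the grid runs, the block to load at each step can be chosen by the index map rather than inside the kernel body, which is what lets the data loads be pipelined like any other block access.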
Stack
This is part of a PR stack. Merge order:
Summary
Test plan
python -m pytest test/inductor/test_pallas.py -v: full test suite passes on TPU

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo