Skip to content

Pallas: add scalar prefetch and indirect access support#177212

Closed
v0i0 wants to merge 3 commits intopytorch:mainfrom
v0i0:v0i0/pallas-scalar-prefetch
Closed

Pallas: add scalar prefetch and indirect access support#177212
v0i0 wants to merge 3 commits intopytorch:mainfrom
v0i0:v0i0/pallas-scalar-prefetch

Conversation

@v0i0
Copy link
Copy Markdown
Contributor

@v0i0 v0i0 commented Mar 11, 2026

Stack

This is part of a PR stack. Merge order:

  1. Pallas: add strided access support via reshape + static indexing #176622 — Pallas: strided access support (ready to merge)
  2. Pallas: generalize permutation detection for N-D tensor transposes #176952 — Pallas: permutation detection
  3. Pallas: add scalar prefetch and indirect access support #177212 — Pallas: scalar prefetch (this PR, depends on Pallas: generalize permutation detection for N-D tensor transposes #176952)

Summary

  • Add scalar prefetch support for TPU Pallas kernels
  • Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
  • Index tensors are loaded in a separate prefetch stage for better TPU performance

Test plan

  • python -m pytest test/inductor/test_pallas.py -v — full test suite passes on TPU
  • Scalar prefetch / indirect access tests pass

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177212

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 5 Unrelated Failures

As of commit fc47a36 with merge base e26fce2 (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 11, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch from 8eaa2c3 to 1822f77 Compare March 11, 2026 22:47
@linux-foundation-easycla
Copy link
Copy Markdown

linux-foundation-easycla bot commented Mar 11, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch 3 times, most recently from 5e0ee0d to 4ae179d Compare March 12, 2026 14:02
pytorchmergebot pushed a commit that referenced this pull request Mar 12, 2026
…6622)

## Stack

This is part of a PR stack. Merge order:
1. **#176622 — Pallas: strided access support (this PR)**
2. #176952 — Pallas: permutation detection
3. #177212 — Pallas: scalar prefetch (depends on #176952)

## Summary

- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan

- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"` — all strided access tests pass
Pull Request resolved: #176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
pytorchmergebot pushed a commit that referenced this pull request Mar 16, 2026
…176952)

## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. **#176952 — Pallas: permutation detection (this PR)**
3. #177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: #176952
Approved by: https://github.com/oulgen
@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch from 4ae179d to 1984cf6 Compare March 17, 2026 19:36
Scale weights by 1/sqrt(fan_in) instead of hardcoded 0.02 so RMSNorm
variance stays ~1 and rsqrt doesn't amplify reduction-order diffs.

Authored with Claude.
@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch from 1984cf6 to f92c906 Compare March 18, 2026 22:19
@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 19, 2026

@pytorchbot label "topic: not user facing"

Add scalar prefetch support for TPU Pallas kernels, enabling efficient
indirect access patterns (gather-like operations). Scalar prefetch
allows index tensors to be loaded in a separate pipeline stage,
improving TPU performance for gather/scatter patterns.

Authored with assistance from Claude
@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Mar 19, 2026
@v0i0 v0i0 requested review from AmesingFlank and oulgen March 19, 2026 00:54
@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch from f92c906 to 4bce3be Compare March 19, 2026 00:54
@v0i0 v0i0 force-pushed the v0i0/pallas-scalar-prefetch branch from 888c2af to fc47a36 Compare March 24, 2026 20:03
@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 24, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 24, 2026
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 25, 2026

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 25, 2026

@pytorchbot merge

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

@v0i0
Copy link
Copy Markdown
Contributor Author

v0i0 commented Mar 25, 2026

@pytorchbot merge -f "tpu tests are passing"

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Copilot AI pushed a commit that referenced this pull request Mar 27, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen

Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…orch#176622)

## Stack

This is part of a PR stack. Merge order:
1. **pytorch#176622 — Pallas: strided access support (this PR)**
2. pytorch#176952 — Pallas: permutation detection
3. pytorch#177212 — Pallas: scalar prefetch (depends on pytorch#176952)

## Summary

- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan

- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"` — all strided access tests pass
Pull Request resolved: pytorch#176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ytorch#176952)

## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. **pytorch#176952 — Pallas: permutation detection (this PR)**
3. pytorch#177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…ytorch#176952)

## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. **pytorch#176952 — Pallas: permutation detection (this PR)**
3. pytorch#177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. pytorch#176952 — Pallas: permutation detection
3. **pytorch#177212 — Pallas: scalar prefetch (this PR, depends on pytorch#176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
pytorch-bot bot pushed a commit that referenced this pull request Apr 2, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
## Stack

This is part of a PR stack. Merge order:
1. pytorch#176622 — Pallas: strided access support (ready to merge)
2. pytorch#176952 — Pallas: permutation detection
3. **pytorch#177212 — Pallas: scalar prefetch (this PR, depends on pytorch#176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants