Pallas: add strided access support via reshape + static indexing#176622

Closed
v0i0 wants to merge 1 commit into main from v0i0/pallas-strided-access

@v0i0 v0i0 commented Mar 5, 2026

Stack

This is part of a PR stack. Merge order:

  1. #176622 - Pallas: strided access support (this PR)
  2. #176952 - Pallas: permutation detection
  3. #177212 - Pallas: scalar prefetch (depends on #176952)

Summary

  • Add strided access support via reshape + static indexing for non-contiguous tensor patterns
  • Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
  • Set `jax_default_device` to CPU when running in interpret mode on TPU machines
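
For readers unfamiliar with the "reshape + static indexing" idea in the first bullet, here is a minimal sketch of the underlying equivalence. It is not the Inductor codegen: NumPy stands in for `jnp`, the helper name `strided_read_via_reshape` is hypothetical, and the sketch assumes a 1-D input whose length is a multiple of the stride.

```python
import numpy as np


def strided_read_via_reshape(x: np.ndarray, stride: int, offset: int = 0) -> np.ndarray:
    """Return the same values as x[offset::stride] without a strided slice.

    Hypothetical helper for illustration only. Assumes x is 1-D and that
    len(x) is a multiple of stride.
    """
    if x.ndim != 1 or x.shape[0] % stride != 0:
        raise ValueError("sketch only handles 1-D inputs whose length is a multiple of stride")
    if not 0 <= offset < stride:
        raise ValueError("offset must lie in [0, stride)")
    # View the flat buffer as (n // stride, stride): each row is one stride-sized group.
    grouped = x.reshape(-1, stride)
    # Selecting a fixed column is a static index into the reshaped buffer.
    return grouped[:, offset]


x = np.arange(16, dtype=np.float32)
result = strided_read_via_reshape(x, stride=2) * 2.0
assert np.array_equal(result, x[::2] * 2.0)
```

The reshape happens outside the indexed read, so the read itself only needs a full slice plus a constant column index.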

Test plan

  • Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
  • python -m pytest test/inductor/test_pallas.py -v -k "strided" — all strided access tests pass

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

@v0i0 v0i0 requested review from norx1991 and oulgen March 5, 2026 18:31

pytorch-bot bot commented Mar 5, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/176622

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 88 Pending, 1 Unrelated Failure

As of commit 1a6dce8 with merge base 4cce831 (image):

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot commented Mar 5, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

self.assertEqual(result, expected)

@skip_if_tpu
@skip_if_cuda
Contributor

This means the GPU path is now broken? Same for the test below.

Contributor Author

It's just moving the skip up.

Contributor

Oh, I see. It is good to unify the style for "skip" then. Let's keep the comment from the previous line 469 if it is still valid.

@v0i0 v0i0 force-pushed the v0i0/pallas-strided-access branch from 26874e8 to e64c4e7 Compare March 6, 2026 00:38

@skip_if_tpu
@skip_if_cuda
def test_strided_int_pallas(self):
Contributor

I looked at the generated kernel here:

  def pallas_fused_mul_slice_27558c9d_kernel(out_ptr0_alias, in_ptr0, out_ptr0):
      x0 = jnp.arange(8)
      tmp0 = in_ptr0[:, 0]  # static index after reshape
      tmp1 = jnp.array(2.0, dtype=jnp.float32)
      tmp2 = tmp0 * tmp1
      _val = jnp.asarray(tmp2)
      out_ptr0[...] = ...  # broadcast/reshape store

I wonder if we can bring the reshape logic inside the kernel itself (maybe in another PR). Also, is it possible to remove the first line (`x0 = jnp.arange(8)`), which is no longer used?

Contributor

Also, it seems the most straightforward approach would be to use Pallas' strided view access pattern (`::2`). Is there a blocker here?

v0i0 commented Mar 10, 2026

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Mar 10, 2026
@pytorchmergebot

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

v0i0 commented Mar 10, 2026

@pytorchbot label "topic: not user facing"

v0i0 commented Mar 10, 2026

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch merge --squash __pull-request-176622__init__ returned non-zero exit code 1

Auto-merging test/inductor/test_pallas.py
Auto-merging torch/_inductor/codegen/pallas.py
CONFLICT (content): Merge conflict in torch/_inductor/codegen/pallas.py
Squash commit -- not updating HEAD
Automatic merge failed; fix conflicts and then commit the result.

@v0i0 v0i0 force-pushed the v0i0/pallas-strided-access branch from e64c4e7 to 03954cc Compare March 11, 2026 16:32

linux-foundation-easycla bot commented Mar 11, 2026

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: v0i0 / name: Markus Hoehnerbach (1a6dce8)

@pytorch-bot pytorch-bot bot added the ciflow/torchtitan label (Run TorchTitan integration tests) Mar 11, 2026
@v0i0 v0i0 force-pushed the v0i0/pallas-strided-access branch from 03954cc to 5260783 Compare March 11, 2026 17:34
Add support for non-contiguous tensor access patterns in the Pallas
backend by reshaping tensors and using static indexing. Also adds
compatibility with both old and new torch_tpu custom kernel APIs
(positional vs keyword-only args), and sets jax_default_device to CPU
when running in interpret mode on TPU machines.

Authored with assistance from Claude
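
The commit message above mentions staying compatible with both the positional and the keyword-only form of the torch_tpu custom kernel APIs. One generic way to bridge such a change is to inspect the callee's signature; the sketch below illustrates that technique only. `call_compat`, `old_api`, and `new_api` are hypothetical names, not the torch_tpu functions or the code in this PR.

```python
import inspect
from typing import Any, Callable, Sequence


def call_compat(fn: Callable[..., Any], names: Sequence[str], values: Sequence[Any]) -> Any:
    """Call fn, passing each value by keyword only where fn requires it.

    names[i] is the parameter name for values[i], listed in the order the
    positional form of the API expects. Hypothetical helper for illustration.
    """
    params = inspect.signature(fn).parameters
    keyword_only = {
        name for name, p in params.items() if p.kind is inspect.Parameter.KEYWORD_ONLY
    }
    # Values for keyword-only parameters go into kwargs; the rest stay positional.
    args = [v for n, v in zip(names, values) if n not in keyword_only]
    kwargs = {n: v for n, v in zip(names, values) if n in keyword_only}
    return fn(*args, **kwargs)


# Two stand-ins for an "old" positional API and a "new" keyword-only API.
def old_api(name, kernel):
    return ("old", name, kernel)


def new_api(name, *, kernel):
    return ("new", name, kernel)


assert call_compat(old_api, ["name", "kernel"], ["k", 1]) == ("old", "k", 1)
assert call_compat(new_api, ["name", "kernel"], ["k", 1]) == ("new", "k", 1)
```

Inspecting the signature avoids a try/except on `TypeError`, which could also swallow errors raised inside the callee.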
@v0i0 v0i0 force-pushed the v0i0/pallas-strided-access branch from eed5e04 to 1a6dce8 Compare March 12, 2026 17:27

v0i0 commented Mar 12, 2026

@pytorchbot merge -f "tpu passed"

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

pytorchmergebot pushed a commit that referenced this pull request Mar 16, 2026
…176952)

## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. **#176952 — Pallas: permutation detection (this PR)**
3. #177212 — Pallas: scalar prefetch (depends on this PR)

## Summary

- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors
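
As a rough illustration of what detecting an N-D permutation can mean, the sketch below recovers the axis order of a transposed view from its strides. This is an assumption-laden NumPy toy, not the Inductor implementation: `detect_permutation` and `element_strides` are hypothetical names, strides are given in elements, and only views that are pure transposes of a contiguous buffer are recognized.

```python
import numpy as np


def detect_permutation(shape, strides):
    """Return an axis order that makes the view contiguous, or None.

    If the result is `order`, then view.transpose(order) is laid out like a
    contiguous (row-major) array. Strides are in elements, not bytes.
    """
    ndim = len(shape)
    # Sort dimensions from largest to smallest stride: outermost first.
    order = sorted(range(ndim), key=lambda d: strides[d], reverse=True)
    expected = 1
    for d in reversed(order):
        if shape[d] == 1:
            continue  # size-1 dimensions carry no layout information
        if strides[d] != expected:
            return None  # gaps or overlaps: not a plain permutation
        expected *= shape[d]
    return tuple(order)


def element_strides(a: np.ndarray):
    """Convert NumPy's byte strides to element strides."""
    return tuple(s // a.itemsize for s in a.strides)


base = np.arange(24).reshape(2, 3, 4)
view = base.transpose(2, 0, 1)
order = detect_permutation(view.shape, element_strides(view))
assert np.array_equal(view.transpose(order), base)

# A strided slice is not a permutation of a contiguous buffer.
sliced = base[:, :, ::2]
assert detect_permutation(sliced.shape, element_strides(sliced)) is None
```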

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v -k "permute"` — all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: #176952
Approved by: https://github.com/oulgen
pytorchmergebot pushed a commit that referenced this pull request Mar 25, 2026
## Stack

This is part of a PR stack. Merge order:
1. #176622 — Pallas: strided access support (ready to merge)
2. #176952 — Pallas: permutation detection
3. **#177212 — Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary

- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan

- `python -m pytest test/inductor/test_pallas.py -v` — full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
Copilot AI pushed a commit that referenced this pull request Mar 27, 2026
Pull Request resolved: #177212
Approved by: https://github.com/oulgen

Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…orch#176622)

Pull Request resolved: pytorch#176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
…ytorch#176952)

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
…ytorch#176952)

Pull Request resolved: pytorch#176952
Approved by: https://github.com/oulgen
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
pytorch-bot bot pushed a commit that referenced this pull request Apr 2, 2026
Pull Request resolved: #177212
Approved by: https://github.com/oulgen
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026
Pull Request resolved: pytorch#177212
Approved by: https://github.com/oulgen
@github-actions github-actions bot deleted the v0i0/pallas-strided-access branch April 12, 2026 02:25