Pallas: add strided access support via reshape + static indexing #176622
test/inductor/test_pallas.py (outdated)
        self.assertEqual(result, expected)

    @skip_if_tpu
    @skip_if_cuda
This means the GPU path is now broken? Same for the test below.
It's just moving the skip up.
Oh, I see. It is good to unify the style for "skip" then. Let's keep the comment in the previous line 469 if it is still valid?
    @skip_if_tpu
    @skip_if_cuda
    def test_strided_int_pallas(self):
I looked at the generated kernel here:
def pallas_fused_mul_slice_27558c9d_kernel(out_ptr0_alias, in_ptr0, out_ptr0):
    x0 = jnp.arange(8)
    tmp0 = in_ptr0[:, 0]  # static index after reshape
    tmp1 = jnp.array(2.0, dtype=jnp.float32)
    tmp2 = tmp0 * tmp1
    _val = jnp.asarray(tmp2)
    out_ptr0[...] = ...  # broadcast/reshape store
I wonder if we can bring the reshape logic inside the kernel itself (maybe in another PR). Also, is it possible to remove the first line of the kernel body (`x0 = jnp.arange(8)`), which is no longer used?
Also, it seems the most straightforward approach would be to use Pallas' strided view access pattern (`::2`). Is there a blocker here?
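For reference, the equivalence between the two access patterns discussed here can be sketched outside of Pallas. This is a minimal NumPy illustration, not the generated kernel: the input length of 16 and the stride of 2 are assumptions chosen to match the 8-element iteration range in the kernel above, and `jnp` arrays follow the same basic indexing rules.

```python
import numpy as np

# Assumed input: 16 contiguous float32 elements, read with stride 2.
x = np.arange(16, dtype=np.float32)
stride = 2

# Direct strided view, the `::2` pattern.
strided = x[::stride]

# Same elements via reshape + static indexing: group the data into
# rows of `stride` elements, then take column 0 of every row.
reshaped = x.reshape(-1, stride)[:, 0]
assert np.array_equal(strided, reshaped)

# A nonzero offset (x[1::2]) maps to a different static column.
assert np.array_equal(x[1::stride], x.reshape(-1, stride)[:, 1])

# Mirrors `tmp0 * tmp1` in the kernel above.
result = reshaped * np.float32(2.0)
```

Note that `reshape(-1, stride)` requires the length to be divisible by the stride; how the PR handles other lengths is not covered by this sketch.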
@pytorchbot merge
Merge failed. Reason: this PR is missing a required label.
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed.
Add support for non-contiguous tensor access patterns in the Pallas backend by reshaping tensors and using static indexing. Also adds compatibility with both old and new torch_tpu custom kernel APIs (positional vs keyword-only args), and sets jax_default_device to CPU when running in interpret mode on TPU machines. Authored with assistance from Claude
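One way to support both positional and keyword-only call conventions is to inspect the callee's signature before calling it. The following is a rough sketch under stated assumptions, not the code from this PR: `call_with_compat`, `register_old`, `register_new`, and the parameter names `name` and `kernel` are all hypothetical stand-ins, since the actual torch_tpu signatures are not shown in this conversation.

```python
import inspect


def call_with_compat(fn, name, kernel):
    """Call `fn` with keywords if it declares keyword-only params, else positionally."""
    params = inspect.signature(fn).parameters.values()
    if any(p.kind is inspect.Parameter.KEYWORD_ONLY for p in params):
        # Newer-style API: arguments must be passed by keyword.
        return fn(name=name, kernel=kernel)
    # Older-style API: arguments are accepted positionally.
    return fn(name, kernel)


# Hypothetical stand-ins for the old and new API shapes.
def register_old(name, kernel):
    return ("old", name, kernel)


def register_new(*, name, kernel):
    return ("new", name, kernel)
```

Checking the signature up front avoids catching `TypeError` around the call, which could mask a genuine error raised inside the callee.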
@pytorchbot merge -f "tpu passed"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
…176952)

## Stack
This is part of a PR stack. Merge order:
1. #176622 - Pallas: strided access support (ready to merge)
2. **#176952 - Pallas: permutation detection (this PR)**
3. #177212 - Pallas: scalar prefetch (depends on this PR)

## Summary
- Generalize permutation detection for N-D tensor transposes (not just 2D swaps)
- Add collapsed-dimension detection for cases where iteration dimensions map to grouped tensor dimensions
- Skip permutation detection on GPU (inputs are flattened to 1D by `pallas_gpu_pad_inputs`)
- Add comprehensive permutation tests for 2D and 3D tensors

## Test plan
- `python -m pytest test/inductor/test_pallas.py -v -k "permute"`: all permutation tests pass on TPU
- Full test suite: 878 passed, 154 skipped, 67 xfailed, 102 subtests passed, 0 unexpected failures

Pull Request resolved: #176952
Approved by: https://github.com/oulgen
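As an illustration of what N-D permutation detection can look like, a transposed view of a contiguous tensor can be recognized from its shape and strides alone. This is a minimal sketch and not the implementation from #176952: `detect_permutation` is a hypothetical helper, strides are assumed to be in elements (as in PyTorch), and size-1 dimensions are not given special treatment.

```python
def detect_permutation(shape, strides):
    """Return the axis order that restores a contiguous layout, or None."""
    # Order axes from largest to smallest stride (outermost first).
    order = sorted(range(len(shape)), key=lambda i: -strides[i])
    expected = 1
    # Walk from the innermost axis outward, checking that each stride
    # equals the product of the sizes of all axes inside it.
    for axis in reversed(order):
        if strides[axis] != expected:
            return None
        expected *= shape[axis]
    return tuple(order)


# A 2D transpose of a contiguous (2, 3) tensor has shape (3, 2), strides (1, 3).
perm_2d = detect_permutation((3, 2), (1, 3))

# A contiguous (2, 3, 4) tensor permuted by (2, 0, 1) has shape (4, 2, 3)
# and strides (1, 12, 4).
perm_3d = detect_permutation((4, 2, 3), (1, 12, 4))

# A stride-2 access is not a permutation of a contiguous layout.
not_perm = detect_permutation((8,), (2,))
```

Permuting the view by the returned order yields the original contiguous layout; a `None` result means the access needs a different strategy, such as the strided handling added in #176622.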
## Stack
This is part of a PR stack. Merge order:
1. #176622 - Pallas: strided access support (ready to merge)
2. #176952 - Pallas: permutation detection
3. **#177212 - Pallas: scalar prefetch (this PR, depends on #176952)**

## Summary
- Add scalar prefetch support for TPU Pallas kernels
- Enable indirect access patterns (gather-like operations) via scalar prefetch pipeline
- Index tensors are loaded in a separate prefetch stage for better TPU performance

## Test plan
- `python -m pytest test/inductor/test_pallas.py -v`: full test suite passes on TPU
- Scalar prefetch / indirect access tests pass

Pull Request resolved: #177212
Approved by: https://github.com/oulgen
Co-authored-by: Xia-Weiwen <12522207+Xia-Weiwen@users.noreply.github.com>
…orch#176622)

## Stack
This is part of a PR stack. Merge order:
1. **pytorch#176622 - Pallas: strided access support (this PR)**
2. pytorch#176952 - Pallas: permutation detection
3. pytorch#177212 - Pallas: scalar prefetch (depends on pytorch#176952)

## Summary
- Add strided access support via reshape + static indexing for non-contiguous tensor patterns
- Fix torch_tpu keyword-only API changes in `register_custom_kernel` and `call_custom_kernel`
- Set `jax_default_device` to CPU when running in interpret mode on TPU machines

## Test plan
- Full test suite on TPU: 869 passed, 153 skipped, 69 xfailed, 0 unexpected failures
- `python -m pytest test/inductor/test_pallas.py -v -k "strided"`: all strided access tests pass

Pull Request resolved: pytorch#176622
Approved by: https://github.com/oulgen, https://github.com/norx1991
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo