
Support of Splash Attention using xla_builder.call_jax #145

Merged

zpcore merged 21 commits into main from piz/sa on Mar 14, 2025
Conversation

@zpcore (Collaborator) commented Mar 8, 2025

Support of Splash Attention (SA) using the new xla_builder.call_jax feature (internal one-pager doc). call_jax lets us copy JAX code directly into torch_xla without diving into the details of the low-level pallas kernel API. This PR also serves as an example of how to run the forward pass and retrieve gradients under the torch framework.
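To make the pattern concrete, here is a minimal sketch of bridging a JAX function into torch autograd via call_jax, assuming `call_jax(fn, args)` takes a JAX function and a tuple of XLA torch tensors. The `jax_attention` body and the `jax.vjp`-based backward wiring are illustrative assumptions, not this PR's actual kernel wrapper:

```python
# Minimal sketch (not the PR's actual implementation). `jax_attention` is a
# hypothetical stand-in for the real splash attention pallas kernel.
import jax
import torch
import torch_xla.core.xla_builder as xb


def jax_attention(q, k, v):
  # Stand-in for the pallas splash attention kernel.
  return jax.nn.softmax(q @ k.T) @ v


class SplashAttentionExample(torch.autograd.Function):
  @staticmethod
  def forward(ctx, q, k, v):
    ctx.save_for_backward(q, k, v)
    # call_jax traces the JAX function and embeds it in the torch_xla graph.
    return xb.call_jax(jax_attention, (q, k, v))

  @staticmethod
  def backward(ctx, grad_out):
    q, k, v = ctx.saved_tensors

    def jax_attention_vjp(q, k, v, g):
      # jax.vjp derives the backward pass from the same JAX code, so no
      # hand-written backward kernel binding is needed on the torch side.
      _, vjp_fn = jax.vjp(jax_attention, q, k, v)
      return vjp_fn(g)

    return xb.call_jax(jax_attention_vjp, (q, k, v, grad_out))
```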

For a quick test, we can use:

```sh
python torchprime/torch_xla_models/train.py model=llama-3.1-8b dataset_config_name=wikitext-103-raw-v1 global_batch_size=4 profile_step=3 ici_mesh.fsdp=4 model.attention_kernel=splash_attention
```

Note: The PR is still in the experimental stage. We plan to test it in TorchPrime first and move the kernel to PyTorch/XLA later.

Performance improvement summary compared with MaxText and Flash Attention (FA) kernel:

| | PTXLA + FA | PTXLA + SA | PTXLA + SA + SCAN + hostoffload | PTXLA + SA + jit caching | MaxText + FA | MaxText + SA |
|---|---|---|---|---|---|---|
| step time | 5.87s | 5.367s | 4.619s | 4.694s | 5.95s | 4.45s |
| MFU | 37.15% | 40.6% | 47.19% | 46.44% | 36.63% | 48.98% |

@zpcore zpcore requested review from bhavya01, qihqi and tengyifei March 8, 2025 00:19
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
@zpcore zpcore marked this pull request as ready for review March 8, 2025 09:01
@zpcore (Collaborator, Author) commented Mar 10, 2025

We now almost match MaxText's fwd and bwd time for each decoder layer. However, there is still an overall gap between the torchprime (profile) and maxtext (profile) performance. Below is the summary:

Test: Llama 3.1 8B, 8K seq_len; host: v6e-256.

| | MaxText (JAX) | TorchPrime | Cause of difference |
|---|---|---|---|
| fwd of each decoder layer | 35.7ms | 33.2ms | N/A |
| bwd of each decoder layer | 93.7ms | 109.4ms | Extra pallas kernel call in attention activation remat |
| step time | 4.451s | 5.367s | Barrier core overhead |

Details:

The first main issue is the barrier core, which takes 704ms:

[profile screenshot: barrier core]

If we can get rid of it, the step time can be improved to 4.68s. As a comparison, the JAX step time is 4.451s.

The second issue: PTXLA has an extra pallas call during activation remat.

JAX:

[profile screenshot: JAX backward pass]

PTXLA:

[profile screenshot: PTXLA backward pass with an extra pallas call]
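For context on why remat produces an extra kernel call: activation rematerialization replays the layer's forward, including its attention kernel, during the backward pass. The snippet below is a generic torch illustration of that effect, not code from this PR:

```python
# Generic illustration (not from this PR): with activation checkpointing,
# the layer's forward is replayed during backward to rebuild activations,
# so a layer whose attention is a pallas kernel invokes it a second time.
import torch
from torch.utils.checkpoint import checkpoint


def decoder_layer(x):
  # Stand-in for a decoder layer that calls the splash attention kernel.
  return torch.nn.functional.silu(x)


x = torch.randn(4, 8, requires_grad=True)
y = checkpoint(decoder_layer, x, use_reentrant=False)
y.sum().backward()  # decoder_layer (and its kernel) runs again here
```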

@zpcore zpcore mentioned this pull request Mar 12, 2025
@zpcore (Collaborator, Author) commented Mar 12, 2025

This PR requires the nightly build from 03/12 so that it picks up pytorch/xla#8789. #145 needs to be merged to fix the test failure.

@tengyifei (Contributor) left a comment

still reviewing the other half

Comment thread torchprime/launcher/Dockerfile Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/test/test_splash_attention.py Outdated
Comment thread .github/workflows/e2e_test.yml Outdated
Comment thread .github/workflows/e2e_test.yml Outdated
Comment thread .github/workflows/e2e_test.yml Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
Comment thread torchprime/torch_xla_models/experimental/test/test_splash_attention.py Outdated
Comment thread torchprime/torch_xla_models/experimental/test/test_splash_attention.py Outdated
Comment thread torchprime/torch_xla_models/llama/model.py
Comment thread torchprime/torch_xla_models/train.py Outdated
Comment thread .github/workflows/e2e_test.yml Outdated
Comment thread torchprime/torch_xla_models/experimental/custom_kernel.py Outdated
@tengyifei (Contributor)

need to rebase i think

@tengyifei (Contributor)

Oh! It doesn't have to be in this PR; would you like to update 1 and close out #133? We could add the cmdline recipe and MFU to 2 and close out #135 too!

@zpcore zpcore merged commit e100715 into main Mar 14, 2025
@zpcore zpcore deleted the piz/sa branch March 14, 2025 23:19
@zpcore (Collaborator, Author) commented Mar 14, 2025

> Oh! It doesn't have to be in this PR; would you like to update 1 and close out #133? We could add the cmdline recipe and MFU to 2 and close out #135 too!

Sure, I will update the performance numbers with SA.
