Adapt Splash Attention from TorchPrime #8911

Merged
zpcore merged 16 commits into master from piz/port_sa
Apr 11, 2025

Conversation

@zpcore (Member) commented Mar 31, 2025

Adapts PR AI-Hypercomputer/torchprime#145 from TorchPrime into PTXLA, and simplifies the code to use the jit hashing from #8878.

In addition, fixes a small bug in xla_builder.call_jax that occurs when an input argument sequence contains both None and other hashable types.
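The kind of bug described above can be sketched in miniature (a hypothetical illustration, not the actual xla_builder code; naive_key and positional_key are made-up names): a cache key builder that silently drops None entries loses positional information, so distinct argument sequences collide on the same key.

```python
def naive_key(args):
    # BUG sketch: filtering out Nones discards where they appeared,
    # so (None, "x") and ("x", None) produce the same key.
    return tuple(a for a in args if a is not None)

def positional_key(args):
    # Fix sketch: substitute a sentinel for None so every position
    # contributes to the key and remains hashable.
    return tuple("<none>" if a is None else a for a in args)

a1 = (None, "x")
a2 = ("x", None)

assert naive_key(a1) == naive_key(a2)          # collision: both ('x',)
assert positional_key(a1) != positional_key(a2)  # positions preserved
```

The same failure mode applies to any hash/cache key derived from a flattened argument sequence that mixes None with hashable values.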

@zpcore zpcore marked this pull request as ready for review April 5, 2025 21:00
@zpcore zpcore requested a review from tengyifei April 5, 2025 21:00
@zpcore zpcore enabled auto-merge (squash) April 5, 2025 23:28
@tengyifei (Collaborator) left a comment

Nice!

Comment threads (outdated): torch_xla/experimental/custom_kernel_from_jax.py, test/test_splash_attention_jax.py, test/tpu/run_tests.sh
Comment threads: torch_xla/experimental/custom_kernel_from_jax.py
@tengyifei tengyifei requested a review from bhavya01 April 8, 2025 21:29
@tengyifei (Collaborator) commented

Looks like some comments still need to be addressed -- LMK whenever I should TAL!

@zpcore (Member, Author) commented Apr 9, 2025

> Looks like some comments still need to be addressed -- LMK whenever I should TAL!

Yes, I am working on getting rid of the lru_cache. Need to fix some small issues before resolving the feedback. Thanks!

Comment thread torch_xla/experimental/splash_attention.py
@zpcore (Member, Author) commented Apr 10, 2025

Oh, interesting that the test failed for the cache miss count. Looks like the HLO cache can be reused between test functions.

@zpcore (Member, Author) commented Apr 11, 2025

Hi @tengyifei , I created issue #8963 to track the hashing issue. Will follow up in a separate PR for the fix.

@tengyifei (Collaborator) left a comment

SGTM / LGTM

Comment thread test/test_splash_attention.py Outdated
Comment thread test/test_splash_attention.py
@zpcore zpcore merged commit 4583051 into master Apr 11, 2025
23 of 24 checks passed
@zpcore zpcore deleted the piz/port_sa branch April 11, 2025 16:41