This repository was archived by the owner on Mar 3, 2026. It is now read-only.
Conversation
tengyifei
reviewed
Mar 8, 2025
bhavya01
reviewed
Mar 8, 2025
Collaborator
Author
|
We are almost achieving the same fwd and bwd for each decoder layer compared with MaxText. However, overall, there is still a gap between the torchprime (profile) and maxtext (profile) performance. Below is the summary: Test: Llama 3.1 8B, 8K seq_len; host: v6e-256.
Details:First main issue is the Second issue, PTXLA have an extra pallas call during activation remat: |
Merged
Collaborator
Author
|
The PR requires the nightly build from 03/12 to include pytorch/xla#8789. Need merge #145 to fix the test failure. |
tengyifei
suggested changes
Mar 13, 2025
Contributor
tengyifei
left a comment
There was a problem hiding this comment.
still reviewing the other half
tengyifei
reviewed
Mar 13, 2025
tengyifei
suggested changes
Mar 13, 2025
tengyifei
approved these changes
Mar 14, 2025
Contributor
|
need to rebase i think |
This was referenced Mar 18, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to subscribe to this conversation on GitHub.
Already have an account?
Sign in.
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.



Support of Splash Attention (SA) using new feature xla_builder.call_jax (internal one pager doc).
call_jaxenables us to copy JAX code directly into torch_xla without diving into detail how to interact with the bottom level API of pallas kernel. This PR serves as an example of how to run feedforward and retrieve the grad under the torch framework.For quick test, we can use
Note: The PR is still in the experimental stage. We plan to test it in TorchPrime first and move the kernel to PyTorch/XLA later.
Performance improvement summary compared with MaxText and Flash Attention (FA) kernel: