Skip to content

Support megacore_mode in paged_attention#7060

Merged
wonjoo-wj merged 1 commit intomasterfrom
wonjoo/paged-attention/megacore-modes
May 14, 2024
Merged

Support megacore_mode in paged_attention#7060
wonjoo-wj merged 1 commit intomasterfrom
wonjoo/paged-attention/megacore-modes

Conversation

@wonjoo-wj
Copy link
Copy Markdown
Collaborator

@wonjoo-wj wonjoo-wj commented May 14, 2024

Support megacore_mode in paged_attention

JAX reference for megacore_mode: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L318

Test plan:

python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes

+ TPU CI

@wonjoo-wj
Copy link
Copy Markdown
Collaborator Author

Locally test is succeeding on my v4-8:

root@t1v-n-4989e8c7-w-0:~/pytorch/xla# python test/test_pallas.py PallasTest.test_paged_attention_wrapper_with_megacore_modes
.
----------------------------------------------------------------------
Ran 1 test in 3.283s

OK
root@t1v-n-4989e8c7-w-0:~/pytorch/xla# 

I'll wait for TPU CI to verify the rest.

@wonjoo-wj wonjoo-wj requested review from JackCaoG and alanwaketan May 14, 2024 19:31
Copy link
Copy Markdown
Collaborator

@alanwaketan alanwaketan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@wonjoo-wj
Copy link
Copy Markdown
Collaborator Author

Thanks for the reviews, merging as all CIs are green.

@wonjoo-wj wonjoo-wj merged commit cbb9e21 into master May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants