
[Pallas] Support ab for flash_attention#7840

Merged
JackCaoG merged 6 commits into master from alanwaketan/flash_ab on Aug 14, 2024
Conversation

@alanwaketan
Collaborator

Summary:
This pull request adds ab support to flash_attention. ab is a custom mask (an additive bias) applied to the attention weights.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_ab
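For intuition, the ab argument behaves like an additive bias added to the attention scores before the softmax. The following is a plain-NumPy reference sketch of that semantics, not the Pallas TPU kernel itself; the function and variable names here are illustrative, not from the PR:

```python
import numpy as np

def attention_with_ab(q, k, v, ab=None):
    """Scaled dot-product attention with an optional additive bias/mask `ab`."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if ab is not None:
        # `ab` is added to the raw scores; large negative entries mask positions out.
        scores = scores + ab
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

q = k = v = np.eye(4, dtype=np.float32)
# A causal mask expressed as an ab bias: -1e9 above the diagonal.
causal_ab = np.triu(np.full((4, 4), -1e9, dtype=np.float32), k=1)
out = attention_with_ab(q, k, v, ab=causal_ab)
```

With this bias, each query position can only attend to itself and earlier positions, which is the classic use case for a custom attention mask.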

@alanwaketan alanwaketan self-assigned this Aug 13, 2024
@ZhiyuLi-goog
Contributor

Thank you @alanwaketan!

Comment thread test/test_pallas.py
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                 "This test only works on TPUv3+.")
def test_flash_attention_ab(self):
    jax.config.update("jax_default_matmul_precision", "highest")
Collaborator

lol we should really make a context manager that takes care of this in this test.
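A minimal sketch of what such a context manager could look like. To keep the example self-contained it uses a FakeConfig stand-in for jax.config (hypothetical names, not the real JAX object); note that JAX also ships a built-in jax.default_matmul_precision context manager that serves this purpose:

```python
import contextlib

class FakeConfig:
    """Minimal stand-in for a jax.config-style object (hypothetical, for illustration)."""
    def __init__(self):
        self.settings = {"jax_default_matmul_precision": None}

    def update(self, key, value):
        self.settings[key] = value

config = FakeConfig()

@contextlib.contextmanager
def matmul_precision(value):
    """Set the matmul-precision flag for the duration of a block, then restore it."""
    old = config.settings["jax_default_matmul_precision"]
    config.update("jax_default_matmul_precision", value)
    try:
        yield
    finally:
        # Restore the previous value even if the test body raises.
        config.update("jax_default_matmul_precision", old)

with matmul_precision("highest"):
    inside = config.settings["jax_default_matmul_precision"]
after = config.settings["jax_default_matmul_precision"]
```

Wrapping the test body in such a manager keeps the precision override from leaking into other tests in the same process.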

@JackCaoG JackCaoG added the tpuci label Aug 14, 2024
Collaborator

@JackCaoG JackCaoG left a comment


Can you rerun the CI? TPUCI was not enabled so the test was not run.

@JackCaoG
Collaborator

actually let me just trigger it...

@JackCaoG JackCaoG merged commit 21a0b5a into master Aug 14, 2024
@JackCaoG JackCaoG deleted the alanwaketan/flash_ab branch August 14, 2024 21:19
@alanwaketan
Collaborator Author

Thanks, Jack.
