[Pallas] Integrate FlashAttention with SPMD#6935

Merged
alanwaketan merged 9 commits into master from alanwaketan/fa_spmd
Apr 18, 2024

Conversation

@alanwaketan
Collaborator

Summary:
This pull request integrates FlashAttention with SPMD. It works by creating a manual sharding region for the kernel: we wrap all the inputs with enable_manual_sharding and all the outputs with disable_manual_sharding.

Added a new test file because the original test file is not SPMD-aware.
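The wrapping described in the summary can be sketched as follows; enable_manual_sharding and disable_manual_sharding are the helper names from the summary, but these stub bodies only model the bookkeeping, not real XLA sharding:

```python
# Illustrative sketch (stub implementations, not torch_xla APIs) of the
# manual-sharding wrap around the Pallas kernel.

def enable_manual_sharding(t, partition_spec):
    # Stub: tag a global tensor as entering the manual-sharding region,
    # where the kernel only sees the local shard.
    return {"data": t, "manual": True, "spec": partition_spec}

def disable_manual_sharding(t, partition_spec, full_shape=None):
    # Stub: tag a kernel output as leaving the region, to be reassembled
    # into a global sharded tensor of `full_shape`.
    return {"data": t["data"], "manual": False, "spec": partition_spec}

def flash_attention_kernel(q, k, v):
    # Stand-in for the Pallas FlashAttention kernel, which runs on
    # per-device shards inside the manual region.
    return q  # placeholder output

def flash_attention_spmd(q, k, v, partition_spec):
    # Wrap every input going into the kernel...
    q = enable_manual_sharding(q, partition_spec)
    k = enable_manual_sharding(k, partition_spec)
    v = enable_manual_sharding(v, partition_spec)
    o = flash_attention_kernel(q, k, v)
    # ...and every output coming out of it.
    return disable_manual_sharding(o, partition_spec)
```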

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas_spmd.py

@alanwaketan alanwaketan self-assigned this Apr 17, 2024
Collaborator

@JackCaoG JackCaoG left a comment


lgtm, minor comments

Comment thread test/test_pallas_spmd.py
Comment on lines +14 to +19
if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
Collaborator


nit, you can put this part in the setup class similar to https://github.com/pytorch/xla/blob/master/test/spmd/test_xla_sharding_base.py#L31-L35
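The suggested move into the setup class might look roughly like this sketch; the class name and the TPU check are placeholders, following the linked test_xla_sharding_base.py pattern:

```python
import unittest

class PallasSpmdTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # Runs once for the whole class, so the guarded import happens a
    # single time instead of at module import time.
    cls.on_tpu = False  # placeholder for xr.device_type() == 'TPU'
    if cls.on_tpu:
      # In the real test: jax_import_guard(), then import jax, etc.
      import jax  # noqa: F401
```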

Collaborator Author


Is python import global?
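(For context: yes, Python imports are process-global. The first import executes the module and caches it in sys.modules; any later import anywhere in the process reuses that cached object. A quick illustration:)

```python
import sys

import json  # first import: the module is executed and cached

# Every later `import json`, in any module of this process, returns
# the same cached object from sys.modules.
assert "json" in sys.modules
assert json is sys.modules["json"]
```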

Comment thread test/test_pallas_spmd.py Outdated
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
Collaborator


nit, this should be called in the setup class since it is a one time global config.

Collaborator Author


I probably can call this in main as well. The setup class seems like overkill for this.

Collaborator

@jonb377 jonb377 left a comment


Awesome stuff Jiewen!

Comment thread test/test_pallas_spmd.py
@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_spmd_data_parallel(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
Collaborator


Does this impact the resulting kernel?

Collaborator Author


Yea.
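(For context: jax_default_matmul_precision is JAX's global matmul-precision knob. Pinning it to HIGHEST forces full-precision matmul passes on TPU instead of the lower-precision default, which changes the lowered kernel and keeps the test numerically comparable to the reference.)

```python
import jax

# Pin matmuls to full precision on TPU; the default lower-precision
# passes would otherwise change the generated kernel's numerics.
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
```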

Comment thread torch_xla/experimental/custom_kernel.py Outdated

@staticmethod
- def forward(ctx, q, k, v, causal=False):
+ def forward(ctx, q, k, v, causal=False, sharding_spec=None, mesh=None):
Collaborator


nit: sharding_spec -> partition_spec?

@alanwaketan
Collaborator Author

Thanks Jon and Jack for the reviews.

@alanwaketan alanwaketan merged commit 9f2b82d into master Apr 18, 2024