
[Pallas] Support segment ids in flash attention #6943

Merged
alanwaketan merged 11 commits into master from alanwaketan/fa_segment_ids
May 1, 2024

Conversation

@alanwaketan (Collaborator) commented Apr 19, 2024

Summary:
This PR is to add segment ids to the flash attention wrapper. The segment ids are a way to create an attention mask where each token can only attend to other tokens within the same segment. The mask is therefore a block diagonal matrix.
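The block-diagonal structure can be illustrated with a small sketch (numpy here, purely illustrative — not the kernel's actual code):

```python
import numpy as np

# Illustrative only: segment ids induce a mask where token i may attend to
# token j only if the two tokens carry the same segment id.
q_segment_ids = np.array([0, 0, 1, 1])
kv_segment_ids = np.array([0, 0, 1, 1])

# Broadcasting the equality comparison yields a block-diagonal boolean mask.
mask = q_segment_ids[:, None] == kv_segment_ids[None, :]
# mask:
# [[ True  True False False]
#  [ True  True False False]
#  [False False  True  True]
#  [False False  True  True]]
```

Positions in different segments (here, the first two tokens vs. the last two) are masked out, giving the block-diagonal attention pattern described above.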

To support it, we further split the flash attention forward into a tracing part and an execution part, and implement all the shape operations needed to make the inputs compatible with the kernel.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py

@alanwaketan alanwaketan self-assigned this Apr 19, 2024
@alanwaketan alanwaketan force-pushed the alanwaketan/fa_segment_ids branch from b6a8ed8 to b9cfc67 on April 26, 2024 01:52
@JackCaoG (Collaborator)

Is this ready for review?

[q, k, v, q_segment_ids, kv_segment_ids], payload, shapes, dtypes)

if not save_residuals:
o = o[0]
Collaborator

what's this for?

Collaborator Author

_xla_tpu_custom_call always returns a list of outputs, so we unwrap the single output here.
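In other words (a hypothetical illustration, not the wrapper's actual code), the custom call is assumed to hand back a list of outputs even for a single-output kernel, so the wrapper unwraps it:

```python
# Hypothetical sketch of the unwrapping above: the custom call is assumed to
# always return a list of outputs, even when the kernel has a single output.
def unwrap_outputs(outputs, save_residuals):
    o = outputs
    if not save_residuals:
        o = o[0]  # single-output case: take the value out of the list
    return o

# Single output: the caller gets the value itself, not a one-element list.
assert unwrap_outputs(["o"], save_residuals=False) == "o"
# With residuals: the caller gets the full list (e.g. o, l, m).
assert unwrap_outputs(["o", "l", "m"], save_residuals=True) == ["o", "l", "m"]
```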

Comment thread torch_xla/experimental/custom_kernel.py Outdated
Comment thread torch_xla/experimental/custom_kernel.py
@alanwaketan (Collaborator, Author)

Is this ready for review?

I still need to add spmd and dynamo support. So not yet.

@alanwaketan alanwaketan marked this pull request as ready for review April 30, 2024 22:46
@alanwaketan (Collaborator, Author)

@JackCaoG Do you think we can do the SPMD and dynamo parts later since the customer is not using either of them now?

@JackCaoG (Collaborator)

Yea, don't worry about SPMD and Dynamo for this PR; let's do that in a separate PR.

Comment thread test/test_pallas.py

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
def test_flash_attention_wrapper_segment_ids_2(self):
Collaborator

So you have two tests, one comparing to native torch and one comparing to JAX?

Collaborator Author

Yea, the JAX test was written before I figured out how to do the non-kernel mask.

Comment thread test/test_pallas.py
Comment on lines +695 to +703
torch.manual_seed(42)
q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
q_segment_ids = torch.zeros(4, 128).to("xla")
kv_segment_ids = torch.zeros(4, 128).to("xla")
q.retain_grad()
k.retain_grad()
v.retain_grad()
Collaborator

Can we refactor this part out into a helper function in this test?

Collaborator

Actually, just refactor this part out and use it in all the tests; it is the same for all of them.
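A helper along the lines suggested might look like this (a sketch only; the function name and the device parameter are made up for illustration, so the same code can also run off-TPU):

```python
import torch

# Hypothetical helper factoring out the repeated tensor setup in the tests.
# The device argument is an illustration-only addition; the tests use "xla".
def make_attention_inputs(device="xla", seed=42):
    torch.manual_seed(seed)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
    q_segment_ids = torch.zeros(4, 128).to(device)
    kv_segment_ids = torch.zeros(4, 128).to(device)
    # retain_grad so grads are kept even when .to() returns non-leaf tensors.
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()
    return q, k, v, q_segment_ids, kv_segment_ids
```

Each test could then start with a single `make_attention_inputs()` call instead of repeating the nine setup lines.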

Collaborator Author

You mean the tensor initializations? Those are kind of expected paperwork. I don't think it's necessary to improve...

Collaborator

I will leave that to you. When I see two large chunks of code that look similar, I usually try to find how they are different. It confused me a bit when I realized it is the same code repeating over and over.

Collaborator Author

Yea, for testing, it's sometimes hard to avoid... haha

grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

return grad_q, grad_k, grad_v, None, None, None
return grad_q, grad_k, grad_v, None, None, None, None, None
Collaborator

why do we need to return these Nones?

Collaborator Author

It's a rule of autograd.Function: every input passed to forward needs a corresponding grad returned from backward. For the inputs that we don't differentiate with respect to, we return None.
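As a minimal sketch of that rule (a made-up toy Function, not the flash attention one): backward returns exactly one value per forward input, with None for the non-differentiable ones.

```python
import torch

# Toy illustration of the autograd.Function contract described above
# (hypothetical example, not the flash attention implementation).
class Scale(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, factor, tag):
        # factor and tag are non-tensor inputs we never differentiate on.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a grad for x, and None for
        # the factor and tag inputs we don't differentiate with respect to.
        return grad_output * ctx.factor, None, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0, "unused").sum().backward()
# x.grad is now a tensor of 2.0s
```

Dropping the trailing Nones would make autograd complain that backward returned the wrong number of gradients.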

@JackCaoG JackCaoG added the tpuci label Apr 30, 2024
@alanwaketan (Collaborator, Author)

Thanks, Jack.

@alanwaketan alanwaketan merged commit 400bd0c into master May 1, 2024
@alanwaketan alanwaketan deleted the alanwaketan/fa_segment_ids branch May 1, 2024 18:38