[Pallas] Support segment ids in flash attention#6943
Conversation
b6a8ed8 to b9cfc67

Is this ready for review?
    [q, k, v, q_segment_ids, kv_segment_ids], payload, shapes, dtypes)

    if not save_residuals:
      o = o[0]
_xla_tpu_custom_call always returns an array.
I still need to add SPMD and dynamo support, so not yet.
@JackCaoG Do you think we can do the SPMD and dynamo parts later, since the customer is not using either of them now?
Yea.. don't worry about SPMD and dynamo for this PR, let's do that in a separate PR..
    @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
                     "This test only works on TPUv3+.")
    def test_flash_attention_wrapper_segment_ids_2(self):
So you have 2 tests, one comparing to native torch and one comparing to JAX?
Yea, the JAX test was written before I figured out how to do the non-kernel mask.
    torch.manual_seed(42)
    q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla")
    q_segment_ids = torch.zeros(4, 128).to("xla")
    kv_segment_ids = torch.zeros(4, 128).to("xla")
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()
Can we refactor this part out into a helper function in this test?
Actually, just refactor this part out and use it in all tests; it is the same for all of them.
You mean the tensor initializations? Those are kind of expected boilerplate. I don't think it's necessary to improve...
I will leave that to you. When I see two large chunks of code that look similar, I usually try to find how they are different. It confused me a bit when I realized it was the same code repeating over and over.
Yea, for testing, it's sometimes hard to avoid... haha
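For what it's worth, the refactor being discussed could look something like this helper. The name, signature, and `device` parameter are hypothetical (defaulting to `"cpu"` so the sketch runs anywhere; the real tests would pass `"xla"`):

```python
import torch


def attention_inputs(device="cpu", seed=42):
  """Hypothetical helper collecting the per-test setup repeated in each test.

  Builds q/k/v plus all-zero segment ids and moves them to `device`.
  """
  torch.manual_seed(seed)
  q = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
  k = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
  v = torch.randn(4, 2, 128, 8, requires_grad=True).to(device)
  q_segment_ids = torch.zeros(4, 128).to(device)
  kv_segment_ids = torch.zeros(4, 128).to(device)
  # .to(device) can produce non-leaf tensors, so keep their grads around.
  for t in (q, k, v):
    t.retain_grad()
  return q, k, v, q_segment_ids, kv_segment_ids
```

Each test would then start with `q, k, v, q_segment_ids, kv_segment_ids = attention_inputs(device="xla")` instead of repeating the block above.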
    grad_v, partition_spec, full_shape, mesh=mesh).global_tensor

-   return grad_q, grad_k, grad_v, None, None, None
+   return grad_q, grad_k, grad_v, None, None, None, None, None
Why do we need to return these Nones?
It's the rule of autograd.Function: every input passed to forward needs a corresponding gradient returned from backward. For inputs that we don't differentiate with respect to, we return None.
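The rule can be illustrated with a toy autograd.Function (ScaledMatmul is a made-up example, not code from this PR): backward must return exactly one value per forward input, with None in the slots for inputs that get no gradient.

```python
import torch


class ScaledMatmul(torch.autograd.Function):
  """Toy Function: forward takes three inputs (a, b, scale)."""

  @staticmethod
  def forward(ctx, a, b, scale):
    ctx.save_for_backward(a, b)
    ctx.scale = scale
    return (a @ b) * scale

  @staticmethod
  def backward(ctx, grad_out):
    a, b = ctx.saved_tensors
    grad_a = (grad_out @ b.T) * ctx.scale
    grad_b = (a.T @ grad_out) * ctx.scale
    # Three forward inputs -> three return values; None for `scale`,
    # which we do not differentiate with respect to.
    return grad_a, grad_b, None
```

In the PR, the extra inputs (segment ids, flags) play the role of `scale` here, which is why the backward gained extra trailing Nones.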
Thanks, Jack.
Summary:
This PR adds segment ids to the flash attention wrapper. Segment ids are a way to create an attention mask where each token can only attend to other tokens within the same segment; the mask is therefore a block-diagonal matrix.
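As a sketch of what the segment-id mask means logically (this is the implied mask, not the Pallas kernel's implementation, and the function name is hypothetical):

```python
import torch


def segment_ids_to_mask(q_segment_ids, kv_segment_ids):
  # Position (i, j) is attendable only when query token i and kv token j
  # carry the same segment id.
  # Shapes: [batch, q_len] and [batch, kv_len] -> [batch, q_len, kv_len].
  return q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
```

With contiguous segments such as `[0, 0, 1, 1]`, the resulting mask is block diagonal, matching the description above.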
To support this, we further split the flash attention forward pass into tracing and execution parts, and implement all the shape operations needed to make the inputs compatible with the kernel.
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py