Enable PagedAttention through Pallas #6912
Conversation
cc @WoosukKwon to take a look
Locally, the tests are succeeding on my v4. I also just triggered the TPU CI on this PR.
The CPU CI is failing with an unrelated test. The rest of the CI, including the TPU CI, is passing, so this PR should be good to review. Thanks!
    torch.allclose(
        output.cpu()[seq_lens > 0],
        expected_output.cpu()[seq_lens > 0],
        atol=1e-1,
wdyt we use a tighter bound for atol and rtol? e.g. 1e-3
Sg, updated to 1e-5 for both tests.
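For reference, roughly what the tightened check looks like (a sketch, assuming the same `output`, `expected_output`, and `seq_lens` tensors from the diff above and that the call is wrapped in an assertion):

```python
# Compare only the rows for sequences that actually hold tokens, with the
# tighter 1e-5 tolerances agreed on above.
self.assertTrue(
    torch.allclose(
        output.cpu()[seq_lens > 0],
        expected_output.cpu()[seq_lens > 0],
        rtol=1e-5,
        atol=1e-5))
```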
miladm
left a comment
Thanks @wonjoolee95 - left a comment for you to eval and address - approving to unblock you
alanwaketan
left a comment
In general, it looks good to me. Left a few comments.
  return FlashAttention.apply(q, k, v, causal)


def paged_attention(q, k_pages, v_pages, lengths, page_indices,
The original kernel has this thing called q_dtype_for_kernel_launch. What does it do? Should we copy that as well?
In the original kernel, the q_dtype_for_kernel_launch is always either jnp.float32 or q's dtype. In our case, I'm expecting the passed-in q's dtype to be torch.float32, so the q_dtype_for_kernel_launch will always be float32.
No, I don't think that will be the case for the actual workflow. It could be bf16 or even int8, etc.
I see, makes sense. Just updated to handle q_dtype_for_kernel_launch, following JAX's kernel -- https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L393. I can follow up in another PR to add some more unit tests for different dtypes for q.
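For context, a rough paraphrase of how the linked JAX kernel picks the launch dtype (the names below are illustrative, and the real kernel also adjusts its block specs in the same branch):

```python
import jax.numpy as jnp

def _q_dtype_for_kernel_launch(q, num_kv_heads):
  # Paraphrased from JAX's paged_attention kernel: when the query-head to
  # kv-head ratio is not a multiple of 8, reshape q to hint XLA toward a
  # <1x128> layout and launch the kernel in float32; otherwise keep q's dtype.
  batch_size, num_heads, head_dim = q.shape
  if (num_heads // num_kv_heads) % 8 != 0:
    return q.reshape(batch_size, num_heads, 1, head_dim), jnp.float32
  return q, q.dtype
```

The kernel then casts q to that dtype right before the launch, so a bf16 or int8 query only falls back to float32 in the reshape case.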
                    pages_per_compute_block: int):
  # This will be called when dynamo uses fake tensors to construct the fake output.
  # We need to make sure the output tensor's shape is correct.
  if k.device != torch.device("meta"):
It feels like this part can be consolidated with the flash attention one.
Sg, refactored these into a helper function.
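A hypothetical shape of that helper (the name, signature, and empty-tensor return are assumptions about the refactor, not the exact merged code):

```python
import warnings
import torch

def _fake_attention_output(k, output_shape):
  # Both the flash attention and paged attention wrappers can call this when
  # dynamo traces with fake tensors: no kernel is launched, so only the output
  # tensor's shape needs to be right.
  if k.device != torch.device("meta"):
    warnings.warn(
        'XLA flash attention should only be applied to tensors on XLA device')
  return torch.empty(output_shape, dtype=k.dtype, device=k.device)
```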
  @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                   "This test only works on TPUv4+.")
  def test_paged_attention_wrapper(self):
    jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
It's interesting that you use JAX as the reference. I guess that works too. Wondering if we can just use the eager attention helper in the class instead? Or does that not work? Anyway, if you are using JAX as the reference, you can drop this.
Sg, yeah I saw that we're dependent on JAX Pallas anyways, so I thought it may be easier to just test against JAX's outputs.
Ah, makes sense. Just removed the jax.config updates.
        q_xla,
        k_pages_xla,
        v_pages_xla,
        seq_lens_xla,
Can you explain what these seq_lens are? Are these the previous tokens for each batch in k, v?
Yep, that is my understanding -- seq_lens here is the number of tokens that have already been processed for each sequence in the batch. Reference: https://docs.vllm.ai/en/latest/dev/kernel/paged_attention.html#concepts.
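As a toy illustration (values made up), seq_lens has one entry per batch slot:

```python
import torch

# Number of tokens already processed (and cached in k_pages / v_pages) for
# each of the 4 sequences in the batch; a 0 means the slot holds no tokens.
seq_lens = torch.tensor([128, 37, 0, 512], dtype=torch.int32)  # shape (batch_size,)
```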
  # We need to make sure the output tensor's shape is correct.
  if k.device != torch.device("meta"):
    warnings.warn(
        'XLA flash attention should only be applied to tensors on XLA device')
nit, paged attention instead of flash attention
Actually, it is not even specific to paged attention; you can just make this warning message more general.
Good catch, updated to use an f string.
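Something along these lines (the `attention_type` parameter name is an assumption, not necessarily what the final code uses):

```python
import warnings

def _warn_non_xla(attention_type: str):
  # e.g. attention_type = "paged attention" or "flash attention"
  warnings.warn(
      f'XLA {attention_type} should only be applied to tensors on XLA device')
```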
  step = torch.zeros((1,), dtype=torch.int32).to("xla")
  output_shape = torch.Size(list(q.shape[:-1]) + [1])
  q_output_dtype = torch.float32
  if (num_heads // num_kv_heads) % 8 != 0:
I guess you can combine this with the above L396 code.
Good catch! Updated.
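A rough sketch of the combined branch (variable names follow the diff above; this is not the exact merged code):

```python
import torch

def _prepare_q(q, num_kv_heads):
  # Fold the layout hint and the output/launch dtype decision into one branch,
  # mirroring the JAX kernel: fall back to float32 only when the query-head to
  # kv-head ratio is not a multiple of 8.
  batch_size, num_heads, head_dim = q.shape
  q_output_dtype = torch.float32
  if (num_heads // num_kv_heads) % 8 != 0:
    q = q.reshape(batch_size, num_heads, 1, head_dim)
  else:
    q_output_dtype = q.dtype
  return q, q_output_dtype
```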
Thanks all for the reviews. After addressing all the comments, the two unit tests are still passing locally on my v4. I'll let the TPU CI verify one more time before merging.
      ], payload, [q.shape, output_shape, output_shape],
      [q_output_dtype, torch.float32, torch.float32])

  return output.reshape(batch_size, num_heads, head_dim)
You probably want to use .to to cast the output back to the original dtype here.
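So the final return would become something like this (a sketch; `q.dtype` here stands in for the caller's original query dtype):

```python
  # Cast back in case the kernel was launched (and produced output) in float32.
  return output.reshape(batch_size, num_heads, head_dim).to(q.dtype)
```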
Merging as all CI is green.
Enable PagedAttention through Pallas
Test plan:
Todo as follow-ups: