Extend paged attention to support query_len>1#8328
Conversation
    page_indices,  # [batch_size, pages_per_sequence]
    num_kv_pages_per_compute_block,
    num_queries_per_compute_block,
    use_kernel=True,
hey @WoosukKwon, this is the integration point between vLLM and torch_xla. I'm wondering if vLLM can toggle this use_kernel flag, perhaps via a configuration flag. I want to use the non-kernel version as a perf baseline. Do you know if that's possible?
For dynamo, it's similar. The integration point is at def multi_queries_paged_attention_xla( in the same file.
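A minimal sketch of one way to expose that switch, assuming vLLM reads an environment variable at its call site. The import path and the tensor arguments (q, k_pages, v_pages, lengths) are assumptions for illustration; only the parameters shown in the diff above come from this PR.

import os

from torch_xla.experimental.custom_kernel import multi_queries_paged_attention

# Hypothetical flag: flip once via the environment instead of code changes.
_USE_KERNEL = os.environ.get("VLLM_USE_PAGED_ATTN_KERNEL", "1") == "1"

def paged_attention(q, k_pages, v_pages, lengths, page_indices,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block):
    return multi_queries_paged_attention(
        q, k_pages, v_pages, lengths,
        page_indices,  # [batch_size, pages_per_sequence]
        num_kv_pages_per_compute_block,
        num_queries_per_compute_block,
        use_kernel=_USE_KERNEL,  # False routes through the non-kernel path
    )

Running with VLLM_USE_PAGED_ATTN_KERNEL=0 would then exercise the non-kernel reference implementation as the baseline.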
    q_index = q_blk_idx * num_queries_per_compute_block
    kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
    kv_len = lengths_ref[b]
    row_ids = (kv_len - query_len) + q_index + jax.lax.broadcasted_iota(
Here, we assume the input query corresponds to the last query_len positions of the input kv. For example, if the input q_len is 8 and kv_len is 24, we assume the query corresponds to the kv at indices [16, 24), and apply the causal mask accordingly.
@WoosukKwon please let us know if this assumption is valid or not for the use cases in vLLM.
Yes that's the desired behavior. Thanks for checking it out with me!
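For reference, a standalone sketch of the indexing discussed above, using the same broadcasted_iota pattern as the kernel. The helper name and block sizes are illustrative, not part of the kernel code.

import jax
import jax.numpy as jnp

# Sketch: build the causal mask for one compute block, assuming the queries
# occupy the last `query_len` positions of the kv sequence.
def causal_mask_block(kv_len, query_len, q_index, kv_index,
                      q_blk_size, kv_blk_size):
    # Global kv position of each query row: query i sits at
    # kv_len - query_len + i.
    row_ids = (kv_len - query_len) + q_index + jax.lax.broadcasted_iota(
        jnp.int32, (q_blk_size, kv_blk_size), 0)
    # Global kv position of each key/value column in this block.
    col_ids = kv_index + jax.lax.broadcasted_iota(
        jnp.int32, (q_blk_size, kv_blk_size), 1)
    # Causal rule: a query attends only to kv positions at or before its own.
    return row_ids >= col_ids

# With query_len=8 and kv_len=24, the queries map to kv indices [16, 24),
# so query 0 attends to kv[0..16] and query 7 attends to kv[0..23].
mask = causal_mask_block(kv_len=24, query_len=8, q_index=0, kv_index=0,
                         q_blk_size=8, kv_blk_size=24)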
This PR extends the existing paged attention kernel to support query_len > 1. It also upgrades the underlying flash attention from v1 to v2.
Test plan:
cc: @miladm