
Extend paged attention to support query_len>1 #8328

Merged: vanbasten23 merged 15 commits into master from xiowei/extend_paged_attention_cleanedup on Oct 31, 2024

Conversation

vanbasten23 (Collaborator) commented Oct 27, 2024

This PR extends the existing paged attention kernel to support query_len > 1. It also upgrades the flash attention implementation from v1 to v2.

Test plan:

  • python pytorch/xla/test/test_pallas.py -v -k PallasTest.test_paged_attention_multi_queries_wrapper
  • python pytorch/xla/test/test_tpu_paged_attention_kernel.py 2>&1 | tee out.txt

cc: @miladm
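
For orientation, here is a rough sketch of how the new multi-query wrapper might be invoked. The function name, argument order, and tensor shapes are assumptions inferred from the test names and the review snippets below, not the definitive torch_xla API:

```python
# Sketch only: the wrapper name, argument order, and shapes below are
# assumptions, not a verbatim copy of the API added in this PR.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental import custom_kernel

device = xm.xla_device()
batch_size, query_len, num_heads, head_dim = 2, 8, 4, 128
page_size, pages_per_sequence = 16, 8
total_pages = batch_size * pages_per_sequence

q = torch.randn(batch_size, query_len, num_heads, head_dim, device=device)
k_pages = torch.randn(num_heads, total_pages, page_size, head_dim, device=device)
v_pages = torch.randn_like(k_pages)
lengths = torch.full((batch_size,), 24, dtype=torch.int32, device=device)  # kv_len per sequence
page_indices = torch.arange(
    total_pages, dtype=torch.int32,
    device=device).reshape(batch_size, pages_per_sequence)  # [batch_size, pages_per_sequence]

out = custom_kernel.multi_queries_paged_attention(  # name assumed; see custom_kernel.py
    q, k_pages, v_pages, lengths, page_indices,
    num_kv_pages_per_compute_block=4,
    num_queries_per_compute_block=8,
    use_kernel=True,  # False would exercise the non-kernel reference path instead
)
```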

vanbasten23 marked this pull request as ready for review on October 28, 2024 at 17:34
Comment thread (code context):

    page_indices,  # [batch_size, pages_per_sequence]
    num_kv_pages_per_compute_block,
    num_queries_per_compute_block,
    use_kernel=True,
vanbasten23 (Collaborator, author) commented:

hey @WoosukKwon, this is the integration point between vLLM and torch_xla. I'm wondering if vLLM can switch this use_kernel flag, perhaps via a flag of its own. I want to use the non-kernel version as a perf baseline. Do you know if that's possible?
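
A minimal sketch of one way a caller could expose that toggle, for example through an environment variable. The variable name and the helper are hypothetical and do not exist in vLLM; only the use_kernel argument comes from the snippet above:

```python
import os

def tpu_paged_attention(query, key_cache, value_cache, context_lens, page_indices,
                        num_kv_pages_per_compute_block, num_queries_per_compute_block):
    # Hypothetical: drive the kernel/non-kernel choice from an env var so the
    # non-kernel path can serve as a perf baseline. Neither the env var nor this
    # helper exists in vLLM; `paged_attention_wrapper` stands in for the
    # torch_xla call shown in the code context above.
    use_kernel = os.environ.get("VLLM_TPU_USE_PALLAS_KERNEL", "1") == "1"
    return paged_attention_wrapper(
        query, key_cache, value_cache, context_lens, page_indices,
        num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
        num_queries_per_compute_block=num_queries_per_compute_block,
        use_kernel=use_kernel,  # False -> non-kernel reference path
    )
```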

vanbasten23 (Collaborator, author) commented:

For dynamo, it's similar: the integration point is def multi_queries_paged_attention_xla( in the same file.
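
A short sketch of how the dynamo path would typically be exercised, reusing the tensors from the sketch in the PR description. The `openxla` backend is torch_xla's standard torch.compile backend; everything else here is illustrative:

```python
# Illustrative only: under dynamo, this call is routed through the
# multi_queries_paged_attention_xla integration point mentioned above.
import torch

def attention_step(q, k_pages, v_pages, lengths, page_indices):
    return custom_kernel.multi_queries_paged_attention(  # name assumed, as before
        q, k_pages, v_pages, lengths, page_indices,
        num_kv_pages_per_compute_block=4,
        num_queries_per_compute_block=8,
        use_kernel=True,
    )

compiled_step = torch.compile(attention_step, backend="openxla")
out = compiled_step(q, k_pages, v_pages, lengths, page_indices)
```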

Comment thread on torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py (outdated), code context:
    q_index = q_blk_idx * num_queries_per_compute_block
    kv_index = kv_blk_idx * kv_seq_len_per_kv_compute_blk
    kv_len = lengths_ref[b]
    row_ids = (kv_len - query_len) + q_index + jax.lax.broadcasted_iota(
A collaborator commented:

Here, we assume the input query corresponds to the last q_len tokens of the input kv. For example, if the input q_len is 8 and kv_len is 24, we assume the query corresponds to the kv at indices [16, 24), and apply the causal mask accordingly.

@WoosukKwon please let us know if this assumption is valid or not for the use cases in vLLM.
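
To make the assumption concrete, here is a minimal standalone illustration of that masking in plain JAX (not the kernel code itself):

```python
# Standalone illustration of the assumption above: queries are the last
# `query_len` tokens of the kv sequence, and the causal mask is built from
# their absolute positions.
import jax
import jax.numpy as jnp

def causal_mask_for_suffix_queries(query_len: int, kv_len: int):
    # Query i sits at absolute position (kv_len - query_len) + i.
    row_ids = (kv_len - query_len) + jax.lax.broadcasted_iota(
        jnp.int32, (query_len, kv_len), 0)  # absolute query positions
    col_ids = jax.lax.broadcasted_iota(
        jnp.int32, (query_len, kv_len), 1)  # kv positions
    # A query may attend to any kv position at or before its own position.
    return col_ids <= row_ids

# With query_len=8 and kv_len=24, query 0 sits at position 16 and may attend
# to kv[0:17]; the last query sits at position 23 and sees the full kv[0:24].
mask = causal_mask_for_suffix_queries(8, 24)
print(mask.shape)           # (8, 24)
print(int(mask[0].sum()))   # 17
print(int(mask[-1].sum()))  # 24
```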

A collaborator replied:
Yes that's the desired behavior. Thanks for checking it out with me!

(Additional comment threads on torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py (outdated) and torch_xla/experimental/custom_kernel.py are not shown here.)
vanbasten23 merged commit 1bac062 into master on Oct 31, 2024.
miladm added the pallas label on Nov 22, 2024.