Add heuristic default block sizes for different cases in ragged attention kernel #8922

Merged
yaochengji merged 8 commits into master from chengji/ragged-attn on Apr 3, 2025

Conversation

@yaochengji (Collaborator) commented Apr 2, 2025

No description provided.

@yaochengji yaochengji requested a review from vanbasten23 April 2, 2025 00:58
Comment thread: test/test_pallas.py (Outdated)
@@ -856,6 +867,22 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
use_dynamo=False,
)

Contributor:
Remove the old _test_ragged_paged_attention?

Collaborator Author:
We'd better also test the non-None block size parameters; that's why it's still there.

Comment thread: test/test_pallas.py
@@ -817,6 +812,22 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
use_dynamo=True,

Contributor:
Remove the old _test_ragged_paged_attention?

Collaborator Author:
We'd better also test the non-None block size parameters; that's why it's still there.

Comment thread: torch_xla/experimental/custom_kernel.py (Outdated)
raise NotImplementedError("TPU version must be 4 or higher.")
# NOTE: the TPU v4's vmem capacity is 16MB
if tpu_version == 4:
vmem_limit_bytes = 16 * 1024 * 1024

Contributor:
It is fine; even if we set more than 16 MB, it will still use 16 MB.

Collaborator Author:
Thanks, done.
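
For context, the pattern under discussion looks roughly like the following; this is a minimal sketch, and the post-v4 value is an illustrative assumption rather than the merged code:

```python
# Sketch only: pick a vmem_limit_bytes to hand to the Pallas kernel.
# The 16 MB figure for TPU v4 comes from the quoted diff; the 64 MB
# default for newer TPUs is a hypothetical placeholder. As the review
# notes, over-requesting is harmless: a chip with 16 MB of VMEM still
# caps usage at 16 MB even if a larger limit is requested.
def get_vmem_limit_bytes(tpu_version: int) -> int:
    if tpu_version < 4:
        raise NotImplementedError("TPU version must be 4 or higher.")
    if tpu_version == 4:
        return 16 * 1024 * 1024  # TPU v4's VMEM capacity is 16 MB.
    return 64 * 1024 * 1024  # Hypothetical default for v5 and later.
```

Since the hardware clamps the limit anyway, the version check can be dropped in favor of a single larger default, which is presumably the simplification adopted above.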

Comment thread: test/test_pallas.py (Outdated)
soft_cap=soft_cap,
pad_tokens_and_seqs=pad_tokens_and_seqs,
use_dynamo=True,
num_kv_pages_per_block=None,

Collaborator:
nit: consider making the block sizes a parameter, e.g. num_kv_pages_per_block=[16, None]; similarly for num_queries_per_block.

Collaborator Author:
Thanks for the suggestion, done.
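
For reference, the suggestion above maps to a parameterization roughly like this; it is a minimal sketch assuming absl.testing.parameterized (used elsewhere in test_pallas.py), and the concrete values plus the _test_ragged_paged_attention helper are illustrative assumptions:

```python
from absl.testing import parameterized


class RaggedPagedAttentionTest(parameterized.TestCase):

  @parameterized.product(
      num_kv_pages_per_block=[16, None],
      num_queries_per_block=[32, None],
  )
  def test_ragged_paged_attention_block_sizes(self, num_kv_pages_per_block,
                                              num_queries_per_block):
    # None exercises the new heuristic defaults; a concrete value keeps
    # the explicit block-size path covered, per the earlier threads.
    # _test_ragged_paged_attention is the shared helper assumed to exist
    # in test_pallas.py.
    self._test_ragged_paged_attention(
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
    )
```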

Comment thread: torch_xla/experimental/custom_kernel.py (Outdated)
# This heuristic is based on the initial kernel micro-benchmarking:
# when token_num is small, there is no long prefill request;
# when it is larger, the block size is adjusted accordingly.
if token_num <= 128:

Collaborator:
I wonder if we should choose the block sizes in vLLM. If we do it in torch_xla, we need to change it in torch_xla and wait for tomorrow's wheel; doing it in vLLM would be more convenient. WDYT?

Collaborator Author:
If vLLM passes a non-None block size, the default value will not be used.


Collaborator:
I wonder why we couldn't do _get_default_ragged_paged_attention_block_size in vLLM.

Collaborator Author:
Usually it's a good idea to put the tuned-parameter table in the kernel lib, not in the app lib.
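
To make that division of responsibility concrete, a tuned-parameter table in the kernel library could look roughly like the sketch below; the function name and the token_num <= 128 branch come from this thread, while the specific block sizes are illustrative assumptions, not the benchmarked defaults:

```python
# Sketch of the heuristic table living in the kernel lib (torch_xla),
# so callers such as vLLM can pass None and inherit tuned defaults.
def _get_default_ragged_paged_attention_block_size(token_num: int):
  # Block sizes below are placeholders, not the tuned values.
  if token_num <= 128:
    # Small batches contain no long prefill request, so modest blocks
    # keep the Pallas grid small.
    num_kv_pages_per_block, num_queries_per_block = 16, 32
  else:
    # Larger batches amortize bigger blocks better.
    num_kv_pages_per_block, num_queries_per_block = 128, 64
  return num_kv_pages_per_block, num_queries_per_block
```

A caller-supplied non-None block size simply bypasses this lookup, matching the behavior described earlier in the thread.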

@vanbasten23 (Collaborator) left a comment:
Thanks Chengji!

@yaochengji yaochengji force-pushed the chengji/ragged-attn branch from 64d9bad to 8286d56 on April 3, 2025 03:37
@yaochengji yaochengji merged commit 6c3f231 into master on Apr 3, 2025
23 checks passed