
[ragged-paged-attn] Combine k_pages and v_pages on num_kv_head#8892

Merged
vanbasten23 merged 9 commits into pytorch:master from bythew3i:ragged-attn-v2 on Mar 27, 2025

Conversation

@bythew3i (Contributor) commented Mar 26, 2025

This PR

  • Combines k_pages and v_pages on num_kv_head to support sharding num_kv_heads down to 1 (multi-chip) while keeping the kernel and scatter performant (see the sketch below).
  • Merges the sliding_window and soft_cap changes.
  • Integrates the kernel as a PyTorch custom kernel.
  • Refactors the tests and improves test coverage.
  • Tests dynamo compilation with the dynamic grid from Pallas.

Tested:

python test/test_pallas.py -v -k PallasTest.test_ragged_paged_attention_wrapper
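
For intuition, here is a minimal sketch of combining the K and V pages along the kv-head axis. The shapes, names, and the interleaved layout are assumptions for illustration, not necessarily the kernel's exact layout:

import torch

# Hypothetical shapes for illustration only.
num_pages, page_size, num_kv_heads, head_dim = 4, 16, 8, 128
k_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim)
v_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim)

# Interleave K and V on the kv-head axis: head h's K lands at index 2*h
# and its V at 2*h + 1, so even after sharding num_kv_heads down to 1,
# each chip keeps a contiguous K/V pair per page.
kv_pages = torch.stack((k_pages, v_pages), dim=3).reshape(
    num_pages, page_size, num_kv_heads * 2, head_dim)

assert torch.equal(kv_pages[:, :, 0::2, :], k_pages)
assert torch.equal(kv_pages[:, :, 1::2, :], v_pages)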

mask_value: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = DEFAULT_MASK_VALUE,
Collaborator

left some comments in your original cl for the kernel.

Contributor Author

Thx! I will resolve them there!

Comment thread: test/test_pallas.py
page_indices_xla = page_indices.to("xla")
cu_q_lens_xla = cu_q_lens.to("xla")
num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
sliding_window = sliding_window
Collaborator

nit, no need: lines 672-673?

Contributor Author

Good point! Looks like this PR has already been merged; let me resolve it in a separate PR.

Comment thread: test/test_pallas.py
sliding_window = sliding_window
soft_cap = soft_cap
# Test mask_value
mask_value = None
Collaborator

imo, we can just use the default mask value rather than letting the user choose one.

Contributor Author

I will remove that in a follow-up PR since this PR is already merged.

@vanbasten23 (Collaborator)

Mostly LGTM, pending on CI. Thanks Jevin!

_, page_size, kv_hidden_size = k_pages.shape
num_kv_heads = kv_hidden_size // head_dim
check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
if mask_value is None:
Collaborator

Is this still needed, given that mask_value is assigned a default value: "mask_value: float | None = DEFAULT_MASK_VALUE"?

Contributor Author

Yes, the float | None just lists None as an allowed type. So if mask_value is passed explicitly as None, it won't use DEFAULT_MASK_VALUE.

Collaborator

But you have "= DEFAULT_MASK_VALUE", which means that if mask_value is None, it will use DEFAULT_MASK_VALUE, right?

Contributor Author

No, as mentioned, the None in float | None is just an allowed type; you can do a simple test:

>>> def f(a, b: float | None = 1.0):
...     print(a, b)
...
>>> f(2, None)
2 None
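
In other words, the explicit "if mask_value is None" guard in the wrapper (see the diff above) is what restores the fallback when a caller passes None. A minimal runnable sketch (the function name and constant value here are illustrative, not the kernel's actual definitions):

import numpy as np

# Illustrative constant; the real DEFAULT_MASK_VALUE lives in the kernel module.
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.float32).max)

def attn_stub(mask_value: float | None = DEFAULT_MASK_VALUE):
    # An explicit None bypasses the default argument, so this guard is
    # what restores the fallback value.
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    return mask_value

print(attn_stub())      # default argument used
print(attn_stub(None))  # explicit None; the guard restores the default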

@vanbasten23 merged commit 7a3c051 into pytorch:master on Mar 27, 2025
23 checks passed