[ragged-paged-attn] Combine k_pages and v_pages on num_kv_head #8892
vanbasten23 merged 9 commits into pytorch:master from
Conversation
    mask_value: float,
    sliding_window: int | None = None,
    soft_cap: float | None = None,
    mask_value: float | None = DEFAULT_MASK_VALUE,
Left some comments in your original CL for the kernel.
Thanks! I will resolve them there!
    page_indices_xla = page_indices.to("xla")
    cu_q_lens_xla = cu_q_lens.to("xla")
    num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
    sliding_window = sliding_window
nit, no need: are lines 672-673 redundant?
Good point! Looks like we have already merged this PR, so let me resolve it in a separate PR.
    sliding_window = sliding_window
    soft_cap = soft_cap
    # Test mask_value
    mask_value = None
imo, we can just use the default mask value rather than letting the user choose one.
I will remove that in a follow-up PR since this PR is already merged.
Mostly LGTM, pending CI. Thanks Jevin!
    _, page_size, kv_hidden_size = k_pages.shape
    num_kv_heads = kv_hidden_size // head_dim
    check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
    if mask_value is None:
Do you still need this, since mask_value is already assigned a default value: "mask_value: float | None = DEFAULT_MASK_VALUE"?
Yes, in float | None the None is just an allowed type. So if the caller explicitly passes mask_value=None, it won't be replaced with DEFAULT_MASK_VALUE.
But you have "= DEFAULT_MASK_VALUE" — doesn't that mean that if mask_value is None, it will use DEFAULT_MASK_VALUE?
No, as mentioned, the None in float | None is just an allowed type; the default only applies when the argument is omitted. You can do a simple test:
>>> def f(a, b: float | None = 1.0):
... print(a, b)
...
>>> f(2, None)
2 None
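This distinction is why the kernel keeps the explicit None check: a parameter default only fires when the argument is omitted, not when None is passed explicitly. A minimal sketch of the coalescing pattern — the DEFAULT_MASK_VALUE value and function name here are placeholders, not the kernel's real definitions:

```python
from __future__ import annotations

# Placeholder; the real kernel defines its own DEFAULT_MASK_VALUE.
DEFAULT_MASK_VALUE = -1e30

def resolve_mask_value(mask_value: float | None = DEFAULT_MASK_VALUE) -> float:
    # A default only applies when the argument is omitted entirely;
    # an explicit mask_value=None still arrives as None, so coalesce it,
    # mirroring the `if mask_value is None:` check in the diff above.
    if mask_value is None:
        mask_value = DEFAULT_MASK_VALUE
    return mask_value

assert resolve_mask_value() == DEFAULT_MASK_VALUE                 # omitted
assert resolve_mask_value(mask_value=None) == DEFAULT_MASK_VALUE  # explicit None
assert resolve_mask_value(mask_value=-5.0) == -5.0                # caller override
```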
This PR: sliding_window and soft_cap change. Tested: