[ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head #8851
vanbasten23 merged 4 commits into pytorch:master from
Conversation
mask_value = DEFAULT_MASK_VALUE
validate_ragged_paged_attention_inputs(q, k_pages, v_pages, kv_lens,
                                       page_indices, cu_q_lens, num_seqs)
Why did we stop calling validate_ragged_paged_attention_inputs?
Because we already have these static shape checks on the JAX side.
q_packing = get_dtype_packing(q_dtype)
max_q_tiling = 8 * q_packing
min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
I am not sure I follow. If dtype is bf16, then max_q_tiling is 16. For Qwen, where num_q_heads=12, num_kv_head=2, and num_q_heads_per_kv_head=6, min_q_heads (= lcm(max_q_tiling, num_q_heads_per_kv_head)) will be 48. What does min_q_heads mean?
It tries to find the smallest number that is divisible by both max_q_tiling and num_q_heads_per_kv_head. If this number divides the total num_q_heads evenly, we use it as num_q_heads_per_blk; if not, we fall back to the total num_q_heads.
Checking divisibility by max_q_tiling makes sure the block can be fully tiled by XLA.
Checking divisibility by num_q_heads_per_kv_head makes sure we never need an inner split within num_q_heads_per_kv_head.
Thanks! Could you add what you said as a comment in the code?
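For reference, the heuristic described above can be sketched as follows. This is an illustrative reconstruction, not the kernel's actual code: `get_dtype_packing` and `pick_num_q_heads_per_blk` are hypothetical names, and packing is assumed to mean how many values fit in a 32-bit word.

```python
from math import lcm

def get_dtype_packing(dtype_bits: int) -> int:
    # Hypothetical helper: how many values of this bit width pack into 32 bits.
    return 32 // dtype_bits

def pick_num_q_heads_per_blk(num_q_heads: int, num_kv_heads: int,
                             q_dtype_bits: int) -> int:
    # min_q_heads is the smallest count divisible by both the XLA tiling
    # requirement (max_q_tiling) and the per-kv-head group size, so no
    # inner split of num_q_heads_per_kv_head is ever needed.
    num_q_heads_per_kv_head = num_q_heads // num_kv_heads
    max_q_tiling = 8 * get_dtype_packing(q_dtype_bits)
    min_q_heads = lcm(max_q_tiling, num_q_heads_per_kv_head)
    # Use it only if it divides the total head count evenly; otherwise
    # fall back to processing all q heads in one block.
    return min_q_heads if num_q_heads % min_q_heads == 0 else num_q_heads
```

In the Qwen-like bf16 example from the discussion, lcm(16, 6) = 48 does not divide 12, so the fallback (all 12 heads per block) is used.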
  raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
  max_kv_len = jnp.max(kv_lens)
- min_pages_per_seq = ceil_div(max_kv_len, page_size)
+ min_pages_per_seq = cdiv(max_kv_len, page_size)
Why is it min? Shouldn't it be max_pages_per_seq, since you used cdiv(jnp.max(kv_lens), page_size)?
That is the lower bound for pages_per_seq: the page table must have room for at least enough pages to hold the longest sequence.
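As a small illustration of that lower bound (a pure-Python sketch; `cdiv` here stands in for the kernel's ceiling-division helper, and the lengths are made up):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division: smallest integer n with n * b >= a.
    return -(-a // b)

# The longest sequence dictates the minimum page-table width: with
# kv lengths [512, 1000, 64] and 128-token pages, every sequence's
# page list must have room for at least cdiv(1000, 128) = 8 entries.
kv_lens = [512, 1000, 64]
page_size = 128
min_pages_per_seq = cdiv(max(kv_lens), page_size)
```

A configured pages_per_seq smaller than this bound could not address all of the longest sequence's kv pages, hence the check.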
_, page_size, kv_model_dim = k_pages.shape
kv_packing = get_dtype_packing(k_pages.dtype)
if page_size % kv_packing != 0:
    raise ValueError(f"Expected {page_size=} is divisible by {kv_packing=}")
page_size % kv_packing != 0 indicates there will be padding, so we may waste some memory. Can we emit a warning instead of raising an exception?
The page size is chosen by the serving config; the error indicates we should choose a better one. Otherwise, when people use bf16 or quantized types (fp8, int8, int4), there will be no bandwidth savings. We should prevent this.
I see. I guess it's the same reason why, before this PR, we would raise an exception when num_kv_head == 1 and dtype == bfloat16:
xla/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py, lines 534 to 536 in c7d0b1e
("Previously, if num_kv_head == 1 and dtype == bfloat16, we would have implicit padding on TPU.") The point is the code may still run fine, but there will be no bandwidth savings.
Yes, the point of quantization is to save more memory and bandwidth
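A minimal sketch of the check being discussed, assuming values pack into 32-bit words (the helper name mirrors the kernel's `get_dtype_packing`, but this implementation and `check_page_size` are illustrative):

```python
def get_dtype_packing(dtype_bits: int) -> int:
    # How many values of this bit width fit into one 32-bit word.
    return 32 // dtype_bits

def check_page_size(page_size: int, dtype_bits: int) -> None:
    # If page_size is not a multiple of the packing factor, each page
    # ends with implicit padding on TPU, and the bandwidth savings of
    # narrow dtypes (bf16, fp8, int8, int4) are lost.
    kv_packing = get_dtype_packing(dtype_bits)
    if page_size % kv_packing != 0:
        raise ValueError(
            f"Expected {page_size=} to be divisible by {kv_packing=}")

check_page_size(16, 16)      # bf16: packing 2, 16 % 2 == 0, accepted
try:
    check_page_size(18, 8)   # int8: packing 4, 18 % 4 != 0
except ValueError:
    pass                     # rejected: the page would carry padding
```

Raising instead of warning forces the serving config to pick a page size that actually realizes the narrow-dtype bandwidth savings.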
vanbasten23 left a comment:
Thanks Jevin. LGTM pending on CI.
@bythew3i I assume you have run the tests in tests/pallas/tpu_ragged_paged_attention_test.py and they all pass?
Yes, I tested the kernel.
This PR uses hidden states (num_kv_head * head_dim) in the kv cache. This change unblocks us for any num_kv_head. Previously, if num_kv_head == 1 and dtype == bfloat16, we would have implicit padding on TPU. Now, by using hidden states directly from the projection, we no longer need strided loads; we can load by slice directly.
This PR should help us support multi-chip sharding, which shards num_kv_head down to 1 for llama-3-70B.
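To illustrate the layout change with NumPy (the shapes below are made up for the example): storing the projection output as hidden states of size num_kv_head * head_dim makes each head a contiguous slice of the last axis, so reading one head no longer requires a strided load.

```python
import numpy as np

# Illustrative shapes, not the PR's: 4 pages, page size 8,
# 2 kv heads, head_dim 128.
num_pages, page_size, num_kv_heads, head_dim = 4, 8, 2, 128

# Old layout: a separate head axis, so reading one head strides memory.
k_pages_old = np.zeros((num_pages, page_size, num_kv_heads, head_dim))

# New layout: the hidden-states dimension is stored as-is.
kv_model_dim = num_kv_heads * head_dim
k_pages = k_pages_old.reshape(num_pages, page_size, kv_model_dim)

# Loading head h is now a plain contiguous slice along the last axis.
h = 1
k_head = k_pages[..., h * head_dim:(h + 1) * head_dim]
assert k_head.shape == (num_pages, page_size, head_dim)
```

With num_kv_head sharded to 1 (as for llama-3-70B multi-chip sharding), the slice covers the whole last axis and no padding is introduced.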
Tested: