MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA #1924

kahyunnam merged 14 commits into flashinfer-ai:main
Conversation
Force-pushed fa658e2 to bd2a338
Force-pushed bd2a338 to 0fc5d11
@nvpohanh for another set of eyes.
pavanimajety left a comment:
LGTM, mostly just nits for documentation and benchmark. Thanks for the effort!
yzh119 left a comment:
Overall LGTM, left some comments for suggestion.
Walkthrough

Renames MLA-specific RoPE APIs to generalized rope_quantize/rope_quantize_fp8, rewrites the C++ binding and CUDA kernel to support variable rope/no-rope dims and MLA/GQA/MHA layouts, replaces the MLA-only benchmark with a unified benchmark script, and adds parameterized tests covering all attention types.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant PyAPI as flashinfer.rope_quantize_fp8
    participant Binding as rope_quantize (C++)
    participant Kernel as RopeQuantizeKernel
    participant Outputs
    rect rgb(250,250,255)
    Note over User,Outputs: Generalized RoPE + FP8 quantization flow (MLA/GQA/MHA)
    User->>PyAPI: call rope_quantize_fp8(q, k, ..., is_neox?, quantize_dtype?)
    PyAPI->>Binding: validate shapes/dtypes, split rope/no-rope, prepare tensors/strides
    Binding->>Kernel: dispatch(num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, strides, scales, interleave)
    Kernel->>Outputs: apply per-head/chunk RoPE and quantize outputs to FP8
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 1
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/pos_enc.cuh (1)
447-487: Critical: OOB loads/stores in non-RoPE paths when no_rope_dim < rope_dim or tail not a multiple of vec_size

bdx is derived from rope_dim, but the K/Q non-RoPE branches index with tx * vec_size against no_rope_dim without guarding. Threads with tx * vec_size ≥ no_rope_dim will cast_load/cast_store out of bounds. Tails not divisible by vec_size also risk partial overruns.
Fix by bounding per-chunk width and handling tails safely. Example patch:
```diff
@@
-      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      // Bound valid width for this chunk
+      uint32_t valid = (no_rope_dim > elem_offset)
+                           ? min(rope_chunk_size, no_rope_dim - elem_offset)
+                           : 0;
@@
-      vec_t<float, vec_size> k_nope_vec;
-      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
-      }
-      k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+      if (tx * vec_size + (vec_size - 1) < valid) {
+        // Full vector fits
+        vec_t<float, vec_size> k_nope_vec;
+        k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+        }
+        k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+      } else if (tx * vec_size < valid) {
+        // Scalar tail
+        const uint32_t rem = valid - tx * vec_size;
+        for (uint32_t i = 0; i < rem; ++i) {
+          float v = static_cast<float>(k_nope_in_ptr[tx * vec_size + i]) * quant_scale_kv;
+          k_nope_out_ptr[tx * vec_size + i] = static_cast<QuantType>(v);
+        }
+      }
@@
-      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t valid = (no_rope_dim > elem_offset)
+                           ? min(rope_chunk_size, no_rope_dim - elem_offset)
+                           : 0;
@@
-      vec_t<float, vec_size> q_nope_vec;
-      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
-      }
-      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      if (tx * vec_size + (vec_size - 1) < valid) {
+        vec_t<float, vec_size> q_nope_vec;
+        q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+        }
+        q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      } else if (tx * vec_size < valid) {
+        const uint32_t rem = valid - tx * vec_size;
+        for (uint32_t i = 0; i < rem; ++i) {
+          float v = static_cast<float>(q_nope_in_ptr[tx * vec_size + i]) * quant_scale_q;
+          q_nope_out_ptr[tx * vec_size + i] = static_cast<QuantType>(v);
+        }
+      }
```

If vec_t already supports masked store/load, prefer using that to avoid the scalar tail. Otherwise, the above prevents OOB while covering all elements.
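If masked vector ops are unavailable, a standalone tail helper keeps the scalar path in one place. A minimal sketch under the patch's naming assumptions; scale_store_tail is a hypothetical helper, not existing flashinfer API:

```cpp
// Sketch (assumption): scalar-tail helper matching the patch's names
// (DType/QuantType, quant_scale); usable if vec_t lacks masked load/store.
template <typename DType, typename QuantType>
__device__ __forceinline__ void scale_store_tail(const DType* in, QuantType* out,
                                                 uint32_t start, uint32_t valid,
                                                 float scale) {
  // Covers elements [start, valid) one by one; callers pass start = tx * vec_size.
  for (uint32_t i = start; i < valid; ++i) {
    out[i] = static_cast<QuantType>(static_cast<float>(in[i]) * scale);
  }
}
```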
♻️ Duplicate comments (2)

flashinfer/rope.py (1)

1262-1262: Add Python-level BC alias.

To honor earlier feedback and avoid breaking downstreams, export the old name.

```diff
     return q_rope_out, k_rope_out, q_nope_out, k_nope_out
+
+
+# Backward compatibility
+mla_rope_quantize_fp8 = rope_quantize_fp8
```

include/flashinfer/pos_enc.cuh (1)
816-827: PDL support added — thanks.

Programmatic stream serialization path looks good and addresses the prior ask for PDL.
🧹 Nitpick comments (11)
csrc/flashinfer_rope_binding.cu (1)
42-46: Preserve FFI BC: export the old symbol as an alias, too.

Some external callers may still invoke flashinfer::mla_rope_quantize. Export it as an alias to rope_quantize to avoid runtime breakage.

```diff
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache);
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
+// Backward compatibility alias for older clients
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(mla_rope_quantize, rope_quantize);
```

Also applies to: 53-53
tests/attention/test_rope.py (1)
359-379: Cover both RoPE layouts (is_neox True/False) and fix the stray main-call.
- Parametrize is_neox to test both interleaved and non-interleaved layouts.
- Remove or update the main invocation; it no longer matches the signature.
```diff
 @pytest.mark.parametrize(
     "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim",
     @@
 )
 @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("is_neox", [True, False])
 def test_generalized_rope_quantize(
     attention_type,
     num_qo_heads,
     num_kv_heads,
     rope_dim,
     no_rope_dim,
     num_tokens,
     input_dtype,
     quant_dtype,
+    is_neox,
 ):
 @@
-    rope_flashinfer = FlashInferRotaryEmbedding(
+    rope_flashinfer = FlashInferRotaryEmbedding(
         total_dim,
         rope_dim,
         4096,  # max_position_embeddings
         10000,  # base
-        False,  # is_neox_style
+        is_neox,  # is_neox_style
         input_dtype,
         device,
     )
 @@
-    flashinfer.rope.rope_quantize_fp8(
+    flashinfer.rope.rope_quantize_fp8(
         q_rope_in,
         k_rope_in,
         q_nope_in,
         k_nope_in,
         rope_flashinfer.cos_sin_cache,
         pos_ids,
-        is_neox=False,
+        is_neox=is_neox,
         q_rope_out=q_rope_out,
         k_rope_out=k_rope_out,
         q_nope_out=q_nope_out,
         k_nope_out=k_nope_out,
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
     )
```

And at file end:

```diff
-if __name__ == "__main__":
-    # ...
-    test_generalized_rope_quantize(1, torch.float16, torch.float8_e4m3fn)
+if __name__ == "__main__":
+    # Prefer: run via pytest; keep manual smoke disabled or provide a minimal harness.
+    pass
```

Also applies to: 447-462
flashinfer/rope.py (3)
170-190: Docstring drift in _rope_quantize.

This wrapper doesn't "convert is_neox"; it forwards interleave. Please update the docstring to avoid confusion.

```diff
 def _rope_quantize(...):
-    r"""Custom operator that routes to the CUDA kernel implementation.
-
-    Converts is_neox parameter to interleave format and dispatches to the underlying
-    CUDA kernel via the JIT-compiled module.
-    """
+    r"""Custom operator that routes to the CUDA kernel implementation.
+    Forwards the pre-split tensors and interleave flag to the JIT-compiled CUDA kernel.
+    """
```
1147-1211: Validate quantize_dtype early (fail fast).

Raise early if quantize_dtype is not FP8; this avoids diving into FFI before erroring.

```diff
 def rope_quantize_fp8(..., quantize_dtype: Optional[torch.dtype] = None, ...):
 @@
     if quantize_dtype is None:
         ...
     else:
-        pass
+        if quantize_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+            raise ValueError("quantize_dtype must be torch.float8_e4m3fn or torch.float8_e5m2")
```
249-260: Align fake op signature with the real custom op.

The fake takes (cos_cache, sin_cache), but the real op takes a single cos_sin_cache. Keep them consistent.

```diff
-@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
-def _fake_apply_rope_pos_ids_cos_sin_cache(
+@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
+def _fake_apply_rope_pos_ids_cos_sin_cache(
     q: torch.Tensor,
     k: torch.Tensor,
     q_rope: torch.Tensor,
     k_rope: torch.Tensor,
-    cos_cache: torch.Tensor,
-    sin_cache: torch.Tensor,
-    pos_ids: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
     interleave: bool,
 ) -> None:
     pass
```

benchmarks/bench_rope_quantize_fp8.py (1)
19-89: Remove unused local FlashInferRotaryEmbedding to reduce duplication.

This class is unused (you import RotaryEmbedding from tests). Dropping it simplifies the benchmark module.

```diff
-# Local FlashInferRotaryEmbedding implementation (unused)
-class FlashInferRotaryEmbedding(nn.Module):
-    ...
-    def _apply_rotary_emb(...):
-        ...
-        return ...
```
362-370: Use rope_chunk_size for RoPE, but compute a per-branch chunk width for non-RoPE to avoid wasted CTAs.

rope_chunk_size is set to rope_dim and reused for no_rope_dim. This is fine for no_rope_dim ≥ rope_dim (the MLA case), but for no_rope_dim < rope_dim it spawns blocks where many threads do nothing after the fix. Consider computing:

- rope_chunk_size = rope_dim
- nope_chunk_size = min(rope_dim, round_up_to_multiple(no_rope_dim, vec_size))

and use separate chunk counts per branch to reduce empty work (see the sketch after this comment).
Also applies to: 769-773
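A minimal sketch of the suggested per-branch sizing; round_up_to_multiple is a hypothetical helper, other names follow the comment above:

```cpp
// Sketch (assumption): size the no-RoPE chunk independently so that
// no_rope_dim < rope_dim does not inherit rope_dim-wide chunks.
__host__ __device__ inline uint32_t round_up_to_multiple(uint32_t x, uint32_t m) {
  return ((x + m - 1) / m) * m;
}

uint32_t rope_chunk_size = rope_dim;
uint32_t nope_chunk_size = round_up_to_multiple(no_rope_dim, vec_size);
if (nope_chunk_size > rope_dim) nope_chunk_size = rope_dim;

uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size;
// Guard the degenerate no_rope_dim == 0 case to avoid a zero divisor.
uint32_t no_rope_chunks =
    (no_rope_dim == 0) ? 0 : (no_rope_dim + nope_chunk_size - 1) / nope_chunk_size;
```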
723-742: Broaden DISPATCH_ROPE_DIM or add a generic fallback.

Limiting rope_dim to {16, 32, 64, 128, 256} constrains future models (e.g., 80, 160, 192). Add cases or a generic path:
- If rope_dim % (32/sizeof(DType)) == 0, set constexpr bdx = rope_dim / vec_size and dispatch; else return an error. This preserves compile‑time bdx while widening support.
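One way to widen the dispatch while keeping bdx compile-time; a sketch only, since the real macro's argument shape in pos_enc.cuh may differ:

```cpp
// Sketch (assumption): enumerate more rope_dim cases so BDX stays constexpr;
// anything unsupported fails fast instead of silently mis-launching.
#define DISPATCH_ROPE_DIM(rope_dim, vec_size, BDX, ...)  \
  switch (rope_dim) {                                    \
    case 16: {                                           \
      constexpr uint32_t BDX = 16 / (vec_size);          \
      __VA_ARGS__;                                       \
    } break;                                             \
    case 80: { /* e.g., models with rotary_dim = 80 */   \
      constexpr uint32_t BDX = 80 / (vec_size);          \
      __VA_ARGS__;                                       \
    } break;                                             \
    case 128: {                                          \
      constexpr uint32_t BDX = 128 / (vec_size);         \
      __VA_ARGS__;                                       \
    } break;                                             \
    /* ...32, 64, 160, 192, 256 alike... */              \
    default:                                             \
      return cudaErrorInvalidValue;                      \
  }
```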
760-771: Verify vec_size choice and alignment assumptions.

vec_size = 32 / sizeof(DType) moves away from the 16-byte convention used elsewhere (and often matches L1/L2 transaction granularity). Please verify alignment of q/k pointers and cos_sin_cache for 32-byte vectorization on all dtypes; if not guaranteed, fall back to max(16/sizeof(DType), rope_dim/32) as used in other paths for safety.
773-806: Remove dead args[] or switch to cudaLaunchKernel with args[].

args[] is constructed but not used with cudaLaunchKernelEx (you pass typed args directly). Drop args[] to reduce confusion, or switch to cudaLaunchKernel((void*)kernel, ..., args, ...) for consistency with other call sites.
Also applies to: 829-837
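If args[] is kept, the classic launcher is the variant that actually consumes it. A sketch; nblks/nthrs/stream are the values already computed at this site, and sharedMem is 0 on the assumption the kernel uses no dynamic shared memory:

```cpp
// Sketch (assumption): route the existing args[] through cudaLaunchKernel.
cudaError_t status =
    cudaLaunchKernel((void*)kernel, nblks, nthrs, args, /*sharedMem=*/0, stream);
if (status != cudaSuccess) {
  return status;  // propagate launch failure to the caller
}
```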
352-357: Naming consistency: k_* stride parameters.

q_* use q_*_stride_n/h, while k_* use k_*_stride and k_*_stride_h. Consider renaming to k_*_stride_n for symmetry with get_elem_offset_impl usage. No behavior change, just readability.
Also applies to: 749-754
📜 Review details

📒 Files selected for processing (7)

- benchmarks/bench_mla_rope_quantize_fp8.py (0 hunks)
- benchmarks/bench_rope_quantize_fp8.py (1 hunks)
- csrc/flashinfer_rope_binding.cu (1 hunks)
- csrc/rope.cu (3 hunks)
- flashinfer/rope.py (6 hunks)
- include/flashinfer/pos_enc.cuh (2 hunks)
- tests/attention/test_rope.py (2 hunks)

💤 Files with no reviewable changes (1)

- benchmarks/bench_mla_rope_quantize_fp8.py

🪛 Ruff (0.14.0), benchmarks/bench_rope_quantize_fp8.py — lines 111 and 214: Avoid specifying long messages outside the exception class (TRY003)
Force-pushed 5c56c72 to f75ce86
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/attention/test_rope.py (1)
448-462: Consider explicitly passing quantize_dtype for clarity.

The test relies on implicit inference of quantize_dtype from the pre-allocated output tensors. While this should work according to the function's logic, explicitly passing quantize_dtype=quant_dtype would make the test's intent clearer and more robust. Apply this diff:

```diff
     # Call the generalized function
     flashinfer.rope.rope_quantize_fp8(
         q_rope_in,
         k_rope_in,
         q_nope_in,
         k_nope_in,
         rope_flashinfer.cos_sin_cache,
         pos_ids,
         is_neox=False,
+        quantize_dtype=quant_dtype,
         q_rope_out=q_rope_out,
         k_rope_out=k_rope_out,
         q_nope_out=q_nope_out,
         k_nope_out=k_nope_out,
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
     )
```
📜 Review details

📥 Commits: reviewing files that changed from the base of the PR and between 8fee3d9 and 5c56c72aa6859f3829ac4850b9a20926a6557c59.

📒 Files selected for processing (1)

- tests/attention/test_rope.py (2 hunks)
🔇 Additional comments (1)
tests/attention/test_rope.py (1)
359-410: Excellent test coverage for multiple attention architectures.

The parametrization comprehensively covers MLA, GQA, and MHA configurations with realistic head counts and dimensions, including explicit DeepSeek R1 MLA and Llama3 8B/70B configurations. The input tensor creation correctly handles the 2D vs 3D key tensor distinction between MLA and GQA/MHA.
Force-pushed f75ce86 to 32149b6
Actionable comments posted: 0
🧹 Nitpick comments (8)
tests/attention/test_rope.py (5)
359-379: Add coverage for both RoPE layouts (is_neox True/False).

Currently only interleaved is exercised (is_neox=False). Parameterize is_neox to cover both code paths and kernel branches.
424-430: Consider cross-checking with the CUDA path too.

Besides forward_native, add a compare against forward_cuda to validate equivalence with our cos/sin-cache operator under the same settings.
431-446: Assert output dtypes and shapes explicitly.

After the call, add quick checks that q_out.dtype == quant_dtype and k_out.dtype == quant_dtype and shapes match inputs to catch silent dtype/shape regressions.

```diff
+    assert q_out.dtype is quant_dtype and k_out.dtype is quant_dtype
+    assert q_out.shape == q_in.shape and k_out.shape == k_in.shape
```
465-478: Tighten tolerances or justify rtol=0.2.

rtol=2e-1 is generous for FP8-cast equivalence. If feasible, lower rtol (e.g., 5e-2) or add a brief comment explaining why 0.2 is required here.
538-544: Remove direct test invocation under __main__.

A stray manual call can surprise users running the file. Prefer relying on pytest discovery or guard it behind an env flag.

```diff
-    test_mla_rope_quantize(1, torch.float16, torch.float8_e4m3fn)
+    pass  # Rely on pytest collection
```

flashinfer/rope.py (3)
166-205: Fix docstring: this function already receives interleave.

The doc says it "converts is_neox to interleave," but _rope_quantize takes interleave directly. Update the doc to avoid confusion.

```diff
-    Converts is_neox parameter to interleave format and dispatches to the underlying
-    CUDA kernel via the JIT-compiled module.
+    Routes to the CUDA kernel via the JIT-compiled module using the provided
+    `interleave` flag.
```
1147-1179: Backward-compat API preserved — LGTM.

Wrapper maintains mla_rope_quantize_fp8 while delegating to the generalized API. Consider a direct alias to reduce indirection if docs don't need a separate entry.

```diff
-def mla_rope_quantize_fp8(...):
-    return rope_quantize_fp8(...)
+mla_rope_quantize_fp8 = rope_quantize_fp8
```
1181-1296: Optional: upfront dtype validation for clearer errors.

The CUDA op will validate dtypes, but adding lightweight Python-side checks (allowed input dtypes: fp16/bf16; output/quant dtypes: FP8 e4m3/e5m2) can surface mistakes earlier.

```diff
+    if q_rope.dtype not in (torch.float16, torch.bfloat16) or q_nope.dtype != q_rope.dtype:
+        raise TypeError("q_* inputs must be float16 or bfloat16 and match dtypes")
+    if k_rope.dtype not in (torch.float16, torch.bfloat16) or k_nope.dtype != k_rope.dtype:
+        raise TypeError("k_* inputs must be float16 or bfloat16 and match dtypes")
+    if quantize_dtype not in (None, torch.float8_e4m3fn, torch.float8_e5m2):
+        raise TypeError("quantize_dtype must be FP8 e4m3fn or e5m2")
```
📜 Review details

📥 Commits: reviewing files that changed from the base of the PR and between f75ce8640a2b5ba6b39fdcfd22516e0653570471 and 32149b6.

📒 Files selected for processing (2)

- flashinfer/rope.py (5 hunks)
- tests/attention/test_rope.py (1 hunks)
yzh119 left a comment:
LGTM, thank you @kahyunnam !
Actionable comments posted: 1
🧹 Nitpick comments (5)
include/flashinfer/pos_enc.cuh (5)
387-387: Use && for consistency.

Replace the C++ alternative token "and" with "&&" to match the rest of the file and avoid style churn.

```diff
-    if ((tx * vec_size < rope_dim) and (by < k_rope_end)) {
+    if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
```
729-748: DISPATCH_ROPE_DIM coverage — confirm required sizes.

Only {16, 32, 64, 128, 256} are supported. If models with rotary_dim=80 (seen in some configs) must be supported, extend the dispatch or fail fast with a clear error at API entry.
766-767: Vector width set to 32B — verify alignment.

vec_size = 32/sizeof(DType) implies 32-byte wide loads/stores. Ensure q/k/nope and cos_sin_cache pointers are 32B-aligned; otherwise, you may hit misaligned global accesses (a perf hit, or UB if vec_t assumes alignment). If 32B alignment isn't guaranteed, fall back to the 16B convention used elsewhere:

```diff
-    constexpr uint32_t vec_size = 32 / sizeof(DType);
+    constexpr uint32_t vec_size = 16 / sizeof(DType);
```

Or document/enforce alignment preconditions at the API boundary.
779-812: Remove unused args[].

args is built but not used with cudaLaunchKernelEx's typed invocation. Drop it to reduce noise.

```diff
-  void* args[] = {(void*)&q_rope_in,
-                  (void*)&k_rope_in,
-                  (void*)&q_nope_in,
-                  (void*)&k_nope_in,
-                  (void*)&q_rope_out,
-                  (void*)&k_rope_out,
-                  (void*)&q_nope_out,
-                  (void*)&k_nope_out,
-                  (void*)&cos_sin_cache,
-                  (void*)&pos_ids,
-                  (void*)&nnz,
-                  (void*)&num_qo_heads,
-                  (void*)&num_kv_heads,
-                  (void*)&rope_dim,
-                  (void*)&no_rope_dim,
-                  (void*)&q_rope_in_stride_n,
-                  (void*)&q_rope_in_stride_h,
-                  (void*)&q_nope_in_stride_n,
-                  (void*)&q_nope_in_stride_h,
-                  (void*)&q_rope_out_stride_n,
-                  (void*)&q_rope_out_stride_h,
-                  (void*)&q_nope_out_stride_n,
-                  (void*)&q_nope_out_stride_h,
-                  (void*)&k_rope_in_stride,
-                  (void*)&k_rope_in_stride_h,
-                  (void*)&k_nope_in_stride,
-                  (void*)&k_nope_in_stride_h,
-                  (void*)&k_rope_out_stride,
-                  (void*)&k_rope_out_stride_h,
-                  (void*)&k_nope_out_stride,
-                  (void*)&k_nope_out_stride_h,
-                  (void*)&quant_scale_q,
-                  (void*)&quant_scale_kv};
```
352-357: Name stride params consistently (*_stride_n / *_stride_h).

k_*_stride lacks the "n" suffix, unlike the q counterparts. Consider renaming for symmetry and to reduce call-site confusion.
Also applies to: 431-436, 459-461, 478-482
📜 Review details

📒 Files selected for processing (1)

- include/flashinfer/pos_enc.cuh (3 hunks)
🔇 Additional comments (1)
include/flashinfer/pos_enc.cuh (1)
358-361: PDL grid dependency hooks added — looks good.

The kernel now issues griddepcontrol.wait at entry and launch_dependents at exit under CUDA 12+/sm90 guards. This satisfies the earlier review ask for PDL. Please run a quick H100 sanity to confirm no functional/perf regressions with enable_pdl on/off.
Also applies to: 492-495
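For readers unfamiliar with PDL, the pattern being described looks roughly like this. A sketch only; the exact guards and helper macros in pos_enc.cuh may differ:

```cpp
// Sketch (assumption): the PDL handshake described above, guarded for CUDA 12+
// on sm90; the kernel name and parameter list are placeholders.
__global__ void RopeQuantizeKernelLike(/* ...params... */) {
#if (__CUDACC_VER_MAJOR__ >= 12) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("griddepcontrol.wait;" ::: "memory");  // wait on prior grid's writes
#endif
  // ... RoPE + quantize body ...
#if (__CUDACC_VER_MAJOR__ >= 12) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  asm volatile("griddepcontrol.launch_dependents;" ::: "memory");  // release dependents early
#endif
}
```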
```cpp
    // Calculate flexible boundaries for block allocation
    uint32_t rope_chunk_size = rope_dim;  // Process entire rope_dim per chunk
    uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size;
    uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size;

    uint32_t q_rope_end = num_qo_heads * rope_chunks;
    uint32_t k_rope_end = q_rope_end + num_kv_heads * rope_chunks;
    uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks;
```
Fix OOB on no_rope tails (bdx derived from rope_dim, no tail masking).
no_rope blocks reuse bdx = rope_dim/vec_size and step by rope_chunk_size = rope_dim. When no_rope_dim is not a multiple of rope_dim or vec_size, the last chunk can read/write past valid elements. Add per‑lane bounds and masked tail load/store for both K and Q no‑RoPE paths.
Apply this diff (K no‑RoPE):

```diff
@@
-        vec_t<float, vec_size> k_nope_vec;
-        k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
-#pragma unroll
-        for (uint32_t i = 0; i < vec_size; ++i) {
-          k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
-        }
-        k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+        const uint32_t lane = tx * vec_size;
+        const uint32_t valid = min(no_rope_dim - elem_offset, rope_chunk_size);
+        if (lane + vec_size <= valid) {
+          vec_t<float, vec_size> k_nope_vec;
+          k_nope_vec.cast_load(k_nope_in_ptr + lane);
+#pragma unroll
+          for (uint32_t i = 0; i < vec_size; ++i) {
+            k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+          }
+          k_nope_vec.cast_store(k_nope_out_ptr + lane);
+        } else if (lane < valid) {
+#pragma unroll
+          for (uint32_t j = 0; j < vec_size && (lane + j) < valid; ++j) {
+            float v = static_cast<float>(k_nope_in_ptr[lane + j]);
+            k_nope_out_ptr[lane + j] = static_cast<QuantType>(v * quant_scale_kv);
+          }
+        }
```

…and similarly for Q no‑RoPE:
```diff
@@
-        vec_t<float, vec_size> q_nope_vec;
-        q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
-#pragma unroll
-        for (uint32_t i = 0; i < vec_size; ++i) {
-          q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
-        }
-        q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+        const uint32_t lane = tx * vec_size;
+        const uint32_t valid = min(no_rope_dim - elem_offset, rope_chunk_size);
+        if (lane + vec_size <= valid) {
+          vec_t<float, vec_size> q_nope_vec;
+          q_nope_vec.cast_load(q_nope_in_ptr + lane);
+#pragma unroll
+          for (uint32_t i = 0; i < vec_size; ++i) {
+            q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+          }
+          q_nope_vec.cast_store(q_nope_out_ptr + lane);
+        } else if (lane < valid) {
+#pragma unroll
+          for (uint32_t j = 0; j < vec_size && (lane + j) < valid; ++j) {
+            float v = static_cast<float>(q_nope_in_ptr[lane + j]);
+            q_nope_out_ptr[lane + j] = static_cast<QuantType>(v * quant_scale_q);
+          }
+        }
```

If you prefer not to implement masked tails now, at minimum guard with a runtime check and return cudaErrorInvalidValue when (no_rope_dim % vec_size) != 0 to avoid silent memory bugs.
Also applies to: 450-469, 470-490
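The minimal guard mentioned above could look like this; a sketch at the host-side launch path, using the names from the comment:

```cpp
// Sketch (assumption): fail fast instead of shipping masked tails now;
// rejects shapes whose no-RoPE tail would over-read/write the last chunk.
if (no_rope_dim % vec_size != 0) {
  return cudaErrorInvalidValue;  // avoid silent OOB in the no-RoPE branches
}
```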
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (#2037) <!-- .github/pull_request_template.md --> ## 📌 Description Add `flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache`, which runs a fused RoPE + Quantization (16 -> 8) + append KV Cache operation kernel. Note that this does not support optional quantization (there is no "RoPE + append KV Cache" fused operation available). Tested on NVIDIA H100 NVL + flashinfer/flashinfer-ci-cu130:latest for MLA/MHA/GQA problem sizes for decode and prefill cases. ## 🔍 Related Issues "[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused kernels for MLA/GQA/MHA" item from Q4 roadmap: #1770. This PR is part 2 to earlier PR for RoPE + Q: #1924 FW Stakeholders: @nvpohanh @pavanimajety ## 🧪 Test results ``` $ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s ======================================================== test session starts =========================================================platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /workspace/flashinfer configfile: pytest.ini collected 384 items tests/attention/test_rope.py ................................................................................................................................................................................................................................................................................................................................................................................................ ======================================================== 384 passed in 35.22s ======================================================== ``` ``` $ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s ======================================================== test session starts ========================================================= platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0 rootdir: /workspace/flashinfer configfile: pytest.ini collected 1248 items tests/attention/test_rope.py ......................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ...................................................................................................................................... ....................................................................... 
================================================== 1248 passed in 63.07s (0:01:03) =================================================== ``` ``` $ python benchmarks/bench_rope_quantize_fp8_append_cache.py Detected GPU: NVIDIA GB200 Theoretical Peak Memory Bandwidth: 7928.06 GB/s ==================================================================================================== MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00258 86.53 1.1 0.010 32 0.00381 1873.82 23.6 0.208 128 0.00763 3744.50 47.2 0.416 384 0.01848 4637.34 58.5 0.515 768 0.03694 4639.75 58.5 0.515 1024 0.04879 4683.57 59.1 0.520 2048 0.09590 4766.09 60.1 0.529 4096 0.19031 4803.27 60.6 0.533 8192 0.38523 4745.78 59.9 0.527 ==================================================================================================== GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00294 6.36 0.1 0.003 32 0.00316 189.48 2.4 0.078 128 0.00317 755.23 9.5 0.310 384 0.00398 1803.09 22.7 0.741 768 0.00522 2750.51 34.7 1.130 1024 0.00617 3100.80 39.1 1.274 2048 0.00927 4130.83 52.1 1.697 4096 0.01631 4695.01 59.2 1.929 8192 0.03466 4418.01 55.7 1.815 ==================================================================================================== MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard) ==================================================================================================== Tokens Time (ms) BW (GB/s) BW% (Peak) TFLOPs ---------------------------------------------------------------------- 1 0.00293 12.68 0.2 0.004 32 0.00313 379.98 4.8 0.126 128 0.00357 1331.80 16.8 0.441 384 0.00517 2756.73 34.8 0.912 768 0.00742 3840.41 48.4 1.271 1024 0.00887 4287.15 54.1 1.419 2048 0.01504 5055.18 63.8 1.673 4096 0.03343 4548.12 57.4 1.505 8192 0.06410 4744.76 59.8 1.571 ==================================================================================================== Configuration details: Page size: 32, Batch size: 4 Token range: 1 (single decode) → 8192 (large prefill) GPU: NVIDIA GB200 Theoretical Peak Memory Bandwidth: 7928.06 GB/s BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100 ==================================================================================================== ``` ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. 
--> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA, GQA/MHA) with layout, page-size, interleave and PDL options; returns quantized Q outputs and writes K/V into paged caches; public ops and high-level API added. * **Tests** * Deterministic, parameterized tests for append and decode/continuation across attention types, layouts, dtypes and quant settings with reference validation. * **Benchmarks** * New benchmark script for performance, bandwidth and Nsight profiling of the paged-KV quantize+append path. * **Chores** * Added cached GPU memory-bandwidth utility for benchmarks. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Zihao Ye <expye@outlook.com>
📌 Description
Generalize the existing MLA RoPE+Q fused kernels to support GQA/MHA problem shapes.
🔍 Test Results
pytest -v tests/attention/test_rope.py::test_generalized_rope_quantize

Benchmark results on GB300:

python benchmarks/bench_rope_quantize_fp8.py

mla-rope-benchmark.png:

gqa-rope-benchmark.png:

mha-rope-benchmark.png:

🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Summary by CodeRabbit
Refactor
New Features
Tests