
MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA#1924

Merged
kahyunnam merged 14 commits into flashinfer-ai:main from kahyunnam:knam/RoPe-fusion
Oct 18, 2025

Conversation

@kahyunnam
Member

@kahyunnam kahyunnam commented Oct 13, 2025

📌 Description

Generalize the existing MLA RoPE+Q fused kernels to support GQA/MHA problem shapes.
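
A minimal usage sketch of the generalized entry point (GQA shapes from the table below; the cos/sin-cache construction here is an assumption modeled on the reference RotaryEmbedding helper used in the tests, and the output tensors may also be pre-allocated and passed in as the tests do):

```python
import torch
import flashinfer

num_tokens, rope_dim, no_rope_dim = 16, 64, 64
num_qo_heads, num_kv_heads = 32, 8  # GQA config from the benchmark table below

q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda")
k_rope = torch.randn(num_tokens, num_kv_heads, rope_dim, dtype=torch.float16, device="cuda")
q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda")
k_nope = torch.randn(num_tokens, num_kv_heads, no_rope_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.arange(num_tokens, dtype=torch.int32, device="cuda")

# Assumed cache layout: [max_pos, rope_dim] with cos and sin halves concatenated,
# matching the reference RotaryEmbedding helper in the tests.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
freqs = torch.outer(torch.arange(4096, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to("cuda")

q_rope_out, k_rope_out, q_nope_out, k_nope_out = flashinfer.rope.rope_quantize_fp8(
    q_rope, k_rope, q_nope, k_nope, cos_sin_cache, pos_ids,
    is_neox=False, quantize_dtype=torch.float8_e4m3fn,
    quant_scale_q=1.0, quant_scale_kv=1.0,
)
```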

🔍 Test Results

pytest -v tests/attention/test_rope.py::test_generalized_rope_quantize

============================================================================= 312 passed in 3.93s ==============================================================================

(Benchmark results on GB300)
python benchmarks/bench_rope_quantize_fp8.py

Running MLA benchmark...
Running GQA benchmark...
Running MHA benchmark...

=== Summary Table ===
Tokens   MLA-FI (ms)  MLA-Torch (ms) GQA-FI (ms)  GQA-Torch (ms) MHA-FI (ms)  MHA-Torch (ms)
------------------------------------------------------------------------------------------
1        0.00260      0.00417        0.00246      0.00406        0.00253      0.00406
2        0.00267      0.00406        0.00253      0.00427        0.00253      0.00407
4        0.00274      0.00468        0.00252      0.00420        0.00253      0.00427
8        0.00274      0.00550        0.00254      0.00427        0.00253      0.00427
16       0.00273      0.00754        0.00253      0.00427        0.00254      0.00439
32       0.00315      0.01121        0.00253      0.00448        0.00264      0.00468
64       0.00416      0.01830        0.00253      0.00508        0.00274      0.00560
128      0.00560      0.03274        0.00274      0.00632        0.00294      0.00775
256      0.00908      0.06161        0.00295      0.00857        0.00331      0.01155
384      0.01380      0.09127        0.00335      0.01102        0.00356      0.01533
512      0.02172      0.12066        0.00356      0.01339        0.00417      0.01901
768      0.03114      0.17852        0.00407      0.01820        0.00500      0.02659

Configuration details:
  MLA: 128 Q heads, 1 K head, 64+512 dims
  GQA: 32 Q heads, 8 K heads, 64+64 dims
  MHA: 32 Q heads, 32 K heads, 64+64 dims
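
For reference, a sketch of the input layouts behind these configs (shapes inferred from the tests): MLA's single shared KV head lets K be passed as a 2D tensor, while GQA/MHA K carries an explicit head axis.

```python
import torch

n = 128  # tokens
# MLA: 128 Q heads, 1 shared K head, 64 rope + 512 no-rope dims; K tensors are 2D
q_rope_mla = torch.randn(n, 128, 64, dtype=torch.bfloat16, device="cuda")
k_rope_mla = torch.randn(n, 64, dtype=torch.bfloat16, device="cuda")
q_nope_mla = torch.randn(n, 128, 512, dtype=torch.bfloat16, device="cuda")
k_nope_mla = torch.randn(n, 512, dtype=torch.bfloat16, device="cuda")

# GQA: 32 Q heads, 8 K heads, 64 + 64 dims; K tensors are 3D
q_rope_gqa = torch.randn(n, 32, 64, dtype=torch.bfloat16, device="cuda")
k_rope_gqa = torch.randn(n, 8, 64, dtype=torch.bfloat16, device="cuda")
```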

Plot files saved to current directory:
  mla-rope-benchmark.png (FlashInfer vs PyTorch)
  gqa-rope-benchmark.png (FlashInfer vs PyTorch)
  mha-rope-benchmark.png (FlashInfer vs PyTorch)


🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • Refactor

    • Generalized RoPE + FP8 quantization API and native implementation to support MLA, GQA, and MHA layouts with flexible shapes, dtypes, and improved validation/error reporting.
  • New Features

    • Consolidated RoPE quantization benchmark covering multiple attention architectures, token counts, and providers; legacy single-architecture benchmark removed.
    • Public operator and helper names unified under the generalized RoPE path.
  • Tests

    • Expanded tests validating generalized RoPE quantization across attention types, dimensions, token counts, and quantization dtypes.

@kahyunnam kahyunnam changed the title from "[not ready for review! draft.] MLA RoPE + quantization kernel generalization for MHA / GQA" to "MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA" Oct 14, 2025
@kahyunnam kahyunnam marked this pull request as ready for review October 14, 2025 04:48
@kahyunnam kahyunnam enabled auto-merge (squash) October 14, 2025 17:37
@kahyunnam kahyunnam requested a review from yzh119 October 14, 2025 17:40
Comment thread csrc/rope.cu
@pavanimajety
Contributor

@nvpohanh for another set of eyes.

Contributor

@pavanimajety pavanimajety left a comment


LGTM, mostly just nits for documentation and benchmark. Thanks for the effort!

Collaborator

@yzh119 yzh119 left a comment


Overall LGTM, left some comments for suggestion.

@kahyunnam kahyunnam self-assigned this Oct 17, 2025
@coderabbitai
Contributor

coderabbitai Bot commented Oct 17, 2025

Walkthrough

Renames MLA-specific RoPE APIs to generalized rope_quantize/rope_quantize_fp8, rewrites C++ binding and CUDA kernel to support variable rope/no-rope dims and MLA/GQA/MHA layouts, replaces the MLA-only benchmark with a unified benchmark script, and adds parameterized tests covering all attention types.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/bench_mla_rope_quantize_fp8.py, benchmarks/bench_rope_quantize_fp8.py
Deleted MLA-only bench_mla_rope_quantize_fp8.py; added bench_rope_quantize_fp8.py implementing a generalized FP8 RoPE benchmark, FlashInferRotaryEmbedding, Triton perf_report entrypoints for MLA/GQA/MHA, dual provider support (flashinfer/torch), and CUDA profiler hooks.
C++ binding
csrc/flashinfer_rope_binding.cu, csrc/rope.cu
Renamed exported symbol mla_rope_quantize → rope_quantize; updated binding to validate flexible tensor shapes/dtypes, accept 2D/3D K layouts, compute and pass num_qo_heads, num_kv_heads, rope_dim, no_rope_dim and adjusted strides into kernel dispatch; error messages and dispatch logic updated.
Python API
flashinfer/rope.py
Renamed public operator/fake operator mla_rope_quantize → rope_quantize; replaced mla_rope_quantize_fp8 with rope_quantize_fp8 (adds is_neox, quantize_dtype, expanded outputs), and re-routed to the generalized C++ binding.
Kernel header / device
include/flashinfer/pos_enc.cuh
Renamed kernel MLARopeQuantizeKernel → RopeQuantizeKernel and public entry MLARopeQuantize → RopeQuantize; added runtime dispatch macro DISPATCH_ROPE_DIM, new params (num_qo_heads, num_kv_heads, rope_dim, no_rope_dim), generalized chunked per-head processing, and dynamic launch configuration.
Tests
tests/attention/test_rope.py
Added test_generalized_rope_quantize (parameterized across attention types, head/dim configs, token counts, dtypes) and test_mla_rope_quantize; tests build rope/no-rope splits, invoke generalized rope_quantize_fp8, and validate Q/K outputs vs. a reference implementation.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PyAPI as flashinfer.rope_quantize_fp8
    participant Binding as rope_quantize (C++)
    participant Kernel as RopeQuantizeKernel
    participant Outputs

    rect rgb(250,250,255)
    Note over User,Outputs: Generalized RoPE + FP8 quantization flow (MLA/GQA/MHA)
    User->>PyAPI: call rope_quantize_fp8(q, k, ..., is_neox?, quantize_dtype?)
    PyAPI->>Binding: validate shapes/dtypes, split rope/no-rope, prepare tensors/strides
    Binding->>Kernel: dispatch(num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, strides, scales, interleave)
    Kernel->>Outputs: apply per-head/chunk RoPE and quantize outputs to FP8
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐇 I hopped through kernels, renamed a little tune,
rope_dim now dances beneath a brighter moon,
bindings stretched their paws to cover every head,
benchmarks and tests cheer on each path we tread,
nibble a carrot — quantize, then swoon! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Title Check ✅ Passed: The pull request title "MLA RoPE + quantization fused kernel: shape generalization for MHA / GQA" clearly and specifically describes the main change in the PR. The title accurately reflects the core objective: generalizing the existing MLA-specific RoPE + quantization fused kernel to support additional attention architectures (MHA and GQA). The title is concise (72 characters), avoids vague terminology, and provides sufficient specificity for developers reviewing the repository history to understand the primary change without ambiguity.
  • Description Check ✅ Passed: The PR description follows most of the required template structure but has some gaps. The author provided a clear Description section explaining the generalization of MLA RoPE+Q kernels to support GQA/MHA shapes, completed all Pre-commit Checks and Tests checklist items (marked [x]), and included comprehensive test results (312 passing tests) and benchmark data with configuration details and visualizations. However, the Related Issues section is completely absent without explanation, and the Reviewer Notes section (while marked optional in the template) is also not included. The description is minimal but substantive given the supporting test and benchmark evidence provided.
  • Docstring Coverage ✅ Passed: No functions found in the changes. Docstring coverage check skipped.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/pos_enc.cuh (1)

447-487: Critical: OOB loads/stores in non-RoPE paths when no_rope_dim < rope_dim or tail not multiple of vec_size

bdx is derived from rope_dim, but the K/Q non-RoPE branches index with tx*vec_size against no_rope_dim without guarding. Threads with tx*vec_size ≥ no_rope_dim will cast_load/cast_store out-of-bounds. Tails not divisible by vec_size also risk partial overruns.

Fix by bounding per-chunk width and handling tails safely. Example patch:

@@
-      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      // Bound valid width for this chunk
+      uint32_t valid = (no_rope_dim > elem_offset)
+                           ? min(rope_chunk_size, no_rope_dim - elem_offset)
+                           : 0;
@@
-      vec_t<float, vec_size> k_nope_vec;
-      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
-      }
-      k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+      if (tx * vec_size + (vec_size - 1) < valid) {
+        // Full vector fits
+        vec_t<float, vec_size> k_nope_vec;
+        k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+        }
+        k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+      } else if (tx * vec_size < valid) {
+        // Scalar tail
+        const uint32_t rem = valid - tx * vec_size;
+        for (uint32_t i = 0; i < rem; ++i) {
+          float v = static_cast<float>(k_nope_in_ptr[tx * vec_size + i]) * quant_scale_kv;
+          k_nope_out_ptr[tx * vec_size + i] = static_cast<QuantType>(v);
+        }
+      }
@@
-      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;  // Use same chunk size
+      uint32_t valid = (no_rope_dim > elem_offset)
+                           ? min(rope_chunk_size, no_rope_dim - elem_offset)
+                           : 0;
@@
-      vec_t<float, vec_size> q_nope_vec;
-      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
-      }
-      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      if (tx * vec_size + (vec_size - 1) < valid) {
+        vec_t<float, vec_size> q_nope_vec;
+        q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+        }
+        q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      } else if (tx * vec_size < valid) {
+        const uint32_t rem = valid - tx * vec_size;
+        for (uint32_t i = 0; i < rem; ++i) {
+          float v = static_cast<float>(q_nope_in_ptr[tx * vec_size + i]) * quant_scale_q;
+          q_nope_out_ptr[tx * vec_size + i] = static_cast<QuantType>(v);
+        }
+      }

If vec_t already supports masked store/load, prefer using that to avoid the scalar tail. Otherwise, the above prevents OOB while covering all elements.

♻️ Duplicate comments (2)
flashinfer/rope.py (1)

1262-1262: Add Python-level BC alias.

To honor earlier feedback and avoid breaking downstreams, export the old name.

 return q_rope_out, k_rope_out, q_nope_out, k_nope_out
+
+# Backward compatibility
+mla_rope_quantize_fp8 = rope_quantize_fp8
include/flashinfer/pos_enc.cuh (1)

816-827: PDL support added — thanks

Programmatic stream serialization path looks good and addresses prior ask for PDL.

🧹 Nitpick comments (11)
csrc/flashinfer_rope_binding.cu (1)

42-46: Preserve FFI BC: export the old symbol as an alias, too.

Some external callers may still invoke flashinfer::mla_rope_quantize. Export it as an alias to rope_quantize to avoid runtime breakage.

 TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache);
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
+// Backward compatibility alias for older clients
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(mla_rope_quantize, rope_quantize);

Also applies to: 53-53

tests/attention/test_rope.py (1)

359-379: Cover both RoPE layouts (is_neox True/False) and fix the stray main-call.

  • Parametrize is_neox to test both interleaved and non-interleaved layouts.
  • Remove or update the main invocation; it no longer matches the signature.
 @pytest.mark.parametrize(
     "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim",
@@
 )
 @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047])
 @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("is_neox", [True, False])
 def test_generalized_rope_quantize(
     attention_type,
     num_qo_heads,
     num_kv_heads,
     rope_dim,
     no_rope_dim,
     num_tokens,
     input_dtype,
     quant_dtype,
+    is_neox,
 ):
@@
-    rope_flashinfer = FlashInferRotaryEmbedding(
+    rope_flashinfer = FlashInferRotaryEmbedding(
         total_dim,
         rope_dim,
         4096,  # max_position_embeddings
         10000,  # base
-        False,  # is_neox_style
+        is_neox,  # is_neox_style
         input_dtype,
         device,
     )
@@
-    flashinfer.rope.rope_quantize_fp8(
+    flashinfer.rope.rope_quantize_fp8(
         q_rope_in,
         k_rope_in,
         q_nope_in,
         k_nope_in,
         rope_flashinfer.cos_sin_cache,
         pos_ids,
-        is_neox=False,
+        is_neox=is_neox,
         q_rope_out=q_rope_out,
         k_rope_out=k_rope_out,
         q_nope_out=q_nope_out,
         k_nope_out=k_nope_out,
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
     )

And at file end:

-if __name__ == "__main__":
-    # ...
-    test_generalized_rope_quantize(1, torch.float16, torch.float8_e4m3fn)
+if __name__ == "__main__":
+    # Prefer: run via pytest; keep manual smoke disabled or provide a minimal harness.
+    pass

Also applies to: 447-462

flashinfer/rope.py (3)

170-190: Docstring drift in _rope_quantize.

This wrapper doesn’t “convert is_neox”; it forwards interleave. Please update the docstring to avoid confusion.

 def _rope_quantize(...):
-    r"""Custom operator that routes to the CUDA kernel implementation.
-
-    Converts is_neox parameter to interleave format and dispatches to the underlying
-    CUDA kernel via the JIT-compiled module.
-    """
+    r"""Custom operator that routes to the CUDA kernel implementation.
+    Forwards the pre-split tensors and interleave flag to the JIT-compiled CUDA kernel.
+    """

1147-1211: Validate quantize_dtype early (fail fast).

Raise early if quantize_dtype is not FP8; avoids diving into FFI before erroring.

 def rope_quantize_fp8(..., quantize_dtype: Optional[torch.dtype] = None, ...):
@@
     if quantize_dtype is None:
         ...
     else:
-        pass
+        if quantize_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+            raise ValueError("quantize_dtype must be torch.float8_e4m3fn or torch.float8_e5m2")

249-260: Align fake op signature with the real custom op.

The fake takes (cos_cache, sin_cache), but the real op takes a single cos_sin_cache. Keep them consistent.

-@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
-def _fake_apply_rope_pos_ids_cos_sin_cache(
+@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache")
+def _fake_apply_rope_pos_ids_cos_sin_cache(
     q: torch.Tensor,
     k: torch.Tensor,
     q_rope: torch.Tensor,
     k_rope: torch.Tensor,
-    cos_cache: torch.Tensor,
-    sin_cache: torch.Tensor,
-    pos_ids: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
     interleave: bool,
 ) -> None:
     pass
benchmarks/bench_rope_quantize_fp8.py (1)

19-89: Remove unused local FlashInferRotaryEmbedding to reduce duplication.

This class is unused (you import RotaryEmbedding from tests). Dropping it simplifies the benchmark module.

-# Local FlashInferRotaryEmbedding implementation (unused)
-class FlashInferRotaryEmbedding(nn.Module):
-    ...
-    def _apply_rotary_emb(...):
-        ...
-        return ...
include/flashinfer/pos_enc.cuh (5)

362-370: Use rope_chunk_size for RoPE, but compute per-branch chunk width for non‑RoPE to avoid wasted CTAs

rope_chunk_size is set to rope_dim and reused for no_rope_dim. This is fine for no_rope_dim ≥ rope_dim (MLA case), but for no_rope_dim < rope_dim it spawns blocks where many threads do nothing after the fix. Consider computing:

  • rope_chunk_size = rope_dim
  • nope_chunk_size = min(rope_dim, round_up_to_multiple(no_rope_dim, vec_size))

and use separate chunk counts per branch to reduce empty work.

Also applies to: 769-773
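
As a sketch of the suggested split, here is a Python stand-in mirroring the grid math quoted later in this review (`round_up_to_multiple` is a hypothetical helper, not an existing function):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def round_up_to_multiple(x: int, m: int) -> int:
    return ceil_div(x, m) * m

def block_ranges(num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, vec_size):
    rope_chunk_size = rope_dim
    # Per-branch chunk width: never wider than rope_dim, rounded up to vec_size
    nope_chunk_size = min(rope_dim, round_up_to_multiple(no_rope_dim, vec_size))
    rope_chunks = ceil_div(rope_dim, rope_chunk_size)      # always 1 here
    no_rope_chunks = ceil_div(no_rope_dim, nope_chunk_size)
    q_rope_end = num_qo_heads * rope_chunks
    k_rope_end = q_rope_end + num_kv_heads * rope_chunks
    k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks
    return q_rope_end, k_rope_end, k_nope_end

# MLA config (fp16, vec_size = 32 / 2 = 16): 128 Q heads, 1 KV head, 64 + 512 dims
print(block_ranges(128, 1, 64, 512, 16))  # (128, 129, 137)
```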


723-742: Broaden DISPATCH_ROPE_DIM or add a generic fallback

Limiting rope_dim to {16,32,64,128,256} constrains future models (e.g., 80, 160, 192). Add cases or a generic path:

  • If rope_dim % (32/sizeof(DType)) == 0, set constexpr bdx = rope_dim / vec_size and dispatch; else return an error. This preserves compile‑time bdx while widening support.

760-771: Verify vec_size choice and alignment assumptions

vec_size = 32 / sizeof(DType) moves from the 16‑byte convention used elsewhere (and often matches L1/L2 transaction granularity). Please verify alignment of q/k pointers and cos_sin_cache for 32‑byte vectorization on all dtypes; if not guaranteed, fall back to max(16/sizeof(DType), rope_dim/32) as used in other paths for safety.
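A sketch of that fallback (a hypothetical helper, assuming bdx must stay a whole number of vec_size lanes):

```python
def pick_vec_size(dtype_bytes: int, rope_dim: int) -> int:
    wide = 32 // dtype_bytes  # 32-byte transactions: 16 lanes for fp16/bf16
    if rope_dim % wide == 0:
        return wide
    # Fall back toward the 16-byte convention used elsewhere
    return max(16 // dtype_bytes, rope_dim // 32)

print(pick_vec_size(2, 64))  # 16 -> bdx = 64 / 16 = 4 threads per head chunk
print(pick_vec_size(2, 40))  # 8  -> 40 is not divisible by 16, so drop to 8 lanes
```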


773-806: Remove dead args[] or switch to cudaLaunchKernel with args[]

args[] is constructed but not used with cudaLaunchKernelEx (you pass typed args directly). Drop args[] to reduce confusion, or switch to cudaLaunchKernel((void*)kernel, ..., args, ...) for consistency with other call sites.

Also applies to: 829-837


352-357: Naming consistency: k_* stride parameters

q_* use q_*_stride_n/h, while k_* use k_*_stride and k_*_stride_h. Consider renaming to k_*_stride_n for symmetry with get_elem_offset_impl usage. No behavior change, just readability.

Also applies to: 749-754

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bea5949 and 8fee3d9.

📒 Files selected for processing (7)
  • benchmarks/bench_mla_rope_quantize_fp8.py (0 hunks)
  • benchmarks/bench_rope_quantize_fp8.py (1 hunks)
  • csrc/flashinfer_rope_binding.cu (1 hunks)
  • csrc/rope.cu (3 hunks)
  • flashinfer/rope.py (6 hunks)
  • include/flashinfer/pos_enc.cuh (2 hunks)
  • tests/attention/test_rope.py (2 hunks)
💤 Files with no reviewable changes (1)
  • benchmarks/bench_mla_rope_quantize_fp8.py
🧰 Additional context used
🧬 Code graph analysis (5)
csrc/flashinfer_rope_binding.cu (1)
csrc/rope.cu (2)
  • rope_quantize (271-421)
  • rope_quantize (271-275)
tests/attention/test_rope.py (3)
benchmarks/bench_rope_quantize_fp8.py (1)
  • FlashInferRotaryEmbedding (19-88)
tests/test_helpers/rope_reference.py (1)
  • forward_native (194-232)
flashinfer/rope.py (1)
  • rope_quantize_fp8 (1147-1262)
benchmarks/bench_rope_quantize_fp8.py (3)
flashinfer/testing/utils.py (2)
  • bench_gpu_time (972-1033)
  • bench_gpu_time_with_cudagraph (855-969)
tests/test_helpers/rope_reference.py (2)
  • RotaryEmbedding (117-232)
  • forward_native (194-232)
flashinfer/rope.py (1)
  • rope_quantize_fp8 (1147-1262)
csrc/rope.cu (2)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
flashinfer/rope.py (3)
csrc/rope.cu (2)
  • rope_quantize (271-421)
  • rope_quantize (271-275)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize (42-46)
flashinfer/utils.py (2)
  • register_fake_op (277-281)
  • register_fake_op (306-311)
🪛 Ruff (0.14.0)
benchmarks/bench_rope_quantize_fp8.py

111-111: Avoid specifying long messages outside the exception class

(TRY003)


214-214: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/attention/test_rope.py (1)

448-462: Consider explicitly passing quantize_dtype for clarity.

The test relies on implicit inference of quantize_dtype from the pre-allocated output tensors. While this should work according to the function's logic, explicitly passing quantize_dtype=quant_dtype would make the test's intent clearer and more robust.

Apply this diff:

     # Call the generalized function
     flashinfer.rope.rope_quantize_fp8(
         q_rope_in,
         k_rope_in,
         q_nope_in,
         k_nope_in,
         rope_flashinfer.cos_sin_cache,
         pos_ids,
         is_neox=False,
+        quantize_dtype=quant_dtype,
         q_rope_out=q_rope_out,
         k_rope_out=k_rope_out,
         q_nope_out=q_nope_out,
         k_nope_out=k_nope_out,
         quant_scale_q=1.0,
         quant_scale_kv=1.0,
     )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8fee3d9 and 5c56c72aa6859f3829ac4850b9a20926a6557c59.

📒 Files selected for processing (1)
  • tests/attention/test_rope.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_rope.py (3)
benchmarks/bench_rope_quantize_fp8.py (1)
  • FlashInferRotaryEmbedding (19-88)
tests/test_helpers/rope_reference.py (1)
  • forward_native (194-232)
flashinfer/rope.py (1)
  • rope_quantize_fp8 (1147-1262)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
tests/attention/test_rope.py (1)

359-410: Excellent test coverage for multiple attention architectures.

The parametrization comprehensively covers MLA, GQA, and MHA configurations with realistic head counts and dimensions, including explicit DeepSeek R1 MLA and Llama3 8B/70B configurations. The input tensor creation correctly handles the 2D vs 3D key tensor distinction between MLA and GQA/MHA.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (8)
tests/attention/test_rope.py (5)

359-379: Add coverage for both RoPE layouts (is_neox True/False).

Currently only interleaved is exercised (is_neox=False). Parameterize is_neox to cover both code paths and kernel branches.


424-430: Consider cross-checking with CUDA path too.

Besides forward_native, add a compare against forward_cuda to validate equivalence with our cos/sin-cache operator under the same settings.


431-446: Assert output dtypes and shapes explicitly.

After the call, add quick checks that q_out.dtype == quant_dtype and k_out.dtype == quant_dtype and shapes match inputs to catch silent dtype/shape regressions.

+    assert q_out.dtype is quant_dtype and k_out.dtype is quant_dtype
+    assert q_out.shape == q_in.shape and k_out.shape == k_in.shape

465-478: Tighten tolerances or justify rtol=0.2.

rtol=2e-1 is generous for FP8-cast equivalence. If feasible, lower rtol (e.g., 5e-2) or add a brief comment explaining why 0.2 is required here.
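
A quick sanity check of what FP8 round-off alone allows (a sketch, not part of the test suite): e4m3's 3 mantissa bits bound the relative rounding error near 2^-4 ≈ 6%, so rtol=0.2 mainly buys headroom for the fp16 RoPE math upstream.

```python
import torch

x = torch.randn(1 << 16) * 4.0
xq = x.to(torch.float8_e4m3fn).to(torch.float32)
mask = x.abs() > 0.5  # stay inside e4m3's normal range
rel = ((xq - x).abs() / x.abs())[mask].max()
print(f"max round-trip relative error: {rel:.4f}")  # roughly 0.06
```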


538-544: Remove direct test invocation under main.

Stray manual call can surprise users running the file. Prefer relying on pytest discovery or guard it behind an env flag.

-    test_mla_rope_quantize(1, torch.float16, torch.float8_e4m3fn)
+    pass  # Rely on pytest collection
flashinfer/rope.py (3)

166-205: Fix docstring: this function already receives interleave.

The doc says it “converts is_neox to interleave,” but _rope_quantize takes interleave directly. Update the doc to avoid confusion.

-    Converts is_neox parameter to interleave format and dispatches to the underlying
-    CUDA kernel via the JIT-compiled module.
+    Routes to the CUDA kernel via the JIT-compiled module using the provided
+    `interleave` flag.

1147-1179: Backward-compat API preserved — LGTM.

Wrapper maintains mla_rope_quantize_fp8 while delegating to the generalized API. Consider a direct alias to reduce indirection if docs don’t need a separate entry.

- def mla_rope_quantize_fp8(...):
-     return rope_quantize_fp8(...)
+mla_rope_quantize_fp8 = rope_quantize_fp8

1181-1296: Optional: upfront dtype validation for clearer errors.

The CUDA op will validate dtypes, but adding lightweight Python-side checks (allowed input dtypes: fp16/bf16; output/quant dtypes: FP8 e4m3/e5m2) can surface mistakes earlier.

+    if q_rope.dtype not in (torch.float16, torch.bfloat16) or q_nope.dtype != q_rope.dtype:
+        raise TypeError("q_* inputs must be float16 or bfloat16 and match dtypes")
+    if k_rope.dtype not in (torch.float16, torch.bfloat16) or k_nope.dtype != k_rope.dtype:
+        raise TypeError("k_* inputs must be float16 or bfloat16 and match dtypes")
+    if quantize_dtype not in (None, torch.float8_e4m3fn, torch.float8_e5m2):
+        raise TypeError("quantize_dtype must be FP8 e4m3fn or e5m2")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f75ce8640a2b5ba6b39fdcfd22516e0653570471 and 32149b6.

📒 Files selected for processing (2)
  • flashinfer/rope.py (5 hunks)
  • tests/attention/test_rope.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_rope.py (3)
benchmarks/bench_rope_quantize_fp8.py (1)
  • FlashInferRotaryEmbedding (19-88)
tests/test_helpers/rope_reference.py (1)
  • forward_native (194-232)
flashinfer/rope.py (1)
  • rope_quantize_fp8 (1181-1296)
flashinfer/rope.py (3)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize (42-46)
csrc/rope.cu (2)
  • rope_quantize (271-421)
  • rope_quantize (271-275)
flashinfer/utils.py (2)
  • register_fake_op (277-281)
  • register_fake_op (306-311)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Collaborator

@yzh119 yzh119 left a comment


LGTM, thank you @kahyunnam !

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (5)
include/flashinfer/pos_enc.cuh (5)

387-387: Use && for consistency.

Replace C++ alternative token “and” with “&&” to match the rest of the file and avoid style churn.

-    if ((tx * vec_size < rope_dim) and (by < k_rope_end)) {
+    if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {

729-748: DISPATCH_ROPE_DIM coverage — confirm required sizes.

Only {16,32,64,128,256} are supported. If models with rotary_dim=80 (seen in some configs) must be supported, extend the dispatch or fail fast with a clear error at API entry.


766-767: Vector width set to 32B — verify alignment.

vec_size = 32/sizeof(DType) implies 32‑byte wide loads/stores. Ensure q/k/nope and cos_sin_cache pointers are 32B‑aligned; otherwise, you may hit misaligned global accesses (perf hit or UB if vec_t assumes alignment). If 32B alignment isn’t guaranteed, fall back to the 16B convention used elsewhere:

-  constexpr uint32_t vec_size = 32 / sizeof(DType);
+  constexpr uint32_t vec_size = 16 / sizeof(DType);

Or document/enforce alignment preconditions at the API boundary.


779-812: Remove unused args[].

args is built but not used with cudaLaunchKernelEx’s typed invocation. Drop it to reduce noise.

-      void* args[] = {(void*)&q_rope_in,
-                      (void*)&k_rope_in,
-                      (void*)&q_nope_in,
-                      (void*)&k_nope_in,
-                      (void*)&q_rope_out,
-                      (void*)&k_rope_out,
-                      (void*)&q_nope_out,
-                      (void*)&k_nope_out,
-                      (void*)&cos_sin_cache,
-                      (void*)&pos_ids,
-                      (void*)&nnz,
-                      (void*)&num_qo_heads,
-                      (void*)&num_kv_heads,
-                      (void*)&rope_dim,
-                      (void*)&no_rope_dim,
-                      (void*)&q_rope_in_stride_n,
-                      (void*)&q_rope_in_stride_h,
-                      (void*)&q_nope_in_stride_n,
-                      (void*)&q_nope_in_stride_h,
-                      (void*)&q_rope_out_stride_n,
-                      (void*)&q_rope_out_stride_h,
-                      (void*)&q_nope_out_stride_n,
-                      (void*)&q_nope_out_stride_h,
-                      (void*)&k_rope_in_stride,
-                      (void*)&k_rope_in_stride_h,
-                      (void*)&k_nope_in_stride,
-                      (void*)&k_nope_in_stride_h,
-                      (void*)&k_rope_out_stride,
-                      (void*)&k_rope_out_stride_h,
-                      (void*)&k_nope_out_stride,
-                      (void*)&k_nope_out_stride_h,
-                      (void*)&quant_scale_q,
-                      (void*)&quant_scale_kv};

352-357: Name stride params consistently (*_stride_n / *_stride_h).

k_*_stride lacks the "_n" suffix unlike its q_* counterparts. Consider renaming for symmetry and to reduce call-site confusion.

Also applies to: 431-436, 459-461, 478-482

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32149b6 and 19e3d04.

📒 Files selected for processing (1)
  • include/flashinfer/pos_enc.cuh (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/pos_enc.cuh (1)

358-361: PDL grid dependency hooks added — looks good.

The kernel now issues griddepcontrol.wait at entry and launch_dependents at exit under CUDA 12+/sm90 guards. This satisfies the earlier review ask for PDL. Please run a quick H100 sanity to confirm no functional/perf regressions with enable_pdl on/off.

Also applies to: 492-495

Comment on lines +365 to 373
// Calculate flexible boundaries for block allocation
uint32_t rope_chunk_size = rope_dim; // Process entire rope_dim per chunk
uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size;
uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size;

uint32_t q_rope_end = num_qo_heads * rope_chunks;
uint32_t k_rope_end = q_rope_end + num_kv_heads * rope_chunks;
uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks;


⚠️ Potential issue | 🔴 Critical

Fix OOB on no_rope tails (bdx derived from rope_dim, no tail masking).

no_rope blocks reuse bdx = rope_dim/vec_size and step by rope_chunk_size = rope_dim. When no_rope_dim is not a multiple of rope_dim or vec_size, the last chunk can read/write past valid elements. Add per‑lane bounds and masked tail load/store for both K and Q no‑RoPE paths.

Apply this diff (K no‑RoPE):

@@
-      vec_t<float, vec_size> k_nope_vec;
-      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
-      }
-      k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
+      const uint32_t lane = tx * vec_size;
+      const uint32_t valid = min(no_rope_dim - elem_offset, rope_chunk_size);
+      if (lane + vec_size <= valid) {
+        vec_t<float, vec_size> k_nope_vec;
+        k_nope_vec.cast_load(k_nope_in_ptr + lane);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+        }
+        k_nope_vec.cast_store(k_nope_out_ptr + lane);
+      } else if (lane < valid) {
+#pragma unroll
+        for (uint32_t j = 0; j < vec_size && (lane + j) < valid; ++j) {
+          float v = static_cast<float>(k_nope_in_ptr[lane + j]);
+          k_nope_out_ptr[lane + j] = static_cast<QuantType>(v * quant_scale_kv);
+        }
+      }

…and similarly for Q no‑RoPE:

@@
-      vec_t<float, vec_size> q_nope_vec;
-      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
-#pragma unroll
-      for (uint32_t i = 0; i < vec_size; ++i) {
-        q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
-      }
-      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+      const uint32_t lane = tx * vec_size;
+      const uint32_t valid = min(no_rope_dim - elem_offset, rope_chunk_size);
+      if (lane + vec_size <= valid) {
+        vec_t<float, vec_size> q_nope_vec;
+        q_nope_vec.cast_load(q_nope_in_ptr + lane);
+#pragma unroll
+        for (uint32_t i = 0; i < vec_size; ++i) {
+          q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+        }
+        q_nope_vec.cast_store(q_nope_out_ptr + lane);
+      } else if (lane < valid) {
+#pragma unroll
+        for (uint32_t j = 0; j < vec_size && (lane + j) < valid; ++j) {
+          float v = static_cast<float>(q_nope_in_ptr[lane + j]);
+          q_nope_out_ptr[lane + j] = static_cast<QuantType>(v * quant_scale_q);
+        }
+      }

If you prefer not to implement masked tails now, at minimum guard with a runtime check and return cudaErrorInvalidValue when (no_rope_dim % vec_size) != 0 to avoid silent memory bugs.

Also applies to: 450-469, 470-490

@kahyunnam kahyunnam merged commit 8a88b0e into flashinfer-ai:main Oct 18, 2025
4 checks passed
@yzh119 yzh119 mentioned this pull request Oct 22, 2025
31 tasks
yzh119 added a commit that referenced this pull request Nov 18, 2025
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (#2037)


## 📌 Description

Add `flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache`, which
runs a fused RoPE + Quantization (16 -> 8) + append KV Cache operation
kernel.

Note that this does not support optional quantization (there is no "RoPE
+ append KV Cache" fused operation available).

Tested on NVIDIA H100 NVL + flashinfer/flashinfer-ci-cu130:latest for
MLA/MHA/GQA problem sizes for decode and prefill cases.

## 🔍 Related Issues

"[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused
kernels for MLA/GQA/MHA" item from Q4 roadmap:
#1770.

This PR is part 2 to earlier PR for RoPE + Q:
#1924

FW Stakeholders: @nvpohanh @pavanimajety 

## 🧪 Test results

```
$ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s
======================================================== test session starts =========================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 384 items

tests/attention/test_rope.py ................................................................................................................................................................................................................................................................................................................................................................................................

======================================================== 384 passed in 35.22s ========================================================
```

```
$ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s
======================================================== test session starts =========================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 1248 items

tests/attention/test_rope.py .........................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
.......................................................................

================================================== 1248 passed in 63.07s (0:01:03) ===================================================
```

```
$ python benchmarks/bench_rope_quantize_fp8_append_cache.py

Detected GPU: NVIDIA GB200
Theoretical Peak Memory Bandwidth: 7928.06 GB/s


====================================================================================================
  MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00258      86.53        1.1            0.010
32         0.00381      1873.82      23.6           0.208
128        0.00763      3744.50      47.2           0.416
384        0.01848      4637.34      58.5           0.515
768        0.03694      4639.75      58.5           0.515
1024       0.04879      4683.57      59.1           0.520
2048       0.09590      4766.09      60.1           0.529
4096       0.19031      4803.27      60.6           0.533
8192       0.38523      4745.78      59.9           0.527

====================================================================================================
  GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00294      6.36         0.1            0.003
32         0.00316      189.48       2.4            0.078
128        0.00317      755.23       9.5            0.310
384        0.00398      1803.09      22.7           0.741
768        0.00522      2750.51      34.7           1.130
1024       0.00617      3100.80      39.1           1.274
2048       0.00927      4130.83      52.1           1.697
4096       0.01631      4695.01      59.2           1.929
8192       0.03466      4418.01      55.7           1.815

====================================================================================================
  MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00293      12.68        0.2            0.004
32         0.00313      379.98       4.8            0.126
128        0.00357      1331.80      16.8           0.441
384        0.00517      2756.73      34.8           0.912
768        0.00742      3840.41      48.4           1.271
1024       0.00887      4287.15      54.1           1.419
2048       0.01504      5055.18      63.8           1.673
4096       0.03343      4548.12      57.4           1.505
8192       0.06410      4744.76      59.8           1.571

====================================================================================================
Configuration details:
  Page size: 32, Batch size: 4
  Token range: 1 (single decode) → 8192 (large prefill)
  GPU: NVIDIA GB200
  Theoretical Peak Memory Bandwidth: 7928.06 GB/s
  BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100
====================================================================================================

```

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).



## Summary by CodeRabbit

* **New Features**
* Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA,
GQA/MHA) with layout, page-size, interleave and PDL options; returns
quantized Q outputs and writes K/V into paged caches; public ops and
high-level API added.

* **Tests**
* Deterministic, parameterized tests for append and decode/continuation
across attention types, layouts, dtypes and quant settings with
reference validation.

* **Benchmarks**
* New benchmark script for performance, bandwidth and Nsight profiling
of the paged-KV quantize+append path.

* **Chores**
  * Added cached GPU memory-bandwidth utility for benchmarks.

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
@coderabbitai coderabbitai Bot mentioned this pull request Dec 1, 2025
5 tasks
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (flashinfer-ai#2037)

(commit message duplicates PR #2037's description above)
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Mar 31, 2026
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (#2037)

(commit message duplicates PR #2037's description above)
