
feat: add xqa mla backend #2053

Merged
yzh119 merged 3 commits into flashinfer-ai:main from qsang-nv:add_xqa_mla_backend
Nov 10, 2025

Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Nov 6, 2025

📌 Description

add xqa mla backend and corresponding unittests

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • MLA batch decoding with explicit backend selection (auto / trtllm-gen / xqa) and a new public MLA decode entrypoint.
  • Improvements

    • Backend-aware routing, hardware capability checks, and stronger validation (fp8 and sequence-length constraints) with clearer errors.
    • Simplified MLA launch to a single sub-sequence configuration and more consistent data/shape handling.
  • Tests

    • Expanded test matrix with backend parameterization, GPU guards, and a new MLA integration test using tolerance-based pass criteria.

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai Bot commented Nov 6, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Kernel launch in csrc/xqa/mla_sm120.cu was simplified to use a single sub-sequence block (gridDim.y = 1). Python decoding gains an XQA MLA path and backend selection in flashinfer/decode.py. Tests were added/parameterized to cover both "trtllm-gen" and "xqa" MLA flows with SM-specific guards and tolerance checks.

Changes

Cohort / File(s) Summary
MLA SM120 Kernel Configuration
csrc/xqa/mla_sm120.cu
Disabled the computation of nbSubSeqPerSeq in both launchMLA and launchMLAFlashInfer and set gridDim.y = 1, forcing a single sub-sequence block per sequence in kernel launches.
MLA Decode API & Backend Routing
flashinfer/decode.py
Added xqa_batch_decode_with_kv_cache_mla; extended trtllm_batch_decode_with_kv_cache_mla(..., backend="auto") to select backend ("auto"/"xqa"/"trtllm-gen"), validate GPU compute capability, enforce fp8 and q_len_per_request constraints for XQA, reshape inputs (kv_cache/seq_lens), and route to xqa_mla for the XQA path.
TRT-LLM MLA Tests
tests/attention/test_trtllm_gen_mla.py
Parameterized test over backend ("trtllm-gen", "xqa"); added backend-specific pre-checks (SM constraints, fp8/q_len limits) and adjusted assertions — XQA uses tolerance-based element checks with a ≥95% pass ratio.
XQA MLA Integration Test
tests/attention/test_xqa_mla_batch_decode.py
New integration test exercising xqa_batch_decode_with_kv_cache_mla on SM120 devices; creates KV-cache and workspace, runs MLA decode, computes element-wise abs/rel errors against BatchMLAPagedAttentionWrapper reference, and asserts ≥95% pass ratio across parameter sweep.
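
For illustration, the backend routing summarized above can be sketched as follows; the helper name and the exact constraint checks are simplified assumptions, not the merged flashinfer/decode.py code:

    # Illustrative sketch only, not the merged implementation.
    import torch
    from flashinfer.utils import get_compute_capability

    def _select_mla_backend(query: torch.Tensor, backend: str = "auto") -> str:
        """Pick an MLA decode backend from the device's compute capability."""
        if backend == "auto":
            major, _ = get_compute_capability(query.device)
            backend = "trtllm-gen" if major == 10 else "xqa"  # SM100/SM103 vs. SM120
        if backend == "xqa":
            # XQA MLA constraints described in this PR: fp8 inputs, q_len_per_request == 1.
            if query.dtype != torch.float8_e4m3fn:
                raise ValueError("XQA MLA only supports fp8 operation")
            if query.size(1) != 1:
                raise ValueError("XQA MLA only supports q_len_per_request == 1")
        return backend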

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant decode.py
    participant BackendSelector
    participant XQA_MLA_Path
    participant TRTLLM_Path

    User->>decode.py: trtllm_batch_decode_with_kv_cache_mla(..., backend="auto")
    decode.py->>BackendSelector: inspect GPU compute_capability & backend param
    BackendSelector-->>decode.py: select "xqa" or "trtllm-gen"

    alt Selected "xqa"
        decode.py->>decode.py: validate fp8 & q_len_per_request == 1
        decode.py->>decode.py: reshape kv_cache, seq_lens, prepare scratch/semaphore
        decode.py->>XQA_MLA_Path: xqa_batch_decode_with_kv_cache_mla(...)
        XQA_MLA_Path-->>decode.py: decoded output
    else Selected "trtllm-gen"
        decode.py->>TRTLLM_Path: trtllm_paged_attention_decode(...)
        TRTLLM_Path-->>decode.py: decoded output
    end

    decode.py-->>User: return output tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • csrc/xqa/mla_sm120.cu: kernel launch grid semantics and rationale for forcing gridDim.y = 1.
    • flashinfer/decode.py: backend selection logic, fp8 and q_len_per_request validations, and reshaping of kv_cache/seq_lens and scratch/semaphore handling for XQA path.
    • New tests: SM-specific guards, tolerance thresholds, and correctness of reference computation in tests/attention/test_xqa_mla_batch_decode.py.

Possibly related PRs

Suggested reviewers

  • cyx-6
  • yongwww
  • nvmbreughe
  • bkryu
  • kahyunnam
  • yzh119

Poem

🐰
I hopped through kernels, tidy and spry,
One sub-sequence block now reaches the sky;
Backends chosen, tests take flight,
MLA hums through day and night,
A carrot-coded cheer — nibble and delight! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check (❓ Inconclusive): The description briefly states the main change ('add xqa mla backend and corresponding unittests'), but the PR template checklist items are not marked complete, suggesting pre-commit checks and test validation may not have been performed. Resolution: clarify whether pre-commit checks were run and all tests pass by marking completed checklist items or providing confirmation in reviewer notes.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title 'feat: add xqa mla backend' is concise, specific, and clearly identifies the main feature addition: adding an XQA MLA backend implementation.


@gemini-code-assist
Contributor

Summary of Changes

Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new XQA MLA backend into FlashInfer, targeting NVIDIA's SM120 (Blackwell) GPU architecture to enhance performance for specific attention operations. It introduces a flexible mechanism to select between the new XQA backend and the existing TRTLLM-GEN backend based on the detected GPU's compute capability or explicit user choice. The changes also include dedicated unit tests to ensure the correctness and compatibility of the new backend.

Highlights

  • XQA MLA Backend Integration: Introduced a new XQA MLA (Multi-head Latent Attention) backend, specifically optimized for the SM120 (Blackwell) GPU architecture.
  • Dynamic Backend Selection: The trtllm_batch_decode_with_kv_cache_mla function now includes a backend parameter, allowing automatic or explicit selection between trtllm-gen (for SM100/SM103) and xqa (for SM120).
  • XQA MLA Specific Constraints: The XQA MLA backend currently supports only FP8 operations, a query length per request of 1, and does not support sinks.
  • Unit Test Expansion: Existing unit tests for MLA batch decode were updated to cover both trtllm-gen and xqa backends, and a new dedicated test file was added for the XQA MLA implementation.
  • Multi-Block Mode Disablement: Multi-block mode has been temporarily disabled in the mla_sm120.cu kernel for MLA operations.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new XQA MLA backend and its corresponding unit tests. The changes are well-structured. My review includes suggestions for improving code clarity and maintainability by removing commented-out code, consolidating validation logic, using constants for magic numbers, and clarifying documentation. I also noted a todo comment in a new test file that should be addressed.

Comment thread csrc/xqa/mla_sm120.cu
Comment on lines +1793 to +1799
/*uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
float const factor = 4.f;
return mha::min<uint32_t>(
mha::max<uint32_t>(
1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
divUp(maxSeqLen, tokensPerTile * 2));
}();
}();*/ // MLA disables multi-block mode for now
Contributor


medium

This block of code for calculating nbSubSeqPerSeq is commented out. To improve code clarity and maintainability, it's better to remove dead code rather than commenting it out. If this feature might be re-enabled in the future, it can be retrieved from version control history.

Comment thread flashinfer/decode.py
bmm2_scale: fused scale for mla bmm2 input.
bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
sinks: additional value per head in the denominator of the softmax.
Contributor


medium

The sinks parameter is not supported by the xqa backend, as indicated by the check at line 2589. This limitation should be documented in the docstring for clarity.

Suggested change
sinks: additional value per head in the denominator of the softmax.
sinks: additional value per head in the denominator of the softmax. Not supported by the ``xqa`` backend.

Comment thread flashinfer/decode.py
Comment on lines +2581 to +2610
if (
    get_compute_capability(query.device)[0] != 12
    or query.dtype != torch.float8_e4m3fn
    or kv_cache.dtype != torch.float8_e4m3fn
):
    raise ValueError(
        f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
    )
if sinks is not None:
    raise ValueError("XQA MLA does not support sinks")
if query.size(1) != 1:
    raise ValueError(
        f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
    )
return xqa_batch_decode_with_kv_cache_mla(
    query,
    kv_cache,
    workspace_buffer,
    qk_nope_head_dim,
    kv_lora_rank,
    qk_rope_head_dim,
    block_tables,
    seq_lens,
    max_seq_len,
    out,
    bmm1_scale,
    bmm2_scale,
    sinks,
    enable_pdl,
)
Contributor


medium

The validation checks for the xqa backend are duplicated in the xqa_batch_decode_with_kv_cache_mla function. To avoid redundancy and improve maintainability, these checks should be consolidated within the xqa_batch_decode_with_kv_cache_mla function. The dispatcher should only be responsible for routing to the correct backend implementation.

        return xqa_batch_decode_with_kv_cache_mla(
            query,
            kv_cache,
            workspace_buffer,
            qk_nope_head_dim,
            kv_lora_rank,
            qk_rope_head_dim,
            block_tables,
            seq_lens,
            max_seq_len,
            out,
            bmm1_scale,
            bmm2_scale,
            sinks,
            enable_pdl,
        )

Comment thread flashinfer/decode.py
Comment on lines +2742 to +2752
q_len_per_request = query.size(1)
if q_len_per_request != 1:
    raise ValueError(
        f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
    )
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
    raise ValueError(
        f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
    )
if sinks is not None:
    raise ValueError("XQA MLA does not support sinks")
Contributor


medium

To make this function self-contained and safe to call directly, the compute capability check should be included here. This consolidates all xqa backend specific validations in one place, which is good practice when refactoring the dispatcher function trtllm_batch_decode_with_kv_cache_mla.

    if get_compute_capability(query.device)[0] != 12:
        raise ValueError("XQA MLA only supports SM120 GPUs")
    q_len_per_request = query.size(1)
    if q_len_per_request != 1:
        raise ValueError(
            f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
        )
    if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
        raise ValueError(
            f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
        )
    if sinks is not None:
        raise ValueError("XQA MLA does not support sinks")

Comment thread flashinfer/decode.py
Comment on lines +2778 to +2779
semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
Contributor


medium

The magic number 8 * 1024 * 1024 for the semaphore size is used here. It's better to define it as a constant with a descriptive name to improve readability and maintainability. This makes it easier to understand the purpose of the value and to change it in one place if needed.

    SEMAPHORE_SIZE_BYTES = 8 * 1024 * 1024
    semaphore = workspace_u8[:SEMAPHORE_SIZE_BYTES]  # reserve 8MB for semaphore
    scratch = workspace_u8[SEMAPHORE_SIZE_BYTES:]

batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim
)

# todo: fix kv_cache
Contributor


medium

There is a todo comment here to fix kv_cache. Could you please clarify what needs to be fixed or remove the comment if it's no longer relevant? Leaving todo comments in the code can create confusion for future maintainers.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
tests/attention/test_xqa_mla_batch_decode.py (1)

171-189: Consider extracting tolerance validation to a shared helper.

The tolerance-based validation logic (computing absolute/relative differences and checking pass ratio) is duplicated in tests/attention/test_trtllm_gen_mla.py (lines 242-260). Consider extracting this to a shared test utility function to reduce code duplication and ensure consistent validation across backends.
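
For example, a shared helper could look roughly like this (a sketch only; the atol/rtol defaults are placeholders, while the ≥95% pass-ratio criterion follows the tests described above):

    import torch

    def assert_elementwise_close(out, ref, atol=1e-2, rtol=1e-2, min_pass_ratio=0.95):
        # Element-wise abs/rel comparison with a pass-ratio criterion.
        out_f = out.float().reshape(-1)
        ref_f = ref.float().reshape(-1)
        abs_err = (out_f - ref_f).abs()
        rel_err = abs_err / ref_f.abs().clamp_min(1e-6)
        passed = (abs_err <= atol) | (rel_err <= rtol)
        ratio = passed.float().mean().item()
        assert ratio >= min_pass_ratio, f"only {ratio:.2%} of elements within tolerance"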

tests/attention/test_trtllm_gen_mla.py (1)

217-240: Use bare raise instead of raise e.

Lines 228 and 240 use raise e, which creates a new traceback. Use bare raise to preserve the original exception traceback for better debugging.

             except AssertionError as e:
                 print("output:", output)
                 print("o_ref:", o_ref)
-                raise e
+                raise
         else:
             try:
                 torch.testing.assert_close(
                     output,
                     o_ref.view(batch_size, q_len_per_request, num_q_heads, -1),
                     rtol=1e-2,
                     atol=1e-2,
                 )
             except AssertionError as e:
                 print("output:", output)
                 print("o_ref:", o_ref)
-                raise e
+                raise

Based on learnings

flashinfer/decode.py (2)

2684-2799: Unused parameters for API consistency.

The parameters max_seq_len, bmm1_scale_log2_tensor, and bmm2_scale_tensor (lines 2693, 2697-2698) are unused in the function body. These appear to be present for API consistency with trtllm_batch_decode_with_kv_cache_mla. Consider adding a docstring note explaining why these parameters are accepted but not used in the XQA implementation, or add validation to reject them if provided.

+    Note
+    ----
+    The parameters `bmm1_scale_log2_tensor`, `bmm2_scale_tensor`, and `max_seq_len` 
+    are accepted for API consistency but not used in the XQA MLA implementation.

2700-2700: Add explicit Optional type annotation.

Line 2700 should use Optional[bool] instead of implicit bool to match PEP 484 and be consistent with the function's default value of None.

-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,

Based on learnings

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between adb0e89 and a06a690.

📒 Files selected for processing (4)
  • csrc/xqa/mla_sm120.cu (1 hunks)
  • flashinfer/decode.py (4 hunks)
  • tests/attention/test_trtllm_gen_mla.py (4 hunks)
  • tests/attention/test_xqa_mla_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_xqa_mla_batch_decode.py (3)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
flashinfer/decode.py (7)
  • xqa_batch_decode_with_kv_cache_mla (2684-2799)
  • plan (811-1102)
  • plan (1610-1733)
  • run (1132-1145)
  • run (1148-1161)
  • run (1163-1381)
  • run (1735-1859)
flashinfer/mla.py (1)
  • BatchMLAPagedAttentionWrapper (66-447)
flashinfer/decode.py (3)
flashinfer/xqa.py (4)
  • xqa (55-94)
  • xqa (125-278)
  • xqa_mla (303-334)
  • xqa_mla (362-472)
flashinfer/utils.py (4)
  • get_compute_capability (252-255)
  • device_support_pdl (569-573)
  • get_device_sm_count (596-597)
  • check_shape_dtype_device (519-537)
csrc/trtllm_fmha_kernel_launcher.cu (2)
  • trtllm_paged_attention_decode (197-265)
  • trtllm_paged_attention_decode (197-204)
tests/attention/test_trtllm_gen_mla.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
🪛 Ruff (0.14.3)
flashinfer/decode.py

2586-2588: Avoid specifying long messages outside the exception class

(TRY003)


2590-2590: Avoid specifying long messages outside the exception class

(TRY003)


2592-2594: Avoid specifying long messages outside the exception class

(TRY003)


2622-2622: Avoid specifying long messages outside the exception class

(TRY003)


2635-2635: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)


2653-2655: Avoid specifying long messages outside the exception class

(TRY003)


2681-2681: Avoid specifying long messages outside the exception class

(TRY003)


2693-2693: Unused function argument: max_seq_len

(ARG001)


2697-2697: Unused function argument: bmm1_scale_log2_tensor

(ARG001)


2698-2698: Unused function argument: bmm2_scale_tensor

(ARG001)


2700-2700: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


2744-2746: Avoid specifying long messages outside the exception class

(TRY003)


2748-2750: Avoid specifying long messages outside the exception class

(TRY003)


2752-2752: Avoid specifying long messages outside the exception class

(TRY003)

tests/attention/test_trtllm_gen_mla.py

228-228: Use raise without specifying exception name

Remove exception name

(TRY201)


240-240: Use raise without specifying exception name

Remove exception name

(TRY201)

🔇 Additional comments (5)
csrc/xqa/mla_sm120.cu (1)

1793-1803: LGTM: Clean simplification for initial MLA implementation.

The change disables multi-block mode by fixing gridDim.y to 1 and commenting out the dynamic nbSubSeqPerSeq calculation. The inline comment clearly indicates this is a temporary simplification for the initial MLA release.

tests/attention/test_xqa_mla_batch_decode.py (1)

7-9: Verify thread safety of global workspace buffers.

Module-level global buffers may cause issues if tests run in parallel (e.g., with pytest-xdist). Consider using pytest fixtures with appropriate scope instead, or ensure tests that use these globals are marked to run sequentially.

Additionally, verify that reusing global_xqa_workspace_buffer across multiple test invocations doesn't cause flakiness, since the comment indicates it "must be zero initialized" for its first use.

Also applies to: 94-104
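
One fixture-based alternative, sketched with an assumed buffer size and name (not taken from the test file):

    import pytest
    import torch

    @pytest.fixture(scope="session")
    def xqa_workspace_buffer():
        # Single allocation per test session; zero-initialized because the semaphore
        # region at the front of the workspace must start at zero on first use.
        return torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")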

tests/attention/test_trtllm_gen_mla.py (1)

26-47: LGTM: Backend parameterization and capability guards.

The backend parameter and associated capability checks appropriately enforce requirements:

  • XQA: SM120, q_len_per_request==1, fp8 dtype
  • TRTLLM-GEN: SM100/SM103
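
A condensed sketch of that parameterization and guard pattern (test body elided; parameter values and skip messages are illustrative):

    import pytest
    import torch
    from flashinfer.utils import get_compute_capability

    @pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
    @pytest.mark.parametrize("q_len_per_request", [1, 2])
    def test_mla_decode(backend, q_len_per_request):
        major, _ = get_compute_capability(torch.device("cuda:0"))
        if backend == "xqa":
            if major != 12:
                pytest.skip("XQA MLA requires SM120")
            if q_len_per_request != 1:
                pytest.skip("XQA MLA only supports q_len_per_request == 1")
        elif major != 10:
            pytest.skip("trtllm-gen MLA requires SM100/SM103")
        # ... build fp8 inputs, run the decode, compare against the reference
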
flashinfer/decode.py (2)

2576-2610: LGTM: Backend routing logic with appropriate validation.

The backend selection and validation logic correctly:

  • Auto-selects backend based on compute capability (SM10x → trtllm-gen, otherwise → xqa)
  • Enforces XQA-specific constraints (SM120, fp8, no sinks, q_len_per_request==1)
  • Routes to the appropriate implementation

2777-2797: Verify parameter scale remapping for xqa_mla.

The function remaps bmm1_scale → q_scale and bmm2_scale → kv_scale when calling xqa_mla. Ensure this mapping is semantically correct given the note in the docstring that:

  • bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
  • bmm2_scale = v_scale * o_scale

The current mapping passes these composite scales directly as q_scale and kv_scale to the underlying kernel.
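
For reference, the fused-scale convention quoted above can be written out as follows (all raw scale values here are placeholders, purely for illustration):

    import math

    q_scale = k_scale = v_scale = o_scale = 1.0   # placeholder per-tensor scales
    sm_scale = 1.0                                 # placeholder softmax scale
    head_dim_qk = 576                              # e.g. kv_lora_rank + qk_rope_head_dim

    bmm1_scale = q_scale * k_scale * sm_scale / math.sqrt(head_dim_qk)
    bmm2_scale = v_scale * o_scale
    # The XQA path then forwards bmm1_scale as q_scale and bmm2_scale as kv_scale
    # to xqa_mla, which is the remapping this comment asks to verify.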

Comment thread tests/attention/test_xqa_mla_batch_decode.py Outdated
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a06a690 and 896c933.

📒 Files selected for processing (1)
  • tests/attention/test_xqa_mla_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_xqa_mla_batch_decode.py (3)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
flashinfer/decode.py (7)
  • xqa_batch_decode_with_kv_cache_mla (2684-2799)
  • plan (811-1102)
  • plan (1610-1733)
  • run (1132-1145)
  • run (1148-1161)
  • run (1163-1381)
  • run (1735-1859)
flashinfer/mla.py (1)
  • BatchMLAPagedAttentionWrapper (66-447)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment thread tests/attention/test_xqa_mla_batch_decode.py Outdated
@yzh119 yzh119 changed the title from "add xqa mla backend" to "feat: add xqa mla backend" on Nov 6, 2025
Comment thread flashinfer/decode.py Outdated
workspace_u8 = workspace_buffer.view(torch.uint8)
semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
kv_cache_new = kv_cache.squeeze(1).unsqueeze(2).contiguous()
Collaborator


Please don't use contiguous here; contiguous implies data movement, which we don't want for decode attention APIs (they are I/O bound).

And in this case I think contiguous is unnecessary (internal data layout doesn't change).

Btw, transposing the second and third dimensions seems more natural to me than squeeze(1).unsqueeze(2).

Collaborator Author


Done. And we cannot replace squeeze(1).unsqueeze(2) with transpose(1, 2) because the strides are not the same.
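
For illustration, the stride difference can be reproduced on a small tensor; the shape below is an assumption chosen to mirror a paged KV cache with a singleton second dimension, not the actual cache shape:

    import torch

    kv = torch.zeros(4, 1, 16, 576)   # e.g. [num_pages, 1, page_size, head_dim]
    a = kv.squeeze(1).unsqueeze(2)    # shape [4, 16, 1, 576]
    b = kv.transpose(1, 2)            # same shape [4, 16, 1, 576]

    print(a.stride())  # (9216, 576, 576, 1)
    print(b.stride())  # (9216, 576, 9216, 1): the size-1 dim keeps the page stride
    # Same data and same shape, but different strides for the singleton dimension,
    # so a kernel that reads raw strides cannot treat the two views interchangeably.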

Comment thread flashinfer/decode.py Outdated
semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
kv_cache_new = kv_cache.squeeze(1).unsqueeze(2).contiguous()
seq_lens_new = seq_lens.unsqueeze(1).contiguous()
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contiguous is not necessary.

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment thread flashinfer/decode.py
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be the assumption of the xqa attention APIs, but the trtllm API assumes a one-dimensional tensor.

My suggestion is to not change the interface semantics here (considering we have another branch for trtllm-gen), and for xqa, we reshape it to [num_semaphores, 4] if necessary.

Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. This comment was copied from trtllm_batch_decode_with_kv_cache_mla; xqa does not actually require such a shape.

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@qsang-nv qsang-nv requested a review from yongwww as a code owner November 10, 2025 03:17
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 896c933 and ab405e6.

📒 Files selected for processing (3)
  • flashinfer/decode.py (6 hunks)
  • tests/attention/test_trtllm_gen_mla.py (6 hunks)
  • tests/attention/test_xqa_mla_batch_decode.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/attention/test_xqa_mla_batch_decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_trtllm_gen_mla.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
flashinfer/decode.py (3)
flashinfer/xqa.py (4)
  • xqa (55-94)
  • xqa (125-278)
  • xqa_mla (303-334)
  • xqa_mla (362-472)
flashinfer/utils.py (4)
  • get_compute_capability (252-255)
  • device_support_pdl (569-573)
  • get_device_sm_count (596-597)
  • check_shape_dtype_device (519-537)
csrc/trtllm_fmha_kernel_launcher.cu (2)
  • trtllm_paged_attention_decode (197-265)
  • trtllm_paged_attention_decode (197-204)
🪛 Ruff (0.14.3)
tests/attention/test_trtllm_gen_mla.py

228-228: Use raise without specifying exception name

Remove exception name

(TRY201)


240-240: Use raise without specifying exception name

Remove exception name

(TRY201)

flashinfer/decode.py

2584-2586: Avoid specifying long messages outside the exception class

(TRY003)


2588-2588: Avoid specifying long messages outside the exception class

(TRY003)


2590-2592: Avoid specifying long messages outside the exception class

(TRY003)


2620-2620: Avoid specifying long messages outside the exception class

(TRY003)


2633-2633: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation

Replace with (*query.shape[:-1], kv_lora_rank)

(RUF005)


2651-2653: Avoid specifying long messages outside the exception class

(TRY003)


2679-2679: Avoid specifying long messages outside the exception class

(TRY003)


2691-2691: Unused function argument: max_seq_len

(ARG001)


2695-2695: Unused function argument: bmm1_scale_log2_tensor

(ARG001)


2696-2696: Unused function argument: bmm2_scale_tensor

(ARG001)


2698-2698: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


2742-2744: Avoid specifying long messages outside the exception class

(TRY003)


2746-2748: Avoid specifying long messages outside the exception class

(TRY003)


2750-2750: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment thread flashinfer/decode.py
Comment on lines +2574 to +2578
if backend == "auto":
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
if backend == "xqa":
Contributor


⚠️ Potential issue | 🟠 Major

Fix auto backend selection for unsupported GPUs.

backend="auto" now maps every non-SM100 device to "xqa", but XQA MLA kernels are only supported on SM120 hardware. On Hopper (SM90) or any other architecture, this branch raises the ValueError defined below instead of choosing a valid backend or cleanly reporting "MLA decode not supported". Please gate the auto-selection by compute capability so that only SM120 falls through to "xqa" and older GPUs get an explicit unsupported error (or their previous behavior). (newreleases.io)

@yzh119
Collaborator

yzh119 commented Nov 10, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !123 has been created, and the CI pipeline #38195344 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #38195344: 15/17 passed
