feat: add xqa mla backend #2053
Conversation
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant decode.py
    participant BackendSelector
    participant XQA_MLA_Path
    participant TRTLLM_Path
    User->>decode.py: trtllm_batch_decode_with_kv_cache_mla(..., backend="auto")
    decode.py->>BackendSelector: inspect GPU compute_capability & backend param
    BackendSelector-->>decode.py: select "xqa" or "trtllm-gen"
    alt Selected "xqa"
        decode.py->>decode.py: validate fp8 & q_len_per_request == 1
        decode.py->>decode.py: reshape kv_cache, seq_lens, prepare scratch/semaphore
        decode.py->>XQA_MLA_Path: xqa_batch_decode_with_kv_cache_mla(...)
        XQA_MLA_Path-->>decode.py: decoded output
    else Selected "trtllm-gen"
        decode.py->>TRTLLM_Path: trtllm_paged_attention_decode(...)
        TRTLLM_Path-->>decode.py: decoded output
    end
    decode.py-->>User: return output tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes

This pull request integrates a new XQA MLA backend into FlashInfer, targeting NVIDIA's SM120 (Blackwell) GPU architecture to improve performance for specific attention operations. It introduces a mechanism to select between the new XQA backend and the existing TRTLLM-GEN backend based on the detected GPU's compute capability or an explicit user choice. The changes also include dedicated unit tests to verify the correctness and compatibility of the new backend.
Code Review
This pull request introduces a new XQA MLA backend and its corresponding unit tests. The changes are well-structured. My review includes suggestions for improving code clarity and maintainability by removing commented-out code, consolidating validation logic, using constants for magic numbers, and clarifying documentation. I also noted a todo comment in a new test file that should be addressed.
```diff
   /*uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t {
     float const factor = 4.f;
     return mha::min<uint32_t>(
         mha::max<uint32_t>(
             1U, (uint32_t)round(multiProcessorCount / 4 / (batchSize * nbKHeads) * factor)),
         divUp(maxSeqLen, tokensPerTile * 2));
-  }();
+  }();*/  // MLA disables multi-block mode for now
```
```text
        bmm2_scale: fused scale for mla bmm2 input.
        bmm1_scale_log2_tensor: On-device fused scale tensor for mla bmm1 input. Must be fused with * M_LOG2E before passing in.
        bmm2_scale_tensor: On-device fused scale tensor for mla bmm2 input.
        sinks: additional value per head in the denominator of the softmax.
```
The sinks parameter is not supported by the xqa backend, as indicated by the check at line 2589. This limitation should be documented in the docstring for clarity.
```diff
-        sinks: additional value per head in the denominator of the softmax.
+        sinks: additional value per head in the denominator of the softmax. Not supported by the ``xqa`` backend.
```
```python
if (
    get_compute_capability(query.device)[0] != 12
    or query.dtype != torch.float8_e4m3fn
    or kv_cache.dtype != torch.float8_e4m3fn
):
    raise ValueError(
        f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
    )
if sinks is not None:
    raise ValueError("XQA MLA does not support sinks")
if query.size(1) != 1:
    raise ValueError(
        f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
    )
return xqa_batch_decode_with_kv_cache_mla(
    query,
    kv_cache,
    workspace_buffer,
    qk_nope_head_dim,
    kv_lora_rank,
    qk_rope_head_dim,
    block_tables,
    seq_lens,
    max_seq_len,
    out,
    bmm1_scale,
    bmm2_scale,
    sinks,
    enable_pdl,
)
```
The validation checks for the xqa backend are duplicated in the xqa_batch_decode_with_kv_cache_mla function. To avoid redundancy and improve maintainability, these checks should be consolidated within the xqa_batch_decode_with_kv_cache_mla function. The dispatcher should only be responsible for routing to the correct backend implementation.
```python
return xqa_batch_decode_with_kv_cache_mla(
    query,
    kv_cache,
    workspace_buffer,
    qk_nope_head_dim,
    kv_lora_rank,
    qk_rope_head_dim,
    block_tables,
    seq_lens,
    max_seq_len,
    out,
    bmm1_scale,
    bmm2_scale,
    sinks,
    enable_pdl,
)
```

```python
q_len_per_request = query.size(1)
if q_len_per_request != 1:
    raise ValueError(
        f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
    )
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
    raise ValueError(
        f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
    )
if sinks is not None:
    raise ValueError("XQA MLA does not support sinks")
```
To make this function self-contained and safe to call directly, the compute capability check should be included here. This consolidates all xqa backend specific validations in one place, which is good practice when refactoring the dispatcher function trtllm_batch_decode_with_kv_cache_mla.
```python
if get_compute_capability(query.device)[0] != 12:
    raise ValueError("XQA MLA only supports SM120 GPUs")
q_len_per_request = query.size(1)
if q_len_per_request != 1:
    raise ValueError(
        f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
    )
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
    raise ValueError(
        f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
    )
if sinks is not None:
    raise ValueError("XQA MLA does not support sinks")
```

```python
semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
```
The magic number 8 * 1024 * 1024 for the semaphore size is used here. It's better to define it as a constant with a descriptive name to improve readability and maintainability. This makes it easier to understand the purpose of the value and to change it in one place if needed.
```python
SEMAPHORE_SIZE_BYTES = 8 * 1024 * 1024
semaphore = workspace_u8[:SEMAPHORE_SIZE_BYTES]  # reserve 8MB for semaphore
scratch = workspace_u8[SEMAPHORE_SIZE_BYTES:]
```

```python
    batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim
)

# todo: fix kv_cache
```
Actionable comments posted: 1
🧹 Nitpick comments (4)
tests/attention/test_xqa_mla_batch_decode.py (1)
171-189: Consider extracting tolerance validation to a shared helper.

The tolerance-based validation logic (computing absolute/relative differences and checking the pass ratio) is duplicated in tests/attention/test_trtllm_gen_mla.py (lines 242-260). Consider extracting this to a shared test utility function to reduce code duplication and ensure consistent validation across backends; one possible shape is sketched below.
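A minimal sketch of what such a shared helper could look like (the helper name `assert_close_with_pass_ratio`, its module location, and the 90% pass-ratio default are assumptions, not taken from the existing tests):

```python
import torch


def assert_close_with_pass_ratio(
    out: torch.Tensor,
    ref: torch.Tensor,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    min_pass_ratio: float = 0.9,
) -> None:
    """Assert that at least ``min_pass_ratio`` of elements are close to the reference."""
    abs_diff = (out.float() - ref.float()).abs()
    rel_diff = abs_diff / ref.float().abs().clamp_min(1e-6)
    # An element passes if it is close in either absolute or relative terms.
    passed = (abs_diff <= atol) | (rel_diff <= rtol)
    pass_ratio = passed.float().mean().item()
    assert pass_ratio >= min_pass_ratio, (
        f"only {pass_ratio:.2%} of elements within tolerance "
        f"(required at least {min_pass_ratio:.2%})"
    )
```

Both test files could then import this from a shared test utility module instead of re-implementing the check.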
tests/attention/test_trtllm_gen_mla.py (1)

217-240: Use a bare `raise` instead of `raise e`.

Lines 228 and 240 use `raise e`, which creates a new traceback. Use a bare `raise` to preserve the original exception traceback for better debugging.

```diff
         except AssertionError as e:
             print("output:", output)
             print("o_ref:", o_ref)
-            raise e
+            raise
     else:
         try:
             torch.testing.assert_close(
                 output,
                 o_ref.view(batch_size, q_len_per_request, num_q_heads, -1),
                 rtol=1e-2,
                 atol=1e-2,
             )
         except AssertionError as e:
             print("output:", output)
             print("o_ref:", o_ref)
-            raise e
+            raise
```

Based on learnings
flashinfer/decode.py (2)
2684-2799: Unused parameters kept for API consistency.

The parameters `max_seq_len`, `bmm1_scale_log2_tensor`, and `bmm2_scale_tensor` (lines 2693, 2697-2698) are unused in the function body. These appear to be present for API consistency with `trtllm_batch_decode_with_kv_cache_mla`. Consider adding a docstring note explaining why these parameters are accepted but not used in the XQA implementation, or add validation to reject them if provided.

```diff
+    Note
+    ----
+    The parameters `bmm1_scale_log2_tensor`, `bmm2_scale_tensor`, and `max_seq_len`
+    are accepted for API consistency but not used in the XQA MLA implementation.
```
2700-2700: Add an explicit `Optional` type annotation.

Line 2700 should use `Optional[bool]` instead of an implicit `bool` to match PEP 484 and be consistent with the function's default value of `None`.

```diff
-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,
```

Based on learnings
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- `csrc/xqa/mla_sm120.cu` (1 hunks)
- `flashinfer/decode.py` (4 hunks)
- `tests/attention/test_trtllm_gen_mla.py` (4 hunks)
- `tests/attention/test_xqa_mla_batch_decode.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tests/attention/test_xqa_mla_batch_decode.py (3)
- flashinfer/utils.py (1)
  - `get_compute_capability` (252-255)
- flashinfer/decode.py (7)
  - `xqa_batch_decode_with_kv_cache_mla` (2684-2799)
  - `plan` (811-1102), `plan` (1610-1733)
  - `run` (1132-1145), `run` (1148-1161), `run` (1163-1381), `run` (1735-1859)
- flashinfer/mla.py (1)
  - `BatchMLAPagedAttentionWrapper` (66-447)

flashinfer/decode.py (3)
- flashinfer/xqa.py (4)
  - `xqa` (55-94), `xqa` (125-278)
  - `xqa_mla` (303-334), `xqa_mla` (362-472)
- flashinfer/utils.py (4)
  - `get_compute_capability` (252-255)
  - `device_support_pdl` (569-573)
  - `get_device_sm_count` (596-597)
  - `check_shape_dtype_device` (519-537)
- csrc/trtllm_fmha_kernel_launcher.cu (2)
  - `trtllm_paged_attention_decode` (197-265)
  - `trtllm_paged_attention_decode` (197-204)

tests/attention/test_trtllm_gen_mla.py (1)
- flashinfer/utils.py (1)
  - `get_compute_capability` (252-255)
🪛 Ruff (0.14.3)
flashinfer/decode.py
2586-2588: Avoid specifying long messages outside the exception class
(TRY003)
2590-2590: Avoid specifying long messages outside the exception class
(TRY003)
2592-2594: Avoid specifying long messages outside the exception class
(TRY003)
2622-2622: Avoid specifying long messages outside the exception class
(TRY003)
2635-2635: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
2653-2655: Avoid specifying long messages outside the exception class
(TRY003)
2681-2681: Avoid specifying long messages outside the exception class
(TRY003)
2693-2693: Unused function argument: max_seq_len
(ARG001)
2697-2697: Unused function argument: bmm1_scale_log2_tensor
(ARG001)
2698-2698: Unused function argument: bmm2_scale_tensor
(ARG001)
2700-2700: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
2744-2746: Avoid specifying long messages outside the exception class
(TRY003)
2748-2750: Avoid specifying long messages outside the exception class
(TRY003)
2752-2752: Avoid specifying long messages outside the exception class
(TRY003)
tests/attention/test_trtllm_gen_mla.py
228-228: Use raise without specifying exception name
Remove exception name
(TRY201)
240-240: Use raise without specifying exception name
Remove exception name
(TRY201)
🔇 Additional comments (5)
csrc/xqa/mla_sm120.cu (1)
1793-1803: LGTM: Clean simplification for the initial MLA implementation.

The change disables multi-block mode by fixing `gridDim.y` to 1 and commenting out the dynamic `nbSubSeqPerSeq` calculation. The inline comment clearly indicates this is a temporary simplification for the initial MLA release.

tests/attention/test_xqa_mla_batch_decode.py (1)
7-9: Verify thread safety of global workspace buffers.

Module-level global buffers may cause issues if tests run in parallel (e.g., with pytest-xdist). Consider using pytest fixtures with appropriate scope instead (a possible shape is sketched below), or ensure tests that use these globals are marked to run sequentially.

Additionally, verify that reusing `global_xqa_workspace_buffer` across multiple test invocations doesn't cause flakiness, since the comment indicates it "must be zero initialized" for its first use.

Also applies to: 94-104
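A sketch of a fixture-based alternative (the 128 MB size and the fixture name are assumptions; the zero initialization mirrors the first-use requirement quoted above):

```python
import pytest
import torch


@pytest.fixture(scope="module")
def xqa_workspace_buffer():
    # Allocated once per test module (and once per pytest-xdist worker process),
    # zero-initialized to satisfy the semaphore region's first-use requirement.
    return torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
```

Tests would then accept `xqa_workspace_buffer` as an argument instead of reading a module-level global.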
tests/attention/test_trtllm_gen_mla.py (1)
26-47: LGTM: Backend parameterization and capability guards.

The backend parameter and the associated capability checks appropriately enforce the requirements:
- XQA: SM120, q_len_per_request==1, fp8 dtype
- TRTLLM-GEN: SM100/SM103
flashinfer/decode.py (2)
2576-2610: LGTM: Backend routing logic with appropriate validation.

The backend selection and validation logic correctly:
- Auto-selects the backend based on compute capability (SM100 → trtllm-gen, else → xqa)
- Enforces XQA-specific constraints (SM120, fp8, no sinks, q_len_per_request==1)
- Routes to the appropriate implementation
2777-2797: Verify parameter scale remapping for xqa_mla.

The function remaps `bmm1_scale → q_scale` and `bmm2_scale → kv_scale` when calling `xqa_mla`. Ensure this mapping is semantically correct given the note in the docstring that:

- `bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)`
- `bmm2_scale = v_scale * o_scale`

The current mapping passes these composite scales directly as `q_scale` and `kv_scale` to the underlying kernel.
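For reference, a host-side sketch of how a caller might assemble these composite scalars before passing them in; the concrete values below are purely illustrative, and only the two formulas themselves come from the docstring quoted above:

```python
# Illustrative fp8 dequantization / output scales; real values come from quantization.
q_scale = k_scale = v_scale = o_scale = 1.0
sm_scale = 1.0            # extra softmax scaling factor
head_dim_qk = 512 + 64    # e.g. kv_lora_rank + qk_rope_head_dim, per the docstring

bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
```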
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- `tests/attention/test_xqa_mla_batch_decode.py` (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_xqa_mla_batch_decode.py (3)
- flashinfer/utils.py (1)
  - `get_compute_capability` (252-255)
- flashinfer/decode.py (7)
  - `xqa_batch_decode_with_kv_cache_mla` (2684-2799)
  - `plan` (811-1102), `plan` (1610-1733)
  - `run` (1132-1145), `run` (1148-1161), `run` (1163-1381), `run` (1735-1859)
- flashinfer/mla.py (1)
  - `BatchMLAPagedAttentionWrapper` (66-447)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```python
workspace_u8 = workspace_buffer.view(torch.uint8)
semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
kv_cache_new = kv_cache.squeeze(1).unsqueeze(2).contiguous()
```
Please don't use `contiguous` here: `contiguous` implies data movement, which we don't want for decode attention APIs (they are I/O bound).
And in this case I think `contiguous` is unnecessary (the internal data layout doesn't change).
btw, transposing the second and third dimensions seems more natural to me than `squeeze(1).unsqueeze(2)`.
Done, and we cannot replace `squeeze(1).unsqueeze(2)` with `transpose(1, 2)` because the resulting strides are not the same.
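A small illustration of that difference (the `[num_pages, 1, page_size, head_dim]` layout is assumed from the `squeeze(1)` in the quoted code; the sizes are arbitrary):

```python
import torch

kv_cache = torch.zeros(4, 1, 64, 576)  # [num_pages, 1, page_size, head_dim]

a = kv_cache.squeeze(1).unsqueeze(2)   # shape [4, 64, 1, 576]
b = kv_cache.transpose(1, 2)           # shape [4, 64, 1, 576]

print(a.stride())  # (36864, 576, 576, 1)
print(b.stride())  # (36864, 576, 36864, 1) -- the size-1 dim keeps the old page stride
```

Both views address the same memory, but the stride tuples differ, which matters if the kernel wrapper checks strides exactly.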
```python
semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
kv_cache_new = kv_cache.squeeze(1).unsqueeze(2).contiguous()
seq_lens_new = seq_lens.unsqueeze(1).contiguous()
```
contiguous is not necessary.
```text
    Parameters:
        query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
        kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
        workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
```
It might be the assumption of the xqa attention APIs, but the trtllm API assumes a one-dimensional tensor.
My suggestion is to not change the interface semantics here (considering we have another branch for trtllm-gen), and for xqa, we reshape it to [num_semaphores, 4] if necessary.
Done. This comment was copied from trtllm_batch_decode_with_kv_cache_mla; xqa does not actually require that shape.
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `flashinfer/decode.py` (6 hunks)
- `tests/attention/test_trtllm_gen_mla.py` (6 hunks)
- `tests/attention/test_xqa_mla_batch_decode.py` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/attention/test_xqa_mla_batch_decode.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_trtllm_gen_mla.py (1)
- flashinfer/utils.py (1)
  - `get_compute_capability` (252-255)

flashinfer/decode.py (3)
- flashinfer/xqa.py (4)
  - `xqa` (55-94), `xqa` (125-278)
  - `xqa_mla` (303-334), `xqa_mla` (362-472)
- flashinfer/utils.py (4)
  - `get_compute_capability` (252-255)
  - `device_support_pdl` (569-573)
  - `get_device_sm_count` (596-597)
  - `check_shape_dtype_device` (519-537)
- csrc/trtllm_fmha_kernel_launcher.cu (2)
  - `trtllm_paged_attention_decode` (197-265)
  - `trtllm_paged_attention_decode` (197-204)
🪛 Ruff (0.14.3)
tests/attention/test_trtllm_gen_mla.py
228-228: Use raise without specifying exception name
Remove exception name
(TRY201)
240-240: Use raise without specifying exception name
Remove exception name
(TRY201)
flashinfer/decode.py
2584-2586: Avoid specifying long messages outside the exception class
(TRY003)
2588-2588: Avoid specifying long messages outside the exception class
(TRY003)
2590-2592: Avoid specifying long messages outside the exception class
(TRY003)
2620-2620: Avoid specifying long messages outside the exception class
(TRY003)
2633-2633: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
2651-2653: Avoid specifying long messages outside the exception class
(TRY003)
2679-2679: Avoid specifying long messages outside the exception class
(TRY003)
2691-2691: Unused function argument: max_seq_len
(ARG001)
2695-2695: Unused function argument: bmm1_scale_log2_tensor
(ARG001)
2696-2696: Unused function argument: bmm2_scale_tensor
(ARG001)
2698-2698: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
2742-2744: Avoid specifying long messages outside the exception class
(TRY003)
2746-2748: Avoid specifying long messages outside the exception class
(TRY003)
2750-2750: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
```python
if backend == "auto":
    backend = (
        "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
    )
if backend == "xqa":
```
Fix auto backend selection for unsupported GPUs.
backend="auto" now maps every non-SM100 device to "xqa", but XQA MLA kernels are only supported on SM120 hardware. On Hopper (SM90) or any other architecture, this branch raises the ValueError defined below instead of choosing a valid backend or cleanly reporting "MLA decode not supported". Please gate the auto-selection by compute capability so that only SM120 falls through to "xqa" and older GPUs get an explicit unsupported error (or their previous behavior). (newreleases.io)
|
/bot run |
|
[SUCCESS] Pipeline #38195344: 15/17 passed |
📌 Description
Add the xqa MLA backend and corresponding unit tests.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests