FIX: (NSA) Compute topk_indices_offset when NSA prefill flashmla_sparse is used with FP8 KV cache#20606

Merged
Fridge003 merged 1 commit into sgl-project:main from bytedance-iaas:horenc/fix_nsa_prefill_flashmla_sparse_kv8
Mar 26, 2026

Conversation

@JackChuang
Contributor

@JackChuang commented Mar 15, 2026

Motivation

When the flashmla_sparse NSA prefill backend is used with FP8 KV cache, topk_indices_offset is never computed outside the normal EXTEND forward mode, causing a crash in forward_extend(). Rather than letting this crash the server, this PR ensures topk_indices_offset is always computed correctly whenever TopkTransformMethod.RAGGED is active, allowing inference to proceed normally.

Root Cause

The bug is triggered by any configuration that satisfies:

  1. FP8 KV cache (--kv-cache-dtype fp8_e4m3fn): BF16 KV cache routes through PAGED and is entirely unaffected.
  2. flashmla_sparse as the prefill backend: selected either explicitly via --nsa-prefill-backend flashmla_sparse, or automatically by the flashmla_auto heuristic when total_kv_tokens < total_q_tokens * 512 (see the sketch below).

The bug is GPU-architecture-independent and affects both SM90 (Hopper/H200) and SM100 (Blackwell/B200). Short prompts that the heuristic routes to flashmla_sparse hit the same crash as when the backend is forced explicitly; there is no graceful degradation or fallback.

Error log:

  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/models/deepseek_v2.py", line 1310, in forward
    s = self.forward_prepare(
        ^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/models/deepseek_v2.py", line 1366, in forward_prepare
    inner_state = self.forward_absorb_prepare(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py", line 196, in forward_absorb_prepare
    topk_indices = self.indexer(
                   ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/layers/utils/multi_platform.py", line 71, in forward
    return self._forward_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/layers/attention/nsa/nsa_indexer.py", line 1145, in forward_cuda
    topk_result = self._get_topk_paged(
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/layers/attention/nsa/nsa_indexer.py", line 460, in _get_topk_paged
    topk_result = metadata.topk_transform(logits, self.index_topk)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data07/jackc/sglang_q8kv8/python/sglang/srt/layers/attention/nsa_backend.py", line 257, in topk_transform
    return fast_topk_transform_ragged_fused(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sgl_kernel/top_k.py", line 114, in fast_topk_transform_ragged_fused
    torch.ops.sgl_kernel.fast_topk_transform_ragged_fused(
  File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1255, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: topk_indices_offset must be a CUDA tensor

Fix/Modification

  • Made get_topk_transform_method aware of forward_mode.
  • Updated call sites to pass forward_mode consistently during metadata initialization and runtime path selection.
  • Added a fail-fast check in the ragged fused top-k transform that requires topk_indices_offset and raises a clear error when the metadata is missing.

With this fix, running FP8 KV cache with the flashmla_sparse NSA prefill backend works correctly across all prompt lengths, including short prompts that previously crashed, with no accuracy or performance impact. A sketch of the resulting logic follows.

Tested

Machine: H200*8
Request: input=3500, output=1500
Model: DeepSeek-V3.2-Exp

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 SGLANG_ENABLE_JIT_DEEPGEMM=0 TORCH_CUDA_ARCH_LIST=9.0 python3 -m sglang.launch_server --model-path /models/DeepSeek-V3.2-Exp --tp-size 8 --ep-size 8 --dp-size 8 --enable-dp-lm-head --enable-dp-attention --trust-remote-code --device cuda --host 127.0.0.1 --port 33000 --mem-fraction-static 0.80 --attention-backend nsa --disable-cuda-graph --nsa-decode-backend flashmla_kv --kv-cache-dtype fp8_e4m3 --nsa-prefill-backend flashmla_sparse

python3 -m sglang.bench_serving --backend sglang-oai-chat --base-url http://127.0.0.1:33000 --model /models/DeepSeek-V3.2-Exp --dataset-name random --seed 5 --random-input-len 3500 --random-output-len 1500 --num-prompts 512

Accuracy test results - gsm8k

$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --port 47000

Results
Accuracy: 0.985
Invalid: 0.000


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a specific runtime error that caused the system to crash when a particular combination of KV cache type (FP8) and prefill backend (flashmla_sparse) was used. The core issue was an uninitialized tensor required for attention indexing. The changes ensure that this tensor is correctly computed in all relevant execution paths, thereby enhancing the stability and reliability of the inference process under these configurations.

Highlights

  • Bug Fix: FP8 KV Cache Crash: Resolved a critical crash occurring when using the flashmla_sparse NSA prefill backend with FP8 KV cache. The crash was due to topk_indices_offset not being computed in specific scenarios.
  • Enhanced topk_indices_offset Computation: Implemented the computation of topk_indices_offset within the is_decode_or_idle() branch, specifically when TopkTransformMethod.RAGGED is active. This ensures correct indexing during decode operations.
  • Improved Inference Stability: The fix allows inference to proceed normally across all prompt lengths, including short prompts that previously caused crashes, without impacting accuracy or performance.


Changelog
  • python/sglang/srt/layers/attention/nsa_backend.py
    • Added logic to compute topk_indices_offset for TopkTransformMethod.RAGGED within the decode/idle forward mode.
Activity
  • The pull request was created by JackChuang to fix a reported crash.
  • The author provided a detailed motivation, root cause analysis, and description of the fix.
  • Testing details on H200*8 with a specific model and benchmark command were included.
Contributor

@gemini-code-assist Bot left a comment


Code Review

This pull request addresses a crash that occurs when using the flashmla_sparse NSA prefill backend with an FP8 KV cache. The root cause is that topk_indices_offset was not being computed for decode steps following a prefill, leading to a RuntimeError. The fix correctly computes this offset within the is_decode_or_idle() branch by setting it to cu_seqlens_k[:-1]. This change is logical and effectively resolves the bug, allowing inference to proceed correctly in the specified configuration. The implementation is clean and well-commented.

@Fridge003
Collaborator

@JackChuang I think the root cause is that when we are running decoding batches, the topk_transform_method shouldn't be TopkTransformMethod.RAGGED. So a better way might be to fix the logic of get_topk_transform_method.

@Fridge003
Collaborator

Also, can you please post the gsm8k 20-shot/GPQA results after this change?

@JackChuang force-pushed the horenc/fix_nsa_prefill_flashmla_sparse_kv8 branch from f1a3c71 to 0ad8ae9 on March 25, 2026 17:49
@JackChuang
Contributor Author

JackChuang commented Mar 25, 2026

@Fridge003 Thanks for your input! I've modified the code according to your suggestion. Please take another look. Thank you again!

Accuracy test
$ python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --port 47000
Accuracy: 0.985
Invalid: 0.000

Thanks to @rainj-me for providing input on the NSA/MTP part.

@JackChuang
Contributor Author

@Fridge003 Thanks for your approval! If there's nothing else that needs to be addressed, could you help merge this PR? Thanks.

@Fridge003
Collaborator

> @Fridge003 Thanks for your approval! If there's nothing else that needs to be addressed, could you help merge this PR? Thanks.

Can be merged after passing CIs

@Fridge003
Collaborator

@JackChuang Please fix lint

Use mode-aware topk transform and guard missing ragged offsets
- Pass forward_mode into topk transform method selection.
- Force PAGED transform on decode when flashmla_sparse+fp8 would otherwise use RAGGED.
- Add a clear RuntimeError if RAGGED fused transform is called without topk_indices_offset.

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
@JackChuang force-pushed the horenc/fix_nsa_prefill_flashmla_sparse_kv8 branch from 0ad8ae9 to c9dd583 on March 26, 2026 17:15
@JackChuang
Contributor Author

@Fridge003 Got it. Thanks!
The lint issue has been fixed; CI passes now.

@Fridge003
Collaborator

/rerun-stage stage-c-large-8-gpu-h200

@github-actions
Contributor

❌ Stage stage-c-large-8-gpu-h200 doesn't support isolated runs yet.

NVIDIA stages:

  • stage-a-test-1-gpu-small
  • stage-a-test-cpu
  • stage-b-test-1-gpu-small
  • stage-b-test-1-gpu-large
  • stage-b-test-2-gpu-large
  • stage-b-test-4-gpu-b200
  • stage-c-test-4-gpu-h100
  • stage-c-test-8-gpu-h200
  • stage-c-test-8-gpu-h20
  • stage-c-test-4-gpu-b200
  • stage-c-test-4-gpu-gb200
  • stage-c-test-deepep-4-gpu-h100
  • stage-c-test-deepep-8-gpu-h200
  • multimodal-gen-test-1-gpu
  • multimodal-gen-test-2-gpu

AMD stages:

  • sgl-kernel-unit-test-amd
  • sgl-kernel-unit-test-2-gpu-amd
  • stage-a-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd
  • stage-b-test-1-gpu-small-amd-nondeterministic
  • stage-b-test-1-gpu-small-amd-mi35x
  • stage-b-test-1-gpu-large-amd
  • stage-b-test-2-gpu-large-amd
  • multimodal-gen-test-1-gpu-amd
  • multimodal-gen-test-2-gpu-amd
  • stage-c-test-large-8-gpu-amd
  • stage-c-test-large-8-gpu-amd-mi35x

Other stages will be added soon. For now, use /rerun-failed-ci for those stages.

@Fridge003
Collaborator

/rerun-stage stage-c-test-8-gpu-h200

@Fridge003
Collaborator

/rerun-stage stage-c-test-4-gpu-b200

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies).

@github-actions
Contributor

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

@github-actions
Contributor

🔗 View workflow run

@Fridge003 Fridge003 merged commit 4b5f63e into sgl-project:main Mar 26, 2026
53 of 61 checks passed
Fridge003 pushed a commit that referenced this pull request Mar 26, 2026
…se is used with FP8 KV cache (#20606)

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
…se is used with FP8 KV cache (sgl-project#20606)

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
…se is used with FP8 KV cache (sgl-project#20606)

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
…se is used with FP8 KV cache (sgl-project#20606)

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>