
Add disable_chunked_prefix_cache feature to TRTLLM MLA #10178

Closed

elfiegg wants to merge 5 commits into sgl-project:main from elfiegg:fix-new

Conversation

@elfiegg elfiegg (Collaborator) commented Sep 8, 2025

Motivation

FlashInfer chunked prefill has an accuracy issue with the DeepSeek FP4 model. This change adds a feature to disable it as a temporary workaround.

Currently, disabling the chunked prefix cache fails with:

  File "/sgl-workspace/sglang/python/sglang/srt/models/deepseek_v2.py", line 1884, in forward_normal_chunked_kv_core
    forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/trtllm_mla_backend.py", line 349, in init_mha_chunk_metadata
    super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)
  File "/sgl-workspace/sglang/python/sglang/srt/layers/attention/flashinfer_mla_backend.py", line 500, in init_mha_chunk_metadata
    self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)
    ^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'TRTLLMMLABackend' object has no attribute 'mha_chunk_kv_cache'
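A rough sketch of the shape of such a guard (not the actual diff in this PR; the forward_normal_core name and the disable_chunked_prefix_cache attribute are assumptions used for illustration): when the chunked prefix cache is disabled, the prefill path skips the chunked-KV branch entirely, so the backend's init_mha_chunk_metadata (which requires mha_chunk_kv_cache) is never reached.

# Hypothetical sketch, not the actual change in this PR.
def forward_prefill(self, positions, hidden_states, forward_batch):
    # `self.disable_chunked_prefix_cache` is assumed to mirror the
    # --disable-chunked-prefix-cache server argument.
    if self.disable_chunked_prefix_cache:
        # Plain MHA prefill path; no chunk metadata or mha_chunk_kv_cache
        # is needed here (forward_normal_core is a placeholder name).
        return self.forward_normal_core(positions, hidden_states, forward_batch)
    # Chunked prefix path: only valid when the attention backend has set up
    # its chunked-KV structures.
    forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
    return self.forward_normal_chunked_kv_core(
        positions, hidden_states, forward_batch
    )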

Accuracy after the fix:

while true; do python3 benchmark/gsm8k/bench_sglang.py --port 40000  --num-questions 10000   --parallel 10000   --num-shots 8; done

Accuracy: 0.949
Invalid: 0.011
Latency: 119.396 s
Output throughput: 1282.291 token/s

Accuracy: 0.942
Invalid: 0.015
Latency: 96.878 s
Output throughput: 1600.295 token/s

Accuracy: 0.945
Invalid: 0.014
Latency: 93.947 s
Output throughput: 1626.703 token/s
python -m sglang.launch_server \
  --max-running-requests 1024 \
  --disable-radix-cache \
  --disable-shared-experts-fusion \
  --disable-chunked-prefix-cache \
  --tp-size 8 \
  --dp-size 8 \
  --enable-dp-attention \
  --chunked-prefill-size 16384 \
  --moe-dense-tp-size 1 \
  --enable-dp-lm-head \
  --model-path nvidia/DeepSeek-R1-0528-FP4 \
  --trust-remote-code \
  --port 40000 \
  --mem-fraction-static 0.84 \
  --quantization modelopt_fp4 \
  --enable-ep-moe \
  --moe-runner-backend flashinfer_cutlass \
  --attention-backend trtllm_mla

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1316 --parallel 1316 --port 40000

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@Fridge003 Fridge003 self-assigned this Sep 8, 2025
@elfiegg elfiegg marked this pull request as ready for review September 8, 2025 21:07
@zhyncs zhyncs added the bug (Something isn't working) and high priority labels Sep 8, 2025
@zhyncs zhyncs (Collaborator) commented Sep 8, 2025

pre-commit run --all-files

@elfiegg elfiegg requested a review from hnyls2002 as a code owner September 8, 2025 21:42
@elfiegg elfiegg (Collaborator, Author) commented Sep 8, 2025

Done, @zhyncs

Comment thread on python/sglang/srt/models/deepseek_v2.py (outdated)
@Fridge003 Fridge003 (Collaborator) commented:

Also, do we have a CI test for the trtllm-mla backend on Blackwell (placed under test/srt)?
If we don't, we need to add one after Blackwell CI is supported in #9604.

@elfiegg @zhyncs

@zhyncs zhyncs (Collaborator) commented Sep 8, 2025

@Fridge003 @elfiegg

"per-commit-8-gpu-b200": [
# add more here
TestFile("test_gpt_oss_4gpu.py", 600),
],
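For example, once the B200 runner is live, a dedicated trtllm-mla test could be registered in that suite; the file name and timeout below are placeholders, not tests that exist today:

"per-commit-8-gpu-b200": [
    TestFile("test_gpt_oss_4gpu.py", 600),
    # Placeholder entry for a trtllm-mla DeepSeek accuracy test.
    TestFile("test_trtllm_mla_deepseek.py", 1800),
],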

Comment thread on python/sglang/srt/layers/attention/trtllm_mla_backend.py
@elfiegg elfiegg (Collaborator, Author) commented Sep 9, 2025

> Also, do we have CI test for trtllm-mla backend on Blackwell? (put under test/srt)
> If we don't, we need to add them after Blackwell CI is supported in #9604

Shall I open an issue for tracking? @Fridge003

@elfiegg elfiegg (Collaborator, Author) commented Sep 9, 2025

Issue for tracking CI for the FP4/FP8 DeepSeek models: #10237

@Fridge003 Fridge003 (Collaborator) commented:

@elfiegg #10180 seems to fix the accuracy issue. Can you give it a try?

@elfiegg elfiegg (Collaborator, Author) commented Sep 10, 2025

@Fridge003 that's a temporary workaround rather than a fix for the root cause, and it only applies to the FP4 model. Also, performance drops to roughly half with the FA2 backend.
By the way, the FP8 model works well with the chunked prefix cache enabled by default.

@elfiegg elfiegg (Collaborator, Author) commented Sep 10, 2025

@Fridge003 I was debugging with Shu yesterday; we both ran the FP4 model with the FlashInfer FA2 kernel and the issue went away. Either the cutlass or the TRTLLM kernel causes the accuracy drop.
Nothing is reproducible at the kernel level, though, so we are not yet confident this is a kernel issue.

@elfiegg elfiegg (Collaborator, Author) commented Sep 10, 2025

It looks like the issue is that the FP4 model is triggering #8995.
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/forward_batch_info.py#L713-L733
It schedules a decode batch as a prefill request, something like:

[2025-09-10 06:28:40 DP2 TP2 EP2] casual=True: q.shape=torch.Size([982, 128, 192]), k.shape=torch.Size([982, 128, 192]), v.shape=torch.Size([982, 128, 128]), seq_lens=tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], device='cuda:2'), max_seq_len=1, cum_seq_lens=tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128], device='cuda:2', dtype=torch.int32)

For this batch, the cutlass / trtllm kernels produce 100% mismatched elements compared to FA2:

  File "/opt/mycode/python/sglang/srt/layers/attention/trtllm_mla_backend.py", line 608, in forward_extend
    assert ratio < 1.0, f"Mismatch ratio {ratio:.4%} is not greater than 1%. Found {num_mismatch} mismatches of {num_total} elements."
           ^^^^^^^^^^^
AssertionError: Mismatch ratio 100.0000% is not greater than 1%. Found 16089088 mismatches of 16089088 elements.
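The assertion above comes from throwaway debug instrumentation rather than shipped code. A minimal sketch of that kind of check, assuming an FA2 reference output is computed on the same batch (the function names, tolerances, and seq_lens heuristic are illustrative, not code from the repo):

import torch

def mismatch_ratio(out: torch.Tensor, ref: torch.Tensor,
                   atol: float = 1e-2, rtol: float = 1e-2) -> float:
    # Fraction of elements where the backend output disagrees with the
    # reference (e.g. the FlashInfer FA2 path) beyond the given tolerances.
    mismatch = ~torch.isclose(out, ref, atol=atol, rtol=rtol)
    return mismatch.sum().item() / mismatch.numel()

def looks_like_decode_routed_as_prefill(seq_lens: torch.Tensor) -> bool:
    # Heuristic for the #8995 pattern in the log above: an "extend" batch
    # whose per-request sequence lengths are all 1 is really a decode batch.
    return seq_lens.numel() > 0 and bool((seq_lens == 1).all())

# Debug-only usage sketch inside forward_extend:
#   ratio = mismatch_ratio(trtllm_out, fa2_out)
#   assert ratio < 0.01, f"Mismatch ratio {ratio:.4%} exceeds 1%"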

@elfiegg elfiegg (Collaborator, Author) commented Sep 11, 2025

Merged changes into #10180


Labels

bug (Something isn't working), high priority