
Support nextn for flashinfer mla attention backend #4218

Merged
zhyncs merged 3 commits into sgl-project:main from Fridge003:deepseek on Mar 9, 2025

Conversation

@Fridge003 (Collaborator) commented Mar 9, 2025

Motivation

Support compatibility between NextN speculative decoding and the FlashInfer MLA attention backend. Currently, topk can only be set to 1 because the FlashInfer MLA wrapper lacks custom-mask support.

Modifications

  • Implement the class FlashInferMLAMultiStepDraftBackend for the draft model when FlashInfer MLA and EAGLE are used together (a rough sketch of the pattern follows this list).
  • Update several methods of FlashInferMLABackend so that draft-extend and target-verify batches can be handled.
  • Update the relevant documentation.
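
For orientation, here is a minimal sketch of the multi-step draft backend pattern described above. It is illustrative only, not the PR's actual code: the constructor signature, the call_fn protocol, and FlashInferMLABackend's exact interface are assumptions.

```python
# Illustrative sketch only, not the PR's code: one single-step MLA backend per
# speculative draft step, plus a shared template that runs a per-step routine
# over all of them. FlashInferMLABackend and model_runner are assumed to exist.
from typing import Callable


class FlashInferMLAMultiStepDraftBackend:
    def __init__(self, model_runner, topk: int, speculative_num_steps: int):
        # Custom masks are unsupported by the FlashInfer MLA wrapper, hence topk == 1.
        assert topk == 1, "flashinfer MLA currently supports only topk == 1"
        self.speculative_num_steps = speculative_num_steps
        # One single-step attention backend per draft step, each with its own state.
        self.attn_backends = [
            FlashInferMLABackend(model_runner) for _ in range(speculative_num_steps)
        ]

    def common_template(self, forward_batch, call_fn: Callable) -> None:
        # Apply the same per-step routine (metadata init, CUDA-graph capture, ...)
        # to every draft step's backend.
        for step in range(self.speculative_num_steps):
            call_fn(step, forward_batch)

    def init_forward_metadata(self, forward_batch) -> None:
        self.common_template(
            forward_batch,
            lambda step, fb: self.attn_backends[step].init_forward_metadata(fb),
        )
```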

Usage

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-R1-NextN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --trust-remote --tp 8 --enable-flashinfer-mla

Parameter constraints:

  • speculative-eagle-topk must be set to 1
  • speculative-num-draft-tokens must be a power of 2 (a sanity-check sketch follows this list)
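
The power-of-2 requirement comes from Triton's tl.arange, which only accepts power-of-2 ranges (see the error discussed later in this thread). A minimal sketch of a launch-time check for these constraints, as a hypothetical helper rather than anything in sglang:

```python
def validate_spec_args(topk: int, num_steps: int, num_draft_tokens: int) -> None:
    """Hypothetical sanity check mirroring the constraints listed above."""
    if topk != 1:
        raise ValueError("flashinfer MLA + EAGLE currently requires --speculative-eagle-topk 1")
    # Power-of-2 test via the bit trick: n & (n - 1) == 0 for n = 1, 2, 4, 8, ...
    if num_draft_tokens < 1 or num_draft_tokens & (num_draft_tokens - 1) != 0:
        raise ValueError("--speculative-num-draft-tokens must be a power of 2")
    # #4217 additionally asserts num_steps < num_draft_tokens (see discussion below).
    if num_steps >= num_draft_tokens:
        raise ValueError("--speculative-num-steps must be smaller than --speculative-num-draft-tokens")


validate_spec_args(topk=1, num_steps=3, num_draft_tokens=4)  # matches the command above
```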

Accuracy

GSM8K

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 200 --parallel 128
Accuracy: 0.975
Invalid: 0.000
Latency: 93.031 s
Output throughput: 213.779 token/s

MMLU

bash benchmark/mmlu/download_data.sh
python3 benchmark/mmlu/bench_sglang.py --nsub 100 --ntrain 5 --parallel 128
Total latency: 198.942
Average accuracy: 0.870

Benchmark

The benchmarks are run on 8×H200 GPUs, with total throughput (tokens/s) as the metric. Each benchmark is run five times and the average is reported.

Launch

# Flashinfer + NextN
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-R1-NextN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --trust-remote --tp 8 --enable-flashinfer-mla

# Flashinfer Only
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --trust-remote-code --enable-flashinfer-mla

# NextN + Triton
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-R1-NextN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --trust-remote --tp 8

Input-4000-Output-200

python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --num-prompt 32
Flashinfer + NextN    Flashinfer Only    Triton + NextN
5702.74               4047.45            3757.34

Input-128-Output-128

python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 128 --num-prompt 32
Flashinfer + NextN    Flashinfer Only    Triton + NextN
905.28                630.62             929.86

Single prompt

python3 -m sglang.test.send_one
                      Flashinfer + NextN    Flashinfer Only    Triton + NextN
Accept length         2.22                  1.0                2.34
Throughput (tok/s)    54.83                 37.71              55.25


@lambert0312 (Contributor) commented Mar 9, 2025

Usage

I just experimented on 16 x A800 GPUs, using block-wise INT8 with NextN for FlashInfer (this PR plus #3911) and torch compile enabled.

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1-Block-INT8 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-R1-NextN-Block-INT8 --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote --tp 16 --enable-flashinfer-mla  --enable-torch-compile --torch-compile-max-bs 4

Benchmark

Input-256-Output-256 (bs1)

python3 -m sglang.bench_serving --backend sglang --num-prompts 100 --dataset-name random --max-concurrency 1 --random-input 256 --random-output 256
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1
Successful requests:                     100
Benchmark duration (s):                  275.58
Total input tokens:                      12612
Total generated tokens:                  13946
Total generated tokens (retokenized):    13862
Request throughput (req/s):              0.36
Input token throughput (tok/s):          45.77
Output token throughput (tok/s):         50.61
Total token throughput (tok/s):          96.37
Concurrency:                             1.00
Accept length:                           1.88
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2754.84
Median E2E Latency (ms):                 2831.74
---------------Time to First Token----------------
Mean TTFT (ms):                          230.47
Median TTFT (ms):                        224.85
P99 TTFT (ms):                           308.32
---------------Inter-Token Latency----------------
Mean ITL (ms):                           18.24
Median ITL (ms):                         17.09
P95 ITL (ms):                            33.57
P99 ITL (ms):                            35.28
Max ITL (ms):                            41.60
==================================================

Input-256-Output-256 (bs16)

python3 -m sglang.bench_serving --backend sglang --num-prompts 100 --dataset-name random --max-concurrency 16 --random-input 256 --random-output 256
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 16
Successful requests:                     100
Benchmark duration (s):                  45.86
Total input tokens:                      12612
Total generated tokens:                  13946
Total generated tokens (retokenized):    13862
Request throughput (req/s):              2.18
Input token throughput (tok/s):          275.02
Output token throughput (tok/s):         304.11
Total token throughput (tok/s):          579.13
Concurrency:                             15.12
Accept length:                           1.88
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   6934.89
Median E2E Latency (ms):                 6949.50
---------------Time to First Token----------------
Mean TTFT (ms):                          350.74
Median TTFT (ms):                        278.31
P99 TTFT (ms):                           716.63
---------------Inter-Token Latency----------------
Mean ITL (ms):                           47.57
Median ITL (ms):                         28.52
P95 ITL (ms):                            147.11
P99 ITL (ms):                            265.12
Max ITL (ms):                            538.80
==================================================

Another Config

With --speculative-num-steps 3 and --speculative-num-draft-tokens 4:

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1-Block-INT8 --speculative-algo EAGLE --speculative-draft lmsys/DeepSeek-R1-NextN-Block-INT8 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --trust-remote --tp 16 --enable-flashinfer-mla  --enable-torch-compile --torch-compile-max-bs 4

Benchmark

Input-256-Output-256 (bs1)

python3 -m sglang.bench_serving --backend sglang --num-prompts 100 --dataset-name random --max-concurrency 1 --random-input 256 --random-output 256
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1
Successful requests:                     100
Benchmark duration (s):                  206.29
Total input tokens:                      12612
Total generated tokens:                  13946
Total generated tokens (retokenized):    13865
Request throughput (req/s):              0.48
Input token throughput (tok/s):          61.14
Output token throughput (tok/s):         67.61
Total token throughput (tok/s):          128.74
Concurrency:                             1.00
Accept length:                           2.70
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   2061.87
Median E2E Latency (ms):                 2059.90
---------------Time to First Token----------------
Mean TTFT (ms):                          220.47
Median TTFT (ms):                        213.55
P99 TTFT (ms):                           329.89
---------------Inter-Token Latency----------------
Mean ITL (ms):                           13.30
Median ITL (ms):                         11.92
P95 ITL (ms):                            20.25
P99 ITL (ms):                            36.89
Max ITL (ms):                            66.98
==================================================

Input-256-Output-256 (bs16)

python3 -m sglang.bench_serving --backend sglang --num-prompts 100 --dataset-name random --max-concurrency 16 --random-input 256 --random-output 256
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 16
Successful requests:                     100
Benchmark duration (s):                  44.07
Total input tokens:                      12612
Total generated tokens:                  13946
Total generated tokens (retokenized):    13877
Request throughput (req/s):              2.27
Input token throughput (tok/s):          286.19
Output token throughput (tok/s):         316.46
Total token throughput (tok/s):          602.65
Concurrency:                             15.22
Accept length:                           2.52
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   6709.11
Median E2E Latency (ms):                 6077.50
---------------Time to First Token----------------
Mean TTFT (ms):                          888.33
Median TTFT (ms):                        282.13
P99 TTFT (ms):                           4160.13
---------------Inter-Token Latency----------------
Mean ITL (ms):                           42.05
Median ITL (ms):                         23.14
P95 ITL (ms):                            134.53
P99 ITL (ms):                            254.66
Max ITL (ms):                            850.49
==================================================

@merrymercy (Contributor) left a comment


Add a test case like this and assert the acceptance length:

server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5)
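
A fuller, self-contained sketch of such a test; the class name, base URL, and the assumption that a server is already running with the flags from the Usage section are mine, not the test that was actually added:

```python
import unittest

import requests


class TestEagleFlashinferMLA(unittest.TestCase):
    # Assumes a server already launched with the flags from the Usage section,
    # e.g. --speculative-algo EAGLE ... --enable-flashinfer-mla, and that some
    # requests have been served so the acceptance-length counter is populated.
    base_url = "http://127.0.0.1:30000"

    def test_avg_spec_accept_length(self):
        # Read the running average acceptance length from /get_server_info.
        server_info = requests.get(self.base_url + "/get_server_info")
        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 2.5)


if __name__ == "__main__":
    unittest.main()
```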

(Outdated review thread on python/sglang/srt/layers/attention/flashinfer_mla_backend.py)
@zhyncs zhyncs merged commit 9fb48f9 into sgl-project:main Mar 9, 2025
@junliu-mde (Contributor) commented Mar 9, 2025

I noticed that the changes applied in #4217 reject the argument combination --speculative-num-steps 2 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 via:

assert self.speculative_num_steps < self.speculative_num_draft_tokens

This comment is just a reminder to prevent others from encountering the same confusion I experienced.

@junliu-mde (Contributor) commented Mar 9, 2025

By the way, maybe --speculative-num-draft-tokens cannot be 3? @Fridge003

I hit the error below when setting it to 3:

[2025-03-09 04:43:08 TP7] Scheduler hit an exception: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 35, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 1192, in arange
    return semantic.arange(start, end, _builder)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/semantic.py", line 512, in arange
    raise ValueError("arange's range must be a power of 2")
ValueError: arange's range must be a power of 2

  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 486, in event_loop_normal
    result = self.run_batch(batch)
  File "/sgl-workspace/sglang/python/sglang/srt/managers/scheduler.py", line 1203, in run_batch
    ) = self.draft_worker.forward_batch_speculative_generation(batch)
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_worker.py", line 179, in forward_batch_speculative_generation
    logits_output, verify_output, model_worker_batch = self.verify(
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_worker.py", line 347, in verify
    res: EagleVerifyOutput = spec_info.verify(
  File "/sgl-workspace/sglang/python/sglang/srt/speculative/eagle_utils.py", line 358, in verify
    eagle_verify_retrive[(bs,)](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)   
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 30:20:
        extract_index (out): Index for last accepted tokens
        max_len: Maximum length in a batch
        draft_token_num: Number of tokens speculatively generated
        max_len_upper An upper bound for token sequence length
    """
    pid = tl.program_id(axis=0)

    retrive_end = tl.load(retrive_cum_len + pid + 1)
    retrive_start = tl.load(retrive_cum_len + pid)
    retrive_len = retrive_end - retrive_start
    accept_ptr = accept_mask + retrive_start
    accept_offset = tl.arange(0, draft_token_num)
                    ^

Because num_steps < num_draft_tokens and num_draft_tokens - 1 <= num_steps * topk,
while topk == 1 restricts num_draft_tokens to 2, 4, 8, ... (a power of 2),
the minimal valid combination seems to be topk = 1, num-steps = 3, num-draft-tokens = 4? (A small enumeration of this reasoning follows.)
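
A small sketch of this reasoning (hypothetical helper, not sglang code): with topk == 1, the first two bounds force num_draft_tokens = num_steps + 1, so only configurations where num_steps + 1 is a power of 2 survive.

```python
def is_valid(topk: int, num_steps: int, num_draft_tokens: int) -> bool:
    power_of_2 = num_draft_tokens >= 1 and num_draft_tokens & (num_draft_tokens - 1) == 0
    return (
        num_steps < num_draft_tokens                   # assert added in #4217
        and num_draft_tokens - 1 <= num_steps * topk   # bound quoted above
        and power_of_2                                 # Triton tl.arange requirement
    )


valid = [
    (steps, draft)
    for steps in range(1, 8)
    for draft in (2, 4, 8)
    if is_valid(topk=1, num_steps=steps, num_draft_tokens=draft)
]
print(valid)  # [(1, 2), (3, 4), (7, 8)]
```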

@Fridge003 (Collaborator, Author) commented Mar 9, 2025

@junliu-mde Currently you can try topk=1, num-steps=3, num-draft-tokens=4. I just discussed this with the team; there will be updates to the speculative decoding features that remove these constraints. I have also changed the setting to num-steps=3, num-draft-tokens=4 and recorded a new set of benchmark data.

@xihuai18 (Contributor) commented Mar 9, 2025

@lambert0312 could you share your dependencies? I hit the following problem:
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 976, in forward
hidden_states = self.self_attn(
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 585, in forward
return self.forward_absorb(positions, hidden_states, forward_batch)
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 663, in forward_absorb
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
RuntimeError: "baddbmm_cuda" not implemented for 'Char'

@Fridge003 (Collaborator, Author)

> @lambert0312 could you share your dependencies? I hit the following problem … RuntimeError: "baddbmm_cuda" not implemented for 'Char'

Try upgrading sgl-kernel to the latest 0.0.4

@xihuai18 (Contributor)

> Try upgrading sgl-kernel to the latest 0.0.4

already 0.0.4

@lambert0312 (Contributor)

> @lambert0312 could you share your dependencies? I hit the following problem … RuntimeError: "baddbmm_cuda" not implemented for 'Char'

@xihuai18 I used this PR, with the modification I mentioned earlier in PR #3911 (MTP with INT8 support).

@xihuai18 (Contributor)

> @xihuai18 I used this PR, with the modification I mentioned earlier in PR #3911 (MTP with INT8 support).

Thanks, I will try it.

@xihuai18 (Contributor)

> @xihuai18 I used this PR, with the modification I mentioned earlier in PR #3911 (MTP with INT8 support).

How about the accuracy in #3911?

@lambert0312 (Contributor) commented Mar 10, 2025

> How about the accuracy in #3911?

Wait, I'll run one. @xihuai18

Usage (Flashinfer Only)

python3 -m sglang.launch_server --model path/to/DeepSeek-R1-Block-INT8 --trust-remote --tp 16 --enable-flashinfer-mla

Accuracy

GSM8K

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 200 --parallel 128
Accuracy: 0.965
Invalid: 0.000
Latency: 46.802 s
Output throughput: 422.272 token/s

MMLU

python3 benchmark/mmlu/bench_sglang.py --nsub 100 --ntrain 5 --parallel 128
Total latency: 269.918
Average accuracy: 0.872

Usage (Flashinfer + NextN)

python3 -m sglang.launch_server --model path/to/DeepSeek-R1-Block-INT8 --speculative-algo EAGLE --speculative-draft path/to/DeepSeek-R1-NextN-Block-INT8 --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --trust-remote --tp 16 --enable-flashinfer-mla

Accuracy

GSM8K

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 200 --parallel 128
Accuracy: 0.985
Invalid: 0.000
Latency: 54.378 s
Output throughput: 358.731 token/s

MMLU

python3 benchmark/mmlu/bench_sglang.py --nsub 100 --ntrain 5 --parallel 128
Total latency: 215.589
Average accuracy: 0.871
