[Feature] Spec V2 DFlash Support #23000

Open

dcw02 wants to merge 13 commits into main from dcw02/dflash-spec-v2

Conversation

@dcw02 (Collaborator) commented Apr 16, 2026

Motivation

Add spec v2 to DFlash

Benchmarks

Run on a GCP B200:8 node using a GSM8K sweep script, with a Qwen3-8B target model, the z-lab/Qwen3-8B-DFlash-b16 draft model, trtllm_mha target attention, fa4 draft attention, and piecewise CUDA graphs enabled.

v1 performance

DFLASH output tok/s
tp\conc       1         32
-------  ------  ---------
      1  845.98  11,405.85

DFLASH accuracy
tp\conc      1     32
-------  -----  -----
      1  0.852  0.844

DFLASH acceptance length (mean spec_accept_length)
tp\conc      1     32
-------  -----  -----
      1  6.345  6.487

v2 performance

DFLASH output tok/s
tp\conc         1         32
-------  --------  ---------
      1  1,075.48  13,022.18

DFLASH accuracy
tp\conc      1     32
-------  -----  -----
      1  0.852  0.844

DFLASH acceptance length (mean spec_accept_length)
tp\conc      1     32
-------  -----  -----
      1  6.352  6.482

This spec v2 version also brings in some extra optimizations on top of #20547, which brought bs1 performance from 900 -> 1,075 tok/s and bs32 from 12,300 -> 13,000 tok/s.

Benchmarking is done with this script on 1xB200, using the command:

SGLANG_ENABLE_SPEC_V2=1 SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1 python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --tp-sizes 1 --concurrencies 1,32 --attention-backends trtllm_mha --speculative-draft-attention-backend fa4

I removed the Mamba memory calculations; I will add them back later once I figure out the best way to do that.


@dcw02 (Collaborator, Author) commented Apr 16, 2026

/rerun-test test/registered/spec/dflash/test_dflash.py

@github-actions (Contributor)

1-gpu-5090 (1 test): View workflow run

cd test/ && python3 registered/spec/dflash/test_dflash.py

@ggg-s commented Apr 16, 2026

What optimizations were made on top of PR #20547? PCG?

@dcw02 (Collaborator, Author) commented Apr 16, 2026

> What optimizations were made on top of PR #20547? PCG?

I rewrote the fused KV helper, added some new Triton ops, removed some syncs, etc. PCG (piecewise CUDA graphs) already exists; I did not add it.
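To illustrate the "removed some syncs" item, here is a generic sketch (not the PR's actual code; the names and shapes are made up): reading a reduction back to the CPU with .item() forces a device-to-host synchronization, while keeping the value on-device lets the GPU stream run ahead.

```python
import torch

# Hypothetical tensor: per-request accepted draft lengths after a verify step.
device = "cuda" if torch.cuda.is_available() else "cpu"
accept_lens = torch.randint(1, 8, (32,), device=device)

# Before: .item() blocks the host until all queued GPU work finishes.
# total_accepted = int(accept_lens.sum().item())

# After: keep the reduction on-device so downstream kernels consume it
# directly and the host never stalls on the copy.
total_accepted = accept_lens.sum()  # still a device tensor, no host sync
```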

@dcw02 (Collaborator, Author) commented Apr 24, 2026

I am investigating accept-length degradations that show up for both the v1 and v2 paths in this PR but not in #20547.

@dcw02 (Collaborator, Author) commented Apr 25, 2026

The accept-length degradation has been fixed; it was a RoPE config handling issue introduced when the transformers version was bumped.
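For readers hitting similar symptoms, a hedged sketch of the general failure mode (this illustrates the class of bug, not the exact fix in this PR): recent transformers releases renamed the rope_scaling "type" key to "rope_type", so config handling that reads only one of the two keys can silently misconfigure RoPE after a version bump.

```python
# Sketch only: tolerate both the old and new rope_scaling key names.
def get_rope_type(rope_scaling: dict | None) -> str:
    if not rope_scaling:
        return "default"
    return rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
```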

@dcw02 (Collaborator, Author) commented Apr 25, 2026

I realized we can carry the reserved KV allocation metadata through the overlap draft state and let next-step prep use the prepared allocation watermark; that removes a scheduling bubble, which helps low concurrency a lot. For correctness, scheduler output processing applies the request watermark monotonically afterwards. A minimal sketch of the idea follows the benchmark tables below.

Prior v2 baseline:

DFLASH output tok/s
tp\conc       1         32
-------  ------  ---------
      1  975.04  12,956.91

With the decoupling:

DFLASH output tok/s
tp\conc         1         32
-------  --------  ---------
      1  1,073.64  12,995.40
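
A minimal sketch of the decoupling idea (hypothetical names only; a sketch of the mechanism described above, not the PR's actual classes): the overlap draft state carries the KV allocation watermark reserved this step, next-step prep reads it without waiting for scheduler output processing, and output processing applies watermarks monotonically so a stale update can never move an allocation backwards.

```python
from dataclasses import dataclass, field


@dataclass
class OverlapDraftState:
    draft_tokens: list = field(default_factory=list)
    reserved_kv_watermark: int = 0  # KV slots reserved so far for this request


def prepare_next_step(state: OverlapDraftState, new_slots: int) -> OverlapDraftState:
    # Prep consumes the prepared watermark directly instead of waiting for
    # scheduler output processing to commit it -- this is the removed bubble.
    return OverlapDraftState(
        reserved_kv_watermark=state.reserved_kv_watermark + new_slots
    )


def apply_scheduler_output(request_watermark: int, prepared_watermark: int) -> int:
    # Monotonic application: output processing only ever advances the
    # watermark, so a late-arriving value cannot roll back slots that
    # next-step prep already counted on.
    return max(request_watermark, prepared_watermark)
```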

@ggg-s commented Apr 27, 2026

@dcw02 I encountered an error:

ValueError: Speculative decoding for Qwen3_5ForConditionalGeneration is not compatible with radix cache when using --mamba-scheduler-strategy no_buffer. To use radix cache with speculative decoding, please use --mamba-scheduler-strategy extra_buffer and set SGLANG_ENABLE_SPEC_V2=1.

How can I solve this problem?

@liusy58 (Collaborator) commented Apr 27, 2026

@dcw02 Great work! How do DFlash and Eagle3 stack up against each other in terms of performance? Do you have any current data on this?

@tugot17 (Contributor) commented Apr 27, 2026

@dcw02 I would like to add support for LFM to this PR.
If I make a PR to your branch introducing the changes, could you merge it?

@dcw02 (Collaborator, Author) commented Apr 27, 2026

@ggg-s You can either disable radix cache or set --mamba-scheduler-strategy extra_buffer, which should work for both v1 and v2 DFlash. There might be some concurrency clamping, since I moved the Mamba memory calculations out to a later PR.
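For example, a launch along these lines should work (hedged: the strategy flag and environment variable are taken from the error message above, and the usual model and speculative-decoding arguments are elided):

SGLANG_ENABLE_SPEC_V2=1 python -m sglang.launch_server ... --mamba-scheduler-strategy extra_buffer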

@dcw02 (Collaborator, Author) commented Apr 27, 2026

@liusy58 In my testing, DFlash is faster for my use cases, but both can be very good depending on how well you train the draft models.

@dcw02 (Collaborator, Author) commented Apr 27, 2026

@tugot17 Yes, we can merge it after spec v2 DFlash is merged. Thanks for your contribution!

@liusy58 (Collaborator) commented Apr 27, 2026

@dcw02 Thank you for your reply. Can we chat on Slack?

@tugot17 (Contributor) commented Apr 27, 2026

@dcw02 I added the LFM changes, but if it is easier to add them after DFlash is merged to main in the first place, then let's wait.

https://github.com/sgl-project/sglang/pull/23847/changes

@liusy58 (Collaborator) commented Apr 28, 2026

@dcw02 Could you please resolve these merge conflicts?

@dcw02 (Collaborator, Author) commented Apr 28, 2026

@liusy58 Fixed the merge conflicts.

@dcw02 requested a review from kpham-sgl as a code owner on May 1, 2026 at 04:56
@dcw02 force-pushed the dcw02/dflash-spec-v2 branch from 8ae7dd3 to 9893ef8 on May 1, 2026 at 05:01
@dcw02 (Collaborator, Author) commented May 1, 2026

I will put up separate PRs for the draft SWA layers and Gemma 4 support so they can be merged first for v1.

yahya010 added a commit to abdelfattah-lab/sglang that referenced this pull request May 4, 2026
…roject#23000

Cherry-picked the two files needed for smcsd's DFlash direct-load path:
- python/sglang/srt/models/dflash.py (DFlashDraftModel + DFlashDecoderLayer)
- python/sglang/srt/speculative/dflash_utils.py (helpers used by the model)

Copied from sglang upstream PR refs/pull/23000/head, which is the canonical
implementation of DFlash speculative decoding referenced by checkpoints like
z-lab/Qwen3.6-27B-DFlash. Adding the model class to our branch lets smcsd's
_init_dflash_direct load DFlash drafts directly via sglang's class registry
instead of transformers' trust_remote_code (which would 404 on dflash.py).

The other DFlash files in PR sgl-project#23000 (dflash_worker, dflash_info,
dflash_accept_bonus, etc.) are sglang-side speculative decoding scaffolding
not used by smcsd's SMC-DFlash worker.
@ggg-s commented May 7, 2026

Hi @dcw02, is the current PR compatible with DFLASH + FlashInfer + mixed batches?

@dcw02 (Collaborator, Author) commented May 7, 2026

> Hi @dcw02, is the current PR compatible with DFLASH + FlashInfer + mixed batches?

I haven't tested that myself, so I'm unsure.

@ggg-s commented May 8, 2026

Hi @dcw02, can the current PCG be used?

Code review comments

class TestDFlashServerSpecV2(TestDFlashServerBase):
    spec_v2 = True

@unittest.skip

(Collaborator) qq: why do we need to skip this?

@@ -26,6 +28,8 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin):
    attention_backend = "flashinfer"

(Collaborator) qq: Does dflash only support flashinfer?

@@ -110,6 +97,23 @@ def _lazy_init_buf(self, draft_input: EagleDraftInput):
            device=self.device,
        )

        if self.spec_algo.is_dflash():

(Collaborator) nit: I prefer adding a more general function (something like need_topk) instead of checking whether it's dflash here. What do you think?
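If it helps the discussion, here is a minimal sketch of that suggestion (hypothetical: SpecAlgo and its values are illustrative stand-ins, and need_topk is the name proposed above):

```python
from enum import Enum, auto


class SpecAlgo(Enum):
    EAGLE = auto()
    DFLASH = auto()

    def need_topk(self) -> bool:
        # Algorithms that require the extra top-k draft buffers opt in here,
        # so _lazy_init_buf no longer hard-codes is_dflash() at the call site.
        return self in (SpecAlgo.DFLASH,)


# Call-site change: `if self.spec_algo.is_dflash():`
# becomes `if self.spec_algo.need_topk():`.
```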

logger.warning(
    "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)."
)
if envs.SGLANG_ENABLE_SPEC_V2.get():

(Collaborator) Spec v2 is enabled by default, so the logic here may need to be changed.
