feat: Full Support for Overlapped Constrained Decoding + Spec V2 #15623
Ubospica wants to merge 4 commits into sgl-project:main from
Conversation
Summary of Changes
Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the efficiency of constrained decoding, particularly when using speculative decoding (Spec V2) with grammars. By enabling full overlap between CPU-bound grammar processing and GPU computations, it minimizes idle time and boosts overall throughput. This is achieved through asynchronous data transfers, a new native C++ implementation for grammar tree traversal, and a refactored scheduling mechanism that processes grammar state updates from previous batches concurrently with the current batch's GPU forward pass.
Highlights
Code Review
This pull request introduces full support for overlapped constrained decoding with Speculative Decoding V2, which is a significant performance enhancement. The changes are well-structured, primarily affecting the scheduler and the speculative worker to enable overlapping CPU-bound grammar processing with GPU computations. Key changes include the introduction of pending_accept_info for carrying over grammar work, asynchronous data transfers using pinned memory, and the integration of a faster C++ implementation for draft tree traversal. The addition of new benchmarks and unit tests is commendable and ensures the correctness and performance of the new features. The implementation appears solid, and my feedback focuses on improving type hint accuracy and reducing some code duplication for better long-term maintainability.
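To make the overlap concrete for reviewers, here is a minimal, self-contained sketch of the pattern the summary describes: the CPU advances the previous batch's grammar matchers while the current batch's forward pass runs concurrently. All names and the thread-based framing are illustrative assumptions, not the PR's actual API; the PR itself appears to get the same effect by launching the GPU work first and processing the deferred CPU-side grammar accepts while the GPU is busy.

```python
import threading
from typing import Callable, List, Sequence


def run_step_overlapped(
    prev_grammars: Sequence[object],            # objects exposing accept_token(int)
    prev_accepted_tokens: Sequence[List[int]],  # per-request tokens from batch N-1
    launch_forward: Callable[[], object],       # launches batch N's forward pass
):
    """Overlap CPU grammar updates for batch N-1 with batch N's forward pass."""

    def accept_prev_batch():
        # CPU-bound: advance each grammar matcher over its accepted tokens.
        for grammar, tokens in zip(prev_grammars, prev_accepted_tokens):
            for token_id in tokens:
                grammar.accept_token(token_id)

    worker = threading.Thread(target=accept_prev_batch)
    worker.start()             # grammar work proceeds on the CPU in the background
    result = launch_forward()  # forward pass runs concurrently
    worker.join()              # both must finish before grammar masks are applied
    return result
```

The gain is exactly the idle time removed: grammar accept_token calls that previously serialized with the forward pass now hide behind it.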
if last_batch.forward_mode.is_extend():
    # Prefill case: single token per request (next_token_ids shape: [bs])
    for i, req in enumerate(last_batch.reqs):
        if (
            req.grammar is not None
            and not req.finished()
            and not req.is_retracted
        ):
            try:
                req.grammar.accept_token(next_token_ids[i])
            except ValueError as e:
                logger.error(
                    f"Grammar accept_token failed for req {req.rid} "
                    f"with token {next_token_ids[i]}: {e}"
                )
else:
    # Decode case: multiple accepted tokens per request
    # next_token_ids shape: [bs * speculative_num_draft_tokens]
    accept_lens = last_result.accept_lens.tolist()
    stride = self.speculative_num_draft_tokens

    for i, req in enumerate(last_batch.reqs):
        if (
            req.grammar is not None
            and not req.finished()
            and not req.is_retracted
        ):
            # Get the accepted tokens for this request
            accepted_tokens = next_token_ids[
                i * stride : i * stride + accept_lens[i]
            ]
            try:
                for token_id in accepted_tokens:
                    req.grammar.accept_token(token_id)
            except ValueError as e:
                logger.error(
                    f"Grammar accept_token failed for req {req.rid} "
                    f"with tokens {accepted_tokens}: {e}"
                )
The logic for handling prefill and decode cases in this method has some code duplication, particularly the checks for req.grammar, req.finished(), req.is_retracted, and the try...except block. This could be refactored to improve maintainability by first preparing a unified list of accepted tokens for each request and then iterating through them in a single loop.
if last_batch.forward_mode.is_extend():
    # Prefill case: single token per request (next_token_ids shape: [bs])
    accepted_tokens_per_req = [[t] for t in next_token_ids]
else:
    # Decode case: multiple accepted tokens per request
    accept_lens = last_result.accept_lens.tolist()
    stride = self.speculative_num_draft_tokens
    accepted_tokens_per_req = [
        next_token_ids[i * stride : i * stride + accept_lens[i]]
        for i in range(len(last_batch.reqs))
    ]
for i, req in enumerate(last_batch.reqs):
    if (
        req.grammar is not None
        and not req.finished()
        and not req.is_retracted
    ):
        accepted_tokens = accepted_tokens_per_req[i]
        try:
            for token_id in accepted_tokens:
                req.grammar.accept_token(token_id)
        except ValueError as e:
            logger.error(
                f"Grammar accept_token failed for req {req.rid} "
                f"with tokens {accepted_tokens}: {e}"
            )
the suggestion looks not bad
Force-pushed from 5d71d30 to 26354d2
Signed-off-by: Ubospica <ubospica@gmail.com>
# the current batch's verify forward to overlap CPU and GPU operations.
if (
    batch
    and batch.is_eagle_v2
AttributeError: 'ScheduleBatch' object has no attribute 'is_eagle_v2'
Also consider adapting python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py?
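One way to avoid the AttributeError above, shown here only as an illustrative guard and not necessarily the author's fix (the real fix may instead define is_eagle_v2 on ScheduleBatch), is to make the attribute access defensive:

```python
# Hypothetical guard; the actual resolution in the PR may differ.
if batch is not None and getattr(batch, "is_eagle_v2", False):
    ...  # defer the grammar accepts into this batch's verify step
```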
@@ -0,0 +1,7 @@
#!/bin/bash
maybe we don't need this file?
def draft_worker(self):
    return self._draft_worker

def _init_grammar_pinned_buffers(self, bs: int, num_draft_tokens: int):
what's the reason we want to lazy init here? seems like non-lazy init is fine?
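For context on the lazy-vs-eager question, here is a rough sketch of what eager pinned-buffer allocation could look like; the buffer names, dtypes, and shapes are assumptions for illustration, not the PR's actual code:

```python
import torch


def init_grammar_pinned_buffers(bs: int, num_draft_tokens: int):
    # Page-locked (pinned) host buffers enable non_blocking GPU->CPU copies.
    accepted_tokens_cpu = torch.empty(
        bs * num_draft_tokens, dtype=torch.int64, pin_memory=True
    )
    accept_lens_cpu = torch.empty(bs, dtype=torch.int32, pin_memory=True)
    return accepted_tokens_cpu, accept_lens_cpu
```

Lazy initialization mainly helps when the maximum batch size or draft length is not known at construction time, or to avoid pinning host memory for runs that never use grammars; otherwise eager allocation looks equally fine, which seems to be the reviewer's point.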
)

# Record event to synchronize before using CPU tensors
copy_event = torch.cuda.Event()
Consider naming this grammar_copy_done for clarity and consistency, since we already have copy_done and verify_done.
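For readers unfamiliar with the event idiom in the snippet above, a minimal sketch of how a recorded CUDA event gates the later CPU read of a pinned buffer; it uses the suggested grammar_copy_done name, and all other variable names are assumptions:

```python
import torch


def async_copy_tokens_to_cpu(tokens_gpu: torch.Tensor, pinned_buf: torch.Tensor):
    # Assumes tokens_gpu is a 1-D CUDA tensor of token ids and pinned_buf is a
    # 1-D pinned (page-locked) CPU tensor at least as long.
    n = tokens_gpu.numel()
    pinned_buf[:n].copy_(tokens_gpu, non_blocking=True)  # async D2H copy
    grammar_copy_done = torch.cuda.Event()
    grammar_copy_done.record()  # records on the current CUDA stream
    return grammar_copy_done

# Later, right before grammar.accept_token() consumes the tokens:
#   grammar_copy_done.synchronize()
#   token_ids = pinned_buf[:n].tolist()
```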
if last_batch.has_grammar:
    batch.pending_accept_info = (last_batch, last_result)
    # Mark that grammar accept will be processed in the next batch's verify
    last_result.grammar_accept_processed = True
I am trying to understand why we need this grammar_accept_processed flag. For spec v2 + decode + not disable_overlap_for_batch, all batches will have grammar_accept_processed = True, and in all other cases grammar_accept_processed = False.
In the output processing, can we just check the conditions above instead?
Also, IIUC, grammar_accept_processed only applies to decode batches. How do you confirm that for last_batch?
I see you check batch.forward_mode.is_decode(), but that checks whether batch N is a decode batch; last_batch is batch N-1, so how can we infer batch N-1's mode from batch N?
I think I got the reason now: batch N's verify handles batch N-1's grammar_accept_processed.
This PR supports fully overlapped constrained decoding with Spec V2.
See #13425 for the non-overlapped version. See #11762, #13019 for more background. This PR is on top of #15465.
Overlapping Pattern
Benchmark
E2E Latency:
Key Findings
test_mix_json_and_other): ~30% faster
Signed-off-by: Ubospica <ubospica@gmail.com>
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist