[Speculative] Support penalty for spec v2 overlap scheduling #22049

hnyls2002 merged 3 commits into sgl-project:main from
Conversation
Hi @hnyls2002, could a maintainer help trigger CI? This PR adds penalty support (frequency_penalty, presence_penalty, logit_bias) for spec v2 overlap scheduling, as listed in #11762. Thanks!
Code Review
This pull request implements a relaxed version of penalty accumulation and application for Eagle speculative decoding (v2) and includes a new test case to verify these parameters. Feedback highlights a logic error where only the last token is accumulated instead of all newly accepted tokens, a runtime error caused by a typo in the attribute name `acc_linear_penalties`, and the omission of scaling penalties like repetition penalty in the current implementation.
```python
output_ids = torch.tensor(
    [
        (
            req.output_ids[-1]
            if len(req.output_ids)
            else req.origin_input_ids[-1]
        )
        for req in batch.reqs
    ],
    dtype=torch.int64,
    device=batch.device,
)
```
In speculative decoding, multiple tokens can be accepted in a single verify round. This logic only accumulates the last accepted token (`req.output_ids[-1]`) into the penalizer state. All newly accepted tokens from the previous round should be accumulated to ensure `frequency_penalty` and `presence_penalty` counters are accurate.
Additionally, creating a new tensor from a list in a loop can be inefficient for large batches; consider gathering the IDs more efficiently if they are already available in a tensor format.
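To make the suggestion concrete, here is a hedged sketch of accumulating all accepted tokens in one vectorized step. The tensor names and padding convention are assumptions for illustration, not the actual sglang internals:

```python
import torch

# Hypothetical sketch: accumulate *all* newly accepted tokens from the last
# verify round into per-request frequency counters, rather than only the
# final token. `accepted` is a padded (bs, max_accept_len) tensor of token
# ids with -1 marking unused slots; `freq_counts` is (bs, vocab_size).
def accumulate_accepted(freq_counts: torch.Tensor, accepted: torch.Tensor) -> None:
    mask = (accepted >= 0).to(freq_counts.dtype)  # 0 for padding slots
    safe_ids = accepted.clamp(min=0)              # padding scatters a 0 into column 0
    freq_counts.scatter_add_(1, safe_ids, mask)

freq = torch.zeros(2, 8)
acc = torch.tensor([[3, 3, 5],     # request 0 accepted three tokens
                    [1, -1, -1]])  # request 1 accepted one token
accumulate_accepted(freq, acc)
```

A single `scatter_add_` also avoids the per-request Python loop the reviewer flags as inefficient for large batches.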
/tag-and-rerun-ci

/rerun-test registered/spec/eagle/test_eagle_infer_beta.py

✅

/rerun-test registered/spec/eagle/test_eagle_infer_b.py

✅

✅
Motivation
Closes the penalty support item in #11762.
Spec v2 (overlap scheduling) previously ignored
frequency_penalty,presence_penalty,repetition_penalty, andlogit_biasduring verification,silently producing unpenalized outputs.
Modifications
Two changes in `python/sglang/srt/speculative/eagle_info_v2.py`:

1. Apply penalties during verify sampling (`sample()`): apply `acc_additive_penalties`, `acc_scaling_penalties`, and `logit_bias` directly to the verify logits, each broadcast via `repeat_interleave` to match the `(bs * draft_token_num, V)` shape. Mirrors the `apply_logits_bias()` logic with per-field expansion; follows the same relaxed approximation as spec v1.
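The broadcast above can be sketched in a few lines. Shapes follow the PR description, but the variable names are assumptions rather than the real code:

```python
import torch

# Per-request additive penalties of shape (bs, V) are expanded with
# repeat_interleave so they line up row-for-row with verify logits of
# shape (bs * draft_token_num, V), then added in one shot.
bs, draft_token_num, V = 2, 3, 5
logits = torch.zeros(bs * draft_token_num, V)
additive = torch.tensor([[0.0, -1.0, 0.0, 0.0, 0.0],   # request 0 penalties
                         [0.0, 0.0, -2.0, 0.0, 0.0]])  # request 1 penalties
logits += additive.repeat_interleave(draft_token_num, dim=0)
```

`repeat_interleave(draft_token_num, dim=0)` repeats each request's row contiguously, matching the layout where each request owns a consecutive block of `draft_token_num` draft positions.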
2. Accumulate penalty state per decode round (`prepare_for_decode()`): call `penalizer_orchestrator.cumulate_output_tokens()` to keep penalty counters up-to-date.
Test (`test/registered/spec/eagle/test_eagle_infer_beta.py`):

`test_penalty()`: concurrent requests with varied penalty combinations and differentiated `max_new_tokens` to exercise `filter_batch`.

Checklist
max_new_tokensto exercisefilter_batchChecklist