streaming session: trim spec v2 overshoot in cache_finished_req#22897
Merged
streaming session: trim spec v2 overshoot in cache_finished_req#22897
Conversation
hnyls2002
added a commit
that referenced
this pull request
Apr 15, 2026
Collaborator
Author
|
/rerun-test test_session_control.py test_session_latency.py test_streaming_session.py test_streaming_session_swa.py |
Contributor
|
✅ ❌ |
Contributor
There was a problem hiding this comment.
Code Review
This pull request introduces a _trim_overshoot method to handle speculative decoding overshoots by trimming the KV cache and output IDs to the correct finished length. It also refactors the KV cache freeing logic into a reusable _free_kv_aligned helper. Feedback suggests updating req.swa_evicted_seqlen within _trim_overshoot to maintain state consistency, matching the logic used in _free_tail.
This was referenced Apr 15, 2026
jmamou
pushed a commit
to jmamou/sglang
that referenced
this pull request
Apr 20, 2026
hnyls2002
added a commit
that referenced
this pull request
Apr 21, 2026
Merged
4 tasks
yhyang201
pushed a commit
to yhyang201/sglang
that referenced
this pull request
Apr 22, 2026
zhangying098
pushed a commit
to zhangying098/sglang
that referenced
this pull request
Apr 23, 2026
kyx1999
pushed a commit
to KMSorSMS/sglang
that referenced
this pull request
Apr 27, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Because spec accepts multiple tokens per round, the finishing round (the one whose post first trips
len(output_ids) >= max_new_tokens) may commit overshoot ("bad") tokens — depending on how much ofaccept_lensfalls past the boundary. It can also land exactly onfinished_len, in which case there is no overshoot.Slot KV layout after a finishing round with overshoot (
O = origin_input_len,F = finished_len,C = committed - origin,C > F):session_controllerbuilds the next turn's input fromoutput_ids[:F], so positions[O+F, O+C)will hold different tokens next turn. Ifmatch_prefixreturnsprefix_len > O+F, attention reads stale overshoot KV against the new tokens.This PR is a general correctness guard against overshoot at finish time. Spec v2 is the only path that triggers it today, but the trim is unconditional — any future scheduler path that lets a finishing round commit past
finished_lenis covered automatically.Fix
In
cache_finished_req, trim the slot's KV state to thefinished_lenboundary beforesave_from_req:[origin+finished_len, kv_allocated_len)kv_allocated_lenandkv_committed_lenatorigin+finished_lenoutput_idstofinished_len(already capped for response, but the in-memory list still held overshoot tokens)Postcondition (always holds after
cache_finished_reqreturns):So next turn's
match_prefixreturnsprefix_len <= origin + finished_len— the previous turn's API boundary. Overshoot is fully contained within the finishing round.Appendix: spec v2 + streaming session — reference analysis
This appendix is the single reference for the streaming-session + spec v2 correctness work. It builds a vocabulary first, then derives every case from it — so future reviewers / bots can reason from definitions instead of reassembling facts from #22790 / #22862 / #22897 / #22900 / #22651.
Definitions
committed_len(req.kv_committed_len): number of sequence positions whose KV has been materialized inreq_to_token. Invariant:committed_len = origin + output_ids_len - 1during spec v2 decode (the-1is the bonus lag, defined below).allocated_len(req.kv_allocated_len): number of sequence positions the allocator has reserved slots for. Always>= committed_lenduring decode (over-allocation for upcoming drafts).finished_len(req.finished_len): the API-visible output length, set bycheck_finishedwhenlen(output_ids) >= max_new_tokens(or a stop condition hits). ForFINISH_LENGTH,finished_len == max_new_tokens.target:origin + finished_len. The slot boundary after trim.+1inaccept_lens = num_drafts + 1. Sampled from target logits at the last verify position. Its KV is not written in this round's target forward; it becomesdraft_token[0]of the next round and its KV is written there.len(output_ids) == committed_len - origin + 1. The last output token (always a bonus) has no KV yet.len(output_ids) >= max_new_tokens. Calledbatch_F.batch_{F+1}. In overlap mode, the scheduler plans it beforebatch_F's post has detected finish, so the req is still included. Its forward runs and writes KV; its post-loop hitsif overlap and req.finished(): continue→ output_ids / check_finished are skipped but_resolve_spec_overlap_token_idsalready bumpedcommittedat the top of post.output_ids[:finished_len]— what the API returned.output_ids[finished_len:]— committed by_resolvebut discarded by the API boundary.Case analysis for the finishing round
Let
k = accept_lens_F(this round's accept count). The finishing round's post atomically runs:Three cases by how many of the
knew tokens fall pastmax_new_tokens:len(output_ids)_after == max_new): 0 invalid tokens.committed == target - 1(bonus lag).allocated >= target.j(1 ≤ j ≤ k - 1):jinvalid tokens.committed == target + j - 1.allocated >= target + j - 1.In every finishing-round case,
committedis non-monotonic relative totarget: it can be below (case 1) or above (case 2). Both need corrective action atcache_finished_req.The role of the extra round (overlap only)
In overlap mode, the extra round always runs for the finishing req (scheduler planned it ahead). Effects:
origin + output_ids_len_after_F - 1(i.e. attarget - 1in the exact-match case, ortarget + j - 1in the overshoot case). This is what turns the "bonus lag" into actual KV._resolve_spec_overlap_token_idsbumpscommittedbyaccept_lens_{F+1}. This fires afterbatch_F's post — specifically aftercache_finished_reqhas already calledsave_from_req.continue) sooutput_ids/check_finishedstate is untouched.Without a separate fix,
save_from_reqcaptures the pre-extra-roundcommitted, so the slot storescommitted = target + j - 1(case 2) ortarget - 1(case 1).In non-overlap mode (spec v1,
--disable-overlap-schedule), the extra round does not exist. Bonus KV is never materialized → next turn inherits 1 position short →kv_inherit_offset = -1.Implications per case, with and without each fix
Using
T = target = origin + finished_len. Each column is the savedslot.kv_committed_lenafter the fix is applied (and earlier fixes in the chain).T - 1T - 1T - 1TT + j - 1(stale KV past T)T(stale tail freed)T(SWA cursor also capped)TT - 1T - 1T - 1T - 1(bonus never materialized; offset=-1 is correct)T + j - 1T - 1(trim clamps; no extra round → offset=-1)T - 1T - 1Key takeaways:
_resolvebeforecheck_finished.min()) to keep the eviction cursor consistent with the post-trim boundary.+1committed bump; it does not create or destroy KV. Overlap-off paths still needkv_inherit_offset = -1.Invalid tokens and KV at their positions
After trim (
_free_kv_aligned(req_to_token[pool, T:alloc])) the positions[T, alloc_before_trim)are released from the KV pool. Whatever KV was written at those positions in the finishing round's target forward — for draft tokens at positions the accept logic ended up rejecting, or for the overshoot drafts — becomes unreachable. The invalid tokens (as token IDs inoutput_ids[T-origin:]) are also truncated. Next turn's match_prefix can grant at mostTtokens of inheritance, regardless of what the finishing round happened to verify.Retract is orthogonal
Retract retries the same req with the same token IDs, so
reset_for_retractleaves the slot's prefix as a valid ancestor. The only fix on the retract side is page-aligned_free_tail(#22862) — partial-page free would corrupt pages still holding committed tokens. Retract does not interact with bonus lag or overshoot, because retract fires during decode, not at finish.Summary: one postcondition, four fixes
After all fixes are in place, every streaming-session slot save guarantees:
T + j - 1toTand free stale tail.swa_evicted_seqlenatTto stop the cursor leaking past.prepare_for_decodesocommittedreachesTupfront (not post-facto in the extra round).match_prefixpath.This appendix was drafted with Claude Code assistance.