Skip to content

streaming session: trim spec v2 overshoot in cache_finished_req#22897

Merged
hnyls2002 merged 1 commit intomainfrom
lsyin/streaming-session-overshoot-trim
Apr 15, 2026
Merged

streaming session: trim spec v2 overshoot in cache_finished_req#22897
hnyls2002 merged 1 commit intomainfrom
lsyin/streaming-session-overshoot-trim

Conversation

@hnyls2002
Copy link
Copy Markdown
Collaborator

@hnyls2002 hnyls2002 commented Apr 15, 2026

Problem

Because spec accepts multiple tokens per round, the finishing round (the one whose post first trips len(output_ids) >= max_new_tokens) may commit overshoot ("bad") tokens — depending on how much of accept_lens falls past the boundary. It can also land exactly on finished_len, in which case there is no overshoot.

Note: this is the finishing round itself, not the extra round that overlap scheduling speculatively launches afterward. The overshoot enters committed before any next-iter post fires, so it doesn't matter whether overlap is on/off or whether the extra round runs at all — trim is needed purely because of the atomic resolve-then-check ordering inside the finishing round's own post.

finishing round post:
  _resolve_spec_overlap_token_ids:
    committed   += accept_lens     <- overshoot enters here
    output_ids  += accept_lens
  check_finished:
    if len(output_ids) >= max_new:
      finished_len = max_new       <- API boundary set, but
                                      committed/output_ids
                                      already past it
  cache_finished_req(req)          <- save runs with overshoot baked in

Slot KV layout after a finishing round with overshoot (O = origin_input_len, F = finished_len, C = committed - origin, C > F):

positions:  0 ........ O-1 | O .......... O+F-1 | O+F ........ O+C-1
            [--- input ---]   [--- kept output ---]  [--- overshoot ---]
                                                      stale; next turn's
                                                      input has DIFFERENT
                                                      tokens at these pos
            <-- inherited correctly by next turn -->
                                                      ^^^^^^^^^^^^^^^^^^^
                                                      if match_prefix
                                                      grants prefix_len
                                                      this far, attention
                                                      reads wrong KV

session_controller builds the next turn's input from output_ids[:F], so positions [O+F, O+C) will hold different tokens next turn. If match_prefix returns prefix_len > O+F, attention reads stale overshoot KV against the new tokens.

This PR is a general correctness guard against overshoot at finish time. Spec v2 is the only path that triggers it today, but the trim is unconditional — any future scheduler path that lets a finishing round commit past finished_len is covered automatically.

Fix

In cache_finished_req, trim the slot's KV state to the finished_len boundary before save_from_req:

  • Free the page-aligned tail at [origin+finished_len, kv_allocated_len)
  • Cap kv_allocated_len and kv_committed_len at origin+finished_len
  • Truncate output_ids to finished_len (already capped for response, but the in-memory list still held overshoot tokens)

Postcondition (always holds after cache_finished_req returns):

slot.kv_committed_len  <= origin + finished_len
slot.kv_allocated_len  <= origin + finished_len
len(req.output_ids)    == finished_len

So next turn's match_prefix returns prefix_len <= origin + finished_len — the previous turn's API boundary. Overshoot is fully contained within the finishing round.

Appendix: spec v2 + streaming session — reference analysis

This appendix is the single reference for the streaming-session + spec v2 correctness work. It builds a vocabulary first, then derives every case from it — so future reviewers / bots can reason from definitions instead of reassembling facts from #22790 / #22862 / #22897 / #22900 / #22651.

Definitions

  • committed_len (req.kv_committed_len): number of sequence positions whose KV has been materialized in req_to_token. Invariant: committed_len = origin + output_ids_len - 1 during spec v2 decode (the -1 is the bonus lag, defined below).
  • allocated_len (req.kv_allocated_len): number of sequence positions the allocator has reserved slots for. Always >= committed_len during decode (over-allocation for upcoming drafts).
  • finished_len (req.finished_len): the API-visible output length, set by check_finished when len(output_ids) >= max_new_tokens (or a stop condition hits). For FINISH_LENGTH, finished_len == max_new_tokens.
  • target: origin + finished_len. The slot boundary after trim.
  • Bonus token: the +1 in accept_lens = num_drafts + 1. Sampled from target logits at the last verify position. Its KV is not written in this round's target forward; it becomes draft_token[0] of the next round and its KV is written there.
  • Bonus lag: the always-true fact len(output_ids) == committed_len - origin + 1. The last output token (always a bonus) has no KV yet.
  • Finishing round: the round whose post first trips len(output_ids) >= max_new_tokens. Called batch_F.
  • Extra round: batch_{F+1}. In overlap mode, the scheduler plans it before batch_F's post has detected finish, so the req is still included. Its forward runs and writes KV; its post-loop hits if overlap and req.finished(): continue → output_ids / check_finished are skipped but _resolve_spec_overlap_token_ids already bumped committed at the top of post.
  • Valid tokens: output_ids[:finished_len] — what the API returned.
  • Invalid tokens (aka overshoot tokens): output_ids[finished_len:] — committed by _resolve but discarded by the API boundary.

Case analysis for the finishing round

Let k = accept_lens_F (this round's accept count). The finishing round's post atomically runs:

_resolve_spec_overlap_token_ids:
  committed  += k
  output_ids += k tokens
check_finished:
  if len(output_ids) >= max_new: set finished_len, finished_reason
cache_finished_req(req)   # fires if finished

Three cases by how many of the k new tokens fall past max_new_tokens:

  1. Exact match (len(output_ids)_after == max_new): 0 invalid tokens. committed == target - 1 (bonus lag). allocated >= target.
  2. Overshoot by j (1 ≤ j ≤ k - 1): j invalid tokens. committed == target + j - 1. allocated >= target + j - 1.
  3. No finish this round: not the finishing round; skip.

In every finishing-round case, committed is non-monotonic relative to target: it can be below (case 1) or above (case 2). Both need corrective action at cache_finished_req.

The role of the extra round (overlap only)

In overlap mode, the extra round always runs for the finishing req (scheduler planned it ahead). Effects:

  • Its target forward materializes the bonus token's KV at position origin + output_ids_len_after_F - 1 (i.e. at target - 1 in the exact-match case, or target + j - 1 in the overshoot case). This is what turns the "bonus lag" into actual KV.
  • Its _resolve_spec_overlap_token_ids bumps committed by accept_lens_{F+1}. This fires after batch_F's post — specifically after cache_finished_req has already called save_from_req.
  • The rest of its post-loop is skipped (continue) so output_ids / check_finished state is untouched.

Without a separate fix, save_from_req captures the pre-extra-round committed, so the slot stores committed = target + j - 1 (case 2) or target - 1 (case 1).

In non-overlap mode (spec v1, --disable-overlap-schedule), the extra round does not exist. Bonus KV is never materialized → next turn inherits 1 position short → kv_inherit_offset = -1.

Implications per case, with and without each fix

Using T = target = origin + finished_len. Each column is the saved slot.kv_committed_len after the fix is applied (and earlier fixes in the chain).

Case Pre-all-fixes save + overshoot trim (this PR) + SWA cap (#22900) + bonus accounting (#22651)
Exact match, overlap on T - 1 T - 1 T - 1 T
Overshoot by j, overlap on T + j - 1 (stale KV past T) T (stale tail freed) T (SWA cursor also capped) T
Exact match, overlap off (spec v1) T - 1 T - 1 T - 1 T - 1 (bonus never materialized; offset=-1 is correct)
Overshoot, overlap off T + j - 1 T - 1 (trim clamps; no extra round → offset=-1) T - 1 T - 1

Key takeaways:

  • Overshoot trim is needed regardless of overlap — purely a consequence of atomic _resolve before check_finished.
  • SWA cap rides on trim (one extra min()) to keep the eviction cursor consistent with the post-trim boundary.
  • Bonus accounting only shifts the timing of the +1 committed bump; it does not create or destroy KV. Overlap-off paths still need kv_inherit_offset = -1.

Invalid tokens and KV at their positions

After trim (_free_kv_aligned(req_to_token[pool, T:alloc])) the positions [T, alloc_before_trim) are released from the KV pool. Whatever KV was written at those positions in the finishing round's target forward — for draft tokens at positions the accept logic ended up rejecting, or for the overshoot drafts — becomes unreachable. The invalid tokens (as token IDs in output_ids[T-origin:]) are also truncated. Next turn's match_prefix can grant at most T tokens of inheritance, regardless of what the finishing round happened to verify.

Retract is orthogonal

Retract retries the same req with the same token IDs, so reset_for_retract leaves the slot's prefix as a valid ancestor. The only fix on the retract side is page-aligned _free_tail (#22862) — partial-page free would corrupt pages still holding committed tokens. Retract does not interact with bonus lag or overshoot, because retract fires during decode, not at finish.

Summary: one postcondition, four fixes

After all fixes are in place, every streaming-session slot save guarantees:

slot.kv_committed_len   == origin + finished_len
slot.kv_allocated_len   <= origin + finished_len
slot.swa_evicted_seqlen <= origin + finished_len
len(req.output_ids)     == finished_len
Fix PR One-liner
Overshoot trim #22897 (this) Clamp down from T + j - 1 to T and free stale tail.
SWA cap #22900 Cap swa_evicted_seqlen at T to stop the cursor leaking past.
Bonus accounting #22651 Pre-claim the bonus slot in prepare_for_decode so committed reaches T upfront (not post-facto in the extra round).
Page-aligned tail free #22862 Avoid partial-page free corrupting committed pages on the match_prefix path.

This appendix was drafted with Claude Code assistance.

@hnyls2002
Copy link
Copy Markdown
Collaborator Author

/rerun-test test_session_control.py test_session_latency.py test_streaming_session.py test_streaming_session_swa.py

@github-actions
Copy link
Copy Markdown
Contributor

1-gpu-h100 (3 tests): View workflow run

cd test/ && python3 registered/sessions/test_session_control.py
cd test/ && python3 registered/sessions/test_session_latency.py
cd test/ && python3 registered/sessions/test_streaming_session.py

test_streaming_session_swa.py: No test file found matching test_streaming_session_swa.py under test/registered/.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a _trim_overshoot method to handle speculative decoding overshoots by trimming the KV cache and output IDs to the correct finished length. It also refactors the KV cache freeing logic into a reusable _free_kv_aligned helper. Feedback suggests updating req.swa_evicted_seqlen within _trim_overshoot to maintain state consistency, matching the logic used in _free_tail.

Comment thread python/sglang/srt/mem_cache/session_aware_cache.py
@hnyls2002 hnyls2002 merged commit efc267c into main Apr 15, 2026
59 of 67 checks passed
@hnyls2002 hnyls2002 deleted the lsyin/streaming-session-overshoot-trim branch April 15, 2026 21:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant