Skip to content

trim_overshoot: cap swa_evicted_seqlen + unit test#22900

Merged
hnyls2002 merged 6 commits intomainfrom
lsyin/trim-overshoot-swa-cap
Apr 15, 2026
Merged

trim_overshoot: cap swa_evicted_seqlen + unit test#22900
hnyls2002 merged 6 commits intomainfrom
lsyin/trim-overshoot-swa-cap

Conversation

@hnyls2002
Copy link
Copy Markdown
Collaborator

@hnyls2002 hnyls2002 commented Apr 15, 2026

Problem

_trim_overshoot (added in #22897) caps kv_committed_len and kv_allocated_len at target = origin + finished_len, but leaves req.swa_evicted_seqlen untouched. _free_tail (the match_prefix counterpart on the same SessionAwareCache) caps all three. The asymmetry causes a SWA pool leak when a finishing round overshoots while SWA was actively evicting.

swa_evicted_seqlen is a cursor — _evict_swa reads it as the start position for the next eviction sweep:

new_swa_evicted = max(req.swa_evicted_seqlen, pre_len - W - P)
free_swa(req_to_token[swa_evicted : new_swa_evicted])  # frees [old_cursor, new_cursor)
req.swa_evicted_seqlen = new_swa_evicted

If overshoot pushed the cursor past the trim boundary (e.g. swa_evicted = 42 but target = 38), save_from_req propagates the stale 42 into the slot and next turn's restore_to_req hands it back. Next turn's prefill writes new SWA slots at positions [38, 42), but _evict_swa starts scanning from 42 — those slots are skipped forever and accumulate as a leak.

Fix

One-line cap to mirror _free_tail:

req.swa_evicted_seqlen = min(req.swa_evicted_seqlen, target)

This is the only other downward-cap site for swa_evicted_seqlen outside _free_tail (and reset_for_retract, which clears to 0). Both downward-moving paths now enforce the same invariant: swa_evicted_seqlen <= origin + committed.

Test

Adds test_trim_overshoot_postcondition covering the full postcondition in one shot: kv_committed_len, kv_allocated_len, swa_evicted_seqlen all capped at target, output_ids truncated to finished_len, and the page-aligned tail freed. Without the swa cap, the test fails on swa_evicted_seqlen == target (assert 42 == 38).

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@hnyls2002 hnyls2002 changed the base branch from lsyin/streaming-session-overshoot-trim to main April 15, 2026 21:34
…swa-cap

# Conflicts:
#	python/sglang/srt/mem_cache/session_aware_cache.py
@hnyls2002
Copy link
Copy Markdown
Collaborator Author

/tag-and-rerun-ci

@hnyls2002
Copy link
Copy Markdown
Collaborator Author

/rerun-test test_streaming_session.py test_streaming_session_unit.py test_session_control.py test_session_latency.py

@github-actions
Copy link
Copy Markdown
Contributor

1-gpu-h100 (3 tests): View workflow run

cd test/ && python3 registered/sessions/test_streaming_session.py
cd test/ && python3 registered/sessions/test_session_control.py
cd test/ && python3 registered/sessions/test_session_latency.py

ubuntu-latest (1 test): View workflow run

cd test/ && python3 registered/unit/mem_cache/test_streaming_session_unit.py

@hnyls2002 hnyls2002 merged commit f979216 into main Apr 15, 2026
107 of 170 checks passed
@hnyls2002 hnyls2002 deleted the lsyin/trim-overshoot-swa-cap branch April 15, 2026 22:05
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant