Fix chunked prefill and KV cache leaks for streaming sessions#20476

Merged
hnyls2002 merged 4 commits into sgl-project:main from
YazhiGao:fix/streaming-chunked-prefill-leak
Mar 13, 2026
Conversation

@YazhiGao
Contributor

Three fixes for streaming session KV cache management:

  1. Enforce single chunked request per prefill batch: Track new_chunked_req and has_reusing_chunked_req in PrefillAdder to prevent multiple chunked requests from being batched together. The previous has_chunked_req parameter only checked self.chunked_req from prior batches, missing within-batch duplicates. Removes the now-unnecessary has_chunked_req parameter from add_one_req.

  2. Make SessionSlot.restore_to_req non-destructive: don't clear req_pool_idx and mamba_pool_idx from the slot after restoring to the request. During chunked prefill, a request may be rejected by the scheduler (e.g. budget exhausted) and retried in the next cycle, causing match_prefix to call restore_to_req again. The destructive restore caused the second call to fall through to the inner radix cache, acquiring real tree node locks that were never properly released.

  3. Always skip inner cache_unfinished_req for streaming chunked stash: move the chunked prefix_indices save above the slot existence check so it applies uniformly to all turns, preventing redundant radix tree insertions during inter-chunk stashing.
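The guard described in fix 1 can be pictured as a small state machine in the adder. A minimal sketch, assuming simplified stand-in names (`PrefillAdder`, `add_one_req`, `would_be_chunked`) rather than the real scheduler signatures:

```python
class Req:
    """Minimal stand-in for a scheduler request (hypothetical fields)."""
    def __init__(self, name, is_chunked=0, req_pool_idx=None):
        self.name = name
        self.is_chunked = is_chunked
        self.req_pool_idx = req_pool_idx

class PrefillAdder:
    """Sketch of the single-chunked-request-per-batch guard."""
    def __init__(self):
        self.can_run_list = []
        self.new_chunked_req = None           # request newly chunked within this batch
        self.has_reusing_chunked_req = False  # chunked request reusing its req_pool_idx

    def add_one_req(self, req, would_be_chunked):
        # Refuse a second chunked request within the same batch; checking only
        # self.chunked_req from prior batches would miss within-batch duplicates.
        if would_be_chunked and (
            self.new_chunked_req is not None or self.has_reusing_chunked_req
        ):
            return False
        self.can_run_list.append(req)
        if would_be_chunked:
            req.is_chunked = 1
            self.new_chunked_req = req
        if req.req_pool_idx is not None and req.is_chunked > 0:
            self.has_reusing_chunked_req = True
        return True

adder = PrefillAdder()
assert adder.add_one_req(Req("A"), would_be_chunked=True)       # first chunked req accepted
assert not adder.add_one_req(Req("B"), would_be_chunked=True)   # second chunked req rejected
```

The key point of the design is that both reuse sources are tracked per batch, so the duplicate is caught even when the prior-batch `self.chunked_req` is empty.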



@hnyls2002 hnyls2002 self-assigned this Mar 12, 2026
Comment thread python/sglang/srt/dllm/mixin/scheduler.py
@YazhiGao YazhiGao force-pushed the fix/streaming-chunked-prefill-leak branch from aca4ea8 to 83d9e75 Compare March 12, 2026 23:37
self.can_run_list.append(req)
# Track if this batch has a reusing request with is_chunked > 0
# to prevent batching another chunked-reusing request (memory_pool assertion).
if req.req_pool_idx is not None and req.is_chunked > 0:
    self.has_reusing_chunked_req = True
Collaborator

In the new chunked prefill pipeline, when the previous request A's last chunk is added to the prefill batch, has_reusing_chunked_req is set to True. This prevents request B (the next request) from being added to the prefill batch. I suggest not changing the original pipeline (this causes a performance regression) and just relaxing the assertion first.

@hnyls2002
Collaborator

/tag-and-rerun-ci

@hnyls2002
Collaborator

hnyls2002 commented Mar 13, 2026

@cctry @YazhiGao

The req_pool_idx reuse assert added in #17850 needs to be relaxed for streaming sessions:

assert (
    sum(1 for i in reusing if reqs[i].is_chunked > 0) <= 1
), "only one chunked request may reuse req_pool_idx in a batch"

Original reuse (from #17850): A chunked request keeps its req_pool_idx across chunks instead of freeing and reallocating — preventing a data race where an async CUDA kernel reads a slot that has been freed and reassigned. The scheduler tracks exactly one self.chunked_req, so <= 1 was always satisfied.

New reuse (streaming sessions): Streaming session requests enter scheduling with a pre-existing req_pool_idx recovered from a session slot (restore_to_req). This introduces a second source of reuse — independent from the chunked prefill reuse.

Why at most 2: A batch can contain at most one request from add_chunked_req (the single self.chunked_req) and at most one new chunked request from add_one_req (new_chunked_req). When both are streaming and carry their own req_pool_idx from session slots, the count reaches 2. It cannot exceed 2 because the scheduler only tracks one chunked_req and one new_chunked_req.

Trigger scenario: Request A (streaming, is_chunked=1, last chunk, req_pool_idx=2) via add_chunked_req + Request B (streaming, new turn, req_pool_idx=3 from session slot, gets chunked, is_chunked bumped to 1 before alloc_req_slots) → sum = 2 → assert fails.

Safety: Each reusing request writes to its own independent row in req_to_token_pool.req_to_token — no slot aliasing, no data race. The invariant from #17850 (no freed-then-reallocated slot read by async CUDA) is preserved.

Proposed fix: <= 1 → <= 2. I will rewrite the chunked prefill logic in the following PRs.
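The relaxed invariant and the trigger scenario above can be sketched as follows; `check_reuse`, `reqs`, and `reusing` are illustrative names, and the real assertion lives in the scheduler's batching path:

```python
from types import SimpleNamespace

def check_reuse(reqs, reusing):
    """Relaxed invariant sketch: at most one chunked request from
    add_chunked_req plus at most one new chunked request from add_one_req
    may carry a pre-existing req_pool_idx, so the bound is 2, not 1."""
    n = sum(1 for i in reusing if reqs[i].is_chunked > 0)
    assert n <= 2, "at most two chunked requests may reuse req_pool_idx in a batch"
    return n

# Trigger scenario from the comment: request A (streaming, last chunk,
# req_pool_idx=2) plus request B (streaming new turn, req_pool_idx=3, chunked).
reqs = [
    SimpleNamespace(is_chunked=1, req_pool_idx=2),  # A via add_chunked_req
    SimpleNamespace(is_chunked=1, req_pool_idx=3),  # B via add_one_req
]
check_reuse(reqs, reusing=[0, 1])  # passes with <= 2; the old <= 1 would fail
```

Safety holds because each reusing request writes to its own row of req_to_token, so the count bound is the only thing that needs relaxing.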

@cctry
Collaborator

cctry commented Mar 13, 2026

Proposed fix: <= 1 → <= 2. I will rewrite the chunked prefill logic in the following PRs.

Makes sense to me, but I'd suggest making the assertion more specific instead of just relaxing the count.
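One hypothetical way to make the assertion more specific, in the spirit of this suggestion, is to verify that each reusing chunked request actually comes from one of the two tracked sources instead of only bounding the count; all names below are illustrative, not the real scheduler API:

```python
from types import SimpleNamespace

def check_reuse_specific(reqs, reusing, chunked_req, new_chunked_req):
    """Stricter variant sketch: every chunked request that reuses a
    req_pool_idx must be the tracked chunked_req or new_chunked_req,
    rather than merely keeping the total at <= 2."""
    allowed = {id(r) for r in (chunked_req, new_chunked_req) if r is not None}
    for i in reusing:
        req = reqs[i]
        if req.is_chunked > 0:
            assert id(req) in allowed, (
                "chunked req_pool_idx reuse must come from chunked_req "
                "or new_chunked_req"
            )

# Both reusing requests are the tracked sources -> passes.
a = SimpleNamespace(is_chunked=1, req_pool_idx=2)
b = SimpleNamespace(is_chunked=1, req_pool_idx=3)
check_reuse_specific([a, b], [0, 1], chunked_req=a, new_chunked_req=b)
```

This pins down where the reuse is allowed to originate, so an unexpected third source would fail loudly rather than slipping under a looser count bound.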

@hnyls2002 hnyls2002 merged commit b1246c5 into sgl-project:main Mar 13, 2026
134 of 155 checks passed