Skip to content

Fix VRAM leak in overlap scheduling with structured output (#20640)#20697

Merged
hnyls2002 merged 5 commits intosgl-project:mainfrom
Cishoon:fix/vram-leak-extra-buffer-structured-output
Mar 23, 2026
Merged

Fix VRAM leak in overlap scheduling with structured output (#20640)#20697
hnyls2002 merged 5 commits intosgl-project:mainfrom
Cishoon:fix/vram-leak-extra-buffer-structured-output

Conversation

@Cishoon
Copy link
Copy Markdown
Contributor

@Cishoon Cishoon commented Mar 16, 2026

Motivation

Fix VRAM leak when extra_buffer (overlap scheduling) is enabled together with structured JSON output (grammar-based decoding). Fixes #20640

When overlap scheduling is active and grammar is used, tp_worker.py wraps sampling into a delay_sample_func closure. This closure captures forward_batch (holding sampling_info.vocab_mask) and logits_output (holding next_token_logits). Since batch_result is kept alive by result_queue and batch_record_buf until the next scheduler iteration, these large GPU tensors are never freed in time. Each iteration allocates a new vocab_mask, so VRAM grows steadily until OOM.

The leak only triggers when both conditions are met:

  • extra_buffer enabled → overlap scheduling → delay_sample_func closure created
  • Structured output requested → grammar active → vocab_mask allocated every iteration

Modifications

Release GPU tensor references as soon as they are no longer needed:

  • model_runner.py: Set sampling_info.vocab_mask = None immediately after apply_logits_bias() — the mask has already been applied to logits and is not needed downstream.
  • scheduler.py: Set batch_result.delay_sample_func = None after the closure executes — breaks the closure's reference to forward_batch and logits_output.
  • scheduler.py: Set logits_output.next_token_logits = None after sampling — the raw logits are not used by copy_to_cpu() (which only copies next_token_ids and logprob fields).

Accuracy Tests

Benchmarking and Profiling

Verification

Tested with a reproduction script. After the fix, torch.cuda.memory_allocated() remains stable across iterations, confirming the leak is resolved.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

…ct#20640)

Release delay_sample_func closure and vocab_mask/next_token_logits GPU
tensors after sampling completes. These references were kept alive by
result_queue and batch_record_buf until the next iteration, causing
steady VRAM growth when grammar-based decoding is used with overlap
scheduling (extra_buffer) enabled.

- Set sampling_info.vocab_mask = None after applying it to logits
- Set batch_result.delay_sample_func = None after execution
- Set logits_output.next_token_logits = None after sampling
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical VRAM leak that manifested under specific conditions involving overlap scheduling and structured output. By strategically clearing references to large GPU tensors as soon as they are no longer needed, the change ensures efficient memory management and prevents out-of-memory errors, thereby enhancing the stability and performance of the system in these advanced usage scenarios.

Highlights

  • VRAM Leak Fix: Addressed a VRAM leak that occurred when overlap scheduling (extra_buffer) was enabled concurrently with structured JSON output (grammar-based decoding).
  • Root Cause Identification: Identified the leak's cause as large GPU tensors (vocab_mask, forward_batch, logits_output) being inadvertently held alive by closures and batch results, preventing timely deallocation.
  • Memory Reference Nullification: Implemented explicit nullification of references to these GPU tensors (sampling_info.vocab_mask, batch_result.delay_sample_func, batch_result.logits_output.next_token_logits) immediately after their utility expired to ensure prompt memory release.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/managers/scheduler.py
    • Released references to batch_result.delay_sample_func and batch_result.logits_output.next_token_logits after sampling to prevent VRAM leaks.
  • python/sglang/srt/model_executor/model_runner.py
    • Released the sampling_info.vocab_mask GPU tensor reference immediately after its application to logits.
Activity
  • The author, Cishoon, created this pull request to fix a VRAM leak.
  • The code has been formatted and adheres to SGLang code style guidance.
  • Verification was performed using a reproduction script, confirming stable torch.cuda.memory_allocated() and resolution of the leak.
  • No human activity (comments, reviews) has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses a VRAM leak occurring with overlap scheduling and structured output. The fix correctly releases references to large GPU tensors as soon as they are no longer needed. The changes in scheduler.py and model_runner.py are clear, well-commented, and directly target the source of the leak. The implementation appears correct and should resolve the memory issue without side effects.

@Swipe4057
Copy link
Copy Markdown
Contributor

I tested it on the Qwen3.5-122B-A10B model and it works great.

@Cishoon
Copy link
Copy Markdown
Contributor Author

Cishoon commented Mar 18, 2026

Hi, @hnyls2002 , this is my first PR in sglang. It has been waiting for review for a while, so I was wondering whether I may have missed any steps or done anything incorrectly. If there's anything I should fix, please let me know. If you have time, could you help review it? Thanks!

@sglang-bot sglang-bot force-pushed the fix/vram-leak-extra-buffer-structured-output branch from 33c04f6 to 68b3a0d Compare March 20, 2026 04:23
Copy link
Copy Markdown
Collaborator

@hnyls2002 hnyls2002 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems to be reasonable. But why does the VRAM grow steadily instead of releasing just one iteration late?

@Cishoon
Copy link
Copy Markdown
Contributor Author

Cishoon commented Mar 20, 2026

This change seems to be reasonable. But why does the VRAM grow steadily instead of releasing just one iteration late?

In tp_worker.py, overlap scheduling creates a delay_sample_func closure that captures forward_batch and
logits_output, and that closure stayed referenced after it ran.

@hnyls2002
Copy link
Copy Markdown
Collaborator

This change seems to be reasonable. But why does the VRAM grow steadily instead of releasing just one iteration late?

In tp_worker.py, overlap scheduling creates a delay_sample_func closure that captures forward_batch and logits_output, and that closure stayed referenced after it ran.

That closure won't hold the reference forever; the tensor will just delay releasing for one or two rounds.

@hnyls2002 hnyls2002 merged commit 999bad5 into sgl-project:main Mar 23, 2026
217 of 238 checks passed
dingzhiqiang pushed a commit to dingzhiqiang/sglang that referenced this pull request Mar 24, 2026
…d release on finish

Address review feedback from DarkSharpness:
- Simplify filter_batch grammar filtering logic (3 cases: empty, slice, clear)
- Remove grammars=None from copy_for_forward (breaks grammar in overlap mode)
- Remove vocab_mask cleanup from filter_batch (handled by sgl-project#20697 in _preprocess_logits)
- Fix comment on req.grammar=None (grammar is CPU-only, not GPU)
- Keep req.grammar=None for releasing CPU resources on request finish

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Qwen3.5+extra_buffer video memory leak during structured JSON response generation.

3 participants