Skip to content

perf: avoid unnecessary gpu-cpu sync in eagle_info#20266

Merged
Qiaolin-Yu merged 3 commits intosgl-project:mainfrom
ehuaa:fix-eagle-sync-deadlock
Mar 20, 2026
Merged

perf: avoid unnecessary gpu-cpu sync in eagle_info#20266
Qiaolin-Yu merged 3 commits intosgl-project:mainfrom
ehuaa:fix-eagle-sync-deadlock

Conversation

@ehuaa
Copy link
Copy Markdown
Contributor

@ehuaa ehuaa commented Mar 10, 2026

Motivation

When stress testing Qwen3.5-397B with MTP enabled, the scheduler occasionally encounters the following scheduler watchdog error. Preliminary analysis suggests a potential deadlock arising from the interaction between GPU-CPU synchronization and collective communications (e.g., AllReduce).
According to the log below, the GPU-CPU sync happens in eagle_info.py.
This PR not only mitigates potential deadlock issues but also eliminates synchronization overhead, thereby improving overall inference performance.

[2026-03-10 02:09:02 TP2] Pyspy dump for PID 501830:
Process 501830: sglang::scheduler_TP2
Python v3.12.3 (/usr/bin/python3.12)

Thread 501830 (active): "MainThread"
    0x7f9484041d1d (libcuda.so.575.57.08)
    0x7f9483df9493 (libcuda.so.575.57.08)
    0x7f9483e3a2eb (libcuda.so.575.57.08)
    0x7f9484acc998 (libcuda.so.575.57.08)
    0x7f9484accdc8 (libcuda.so.575.57.08)
    0x7f9483e0542c (libcuda.so.575.57.08)
    0x7f9483ec4d3a (libcuda.so.575.57.08)
    0x7f9484a864e9 (libcuda.so.575.57.08)
    0x7f9483f76e37 (libcuda.so.575.57.08)
    cuStreamSynchronize (libcuda.so.575.57.08)
    0x7f95b1e128b3 (libcudart.so.12)
    cudaStreamSynchronize (libcudart.so.12)
    at::native::_local_scalar_dense_cuda(at::Tensor const&)::{lambda()#1}::operator() const (libtorch_cuda.so)
    at::native::_local_scalar_dense_cuda (libtorch_cuda.so)
    at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___local_scalar_dense (libtorch_cuda.so)
    c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar(at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CUDA___local_scalar_
dense(at::Tensor const&)>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar(at::Tensor const&)>::call (libtorch_cuda.so)
    at::_ops::_local_scalar_dense::redispatch (libtorch_cpu.so)
    torch::autograd::VariableType::(anonymous namespace)::_local_scalar_dense (libtorch_cpu.so)
    c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar(c10::DispatchKeySet, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::_lo
cal_scalar_dense(c10::DispatchKeySet, at::Tensor const&)>, c10::Scalar, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, c10::Scalar(c10::DispatchKeySet, at::Tensor const&)>::call (libtorch_cpu.so)
    at::_ops::_local_scalar_dense::call (libtorch_cpu.so)
    at::native::item (libtorch_cpu.so)
    c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<c10::Scalar(at::Tensor const&), &at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAut
ograd__item(at::Tensor const&)>, c10::Scalar, c10::guts::typelist::typelist<at::Tensor const&> >, c10::Scalar(at::Tensor const&)>::call (libtorch_cpu.so)
    at::_ops::item::call (libtorch_cpu.so)
    at::native::is_nonzero (libtorch_cpu.so)
    at::_ops::is_nonzero::call (libtorch_cpu.so)
    torch::autograd::THPVariable_is_nonzero (libtorch_python.so)
    verify (eagle_info.py:340)
    verify (eagle_worker.py:748)
    forward_batch_generation (eagle_worker.py:316)
    run_batch (scheduler.py:2368)
    event_loop_normal (scheduler.py:1123)
   ...

Modifications

modify if-statement in eagle_info.py to use need_top_p_sampling instead of computing top_p sampling mask again.

Accuracy Tests

This pull requests does not affect model outputs.

Benchmarking and Profiling

Before
image

image

After
image

image b4c7080e-0027-4270-9da2-74b4ac789922

The comparison shows that we’ve shaved off serveral cuda kernel operations, for example, ten D2H operations (equivalent to the number of steps in the profile). And one cudaStreamSynchronize call has been eliminated, deferring the synchronization until the subsequent line:

accept_index_cpu = accept_index.tolist().

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue where GPU-CPU synchronization in the eagle_info.py module could lead to deadlocks, particularly under heavy load. By refactoring a conditional check, the change removes the explicit synchronization, thereby preventing potential deadlocks and simultaneously enhancing inference performance by reducing overhead.

Highlights

  • Deadlock Prevention: Modified an if-statement in eagle_info.py to avoid implicit GPU-CPU synchronization, which was identified as a potential cause of deadlocks during stress testing with MTP enabled.
  • Performance Improvement: Eliminated synchronization overhead by preventing cudaStreamSynchronize calls, leading to improved overall inference performance.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/speculative/eagle_info.py
    • Updated an if condition to use sampling_info.need_top_p_sampling instead of torch.all(sampling_info.top_ps == 1.0) to avoid implicit GPU-CPU synchronization.
Activity
  • No human activity has occurred on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request refactors the condition for applying top-p renormalization in the verify function within eagle_info.py. Instead of checking if all sampling_info.top_ps are equal to 1.0, the code now directly uses a dedicated boolean flag, sampling_info.need_top_p_sampling, to determine if top-p sampling is required. This change likely improves clarity and potentially efficiency or correctness by using an explicit flag.

@ehuaa ehuaa changed the title fix: avoid gpu-cpu sync in eagle_info to prevent deadlock fix: avoid unnecessary gpu-cpu sync in eagle_info Mar 12, 2026
@ehuaa
Copy link
Copy Markdown
Contributor Author

ehuaa commented Mar 12, 2026

Hi @hnyls2002 , can you help review this pr? Thanks.

@binbabou
Copy link
Copy Markdown

@ehuaa Hi, I am experiencing the same issue with Qwen3-235B deployment using sglang v0.5.4.post1. Despite simulating multiple stress testing cases, I have been unable to replicate the error. 😂
Could you provide more details on your stress testing methodology?"

eagle_info.py:308: if not torch.all(sampling_info.top_ps == 1.0):

Thread 1446 (active): "MainThread"
    verify (eagle_info.py:308)
    verify (eagle_worker.py:674)
    forward_batch_generation (eagle_worker.py:268)
    run_batch (scheduler.py:1936)
    event_loop_normal (scheduler.py:965)
    decorate_context (torch/utils/_contextlib.py:120)
    run_scheduler_process (scheduler.py:2777)
    run (multiprocessing/process.py:108)
    _bootstrap (multiprocessing/process.py:314)
    _main (multiprocessing/spawn.py:135)
    spawn_main (multiprocessing/spawn.py:122)
    <module> (<string>:1)

@ehuaa
Copy link
Copy Markdown
Contributor Author

ehuaa commented Mar 18, 2026

@ehuaa Hi, I am experiencing the same issue with Qwen3-235B deployment using sglang v0.5.4.post1. Despite simulating multiple stress testing cases, I have been unable to replicate the error. 😂 Could you provide more details on your stress testing methodology?"

eagle_info.py:308: if not torch.all(sampling_info.top_ps == 1.0):

Thread 1446 (active): "MainThread"
    verify (eagle_info.py:308)
    verify (eagle_worker.py:674)
    forward_batch_generation (eagle_worker.py:268)
    run_batch (scheduler.py:1936)
    event_loop_normal (scheduler.py:965)
    decorate_context (torch/utils/_contextlib.py:120)
    run_scheduler_process (scheduler.py:2777)
    run (multiprocessing/process.py:108)
    _bootstrap (multiprocessing/process.py:314)
    _main (multiprocessing/spawn.py:135)
    spawn_main (multiprocessing/spawn.py:122)
    <module> (<string>:1)

Hi @binbabou , my sglang version is v0.5.9 and my model version is Qwen3.5-397B, I simply use bench_serving.py to simulate stress testing, with 512 sequences(input 1k+8k) send to server simultaneously

@ehuaa ehuaa changed the title fix: avoid unnecessary gpu-cpu sync in eagle_info perf: avoid unnecessary gpu-cpu sync in eagle_info Mar 18, 2026
@Qiaolin-Yu
Copy link
Copy Markdown
Collaborator

Qiaolin-Yu commented Mar 20, 2026

/rerun-ut test_eagle_infer_beta.py

@github-actions
Copy link
Copy Markdown
Contributor

/rerun-ut is not available for fork PRs (security restriction).

Please ask a maintainer to add the run-ci label and use the normal CI flow, or use /rerun-failed-ci to rerun workflows that have already passed the gate.

1 similar comment
@github-actions
Copy link
Copy Markdown
Contributor

/rerun-ut is not available for fork PRs (security restriction).

Please ask a maintainer to add the run-ci label and use the normal CI flow, or use /rerun-failed-ci to rerun workflows that have already passed the gate.

Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tested it locally, lgtm

@Qiaolin-Yu Qiaolin-Yu merged commit fa0d8f6 into sgl-project:main Mar 20, 2026
59 of 67 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026
dutsc pushed a commit to dutsc/sglang that referenced this pull request Mar 30, 2026
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants