
Refactor: decouple segment tracking from comm registration #21392

Merged
ShangmingCai merged 24 commits into sgl-project:main from wangfakang:symm_spool_track
May 6, 2026

Conversation

@wangfakang
Contributor

@wangfakang wangfakang commented Mar 25, 2026

CC @nvcastet @yizhang2077 @merrymercy @ShangmingCai @Fridge003 @ch-wan PTAL, thx.

Motivation

When multiple communication groups share a single global MemPool, memory blocks released by one group's comm may be reused by another group's comm. However, symmetric memory requires buffers to be registered with a specific ncclComm via ncclCommWindowRegister. Reusing memory across groups causes the registration to be associated with the wrong communicator.

This PR redesigns the symmetric memory allocator to defer NCCL window registration from allocation time to context exit time. This enables correct memory reuse across different communicators and eliminates the CPU overhead of snapshot(). Thanks to @nvcastet for the help!

Related PR: #19329 (comment) and #20153

Modifications

Key changes:

  1. Allocation-time tracking: C++ layer now tracks memory segments (ptr, size) during their lifetime without registering to any comm.

  2. Deferred registration: Registration with the NCCL communicator happens at SymmetricMemoryContext.__exit__() using the pynccl API, enabling proper handling of both newly allocated and reused memory.

  3. Shared MemPool: All groups share a single MemPool to reduce memory fragmentation, with proper per-comm registration tracking.
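The two-phase design above can be sketched with a small, self-contained Python model. This is a hypothetical simplification, not the actual sglang code: the names track_segment, SymmetricMemoryContextSketch, _ptr_to_registered_comms, and register_fn are illustrative, with register_fn standing in for the real ncclCommWindowRegister call.

```python
# Sketch: allocation-time tracking (phase 1) + deferred per-comm
# registration at context exit (phase 2). All names are illustrative.

_tracked_segments = []         # (ptr, size) pairs, not tied to any comm
_ptr_to_registered_comms = {}  # ptr -> set of comms it is registered with


def track_segment(ptr: int, size: int) -> None:
    """Allocation time: record the segment; register nothing yet."""
    _tracked_segments.append((ptr, size))


class SymmetricMemoryContextSketch:
    """On __exit__, register every tracked segment this comm has not
    seen yet, which covers both freshly allocated and reused memory."""

    def __init__(self, comm, register_fn):
        self.comm = comm
        self.register_fn = register_fn  # stand-in for ncclCommWindowRegister

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        for ptr, size in _tracked_segments:
            comms = _ptr_to_registered_comms.setdefault(ptr, set())
            if self.comm not in comms:
                self.register_fn(self.comm, ptr, size)
                comms.add(self.comm)
        return False
```

Because registration is keyed per communicator, a block released by one group and reused by another gets registered again with the new comm, which is exactly the cross-group correctness issue the PR fixes.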

Benchmarking and Profiling

In benchmark testing, the CPU overhead of _get_tracked_segments() was measured to be about 25 times lower than that of snapshot() (5.351 µs vs 134.320 µs).

#NCCL_DEBUG=WARN python benchmark/bench_pynccl_allocator/bench_segment_tracking.py --num-segments 50 --num-iters 1000
================================================================================
Benchmark: Segment Tracking CPU Overhead
================================================================================
Segment size: 1.00 MB
Iterations per measurement: 1000

Segments     _get_tracked_segments (µs)     snapshot (µs)        Speedup   
--------------------------------------------------------------------------------
25           5.351                          134.320              25.10     x
--------------------------------------------------------------------------------
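The measurement loop can be illustrated with a generic, self-contained timing harness in the same spirit as bench_segment_tracking.py. This is a sketch: bench, tracked_segments, and snapshot_like are hypothetical stand-ins (no GPU or torch required), with a flat list copy modeling reading (ptr, size) pairs out of the C++ tracker and a nested-dict build modeling parsing a full allocator snapshot.

```python
import time


def bench(fn, iters=1000):
    """Return mean wall-clock time per call in microseconds."""
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e6


# Stand-in workloads for the two approaches being compared.
segments = [(i * 4096, 4096) for i in range(25)]


def tracked_segments():
    # Light: copy out the tracked (ptr, size) pairs.
    return list(segments)


def snapshot_like():
    # Heavy: build a nested per-segment structure, as a full
    # allocator snapshot would.
    return [{"address": p, "total_size": s, "blocks": [{"size": s}]}
            for p, s in segments]


t_tracked = bench(tracked_segments)
t_snapshot = bench(snapshot_like)
print(f"tracked: {t_tracked:.3f} us  snapshot-like: {t_snapshot:.3f} us")
```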

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: wangfakang <fakangwang@gmail.com>
… comms

Signed-off-by: wangfakang <fakangwang@gmail.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the symmetric memory allocator to decouple NCCL window registration from memory allocation. The previous approach led to issues when multiple communication groups shared a MemPool, as memory registered with one communicator could be incorrectly reused by another. The new design defers registration to the context exit time, allowing memory to be tracked at the C++ level and registered with the appropriate communicator in Python, which also significantly reduces CPU overhead compared to the prior snapshot() method.

Highlights

  • Allocation-time tracking: The C++ layer now tracks memory segments (pointer, size) throughout their lifetime without performing immediate NCCL registration.
  • Deferred registration: Registration with the NCCL communicator is now handled at the SymmetricMemoryContext.__exit__() using the pynccl API, ensuring correct handling for both newly allocated and reused memory.
  • Shared MemPool: All communication groups now share a single MemPool to minimize memory fragmentation, with proper per-communicator registration tracking managed by the Python layer.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the PyNCCL allocator to introduce a C++-based segment tracking mechanism, deferring memory registration to the Python layer within the SymmetricMemoryContext. A new benchmark script is added to evaluate the CPU overhead of this new tracking method. Review feedback indicates a memory leak in the _ptr_to_registered_comms dictionary, as stale entries are not removed, and suggests a performance improvement for the C++ untrack_segment function by considering a different data structure than std::vector for better deallocation time complexity.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new benchmark script to compare the CPU overhead of custom C++ segment tracking (using std::vector) against PyTorch's memory_snapshot(). The core changes in pynccl_allocator.py refactor how NCCL memory segments are tracked and registered. Segments are now tracked in a C++ std::vector and exposed to Python via ctypes. The registration with NCCL communicators is deferred to the Python SymmetricMemoryContext's __exit__ method, allowing for a single shared memory pool across groups and handling registration for both new and reused memory. Review comments point out a potential performance bottleneck in untrack_segment due to a linear scan, a thread-safety issue in the global _ptr_to_registered_comms dictionary, and several minor issues in the benchmark script including an incorrect type hint, an unused import, and code style improvements. An outdated comment in the C++ source also needs to be updated to reflect the use of std::vector instead of map.
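The untrack_segment bottleneck the review mentions can be seen with a tiny sketch: keying segments by base pointer (a dict in Python, std::unordered_map in C++) makes removal O(1) on average instead of a linear scan over a vector. This illustrates the review point only; the names here are hypothetical, not the code in the PR.

```python
# Dict keyed by base pointer: O(1) average-case untrack, versus an
# O(n) linear scan when segments live in a plain vector/list.
tracked = {}  # base ptr -> size


def track_segment(ptr: int, size: int) -> None:
    tracked[ptr] = size


def untrack_segment(ptr: int):
    # pop returns the size if present, None otherwise.
    return tracked.pop(ptr, None)
```

The trade-off is that a map loses the cheap "register everything after index N" iteration a vector gives, which is why the final design keeps an append-only segment list and tracks a per-comm cursor separately.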

wangfakang and others added 4 commits March 25, 2026 17:24
Signed-off-by: wangfakang <fakangwang@gmail.com>
…cator.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: wangfakang <fakangwang@gmail.com>
@wangfakang
Contributor Author

/rerun-failed-ci

1 similar comment
@wangfakang
Contributor Author

/rerun-failed-ci

Collaborator

@nvcastet nvcastet left a comment


Instead of copying the list of segments from C++ to Python, can you have a single C++ API register_segments_with_comm(comm_ptr) and keep the registrations and bookkeeping inside C++?
Ideally you would have an unordered_map mapping a comm_ptr to the next index in g_segments to register; then you would register with:

size_t next_idx = map[comm_ptr]; // next_idx will be 0 if map does not contain comm_ptr
ncclComm_t comm = (ncclComm_t)(comm_ptr);
for (size_t i = next_idx; i < g_segments.size(); ++i) {
  auto seg = g_segments[i];
  ncclWindow_t win;
  NCCLCHECK(ncclCommWindowRegister(comm, seg[0], seg[1], &win, NCCL_WIN_COLL_SYMMETRIC));
}
map[comm_ptr] = g_segments.size();
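The suggested per-comm cursor bookkeeping can be modeled in a few lines of self-contained Python (the C++ version would use an unordered_map; register_fn is a hypothetical stand-in for ncclCommWindowRegister):

```python
# Append-only segment list plus a per-comm cursor: each comm registers
# only the segments added since its last visit.
g_segments = []   # (ptr, size) pairs, append-only
g_next_idx = {}   # comm_ptr -> first index not yet registered with it


def register_segments_with_comm(comm_ptr, register_fn):
    """Register the segments this comm has not seen yet, then
    advance its cursor to the end of the list."""
    next_idx = g_next_idx.get(comm_ptr, 0)  # 0 for a brand-new comm
    for ptr, size in g_segments[next_idx:]:
        register_fn(comm_ptr, ptr, size)    # would be ncclCommWindowRegister
    g_next_idx[comm_ptr] = len(g_segments)
```

A comm seen for the first time starts at index 0 and therefore picks up every segment allocated so far, including memory that was first used under a different group.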

@wangfakang
Contributor Author

Instead of copying the list of segments from C++ to Python, can you have a single C++ API register_segments_with_comm(comm_ptr) and keep the registrations and bookkeeping inside C++? Ideally you would have an unordered_map mapping a comm_ptr to the next index in g_segments to register; then you would register with:

size_t next_idx = map[comm_ptr]; // next_idx will be 0 if map does not contain comm_ptr
ncclComm_t comm = (ncclComm_t)(comm_ptr);
for (size_t i = next_idx; i < g_segments.size(); ++i) {
  auto seg = g_segments[i];
  ncclWindow_t win;
  NCCLCHECK(ncclCommWindowRegister(comm, seg[0], seg[1], &win, NCCL_WIN_COLL_SYMMETRIC));
}
map[comm_ptr] = g_segments.size();

@nvcastet Thank you for the suggestions. I have addressed all comments.

Collaborator

@nvcastet nvcastet left a comment


Once all changes are done, please make sure to test the TP config and the DP+EP config for DeepSeek-R1 FP4 and check GPQA eval accuracy.
See comments at:
#8238
#9358

Signed-off-by: wangfakang <fakangwang@gmail.com>
Signed-off-by: wangfakang <fakangwang@gmail.com>
@nvcastet
Collaborator

/tag-and-rerun-ci

@wangfakang
Contributor Author

/rerun-failed-ci

stage-b-test-2-gpu-large (2) install dependencies failed


@wangfakang
Contributor Author

/rerun-failed-ci

Triggering the waiting test tasks.

@wangfakang
Contributor Author

Hello @nvcastet, the GPU-related test cases (stage-a-test-*, stage-b-test-*, and some stage-c-test-*) have all passed. I don't have permission to trigger the remaining stage-c-test-* cases individually. Could you please help with those? Thanks!

@wangfakang
Contributor Author

Hello @nvcastet, the GPU-related test cases (stage-a-test-*, stage-b-test-*, and some stage-c-test-*) have all passed. I don't have permission to trigger the remaining stage-c-test-* cases individually. Could you please help with those? Thanks!

Friendly ping @nvcastet

@nvcastet
Collaborator

/rerun-stage stage-c-test-4-gpu-b200

@github-actions
Contributor

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@nvcastet
Collaborator

/rerun-stage stage-c-test-8-gpu-h200

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies). View workflow run

@nvcastet
Collaborator

You will need a code owner to review first for the PR to be merged.

@wangfakang
Contributor Author

You will need a code owner to review first for the PR to be merged.

Hello @nvcastet, thanks for the reminder and your patient review. This PR has already been approved by the code owner, @yizhang2077.

@nvcastet
Collaborator

stage-c-test-4-gpu-b200 times out. I don't know if it is related to your changes.

@wangfakang
Contributor Author

wangfakang commented Apr 29, 2026

stage-c-test-4-gpu-b200 times out. I don't know if it is related to your changes.

Hi @nvcastet, I checked the logs and found that the failing test case is test_fp8_blockwise_gemm.py. Since this test case doesn't enable symm, it won't execute the code modified in this PR. Therefore, the timeout issue is unrelated to the changes in this PR.
Additionally, I noticed that these test cases have had timeout problems before, as shown in these previous fix PRs:

Meanwhile, the corresponding deepseek-v3-fp4 and deepseek-v32 tests in this PR were also executed successfully, as shown in the logs. This PR has been thoroughly validated in terms of both performance and accuracy, as detailed in the previous report.

@wangfakang
Contributor Author

stage-c-test-4-gpu-b200 times out. I don't know if it is related to your changes.

Hi @nvcastet, I checked the logs and found that the failing test case is test_fp8_blockwise_gemm.py. Since this test case doesn't enable symm, it won't execute the code modified in this PR. Therefore, the timeout issue is unrelated to the changes in this PR. Additionally, I noticed that these test cases have had timeout problems before, as shown in these previous fix PRs:

Meanwhile, the corresponding deepseek-v3-fp4 and deepseek-v32 tests in this PR were also executed successfully, as shown in the logs. This PR has been thoroughly validated in terms of both performance and accuracy, as detailed in the previous report.

Friendly ping @nvcastet

@nvcastet
Collaborator

nvcastet commented May 4, 2026

There is a conflict to solve but looks good to me.

wangfakang added 2 commits May 5, 2026 00:12
Signed-off-by: wangfakang <fakangwang@gmail.com>
@wangfakang
Contributor Author

There is a conflict to solve but looks good to me.

Hello @nvcastet, conflicts resolved with no logic changes. PTAL when you have time. Thanks!

@nvcastet
Collaborator

nvcastet commented May 4, 2026

LGTM.
Someone with merge permission would need to push it. @Fridge003 ?

@wangfakang
Contributor Author

LGTM. Someone with merge permission would need to push it. @Fridge003 ?

Thank you for the review, @nvcastet and @yizhang2077. Ping @ch-wan, @ShangmingCai, or @Fridge003: could you please take a look and help merge this when you have a moment? Thank you!

Signed-off-by: wangfakang <fakangwang@gmail.com>
@ShangmingCai
Collaborator

/rerun-stage stage-c-test-4-gpu-b200

@github-actions
Contributor

github-actions Bot commented May 6, 2026

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

@wangfakang
Contributor Author

✅ Triggered stage-c-test-4-gpu-b200 to run independently (skipping dependencies). View workflow run

I checked the logs and found that the failing test case is test_qwen35_models.py. Since this test case doesn't enable symm, it won't execute the code modified in this PR. The error message is CUDA out of memory, which is unrelated to the changes in this PR.


cc @ShangmingCai

This PR has been thoroughly validated for both performance and accuracy, as detailed in the previous comment: #21392 (comment). Based on these results, I believe it's ready to merge.


@ShangmingCai
Collaborator


Related CI has passed.

@ShangmingCai ShangmingCai merged commit c8bc235 into sgl-project:main May 6, 2026
78 of 108 checks passed