Skip to content

[Fix] Solve the error lead by _commit_transfer_to_req() when using IntraNode NVLink in PD disaggregation#23252

Merged
ShangmingCai merged 22 commits intosgl-project:mainfrom
TTThanos:Debug/new_intra_nvlink
Apr 21, 2026
Merged

[Fix] Solve the error lead by _commit_transfer_to_req() when using IntraNode NVLink in PD disaggregation#23252
ShangmingCai merged 22 commits intosgl-project:mainfrom
TTThanos:Debug/new_intra_nvlink

Conversation

@TTThanos
Copy link
Copy Markdown
Contributor

@TTThanos TTThanos commented Apr 20, 2026

Motivation

When using Mooncake IntraNode nvlink as kv cache transport backend in PD disaggregation, decode server will crash and error log can be observed
image
This is related to the commit #8ed35df introduced metadata validation in _commit_transfer_to_req(). In IntraNode nvlink scenarios, metadata buffer is allocated on the GPU and transferred via nvlink. However, due to asynchronous nature of NVLink transport, metadata buffer may not have been transferred when poll == KVPoll.Success. Consequently, bootstrap_room remains 0 for some ranks, while other ranks have successfully read and removed their metadata, leading to mismatch in the subsequent poll_and_all_reduce() call.

Modifications

Accuracy Tests

launch server
`model_path=/mnt/models/Qwen3-235B-A22B-FP8
FILE_NAME_PREFIX=Decode_Mooncake_INTRANVLINK_kv_transfer_Hicache_test_qwen3_235b_tp4_0412

export MC_TE_METRIC=true
export MC_INTRANODE_NVLINK=true
SGLANG_MOONCAKE_CUSTOM_MEM_POOL=true SGLANG_MOONCAKE_CUSTOM_MEM_POOL=INTRA_NODE_NVLINK MC_LOG_LEVEL=INFO SGLANG_TORCH_PROFILER_DIR=/root/profile/ python3 -m sglang.launch_server
--model-path ${model_path}
--tp 4
--mem-fraction-static 0.85
--base-gpu-id 4
--disaggregation-mode decode
--port 7002
--watchdog-timeout 1000000 --decode-log-interval 1 >/root/log/${FILE_NAME_PREFIX}.log 2>&1`

Speed Tests and Profiling

Tested on 4K/2K input/output using bench serving test
image

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

百麒 and others added 20 commits December 27, 2025 08:16
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the Mooncake disaggregation logic to support the INTRA_NODE_NVLINK custom memory pool type. Key changes include forcing the use of TCP for auxiliary data transfers and switching the device type to CPU for specific memory pool configurations. A review comment suggests improving code readability in conn.py by using internal class properties instead of directly accessing environment variables in conditional checks.

Comment on lines +878 to +880
if (
self.enable_custom_mem_pool and self.custom_mem_pool_type == "NVLINK"
) or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():
) or envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get() == "INTRA_NODE_NVLINK" or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and to avoid re-reading environment variables, consider using the self.custom_mem_pool_type property which should already hold the correct value. This simplifies the condition and makes it more consistent with the surrounding code.

Suggested change
if (
self.enable_custom_mem_pool and self.custom_mem_pool_type == "NVLINK"
) or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():
) or envs.SGLANG_MOONCAKE_CUSTOM_MEM_POOL.get() == "INTRA_NODE_NVLINK" or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():
if (
self.enable_custom_mem_pool
and self.custom_mem_pool_type in ("NVLINK", "INTRA_NODE_NVLINK")
) or envs.SGLANG_MOONCAKE_SEND_AUX_TCP.get():

Copy link
Copy Markdown
Collaborator

@ShangmingCai ShangmingCai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bootstrap room validation makes sure the values are correct on the decode side. Maybe we can sync the status before marking it as successful, then moving it to the waiting queue.

This fix should work, but not a real fix, actually. I will look into it.

@ShangmingCai
Copy link
Copy Markdown
Collaborator

please fix lint first.

@stmatengss
Copy link
Copy Markdown
Collaborator

Following our offline discussion, LGTM.

@TTThanos
Copy link
Copy Markdown
Contributor Author

The bootstrap room validation makes sure the values are correct on the decode side. Maybe we can sync the status before marking it as successful, then moving it to the waiting queue.

This fix should work, but not a real fix, actually. I will look into it.

Exactly, sychronization can be made to avoid this problem.

@ShangmingCai ShangmingCai merged commit 0d04052 into sgl-project:main Apr 21, 2026
56 of 64 checks passed
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
…traNode NVLink in PD disaggregation (sgl-project#23252)

Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
…traNode NVLink in PD disaggregation (sgl-project#23252)

Co-authored-by: 百麒 <yaozhong.lyz@alibaba-inc.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants