[Bug Fix] Remove follow_bootstrap_room fast path in PD disaggregation DP rank resolution #22901

Merged
hnyls2002 merged 4 commits into main from
fix/remove-follow-bootstrap-room-fast-path
Apr 16, 2026
@ByronHsu
Collaborator

@ByronHsu ByronHsu commented Apr 15, 2026

Motivation

In PD (Prefill-Decode) disaggregation mode with data parallelism, the decode server needs to know which prefill DP worker handled a given request to establish the correct KV transfer connection. There are two mechanisms for determining the DP worker:

  1. Explicit DP rank: The request carries data_parallel_rank (mapped to routed_dp_rank), which the DP controller dispatcher respects with highest priority via maybe_external_dp_rank_routing, overriding any load balance method.
  2. Load balance method: If no explicit rank is specified, the DP controller dispatches according to the configured method (e.g., round_robin, follow_bootstrap_room, etc.).
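The priority order between these two mechanisms can be sketched as follows. This is a minimal illustration, not SGLang's dispatcher: `pick_dp_rank` and its parameters are hypothetical names, and only two of the load balance methods are modeled.

```python
from itertools import count
from typing import Iterator, Optional


def pick_dp_rank(
    routed_dp_rank: Optional[int],
    bootstrap_room: int,
    dp_size: int,
    method: str,
    rr_counter: Iterator[int],
) -> int:
    """Hypothetical dispatcher: an explicit rank wins over any load balance method."""
    # 1. Explicit DP rank carried by the request has the highest priority.
    if routed_dp_rank is not None:
        return routed_dp_rank
    # 2. Otherwise fall back to the configured load balance method.
    if method == "follow_bootstrap_room":
        return bootstrap_room % dp_size
    if method == "round_robin":
        return next(rr_counter) % dp_size
    raise ValueError(f"unsupported load balance method: {method}")


rr = count()
# Explicit rank 3 overrides follow_bootstrap_room (10 % 4 would be 2).
explicit = pick_dp_rank(3, 10, 4, "follow_bootstrap_room", rr)
# Without an explicit rank, the rank is derived from the bootstrap room.
derived = pick_dp_rank(None, 10, 4, "follow_bootstrap_room", rr)
```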

To communicate the actual prefill DP rank to the decode server:

  • Prefill calls _register_prefill_dp_rank to register the request's assigned DP rank to the bootstrap server.
  • Decode calls query_prefill_dp_ranks to retrieve the actual DP rank used by prefill.
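The register/query handshake can be pictured with an in-memory stand-in for the bootstrap server. The class and method signatures below are assumptions made for illustration; the real calls go over HTTP, and their actual signatures are not shown in this PR.

```python
from typing import Dict, Iterable


class FakeBootstrapServer:
    """Hypothetical in-memory stand-in for the bootstrap server's rank table."""

    def __init__(self) -> None:
        self._room_to_rank: Dict[int, int] = {}

    def register_prefill_dp_rank(self, bootstrap_room: int, dp_rank: int) -> None:
        # Prefill side: record which DP rank actually handled this room.
        self._room_to_rank[bootstrap_room] = dp_rank

    def query_prefill_dp_ranks(self, bootstrap_rooms: Iterable[int]) -> Dict[int, int]:
        # Decode side: look up ranks; rooms not yet registered are omitted.
        return {
            room: self._room_to_rank[room]
            for room in bootstrap_rooms
            if room in self._room_to_rank
        }


server = FakeBootstrapServer()
server.register_prefill_dp_rank(bootstrap_room=10, dp_rank=3)
ranks = server.query_prefill_dp_ranks([10, 11])
```

In this sketch, room 11 is missing from the result because no prefill worker has registered it yet; how the real decode side handles that case is outside what this PR describes.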

Bug: When the prefill load balance method is follow_bootstrap_room, there were two fast-path optimizations that assumed the prefill DP rank is always bootstrap_room % dp_size:

  1. Prefill side (CommonKVSender.__init__): Skipped calling _register_prefill_dp_rank when load_balance_method == "follow_bootstrap_room", since the decode side could infer the rank from the bootstrap room number.
  2. Decode side (_resolve_prefill_dp_rank): Returned bootstrap_room % dp_size directly when prefill_info.follow_bootstrap_room was true, bypassing the query_prefill_dp_ranks call.

This assumption breaks when an external router (e.g., the model gateway) sets routed_dp_rank on the request. The DP controller routes to the externally-specified rank (which may differ from bootstrap_room % dp_size), but prefill never registers the actual rank and decode computes the wrong one.
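A small worked example makes the mismatch concrete. The numbers and the helper name `old_decode_fast_path` are illustrative only; the modulo rule itself is the one stated above.

```python
def old_decode_fast_path(bootstrap_room: int, dp_size: int) -> int:
    """Rank that decode inferred under the old fast path (no query)."""
    return bootstrap_room % dp_size


dp_size = 4
bootstrap_room = 10
routed_dp_rank = 3  # set by an external router on the prefill request

# The DP controller honors the explicit rank, so prefill runs on rank 3.
actual_prefill_rank = routed_dp_rank

# Prefill skipped registration, so decode can only guess from the room number.
inferred_rank = old_decode_fast_path(bootstrap_room, dp_size)

# Decode would set up the KV transfer against the wrong prefill worker.
mismatch = actual_prefill_rank != inferred_rank
```

Here decode infers rank 2 while the KV cache lives on rank 3, and nothing in the old flow reports the disagreement.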

Revised Fix

The original approach removed the fast path entirely and always fell through to register/query. That causes an unnecessary performance regression for strict follow_bootstrap_room deployments (an extra HTTP round-trip per request).

follow_bootstrap_room is a strict dispatch policy — SGLang assumes that if prefill uses it, the rank is always bootstrap_room % dp_size, and decode infers the rank accordingly without querying. The real issue is that routed_dp_rank can silently override this policy on the prefill side, causing a mismatch that decode cannot detect.

The revised fix takes a different approach:

  • Preserve the fast path: Decode still infers bootstrap_room % dp_size when follow_bootstrap_room is active. No perf regression for existing deployments.
  • Detect conflict on prefill side: If follow_bootstrap_room is configured but the actual attn_dp_rank != bootstrap_room % dp_size (i.e., an external routed_dp_rank overrode the dispatch), the request is aborted with a clear error message. This prevents decode from silently computing the wrong rank.
  • Env var escape hatch: SGLANG_DISAGGREGATION_FORCE_QUERY_PREFILL_DP_RANK=1 disables the fast path on both sides — prefill always registers, decode always queries. This supports non-strict routing (e.g., external router overriding follow_bootstrap_room) at the cost of an extra round-trip.
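The three points above can be sketched together. This is a rough model under stated assumptions, not the merged code: the function names, the `DpRankConflict` exception (standing in for aborting the request), and the exact way the environment variable is parsed are all hypothetical. Only the variable name and the modulo rule come from the description.

```python
import os
from typing import Callable, Mapping

FORCE_QUERY_ENV = "SGLANG_DISAGGREGATION_FORCE_QUERY_PREFILL_DP_RANK"


class DpRankConflict(RuntimeError):
    """Hypothetical stand-in for aborting the request with an error message."""


def force_query_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    # Assumption: the flag is considered on only when set to "1".
    return environ.get(FORCE_QUERY_ENV, "0") == "1"


def prefill_should_register(
    method: str, attn_dp_rank: int, bootstrap_room: int, dp_size: int, force_query: bool
) -> bool:
    """Return True if prefill must register its rank; raise on a policy conflict."""
    if force_query or method != "follow_bootstrap_room":
        return True
    expected = bootstrap_room % dp_size
    if attn_dp_rank != expected:
        # An external routed_dp_rank overrode the strict policy.
        raise DpRankConflict(
            f"follow_bootstrap_room expects rank {expected}, got {attn_dp_rank}"
        )
    return False  # fast path: decode can infer the rank on its own


def decode_resolve_rank(
    follow_bootstrap_room: bool,
    bootstrap_room: int,
    dp_size: int,
    force_query: bool,
    query: Callable[[int], int],
) -> int:
    """Infer the prefill rank on the fast path, otherwise ask the bootstrap server."""
    if follow_bootstrap_room and not force_query:
        return bootstrap_room % dp_size
    return query(bootstrap_room)
```

For example, with `bootstrap_room=10` and `dp_size=4`, a prefill worker on rank 2 takes the fast path, a worker on rank 3 raises `DpRankConflict`, and with the flag enabled the rank-3 case registers and decode queries instead of inferring.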

Checklist

@github-actions github-actions Bot added the blackwell (SM100/SM120) and diffusion (SGLang Diffusion) labels Apr 15, 2026
…s used with external routing

Remove the follow_bootstrap_room fast path in PD disaggregation DP rank
resolution. The fast path assumed the prefill DP rank is always
bootstrap_room % dp_size, which is incorrect when an external router
(e.g., the model gateway) overrides routing via routed_dp_rank /
data_parallel_rank on the request. The DP controller correctly respects
this override, but the fast path bypassed the actual rank registration
and query:

- Prefill: skipped _register_prefill_dp_rank when load_balance_method
  was follow_bootstrap_room, so the bootstrap server never learned the
  actual DP rank used.
- Decode: returned bootstrap_room % dp_size directly instead of querying
  the bootstrap server, getting the wrong rank when external routing
  was in effect.

Now prefill always registers its DP rank (when dp_size > 1), and decode
always queries for it when disagg_prefill_dp_rank is not explicitly set.

Made-with: Cursor
@ByronHsu ByronHsu force-pushed the fix/remove-follow-bootstrap-room-fast-path branch from 0bc16df to 638a986 Compare April 15, 2026 21:37
@ByronHsu ByronHsu changed the title Fix/remove follow bootstrap room fast path [Bug Fix] Remove follow_bootstrap_room fast path in PD disaggregation DP rank resolution Apr 15, 2026
Collaborator

@ShangmingCai ShangmingCai left a comment

Looks good, but then it is no longer a pure follow_bootstrap_room load-balancing strategy, and it will increase TTFT slightly. I think it becomes a mixed strategy. Should we add a condition that checks whether the DP rank was assigned by the router, and use it to decide whether to bypass the DP rank register and query steps?

@ShangmingCai
Collaborator

CC @hnyls2002

@hnyls2002
Collaborator

@ByronHsu @ShangmingCai

  • follow_bootstrap_room is a strict dispatch policy — SGLang assumes that if prefill's DP controller uses it, the rank is always bootstrap_room % dp_size, and decode infers the rank accordingly without querying.
  • However, routed_dp_rank can silently override any dispatch policy on the prefill side. When this happens, decode has no way to know the actual prefill rank diverged from bootstrap_room % dp_size. Ideally the router should also pass disagg_prefill_dp_rank to the decode instance so both sides stay aligned.
  • As a quick fix, we can add a flag telling decode to always query the real prefill dp rank from the bootstrap server. This avoids perf regression for strict follow_bootstrap_room deployments while also supporting non-strict external routing.

@ByronHsu
Collaborator Author

ByronHsu commented Apr 16, 2026

Can we disallow an explicit DP rank for follow_bootstrap_room but allow it for other policies, so that we avoid adding one more flag? In practice, will users combine follow_bootstrap_room with an explicit DP rank? To me, follow_bootstrap_room is a policy that skips the query for every request, but it conflicts with an explicit DP rank. Users who want an explicit DP rank can choose another policy. Maybe we can set the default prefill and decode policy to RR.

@hnyls2002
Collaborator

/rerun-test test_disaggregation_basic.py test_disaggregation_dp_attention.py

@hnyls2002
Collaborator

/rerun-test test/registered/disaggregation/test_disaggregation_basic.py

@hnyls2002
Collaborator

/rerun-test test/registered/disaggregation/test_disaggregation_basic.py test_disaggregation_dp_attention.py

@github-actions
Contributor

2-gpu-h100 (1 test): View workflow run

cd test/ && python3 registered/disaggregation/test_disaggregation_basic.py

8-gpu-h20 (1 test): View workflow run

cd test/ && python3 registered/distributed/test_disaggregation_dp_attention.py

@sgl-project sgl-project deleted a comment from github-actions Bot Apr 16, 2026
@sgl-project sgl-project deleted a comment from github-actions Bot Apr 16, 2026
@hnyls2002 hnyls2002 merged commit 3600465 into main Apr 16, 2026
31 of 70 checks passed
@hnyls2002 hnyls2002 deleted the fix/remove-follow-bootstrap-room-fast-path branch April 16, 2026 05:53
ByronHsu added a commit that referenced this pull request Apr 17, 2026
jmamou pushed a commit to jmamou/sglang that referenced this pull request Apr 20, 2026
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
zhangying098 pushed a commit to zhangying098/sglang that referenced this pull request Apr 23, 2026
kyx1999 pushed a commit to KMSorSMS/sglang that referenced this pull request Apr 27, 2026
3 participants