Skip to content

Commit 09ae5b2

Browse files
authored
Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
1 parent 712bf9e commit 09ae5b2

13 files changed

Lines changed: 4045 additions & 187 deletions

File tree

python/sglang/srt/disaggregation/mini_lb.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,39 @@ async def get_server_info():
218218
)
219219
prefill_infos = []
220220
decode_infos = []
221+
all_internal_states = []
222+
221223
async with aiohttp.ClientSession() as session:
222224
for server in chain(prefill_servers):
223225
server_info = await session.get(f"{server}/get_server_info")
224226
prefill_infos.append(await server_info.json())
225227
for server in chain(decode_servers):
226228
server_info = await session.get(f"{server}/get_server_info")
227-
decode_infos.append(await server_info.json())
228-
229-
return {"prefill": prefill_infos, "decode": decode_infos}
229+
info_json = await server_info.json()
230+
decode_infos.append(info_json)
231+
# Extract internal_states from decode servers
232+
if "internal_states" in info_json:
233+
all_internal_states.extend(info_json["internal_states"])
234+
235+
# Return format expected by bench_one_batch_server.py
236+
if all_internal_states:
237+
return {
238+
"internal_states": all_internal_states,
239+
"prefill": prefill_infos,
240+
"decode": decode_infos,
241+
}
242+
else:
243+
# Fallback with dummy data if no internal states found
244+
return {
245+
"internal_states": [
246+
{
247+
"last_gen_throughput": 0.0,
248+
"avg_spec_accept_length": None,
249+
}
250+
],
251+
"prefill": prefill_infos,
252+
"decode": decode_infos,
253+
}
230254

231255

232256
@app.get("/get_model_info")

sgl-router/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
1515
clap = { version = "4.4", features = ["derive"] }
1616
bytes = "1.8.0"
1717
rand = "0.8.5"
18-
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
18+
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
1919
futures-util = "0.3"
2020
serde_json = "1.0"
2121
pyo3 = { version = "0.22.5", features = ["extension-module"] }
@@ -33,6 +33,8 @@ futures = "0.3"
3333
# Added for metrics
3434
metrics = "0.24.2"
3535
metrics-exporter-prometheus = "0.17.0"
36+
# Added for request tracing
37+
uuid = { version = "1.10", features = ["v4", "serde"] }
3638
[profile.release]
3739
lto = "thin"
3840
codegen-units = 1

sgl-router/py_src/sglang_router/launch_router.py

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ class RouterArgs:
3131
host: str = "127.0.0.1"
3232
port: int = 30000
3333

34+
# PD-specific configuration
35+
pd_disaggregated: bool = False # Enable PD disaggregated mode
36+
prefill_urls: List[tuple] = dataclasses.field(
37+
default_factory=list
38+
) # List of (url, bootstrap_port)
39+
decode_urls: List[str] = dataclasses.field(default_factory=list)
40+
3441
# Routing policy
3542
policy: str = "cache_aware"
3643
worker_startup_timeout_secs: int = 300
@@ -40,7 +47,7 @@ class RouterArgs:
4047
balance_rel_threshold: float = 1.0001
4148
eviction_interval: int = 60
4249
max_tree_size: int = 2**24
43-
max_payload_size: int = 4 * 1024 * 1024 # 4MB
50+
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
4451
verbose: bool = False
4552
log_dir: Optional[str] = None
4653
# Service discovery configuration
@@ -95,8 +102,29 @@ def add_cli_args(
95102
f"--{prefix}policy",
96103
type=str,
97104
default=RouterArgs.policy,
98-
choices=["random", "round_robin", "cache_aware"],
99-
help="Load balancing policy to use",
105+
choices=["random", "round_robin", "cache_aware", "power_of_two"],
106+
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
107+
)
108+
109+
# PD-specific arguments
110+
parser.add_argument(
111+
f"--{prefix}pd-disaggregated",
112+
action="store_true",
113+
help="Enable PD (Prefill-Decode) disaggregated mode",
114+
)
115+
parser.add_argument(
116+
f"--{prefix}prefill",
117+
nargs=2,
118+
action="append",
119+
metavar=("URL", "BOOTSTRAP_PORT"),
120+
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
121+
)
122+
parser.add_argument(
123+
f"--{prefix}decode",
124+
nargs=1,
125+
action="append",
126+
metavar=("URL",),
127+
help="Decode server URL. Can be specified multiple times.",
100128
)
101129
parser.add_argument(
102130
f"--{prefix}worker-startup-timeout-secs",
@@ -205,11 +233,19 @@ def from_cli_args(
205233
use_router_prefix: If True, look for arguments with 'router-' prefix
206234
"""
207235
prefix = "router_" if use_router_prefix else ""
208-
worker_urls = args.worker_urls if args.worker_urls is not None else []
236+
worker_urls = getattr(args, "worker_urls", [])
237+
238+
# Parse PD URLs
239+
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
240+
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
241+
209242
return cls(
210243
worker_urls=worker_urls,
211244
host=args.host,
212245
port=args.port,
246+
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
247+
prefill_urls=prefill_urls,
248+
decode_urls=decode_urls,
213249
policy=getattr(args, f"{prefix}policy"),
214250
worker_startup_timeout_secs=getattr(
215251
args, f"{prefix}worker_startup_timeout_secs"
@@ -247,13 +283,54 @@ def _parse_selector(selector_list):
247283
selector[key] = value
248284
return selector
249285

286+
@staticmethod
287+
def _parse_prefill_urls(prefill_list):
288+
"""Parse prefill URLs from --prefill arguments.
289+
290+
Format: --prefill URL BOOTSTRAP_PORT
291+
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
292+
"""
293+
if not prefill_list:
294+
return []
295+
296+
prefill_urls = []
297+
for url, bootstrap_port_str in prefill_list:
298+
# Handle 'none' as None
299+
if bootstrap_port_str.lower() == "none":
300+
bootstrap_port = None
301+
else:
302+
try:
303+
bootstrap_port = int(bootstrap_port_str)
304+
except ValueError:
305+
raise ValueError(
306+
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
307+
)
308+
309+
prefill_urls.append((url, bootstrap_port))
310+
311+
return prefill_urls
312+
313+
@staticmethod
314+
def _parse_decode_urls(decode_list):
315+
"""Parse decode URLs from --decode arguments.
316+
317+
Format: --decode URL
318+
Example: --decode http://decode1:8081 --decode http://decode2:8081
319+
"""
320+
if not decode_list:
321+
return []
322+
323+
# decode_list is a list of single-element lists due to nargs=1
324+
return [url[0] for url in decode_list]
325+
250326

251327
def policy_from_str(policy_str: str) -> PolicyType:
    """Translate a policy name into its ``PolicyType`` member.

    Raises:
        KeyError: If ``policy_str`` is not one of the names listed below.
    """
    # The lookup table is built on each call and indexed directly, so an
    # unknown name surfaces as a KeyError.
    return {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }[policy_str]
259336

@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
277354
else:
278355
router_args = args
279356

357+
# Validate configuration based on mode
358+
if router_args.pd_disaggregated:
359+
# Validate PD configuration
360+
if not router_args.prefill_urls:
361+
raise ValueError("PD disaggregated mode requires --prefill")
362+
if not router_args.decode_urls:
363+
raise ValueError("PD disaggregated mode requires --decode")
364+
365+
# Create router with unified constructor
280366
router = Router(
281-
worker_urls=router_args.worker_urls,
367+
worker_urls=(
368+
router_args.worker_urls if not router_args.pd_disaggregated else []
369+
),
282370
host=router_args.host,
283371
port=router_args.port,
284372
policy=policy_from_str(router_args.policy),
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
298386
service_discovery_namespace=router_args.service_discovery_namespace,
299387
prometheus_port=router_args.prometheus_port,
300388
prometheus_host=router_args.prometheus_host,
389+
pd_disaggregated=router_args.pd_disaggregated,
390+
prefill_urls=(
391+
router_args.prefill_urls if router_args.pd_disaggregated else None
392+
),
393+
decode_urls=(
394+
router_args.decode_urls if router_args.pd_disaggregated else None
395+
),
301396
)
302397

303398
router.start()
@@ -326,8 +421,14 @@ def parse_router_args(args: List[str]) -> RouterArgs:
326421
multi-node setups or when you want to start workers and router separately.
327422
328423
Examples:
424+
# Regular mode
329425
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
330-
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
426+
427+
# PD disaggregated mode
428+
python -m sglang_router.launch_router --pd-disaggregated \\
429+
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
430+
--decode http://decode1:8001 --decode http://decode2:8001 \\
431+
--policy cache_aware
331432
332433
""",
333434
formatter_class=CustomHelpFormatter,

sgl-router/py_src/sglang_router/router.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Router:
1515
- PolicyType.Random: Randomly select workers
1616
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
1717
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
18+
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
1819
host: Host address to bind the router server. Default: '127.0.0.1'
1920
port: Port number to bind the router server. Default: 3001
2021
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
@@ -28,7 +29,7 @@ class Router:
2829
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
2930
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
3031
routing. Default: 60
31-
max_payload_size: Maximum payload size in bytes. Default: 4MB
32+
max_payload_size: Maximum payload size in bytes. Default: 256MB
3233
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
3334
verbose: Enable verbose logging. Default: False
3435
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
@@ -42,6 +43,9 @@ class Router:
4243
watches pods across all namespaces (requires cluster-wide permissions). Default: None
4344
prometheus_port: Port to expose Prometheus metrics. Default: None
4445
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
46+
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
47+
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
48+
decode_urls: List of URLs for decode servers (PD mode only)
4549
"""
4650

4751
def __init__(
@@ -57,7 +61,7 @@ def __init__(
5761
balance_rel_threshold: float = 1.0001,
5862
eviction_interval_secs: int = 60,
5963
max_tree_size: int = 2**24,
60-
max_payload_size: int = 4 * 1024 * 1024, # 4MB
64+
max_payload_size: int = 256 * 1024 * 1024, # 256MB
6165
verbose: bool = False,
6266
log_dir: Optional[str] = None,
6367
service_discovery: bool = False,
@@ -66,6 +70,9 @@ def __init__(
6670
service_discovery_namespace: Optional[str] = None,
6771
prometheus_port: Optional[int] = None,
6872
prometheus_host: Optional[str] = None,
73+
pd_disaggregated: bool = False,
74+
prefill_urls: Optional[List[tuple]] = None,
75+
decode_urls: Optional[List[str]] = None,
6976
):
7077
if selector is None:
7178
selector = {}
@@ -91,6 +98,9 @@ def __init__(
9198
service_discovery_namespace=service_discovery_namespace,
9299
prometheus_port=prometheus_port,
93100
prometheus_host=prometheus_host,
101+
pd_disaggregated=pd_disaggregated,
102+
prefill_urls=prefill_urls,
103+
decode_urls=decode_urls,
94104
)
95105

96106
def start(self) -> None:

0 commit comments

Comments
 (0)