Skip to content

Commit 88a405c

Browse files
authored
Support EPLB balancedness prometheus metric without GPU->CPU synchronize (sgl-project#15401)
1 parent 602fe3b commit 88a405c

8 files changed

Lines changed: 111 additions & 33 deletions

File tree

python/sglang/srt/environ.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ class Envs:
268268
SGLANG_LOG_EXPERT_LOCATION_METADATA = EnvBool(False)
269269
SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR = EnvStr("/tmp")
270270
SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL = EnvInt(0)
271+
SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC = EnvBool(False)
271272

272273
# TBO
273274
SGLANG_TBO_DEBUG = EnvBool(False)

python/sglang/srt/eplb/expert_distribution.py

Lines changed: 58 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from abc import ABC
2121
from collections import deque
2222
from contextlib import contextmanager
23+
from dataclasses import dataclass
2324
from pathlib import Path
2425
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Type
2526

@@ -43,6 +44,14 @@
4344
_OutputMode = Literal["file", "object"]
4445

4546

47+
@dataclass
48+
class ExpertDistributionMetrics:
49+
eplb_balancedness: torch.Tensor
50+
51+
def copy_to_cpu(self):
52+
self.eplb_balancedness = self.eplb_balancedness.to("cpu", non_blocking=True)
53+
54+
4655
class ExpertDistributionRecorder(ABC):
4756
"""Global expert distribution recording"""
4857

@@ -78,7 +87,7 @@ def disable_this_region(self):
7887

7988
@contextmanager
8089
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
81-
yield
90+
yield {}
8291

8392
def on_select_experts(self, topk_ids: torch.Tensor):
8493
pass
@@ -157,12 +166,13 @@ def with_debug_name(self, debug_name):
157166

158167
@contextmanager
159168
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
169+
outputs = {}
160170
with self._current_forward_pass_id.with_value(forward_pass_id):
161171
self._on_forward_pass_start(forward_batch)
162172
try:
163-
yield
173+
yield outputs
164174
finally:
165-
self._on_forward_pass_end(forward_pass_id)
175+
self._on_forward_pass_end(forward_pass_id, outputs)
166176

167177
@contextmanager
168178
def disable_this_region(self):
@@ -181,12 +191,14 @@ def _on_forward_pass_start(self, forward_batch: ForwardBatch):
181191
gatherer.reset()
182192
gatherer.on_forward_pass_start(forward_batch)
183193

184-
def _on_forward_pass_end(self, forward_pass_id: int):
194+
def _on_forward_pass_end(self, forward_pass_id: int, outputs: Dict[str, Any]):
185195
if not self._recording:
186196
return
187197
for gatherer_key, gatherer in self._single_pass_gatherers.items():
188198
single_pass_data = gatherer.collect()
189-
self._accumulator.append(forward_pass_id, gatherer_key, single_pass_data)
199+
self._accumulator.append(
200+
forward_pass_id, gatherer_key, single_pass_data, outputs
201+
)
190202

191203
def on_select_experts(self, topk_ids: torch.Tensor):
192204
self._on_hook("on_select_experts", topk_ids=topk_ids)
@@ -636,6 +648,7 @@ def append(
636648
forward_pass_id: int,
637649
gatherer_key: str,
638650
single_pass_data: Dict,
651+
outputs: Dict[str, Any],
639652
):
640653
pass
641654

@@ -659,18 +672,19 @@ def __init__(self, *args, **kwargs):
659672
self._expert_dispatch_collector = ExpertDispatchCollector(
660673
self._expert_location_metadata.ep_size
661674
)
662-
self._collection_counter = 0
675+
self._metric_heatmap_collection_counter = 0
663676

664677
def append(
665678
self,
666679
forward_pass_id: int,
667680
gatherer_key: str,
668681
single_pass_data: Dict,
682+
outputs: Dict[str, Any],
669683
):
670-
super().append(forward_pass_id, gatherer_key, single_pass_data)
684+
super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)
671685
if self._enable:
672-
self._append_utilization_rate(
673-
forward_pass_id, single_pass_data["global_physical_count"]
686+
return self._append_utilization_rate(
687+
forward_pass_id, single_pass_data["global_physical_count"], outputs
674688
)
675689

676690
def reset(self):
@@ -679,7 +693,10 @@ def reset(self):
679693
self._history.clear()
680694

681695
def _append_utilization_rate(
682-
self, forward_pass_id: int, single_pass_global_physical_count: torch.Tensor
696+
self,
697+
forward_pass_id: int,
698+
single_pass_global_physical_count: torch.Tensor,
699+
outputs: Dict[str, Any],
683700
):
684701
gpu_physical_count = compute_gpu_physical_count(
685702
single_pass_global_physical_count,
@@ -691,27 +708,37 @@ def _append_utilization_rate(
691708
)
692709

693710
if self._rank == 0:
694-
self._collect_metrics_if_needed(gpu_physical_count)
711+
self._handle_metric_eplb_heatmap(gpu_physical_count)
695712

696-
utilization_rate_tensor = compute_utilization_rate(gpu_physical_count)
697-
utilization_rate = torch.mean(utilization_rate_tensor).item()
698-
self._history.append(utilization_rate)
699-
700-
gpu_physical_count_sum = gpu_physical_count.sum().item()
701-
702-
logger.info(
703-
f"[Expert Balancedness] "
704-
f"forward_pass_id={forward_pass_id} "
705-
f"current_pass_balancedness={utilization_rate:.03f} "
706-
f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
707-
f"gpu_physical_count_sum={gpu_physical_count_sum}"
708-
# f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
713+
utilization_rate_gpu = torch.mean(
714+
compute_utilization_rate(gpu_physical_count)
709715
)
716+
if envs.SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC.get():
717+
print(f"hi {self._rank=} {utilization_rate_gpu=}")
718+
outputs["metrics"] = ExpertDistributionMetrics(
719+
eplb_balancedness=utilization_rate_gpu,
720+
)
721+
else:
722+
# TODO maybe refactor this part to also avoid a `.item()` gpu->cpu sync
723+
utilization_rate_cpu = utilization_rate_gpu.item()
724+
self._history.append(utilization_rate_cpu)
725+
726+
gpu_physical_count_sum = gpu_physical_count.sum().item()
727+
728+
logger.info(
729+
f"[Expert Balancedness] "
730+
f"forward_pass_id={forward_pass_id} "
731+
f"current_pass_balancedness={utilization_rate_cpu:.03f} "
732+
f"{''.join(f'last_{size}_average_balancedness={value:.03f} ' for size, value in self._history.mean().items())} "
733+
f"gpu_physical_count_sum={gpu_physical_count_sum}"
734+
# f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
735+
)
710736

711-
def _collect_metrics_if_needed(self, gpu_physical_count: torch.Tensor):
737+
# TODO refactor
738+
def _handle_metric_eplb_heatmap(self, gpu_physical_count: torch.Tensor):
712739
# sglang:eplb_gpu_physical_count metric is disabled if SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL <= 0
713740
interval = get_int_env_var("SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL", 0)
714-
if interval > 0 and self._collection_counter % interval == 0:
741+
if interval > 0 and self._metric_heatmap_collection_counter % interval == 0:
715742
for layer_idx in range(self._expert_location_metadata.num_layers):
716743
count_of_layer = (
717744
self._expert_dispatch_collector.eplb_gpu_physical_count.labels(
@@ -728,7 +755,7 @@ def _collect_metrics_if_needed(self, gpu_physical_count: torch.Tensor):
728755
if count > 0:
729756
count_of_layer._sum.inc(count * gpu_rank)
730757
count_of_layer._buckets[gpu_rank].inc(count)
731-
self._collection_counter += 1
758+
self._metric_heatmap_collection_counter += 1
732759

733760

734761
class _DequeCollection:
@@ -767,8 +794,9 @@ def append(
767794
forward_pass_id: int,
768795
gatherer_key: str,
769796
single_pass_data: Dict,
797+
outputs: Dict[str, Any],
770798
):
771-
super().append(forward_pass_id, gatherer_key, single_pass_data)
799+
super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)
772800

773801
def _process_object(obj):
774802
if isinstance(obj, torch.Tensor):
@@ -824,8 +852,9 @@ def append(
824852
forward_pass_id: int,
825853
gatherer_key: str,
826854
single_pass_data: Dict,
855+
outputs: Dict[str, Any],
827856
):
828-
super().append(forward_pass_id, gatherer_key, single_pass_data)
857+
super().append(forward_pass_id, gatherer_key, single_pass_data, outputs)
829858
# Can optimize if overhead here is large
830859
self._global_physical_count_of_buffered_step.append(
831860
single_pass_data["global_physical_count"]

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2221,6 +2221,7 @@ def process_batch_result(
22212221
if result.copy_done is not None:
22222222
result.copy_done.synchronize()
22232223

2224+
self.log_batch_result_stats(batch, result)
22242225
self.maybe_send_health_check_signal()
22252226

22262227
def maybe_send_health_check_signal(self):

python/sglang/srt/managers/scheduler_metrics_mixin.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,21 @@
44
import time
55
from collections import defaultdict
66
from contextlib import contextmanager
7-
from typing import TYPE_CHECKING, List, Optional
7+
from typing import TYPE_CHECKING, List, Optional, Union
88

99
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
1010
from sglang.srt.disaggregation.utils import DisaggregationMode
1111
from sglang.srt.environ import envs
1212
from sglang.srt.managers.io_struct import GetLoadReqInput, GetLoadReqOutput
1313
from sglang.srt.managers.schedule_policy import PrefillAdder
1414
from sglang.srt.managers.scheduler import Req, ScheduleBatch
15+
from sglang.srt.managers.utils import GenerationBatchResult
1516
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
1617
from sglang.srt.utils import get_bool_env_var
1718
from sglang.srt.utils.device_timer import DeviceTimer
1819

1920
if TYPE_CHECKING:
20-
from sglang.srt.managers.scheduler import Scheduler
21+
from sglang.srt.managers.scheduler import EmbeddingBatchResult, Scheduler
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -395,6 +396,22 @@ def log_decode_stats(
395396
self._emit_kv_metrics()
396397
self._publish_kv_events()
397398

399+
def log_batch_result_stats(
400+
self: Scheduler,
401+
batch: ScheduleBatch,
402+
result: Union[GenerationBatchResult, EmbeddingBatchResult],
403+
):
404+
if not self.enable_metrics:
405+
return
406+
if not isinstance(result, GenerationBatchResult):
407+
return
408+
409+
if (m := result.expert_distribution_metrics) is not None:
410+
self.metrics_collector.increment_eplb_balancedness(
411+
forward_mode=batch.forward_mode.name.lower(),
412+
balancedness=m.eplb_balancedness.item(),
413+
)
414+
398415
def _emit_kv_metrics(self: Scheduler):
399416
if not self.enable_kv_cache_events:
400417
return

python/sglang/srt/managers/tp_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def forward_batch_generation(
406406
batch_result = GenerationBatchResult(
407407
logits_output=logits_output,
408408
can_run_cuda_graph=can_run_cuda_graph,
409+
expert_distribution_metrics=out.expert_distribution_metrics,
409410
)
410411

411412
if is_verify:
@@ -460,6 +461,7 @@ def sample_batch_func():
460461
return GenerationBatchResult(
461462
pp_hidden_states_proxy_tensors=pp_proxy_tensors,
462463
can_run_cuda_graph=can_run_cuda_graph,
464+
expert_distribution_metrics=out.expert_distribution_metrics,
463465
)
464466

465467
def forward_batch_split_prefill(self, batch: ScheduleBatch):
@@ -482,6 +484,7 @@ def forward_batch_split_prefill(self, batch: ScheduleBatch):
482484
batch_result = GenerationBatchResult(
483485
logits_output=logits_output,
484486
can_run_cuda_graph=can_run_cuda_graph,
487+
expert_distribution_metrics=out.expert_distribution_metrics,
485488
)
486489
batch_result.next_token_ids = next_token_ids
487490
return batch_result

python/sglang/srt/managers/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88

9+
from sglang.srt.eplb.expert_distribution import ExpertDistributionMetrics
910
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
1011
from sglang.srt.managers.overlap_utils import FutureIndices
1112
from sglang.srt.managers.schedule_batch import Req
@@ -44,6 +45,9 @@ class GenerationBatchResult:
4445
# relay path: forward stream -> next step forward
4546
next_draft_input: Optional[EagleDraftInput] = None
4647

48+
# metrics
49+
expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
50+
4751
def copy_to_cpu(self, return_logprob: bool):
4852
"""Copy tensors to CPU in overlap scheduling.
4953
Only the tensors which are needed for processing results are copied,
@@ -67,6 +71,9 @@ def copy_to_cpu(self, return_logprob: bool):
6771
if self.accept_lens is not None:
6872
self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
6973

74+
if (x := self.expert_distribution_metrics) is not None:
75+
x.copy_to_cpu()
76+
7077
self.copy_done.record()
7178

7279
@classmethod

python/sglang/srt/metrics/collector.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Dict, List, Optional, Union
2020

2121
from sglang.srt.disaggregation.utils import DisaggregationMode
22+
from sglang.srt.environ import envs
2223
from sglang.srt.metrics.utils import exponential_buckets, generate_buckets
2324
from sglang.srt.server_args import ServerArgs
2425
from sglang.srt.utils import get_bool_env_var
@@ -241,7 +242,7 @@ def __init__(
241242
labels: Dict[str, str],
242243
) -> None:
243244
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
244-
from prometheus_client import Counter, Gauge, Histogram
245+
from prometheus_client import Counter, Gauge, Histogram, Summary
245246

246247
self.labels = labels
247248
self.last_log_time = time.perf_counter()
@@ -641,6 +642,15 @@ def __init__(
641642
labelnames=list(labels.keys()) + ["mode"],
642643
)
643644

645+
if (
646+
labels["moe_ep_rank"] == 0
647+
) and envs.SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC.get():
648+
self.eplb_balancedness = Summary(
649+
name="sglang:eplb_balancedness",
650+
documentation="Balancedness of MoE in expert parallelism.",
651+
labelnames=list(labels.keys()) + ["forward_mode"],
652+
)
653+
644654
self.new_token_ratio = Gauge(
645655
name="sglang:new_token_ratio",
646656
documentation="The new token ratio.",
@@ -698,6 +708,13 @@ def increment_cuda_graph_pass(self, value: bool) -> None:
698708
mode = "decode_cuda_graph" if value else "decode_none"
699709
self.cuda_graph_passes_total.labels(**self.labels, mode=mode).inc(1)
700710

711+
def increment_eplb_balancedness(
712+
self, forward_mode: str, balancedness: float
713+
) -> None:
714+
self.eplb_balancedness.labels(**self.labels, forward_mode=forward_mode).observe(
715+
balancedness
716+
)
717+
701718
def increment_realtime_tokens(
702719
self, prefill_compute_tokens=0, prefill_cache_tokens=0, decode_tokens=0
703720
):

python/sglang/srt/model_executor/model_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from sglang.srt.environ import envs
6969
from sglang.srt.eplb.eplb_manager import EPLBManager
7070
from sglang.srt.eplb.expert_distribution import (
71+
ExpertDistributionMetrics,
7172
ExpertDistributionRecorder,
7273
get_global_expert_distribution_recorder,
7374
set_global_expert_distribution_recorder,
@@ -272,6 +273,7 @@ def filter(self, record):
272273
class ModelRunnerOutput:
273274
logits_output: Union[LogitsProcessorOutput, PPProxyTensors]
274275
can_run_graph: bool
276+
expert_distribution_metrics: Optional[ExpertDistributionMetrics] = None
275277

276278

277279
class ModelRunner:
@@ -2738,14 +2740,15 @@ def forward(
27382740
with get_global_expert_distribution_recorder().with_forward_pass(
27392741
self.forward_pass_id,
27402742
forward_batch,
2741-
):
2743+
) as recorder_outputs:
27422744
output = self._forward_raw(
27432745
forward_batch,
27442746
skip_attn_backend_init,
27452747
pp_proxy_tensors,
27462748
reinit_attn_backend,
27472749
split_forward_count,
27482750
)
2751+
output.expert_distribution_metrics = recorder_outputs.get("metrics")
27492752

27502753
if self.eplb_manager is not None:
27512754
self.eplb_manager.on_forward_pass_end()

0 commit comments

Comments
 (0)