2020from abc import ABC
2121from collections import deque
2222from contextlib import contextmanager
23+ from dataclasses import dataclass
2324from pathlib import Path
2425from typing import TYPE_CHECKING , Any , Dict , List , Literal , Optional , Tuple , Type
2526
# Destination for recorded data when dumping: "file" (write to disk) or "object" (return in memory).
_OutputMode = Literal["file", "object"]
4445
4546
@dataclass
class ExpertDistributionMetrics:
    """Per-forward-pass expert-distribution metrics surfaced through the
    `outputs` dict yielded by `with_forward_pass` (under the "metrics" key)."""

    # EPLB balancedness value; stored as a tensor so it can be produced
    # without forcing a GPU->CPU sync at collection time.
    eplb_balancedness: torch.Tensor

    def copy_to_cpu(self):
        # Moves the metric to host memory asynchronously.
        # NOTE(review): non_blocking=True on a device->host copy is only safe
        # to read after a synchronization point — confirm callers sync (or the
        # source is already in pinned/CPU memory) before consuming the value.
        self.eplb_balancedness = self.eplb_balancedness.to("cpu", non_blocking=True)
54+
4655class ExpertDistributionRecorder (ABC ):
4756 """Global expert distribution recording"""
4857
@@ -78,7 +87,7 @@ def disable_this_region(self):
7887
    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        # No-op base implementation. Yields an empty outputs dict so callers
        # can uniformly read per-pass outputs (e.g. outputs.get("metrics"))
        # whether or not a recording subclass is active.
        yield {}
8291
    def on_select_experts(self, topk_ids: torch.Tensor):
        # No-op in the base recorder; recording subclasses override this to
        # capture the selected top-k expert ids.
        pass
@@ -157,12 +166,13 @@ def with_debug_name(self, debug_name):
157166
    @contextmanager
    def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
        """Scope a single forward pass for recording.

        Yields a mutable `outputs` dict that downstream components may fill
        (e.g. outputs["metrics"] = ExpertDistributionMetrics(...)); the same
        dict is handed to `_on_forward_pass_end` so accumulated results are
        visible to the caller after the `with` block exits.
        """
        outputs = {}
        with self._current_forward_ass_id.with_value(forward_pass_id) if False else self._current_forward_pass_id.with_value(forward_pass_id):
            self._on_forward_pass_start(forward_batch)
            try:
                yield outputs
            finally:
                # Always finalize, even if the forward pass raised, so
                # gatherer state is flushed into the accumulator.
                self._on_forward_pass_end(forward_pass_id, outputs)
166176
167177 @contextmanager
168178 def disable_this_region (self ):
@@ -181,12 +191,14 @@ def _on_forward_pass_start(self, forward_batch: ForwardBatch):
181191 gatherer .reset ()
182192 gatherer .on_forward_pass_start (forward_batch )
183193
    def _on_forward_pass_end(self, forward_pass_id: int, outputs: Dict[str, Any]):
        # Collect each gatherer's single-pass data into the accumulator.
        # `outputs` is the dict yielded by `with_forward_pass`; the accumulator
        # may write results (e.g. "metrics") into it for the caller to read.
        if not self._recording:
            return
        for gatherer_key, gatherer in self._single_pass_gatherers.items():
            single_pass_data = gatherer.collect()
            self._accumulator.append(
                forward_pass_id, gatherer_key, single_pass_data, outputs
            )
190202
    def on_select_experts(self, topk_ids: torch.Tensor):
        # Dispatch the selected top-k expert ids through the generic hook
        # mechanism (fans out to the registered single-pass gatherers).
        self._on_hook("on_select_experts", topk_ids=topk_ids)
@@ -636,6 +648,7 @@ def append(
636648 forward_pass_id : int ,
637649 gatherer_key : str ,
638650 single_pass_data : Dict ,
651+ outputs : Dict [str , Any ],
639652 ):
640653 pass
641654
@@ -659,18 +672,19 @@ def __init__(self, *args, **kwargs):
659672 self ._expert_dispatch_collector = ExpertDispatchCollector (
660673 self ._expert_location_metadata .ep_size
661674 )
662- self ._collection_counter = 0
675+ self ._metric_heatmap_collection_counter = 0
663676
664677 def append (
665678 self ,
666679 forward_pass_id : int ,
667680 gatherer_key : str ,
668681 single_pass_data : Dict ,
682+ outputs : Dict [str , Any ],
669683 ):
670- super ().append (forward_pass_id , gatherer_key , single_pass_data )
684+ super ().append (forward_pass_id , gatherer_key , single_pass_data , outputs )
671685 if self ._enable :
672- self ._append_utilization_rate (
673- forward_pass_id , single_pass_data ["global_physical_count" ]
686+ return self ._append_utilization_rate (
687+ forward_pass_id , single_pass_data ["global_physical_count" ], outputs
674688 )
675689
676690 def reset (self ):
@@ -679,7 +693,10 @@ def reset(self):
679693 self ._history .clear ()
680694
681695 def _append_utilization_rate (
682- self , forward_pass_id : int , single_pass_global_physical_count : torch .Tensor
696+ self ,
697+ forward_pass_id : int ,
698+ single_pass_global_physical_count : torch .Tensor ,
699+ outputs : Dict [str , Any ],
683700 ):
684701 gpu_physical_count = compute_gpu_physical_count (
685702 single_pass_global_physical_count ,
@@ -691,27 +708,37 @@ def _append_utilization_rate(
691708 )
692709
693710 if self ._rank == 0 :
694- self ._collect_metrics_if_needed (gpu_physical_count )
711+ self ._handle_metric_eplb_heatmap (gpu_physical_count )
695712
696- utilization_rate_tensor = compute_utilization_rate (gpu_physical_count )
697- utilization_rate = torch .mean (utilization_rate_tensor ).item ()
698- self ._history .append (utilization_rate )
699-
700- gpu_physical_count_sum = gpu_physical_count .sum ().item ()
701-
702- logger .info (
703- f"[Expert Balancedness] "
704- f"forward_pass_id={ forward_pass_id } "
705- f"current_pass_balancedness={ utilization_rate :.03f} "
706- f"{ '' .join (f'last_{ size } _average_balancedness={ value :.03f} ' for size , value in self ._history .mean ().items ())} "
707- f"gpu_physical_count_sum={ gpu_physical_count_sum } "
708- # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
713+ utilization_rate_gpu = torch .mean (
714+ compute_utilization_rate (gpu_physical_count )
709715 )
716+ if envs .SGLANG_ENABLE_EPLB_BALANCEDNESS_METRIC .get ():
717+ print (f"hi { self ._rank = } { utilization_rate_gpu = } " )
718+ outputs ["metrics" ] = ExpertDistributionMetrics (
719+ eplb_balancedness = utilization_rate_gpu ,
720+ )
721+ else :
722+ # TODO maybe refactor this part to also avoid a `.item()` gpu->cpu sync
723+ utilization_rate_cpu = utilization_rate_gpu .item ()
724+ self ._history .append (utilization_rate_cpu )
725+
726+ gpu_physical_count_sum = gpu_physical_count .sum ().item ()
727+
728+ logger .info (
729+ f"[Expert Balancedness] "
730+ f"forward_pass_id={ forward_pass_id } "
731+ f"current_pass_balancedness={ utilization_rate_cpu :.03f} "
732+ f"{ '' .join (f'last_{ size } _average_balancedness={ value :.03f} ' for size , value in self ._history .mean ().items ())} "
733+ f"gpu_physical_count_sum={ gpu_physical_count_sum } "
734+ # f"current_pass_per_layer={[round(x, 2) for x in utilization_rate_tensor.cpu().tolist()]}"
735+ )
710736
711- def _collect_metrics_if_needed (self , gpu_physical_count : torch .Tensor ):
737+ # TODO refactor
738+ def _handle_metric_eplb_heatmap (self , gpu_physical_count : torch .Tensor ):
712739 # sglang:eplb_gpu_physical_count metric is disabled if SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL <= 0
713740 interval = get_int_env_var ("SGLANG_EPLB_HEATMAP_COLLECTION_INTERVAL" , 0 )
714- if interval > 0 and self ._collection_counter % interval == 0 :
741+ if interval > 0 and self ._metric_heatmap_collection_counter % interval == 0 :
715742 for layer_idx in range (self ._expert_location_metadata .num_layers ):
716743 count_of_layer = (
717744 self ._expert_dispatch_collector .eplb_gpu_physical_count .labels (
@@ -728,7 +755,7 @@ def _collect_metrics_if_needed(self, gpu_physical_count: torch.Tensor):
728755 if count > 0 :
729756 count_of_layer ._sum .inc (count * gpu_rank )
730757 count_of_layer ._buckets [gpu_rank ].inc (count )
731- self ._collection_counter += 1
758+ self ._metric_heatmap_collection_counter += 1
732759
733760
734761class _DequeCollection :
@@ -767,8 +794,9 @@ def append(
767794 forward_pass_id : int ,
768795 gatherer_key : str ,
769796 single_pass_data : Dict ,
797+ outputs : Dict [str , Any ],
770798 ):
771- super ().append (forward_pass_id , gatherer_key , single_pass_data )
799+ super ().append (forward_pass_id , gatherer_key , single_pass_data , outputs )
772800
773801 def _process_object (obj ):
774802 if isinstance (obj , torch .Tensor ):
@@ -824,8 +852,9 @@ def append(
824852 forward_pass_id : int ,
825853 gatherer_key : str ,
826854 single_pass_data : Dict ,
855+ outputs : Dict [str , Any ],
827856 ):
828- super ().append (forward_pass_id , gatherer_key , single_pass_data )
857+ super ().append (forward_pass_id , gatherer_key , single_pass_data , outputs )
829858 # Can optimize if overhead here is large
830859 self ._global_physical_count_of_buffered_step .append (
831860 single_pass_data ["global_physical_count" ]
0 commit comments