[Feature] Deep EPLB Integration Proposal (Draft) #5309

@yicwang

Motivation

Background

Deep EPLB

With Expert Parallelism, experts are assigned to different GPUs at startup time. Expert load varies by scenario and changes dynamically, which can leave GPU loads imbalanced. DeepSeek adopted a redundant-experts strategy, and EPLB (https://github.com/deepseek-ai/EPLB) is the load-balancing algorithm used to achieve the balancing goal.

Expert Distribution

With PR #4435 we did some initial analysis of expert distribution. We ran the DeepSeek-R1 model with EPMoE enabled on an 8 * H20 server, with both EP and TP set to 8. For a sample request of "What is the capital of France?", the expert usage heat map for the first MoE layer is shown below. The summary statistics show that leveraging EPLB could balance the workload across GPUs better.

[Image: expert usage heat map for the first MoE layer]

With the above data, we ran the counts through eplb.py. Take the first layer's raw count data as an example:

tensor([489, 185,  49, 476,  89, 168, 203, 213,  40, 180, 179, 105, 158,  28,
        165, 385,  51, 341, 254, 228, 160,  50, 105, 199, 202,  54,  82,  52,
         30, 293,  90, 342,  24,  72,  47, 158, 261, 444,  71,  64,  72, 230,
        213, 180, 119,  23, 108, 310, 210, 167, 100,  60,  28,  46, 269, 158,
         70, 189,  54,  49,  31, 155, 239, 121,  48,  89, 162, 154,  58,  81,
         68,  18, 111,  24, 224, 224, 142, 182,  69, 131, 135, 320, 161, 113,
         89, 293, 283,  69, 114, 403, 106,  53,  57, 128, 120,  35, 560, 190,
        119, 108, 190, 170, 124, 173, 171,  70,  40, 143,  65, 671, 102,  31,
        136,  30,  95,  52, 127, 117, 111,  95,  32,  43, 166, 187, 114,  47,
        195, 112, 189,  27,   9, 177,  38, 202,  25, 121, 105,  28,  15, 713,
         55,  40, 288, 181, 135,  47,  50,  29,  83, 145, 215,  46,  63, 198,
         88, 105,  45, 131,  66,  43,  24, 103, 151, 155, 173, 153,  66,  48,
         36,  64,  73,  51, 175,  49, 123,  25,  49,  56,  86,   4,  48,  86,
         43,  64, 229,  50,  18,  10, 207,   4,  12, 128, 164,  70, 105, 419,
         72, 117,  16,  95,  99,  31,   9,  56,  29, 105, 105,  12, 107,  30,
        110,  17,  28,  35, 186, 130, 182, 162,  20,  35,  19, 175,  38,  21,
         46,  81,  68,  78,  23,  63, 173,  21,  25,  26,  52,  89,  68,  37,
         37,  50,  64,  56, 137,  72,  18, 138,   5,  64, 171,  24,  15,  18,
         17,  30, 100,  57], dtype=torch.int32)

num_replicas = 288
num_groups = 4
num_nodes = 1
num_gpus = 8
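
For reference, the counts can be fed to the algorithm as below, following the usage documented in the EPLB repository (a sketch; `counts` is the tensor printed above):

    import torch
    import eplb  # from https://github.com/deepseek-ai/EPLB

    # weight: [num_layers, num_experts] expert-load matrix; here a single row
    # holding the first-layer counts shown above (256 logical experts).
    weight = counts.unsqueeze(0)

    num_replicas = 288  # physical slots: 256 logical experts + 32 redundant
    num_groups = 4      # expert groups
    num_nodes = 1       # server nodes
    num_gpus = 8        # GPUs in total

    # phy2log maps each physical slot to its logical expert; log2phy is the
    # inverse mapping; logcnt is the replica count per logical expert.
    phy2log, log2phy, logcnt = eplb.rebalance_experts(
        weight, num_replicas, num_groups, num_nodes, num_gpus)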

phy2log for the first layer is:

phy2log[0]:
# expert index
tensor([109,  24,  15, 123,  77, 103,  31,  49,   3,  47,  29, 142,  36,  18,
         62,  41,  75,  99, 161, 119, 178, 196, 226, 166,  92, 203,  21, 125,
         53, 168,  60,  28,  32, 249, 138, 246, 109, 133,  97,  57, 216, 139,
        101,  66,  81, 163,  85, 107,  36, 191,  62, 184,  74,  46, 136,  65,
        154, 227, 236, 247, 177,  25,  16,  34, 224, 238, 113, 253, 232, 231,
        211, 130, 109,  89,  15, 128,  43, 139,   5,   0,  82,  61,  85,  76,
        157, 215,  98,  88, 127, 208, 155,   4, 235, 170, 105,  39, 152, 234,
         59, 145, 151,  95, 213,  52,  73, 218, 244, 179,  37, 188, 100,  96,
        143, 131,  17,   0,   3,  35,  29, 245, 144, 116,  44, 184,  75,  72,
         50, 200,  26, 243,  87, 108,  68,  58, 171, 167, 159, 106, 209,  13,
        129, 186,  71, 207,  37,   6, 126,  96,   9, 172, 104,   0, 217,  67,
        162, 242,  80, 102, 135,  19,  83,  11, 205, 254,  69,  33,  78, 158,
        255, 140, 185, 180, 156, 237, 111, 147, 134,  45, 250, 187, 150, 195,
        153, 214,  10, 230, 248,  14,  81,  47, 165, 112,  79,  93,  94, 124,
         19, 206, 194, 199, 148, 225,  70, 240, 241, 115, 239, 173, 141,   8,
        219, 212, 160, 220, 251, 190,   7,  48,  23,  96, 139, 221,  17, 192,
         20,  12, 142,  86,  54,  18, 117, 197, 118,  90,  22, 114, 181,  38,
        193, 169,  51,  91, 146,   2, 182, 222, 201, 204, 175, 228, 198, 189,
         42, 195,  89,   1, 139, 164,  31, 122,   3,  55, 149,  86,  54, 174,
         63,  41,  74, 210, 110,  30,  84,  40,  56, 183, 229,  27, 176,  64,
        121, 132, 120, 137, 233, 223, 252, 202])

Overloaded experts like expert 109 (count = 671) and expert 139 (count = 713) get several replicas on different GPUs, which makes sense.
With EPLB, experts are duplicated when needed and packed into groups according to each expert's estimated workload. To estimate the benefit of EPLB, one can first compute the average workload of each physical expert and then sum the workloads of the experts assigned to each GPU. We compared GPU workload distributions under a naive baseline (the default contiguous expert-to-GPU mapping) against the EPLB strategy:

[Image: GPU workload distribution without EPLB]

GPU workload without EPLB: [5645, 4342, 4264, 4586, 3702, 2563, 2799, 1923]
Mean: 3728, Std: 1227.908

[Image: expected GPU workload distribution with EPLB]

Expected GPU workload with EPLB (assuming replicas of the same expert split its workload evenly; values unrounded):
[3725.33, 3728.92, 3724.42, 3730.83, 3729.66, 3731.5, 3724.42, 3728.92]
Mean: 3728, Std: 2.867
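
These numbers can be reproduced with a short script (a sketch, assuming EPLB lays out physical slots contiguously per GPU; `counts` and `phy2log` are the tensors shown earlier):

    import torch

    def expected_gpu_workload(counts: torch.Tensor, phy2log: torch.Tensor,
                              num_gpus: int) -> torch.Tensor:
        # Each logical expert's load is split evenly across its replicas.
        replicas = torch.bincount(phy2log, minlength=counts.numel())
        per_slot = counts[phy2log].float() / replicas[phy2log].float()
        # Contiguous slot-to-GPU layout: 288 / 8 = 36 physical slots per GPU.
        return per_slot.view(num_gpus, -1).sum(dim=1)

    # Baseline without EPLB: contiguous mapping of 256 / 8 = 32 experts per GPU.
    baseline = counts.float().view(8, -1).sum(dim=1)
    balanced = expected_gpu_workload(counts, phy2log, num_gpus=8)
    print(baseline.std(), balanced.std())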

Proposed Changes

EPLB provides a potentially better way of assigning experts to each GPU, all within just ~160 lines of Python code. To realize the benefits of load balancing, we need to integrate the algorithm into the inference framework and build the supporting features in other components. The work spans multiple components.

EPLB Worker

  1. Weight Loader
    python/sglang/srt/layers/moe/ep_moe/layer.py, the EPMoE implementation class. The proposed changes will not conflict with the DeepEP implementation, which focuses more on communication.

The current implementation manages weights by expert_id, so loading an expert loads all of its layers into GPU memory. EPLB, however, operates at a finer, per-layer granularity: layers from different experts may be mixed and placed together on the same GPU.

[Image: layers from different experts mixed and placed across GPUs with EPLB]

a. create_weights(). In each quantization method implementation, we will need to allocate additional memory to store the weights of replica expert layers (a sketch follows the signature below):

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
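
A minimal sketch of the extra allocation, assuming a hypothetical num_replica_slots attribute carries the number of additional physical slots reserved on this partition; the weight shapes follow the usual fused gate/up and down projections, but are illustrative:

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Hypothetical: extra physical slots reserved for EPLB replica layers.
        num_replica_slots = extra_weight_attrs.pop("num_replica_slots", 0)
        total_slots = num_experts_per_partition + num_replica_slots

        # Fused gate/up projection weights, one slot per physical expert.
        w13_weight = torch.nn.Parameter(
            torch.empty(total_slots, 2 * intermediate_size, hidden_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)

        # Down projection weights.
        w2_weight = torch.nn.Parameter(
            torch.empty(total_slots, hidden_size, intermediate_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)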

b. weight_loader(). The current implementation loads weights into an array indexed by expert_id. Experts are contiguous, described by start_expert_id and end_expert_id. With EPLB, the experts and layers handled by each GPU (TP worker) are fully shuffled, so the function needs a two-level lookup table.
For example, expert_id_mappings stores Expert "E", Layer "L" as an array of [(GPU "G1", OFFSET "K1"), (GPU "G2", OFFSET "K2")]: expert_id_mappings[1, 58] = [(3, 23), (5, 1)] means layer 58 of expert 1 lives at two locations, GPU 3 at offset 23 and GPU 5 at offset 1. This mapping is maintained globally and updated dynamically by the weight_loader function to reflect the latest status (i.e. registration).

    def weight_loader(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:

The weight loader will need "layer_id" passed in, and will load the [Expert ID, Layer ID] weight into the correct GPU memory slot. Instead of being called only once at startup, the function will be called repeatedly while EPLB rebalancing is taking place.
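
A sketch of the two-level lookup table and the registration step (the structure and the register_expert_layer helper are illustrative, not existing APIs):

    from collections import defaultdict
    from typing import Dict, List, Tuple

    # Key: (expert_id, layer_id); value: list of (gpu_rank, slot_offset) replicas.
    ExpertLocation = Tuple[int, int]
    expert_id_mappings: Dict[Tuple[int, int], List[ExpertLocation]] = defaultdict(list)

    def register_expert_layer(expert_id: int, layer_id: int,
                              gpu_rank: int, offset: int) -> None:
        # Called by weight_loader only after the weight is fully resident in GPU
        # memory, so gating never routes tokens to a partially loaded replica.
        expert_id_mappings[(expert_id, layer_id)].append((gpu_rank, offset))

    # Example from the text: layer 58 of expert 1 on GPU 3 (offset 23) and GPU 5 (offset 1).
    register_expert_layer(1, 58, 3, 23)
    register_expert_layer(1, 58, 5, 1)
    assert expert_id_mappings[1, 58] == [(3, 23), (5, 1)]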

Future Features: Zero-overhead expert movement. Loading from disk is slow, so a potential solution is to load the weights in the background and only activate/register them with MOE Gating once the weight is completely loaded.

  2. MOE Gating
    python/sglang/srt/layers/moe/ep_moe/layer.py
    The first three layers of the DeepSeek V3/R1 model are dense MLP layers; the MoE layers start from the 4th layer. MOE Gating selects the experts with select_experts():
        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            use_grouped_topk=self.use_grouped_topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
            custom_routing_function=self.custom_routing_function,
        )

select_experts() returns the target experts for the current MoE layer, and the topk_ids output contains only logical expert IDs. Without EPLB, each target expert is stored uniquely on a designated GPU. With EPLB, the layers of each expert are shuffled and replicated, so MOE Gating must determine the exact location among multiple replicas, following certain preferences or algorithms. The pseudocode below demonstrates the idea:

expert_id_mappings[1, 58] = [(3, 23), (5, 1)]
best_replicas = choose_best_replica(topk_ids, layer_id, expert_id_mappings)

expert_id_mappings changes dynamically while EPLB rebalancing is in progress, but once the weight actions complete it reaches a final state that matches the mapping derived from the EPLB output. Using the current mapping, choose_best_replica() in MOE Gating selects the best replica for each selected top-k expert, and tokens are communicated via all-reduce or all-to-all to the target GPU that holds the weights. The algorithm will need to weigh multiple factors to make optimal decisions, including distance (local > cross-GPU > cross-node) and batching, so that one communication carries most of the tensors.
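
A sketch of what choose_best_replica() could look like, preferring the closest replica; the distance model, topology assumption, and my_rank parameter are illustrative additions:

    from typing import Dict, List, Tuple

    LOCAL, CROSS_GPU, CROSS_NODE = 0, 1, 2  # smaller is closer

    def replica_distance(my_rank: int, gpu_rank: int, gpus_per_node: int = 8) -> int:
        # Assumed topology: ranks on the same node share gpu_rank // gpus_per_node.
        if gpu_rank == my_rank:
            return LOCAL
        if gpu_rank // gpus_per_node == my_rank // gpus_per_node:
            return CROSS_GPU
        return CROSS_NODE

    def choose_best_replica(
        topk_ids,  # tensor of selected logical expert IDs
        layer_id: int,
        expert_id_mappings: Dict[Tuple[int, int], List[Tuple[int, int]]],
        my_rank: int,
    ) -> List[Tuple[int, int]]:
        # For each selected expert, pick the (gpu, offset) replica with the
        # smallest communication distance from the current rank.
        return [
            min(expert_id_mappings[eid, layer_id],
                key=lambda loc: replica_distance(my_rank, loc[0]))
            for eid in topk_ids.flatten().tolist()
        ]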

EPLB Manager

python/sglang/srt/managers/eplb_manager.py
EPLB needs to run periodically to evaluate expert distributions and make replication decisions. It could be triggered by a certain threshold or simply a timer. The EPLB manager is designed to:

  1. Periodically collect metrics about expert distributions.
    Note: the community has introduced expert distribution recording without overhead for EPLB (#4957). The manager will use the latest available methods to collect distribution data and process it through the EPLB algorithm.
  2. Evaluate the difference between the current running profile and the EPLB results. Based on the deltas, decide whether to start replica adjustment actions (a sketch of this loop follows the list).
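
A sketch of the manager loop under these assumptions; the hook names collect_expert_counts, compute_gpu_workload, and apply_rebalance are placeholders for the metrics recorder (#4957), the workload estimator, and the weight-movement path described above:

    import time
    import torch

    class EPLBManager:
        """Illustrative sketch: rebalance when imbalance exceeds a threshold."""

        def __init__(self, interval_s: float = 60.0, threshold: float = 0.2):
            self.interval_s = interval_s
            self.threshold = threshold  # rebalance if (max - mean) / mean exceeds this

        def should_rebalance(self, gpu_workload: torch.Tensor) -> bool:
            mean = gpu_workload.float().mean()
            return ((gpu_workload.max() - mean) / mean).item() > self.threshold

        def run(self, collect_expert_counts, compute_gpu_workload, apply_rebalance):
            while True:
                counts = collect_expert_counts()
                if self.should_rebalance(compute_gpu_workload(counts)):
                    # Run EPLB on the fresh counts and kick off background
                    # weight movement (see "Zero-overhead expert movement").
                    apply_rebalance(counts)
                time.sleep(self.interval_s)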

