[Feature] Add Elasticity Support to DeepEP for Fault-Tolerant EP Inference #370
sphish merged 10 commits into deepseek-ai:main
Conversation
```cuda
    // Mask rank if timeout
    if (mask_buffer_ptr != nullptr && wait_recv_cost > NUM_TIMEOUT_CYCLES) {
        atomicExch(mask_buffer_ptr + dst_rank, 1);
        // printf("[rank %d] Clean LL buffer: rank %d is masked due to timeout\n", rank, dst_rank);
```
An error needs to be raised when a timeout occurs and `enable_shrink` is not set.
```cuda
    if (mask_buffer_ptr == nullptr || ld_acquire_sys_global(mask_buffer_ptr + dst_rank) == 0) {
        // Update remote counter
        if (dst_p2p_ptr == 0) {
            nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -1, dst_rank, 0);
```
We need to quiet all QPs before the barrier.
```cuda
                      clean_0, num_clean_int_0, clean_1, num_clean_int_1);

    if (sync_buffer_ptr == nullptr) {
        LAUNCH_KERNEL(&cfg, clean_low_latency_buffer<kNumThreads>,
```
I think we should merge these two kernels.
```cuda
            UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
        }
    }

    if (mask_buffer_ptr != nullptr && wait_recv_cost > NUM_TIMEOUT_CYCLES) {
        if (local_expert_idx == 0) {
            atomicExch(mask_buffer_ptr + src_rank, 1);
            // printf("[rank %d] Dispatch: rank %d is masked due to timeout\n", rank, src_rank);
```
And if `enable_shrink` is not set, an error should be raised.
```cuda
    while (
        (mask_buffer_ptr == nullptr || ld_acquire_sys_global(mask_buffer_ptr + src_rank) == 0)                        // rank not masked
        && (num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0  // data not arrived
        && (mask_buffer_ptr == nullptr || ((wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES))          // not timeout
```
```cuda
        nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), 1, dst_rank, local_expert_idx);
    } else {
        st_release_sys_global(reinterpret_cast<int*>(dst_p2p_ptr), 1);
        if (mask_buffer_ptr == nullptr || ld_acquire_sys_global(mask_buffer_ptr + dst_rank) == 0) {
```
This condition appears many times; I think it can be replaced with an inline function like is_masked(rank).
|
all fixed 😄
```cuda
    if (not is_rank_masked(mask_buffer_ptr, dst_rank)) {
        if (dst_p2p_ptr == 0) {
            nvshmemi_ibgda_rma_p(reinterpret_cast<int*>(dst_ptr), cnt, dst_rank, 0);
            // nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast<int*>(dst_ptr), -1, dst_rank, 0);
```
```cuda
    auto start_time = clock64();
    uint64_t wait_recv_cost = 0;
    while ((ld_acquire_sys_global(sync_buffer_ptr + dst_rank) != cnt)          // remote is not ready
           && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES  // not timeout
```
```cuda
    uint64_t wait_recv_cost = 0;
    while (not is_rank_masked(mask_buffer_ptr, src_rank)                           // rank not masked
           && ld_acquire_sys_global(rdma_recv_flag + responsible_expert_idx) == 0  // recv not ready
           && (wait_recv_cost = clock64() - start_time) <= NUM_TIMEOUT_CYCLES      // not timeout
```
|
Thanks!
|
hello! I'd like to ask whether fault tolerance can work correctly in this case. Generally, if there is a hardware problem on the GPU, the CUDA kernel will raise an error, and then the process will report the error and quit, or hit a segmentation fault. The other PyTorch processes will then sense that a peer has exited, and an avalanche starts (all PyTorch processes die, which is what I observed). I don't know if you have made any modifications in PyTorch to avoid this avalanche phenomenon.
Hi! |
|
@jammycc Regarding the design mentioned above, is there an existing PR for its vLLM integration? I’d like to try running it myself but I’m not sure which PR to follow. Could you provide a link? Thanks.
Hi, |
@jammycc Is there a WIP pull request for vLLM integration? |
Motivation and Target
This PR ensures that large-scale Expert Parallel (EP) inference remains operational in a downscaled state when it encounters failures such as individual GPU, NIC, or node outages, or even corner-case engine failures, instead of requiring a full teardown.
This provides the following benefits:
The downscaled deployment can still provide degraded (lower-capacity) service.
For multi-endpoint/API-server setups, most requests will only experience a brief stall before resuming processing, rather than being entirely aborted.
This graceful degradation provides sufficient time for instance-level redeployment or recovery actions.
This PR includes
Enable `dispatch` and `combine` operations to bypass failed ranks by introducing per-rank status information in the buffer, achieving real-time degradation.
Decentralized rank failure detection, with each rank monitoring communication timeouts to determine failure.
Provide APIs for the engine to query and update per-rank status.
This PR does not include (as future work)
Support for Prefill phase / Normal kernels
If a Prefill request fails, it usually needs to be fully re-executed anyway, making fault tolerance less impactful.
In PD-disaggregated architectures, the Prefill stage typically uses smaller EP sizes — sometimes even single-node deployments.
Dynamic scaling up (expansion)
DeepEP Modifications
Currently, DeepEP’s low-latency (LL) kernels hang when any rank fails, causing the entire EP inference service and request-processing pipeline to stall. We aim to enhance DeepEP with failure detection, graceful exit on faults, and dynamic downscaling. To this end, all communication-related APIs, including `low_latency_dispatch`, `low_latency_combine`, and `clean_low_latency_buffer`, must support these capabilities.
Additionally, since other engine-level collective communications (e.g., metadata synchronization within DP groups) can also be affected by rank failures, we introduce a new `low_latency_allgather` API to replace these collectives.
1. Failure Detection and Downscaling via Mask Buffer
Each node maintains a mask buffer to track the status of all ranks. Once a rank failure is detected, its corresponding status in the mask buffer is set to 1, indicating that subsequent communication with this rank should be skipped.
Key Changes:
Add an `enable_elastic` parameter during buffer initialization to enable elastic mode. This triggers allocation of an additional `mask_buffer` of size `num_ranks` (int32).
Add timeout-based failure detection and rank masking in all communication APIs (`low_latency_dispatch`, `low_latency_combine`, `clean_low_latency_buffer`):
If a rank's data is not received within `NUM_TIMEOUT_CYCLES`, its status in the mask buffer is set to 1.
Subsequent communication will neither receive data from nor send data to masked ranks.
New APIs:
`low_latency_query_mask_buffer`: Query the current rank mask status.
`low_latency_update_mask_buffer`: Manually update the mask status of a specific rank.
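For illustration, a hedged sketch of driving these two APIs from the engine side (the method names come from this PR; the return layout and argument order are assumptions):

```python
# Hypothetical usage sketch; `buffer` is an elastic-mode DeepEP buffer.
# Assumed contract: the query returns one int per rank, 1 = masked/failed.
mask = buffer.low_latency_query_mask_buffer()
failed_ranks = [r for r, m in enumerate(mask) if m == 1]

# Assumed argument order (rank, status): manually mask rank 3 after an
# external health checker declares it dead.
buffer.low_latency_update_mask_buffer(3, 1)
```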
2. Synchronization Mechanism via Sync Buffer
For non-dispatch/combine communication (e.g., `clean_low_latency_buffer`), global synchronization is required. Previously, this relied on NVSHMEM barriers, which fail if any rank is down.
To support synchronization in downscaled scenarios:
Introduce a dedicated sync_buffer to track synchronization counters.
Replace global `nvshmem_barrier` calls in `clean_low_latency_buffer` with sync-buffer-based logic: each rank writes its counter to every peer and waits for each unmasked peer's counter with the same timeout-based masking, so the barrier completes even when some ranks are masked.
3. New Collective Communication: `low_latency_allgather`
We observed that non-EP collectives (e.g., metadata sync in pure DP attention) can also be blocked by rank failures. To address this, we implement a fault-tolerant `low_latency_allgather` API to replace these communications.
Add a `num_coll_buffer_bytes` parameter during initialization to allocate space for collective communication buffers.
Use the sync_buffer for synchronization and skip masked ranks.
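For example, a sketch of replacing a DP-group metadata allgather (the Python signature and return layout shown here are assumptions):

```python
import torch

num_local_tokens = 128  # example per-rank metadata value

# Each rank contributes a small metadata tensor (e.g., its token count).
local_meta = torch.tensor([num_local_tokens], dtype=torch.int32, device="cuda")

# Assumed signature: gathers local_meta from every rank into one tensor.
# Masked ranks are skipped rather than waited on, so the call completes
# even after failures; slots belonging to masked ranks hold stale data.
gathered = buffer.low_latency_allgather(local_meta)
```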
Engine Integration Example
We have validated this design by integrating DeepEP into vLLM. We briefly describe below how the engine integrates with DeepEP’s fault-tolerance features, using vLLM v1 as an example.
1. Enable Elastic Mode in DeepEP Initialization
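As a minimal sketch, assuming the Python `Buffer` constructor gains the `enable_elastic` and `num_coll_buffer_bytes` parameters described above (the surrounding arguments follow stock DeepEP low-latency usage and may differ in detail):

```python
import deep_ep

buffer = deep_ep.Buffer(
    group,                                 # the EP process group
    num_rdma_bytes=num_rdma_bytes,         # LL buffer sizing, as in stock DeepEP
    low_latency_mode=True,
    num_qps_per_rank=num_local_experts,
    enable_elastic=True,                   # new: allocate the per-rank mask buffer
    num_coll_buffer_bytes=num_coll_bytes,  # new: space for low_latency_allgather
)
```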
2. Check Status After Each Decode Step
In `gpu_worker.py`, the `Worker` class exposes a `check_step_status` method, called via RPC by `EngineCore` to check for failures:
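A minimal sketch of such a method, assuming the mask-query API returns one status flag per rank (the body and return convention are illustrative, not the PR's exact code):

```python
class Worker:
    ...
    def check_step_status(self) -> list[int]:
        """Called via RPC by EngineCore after each decode step.

        Returns the indices of ranks currently masked as failed.
        """
        mask = self.deepep_buffer.low_latency_query_mask_buffer()
        return [rank for rank, masked in enumerate(mask) if masked]
```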
3. Handle Detected Failures
Once `EngineCore` detects new failed ranks:
Report to `LLMEngine`
Trigger expert rebalance (EPLB)
Recover requests (via migration or recomputation)
4. Handling Externally Triggered Failures
In `gpu_worker.py`, the `Worker` class exposes an `update_deepep_mask_buffer` method, called via RPC by `EngineCore` to manually mask a rank:
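A matching sketch, assuming the underlying update call takes a `(rank, status)` pair (the argument order is an assumption):

```python
class Worker:
    ...
    def update_deepep_mask_buffer(self, failed_rank: int) -> None:
        """Called via RPC by EngineCore to manually mask a rank, e.g. one
        that an external health checker or scheduler declared dead."""
        self.deepep_buffer.low_latency_update_mask_buffer(failed_rank, 1)
```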
5. Redundant Expert Slots for CUDAGraph
Reserve extra expert slots per rank to accommodate experts from failed ranks (see the sizing sketch below).
Ensures sufficient capacity for redistribution upon single-rank failure.
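As a concrete illustration of the sizing arithmetic (the numbers and variable names are hypothetical, not from this PR):

```python
# 256 experts over 32 ranks: each rank normally hosts 8 expert slots.
num_experts, num_ranks = 256, 32
base_slots = num_experts // num_ranks                # 8

# If one rank dies, its 8 experts spread over the remaining 31 ranks,
# so each survivor needs at most ceil(8 / 31) = 1 extra slot. Capturing
# CUDA graphs with the padded slot count avoids re-capture on failure.
redundant_slots = -(-base_slots // (num_ranks - 1))  # ceil division -> 1
slots_per_rank = base_slots + redundant_slots        # 9
```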
6. Decoupling from Other Communication Backends
Replace other collectives (e.g., DP-group metadata sync) with DeepEP’s extended `low_latency_allgather`.
TP-group communication is not replaced, as entire TP groups can be treated as atomic units: if one rank fails, the whole group is isolated.