Summary
When using the HybridEP backend (--moe_flex_dispatcher_backend hybridep) for MoE expert-parallel training across multiple nodes, HybridEPDispatch.forward() in fused_a2a.py incorrectly uses the total number of tokens in a micro-batch (seq_length × micro_batch_size) as the max_num_of_tokens_per_rank parameter for HybridEPBuffer. This causes the RDMA Queue Pair send-queue depth (tx_depth) to exceed the hardware limit of 65535, triggering an assertion failure in DeepEP's internode communication initialization.
Environment
- Megatron-LM: latest (hybrid-ep branch of DeepEP integrated)
- DeepEP: v1.2.1 (hybrid-ep branch)
- Model: Qwen3-30B-A3B (MoE, 128 experts)
- Hardware: 2 nodes × 8 GPUs, InfiniBand RDMA interconnect
- Training config:
--max_length 8192 --micro_batch_size 8 --packing true --expert_model_parallel_size 16 --moe_flex_dispatcher_backend hybridep
Error
python: /path/to/DeepEP/csrc/hybrid_ep/buffer/internode.cu:167:
void setup_qp_init_attr(..., int): Assertion `tx_depth > 0 && tx_depth < 65536' failed.
All ranks crash with SIGABRT (signal 6) during HybridEP buffer initialization.
Root Cause
The call chain
- MoEFlexTokenDispatcher.dispatch_preprocess() (token_dispatcher.py:1438) reshapes hidden_states from [seq_length, batch_size, hidden_size] to [seq_length * batch_size, hidden_size]:
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
- HybridEPDispatch.forward() (fused_a2a.py:354-359) takes the first dimension of the already-flattened tensor and uses it as seq_len:
if _hybrid_ep_buffer is None:
    seq_len, hidden_dim = x.shape[-2:]  # x is [seq_len * batch_size, hidden_dim]
    init_hybrid_ep_buffer(group, hidden_dim, seq_len, ...)  # seq_len is actually num_total_tokens
- init_hybrid_ep_buffer() (fused_a2a.py:316) passes this value directly as max_num_of_tokens_per_rank:
_hybrid_ep_buffer = HybridEPBuffer(
    group=group,
    hidden_dim=hidden_dim,
    max_num_of_tokens_per_rank=seq_len,  # <-- This is seq_length * micro_batch_size, not per-rank tokens
    ...
)
- DeepEP internode.cu uses max_num_of_tokens_per_rank to compute the RDMA QP send-queue depth:
// internode.cu:430 (dispatch)
setup_qp_init_attr(..., 3 * buffer_config.max_num_of_tokens_per_rank + 1);
// internode.cu:587 (combine)
setup_qp_init_attr(..., 2 * buffer_config.max_num_of_tokens_per_rank + 1);
- The assertion enforces the IB hardware limit:
assert(tx_depth > 0 && tx_depth < 65536);
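The chain above can be reproduced with plain integers. The following is a minimal sketch that models each step with shape tuples instead of tensors. The function names are hypothetical, and the arithmetic only restates the formulas quoted from fused_a2a.py and internode.cu; it is not the actual Megatron-LM or DeepEP code.

```python
# Pure-Python model of the HybridEP call chain (no torch or DeepEP needed).
# All function names are hypothetical; only the arithmetic mirrors the report.

QP_DEPTH_UPPER_BOUND = 65536  # internode.cu asserts tx_depth < 65536


def flatten_hidden_states(shape):
    """Mimic hidden_states.view(-1, hidden): [S, B, H] -> [S * B, H]."""
    seq_length, batch_size, hidden_size = shape
    return (seq_length * batch_size, hidden_size)


def seq_len_from_flat(flat_shape):
    """Mimic `seq_len, hidden_dim = x.shape[-2:]` on the already-flattened tensor."""
    seq_len, hidden_dim = flat_shape[-2:]
    return seq_len, hidden_dim


def qp_tx_depths(max_num_of_tokens_per_rank):
    """Send-queue depths requested for the dispatch and combine QPs."""
    dispatch = 3 * max_num_of_tokens_per_rank + 1
    combine = 2 * max_num_of_tokens_per_rank + 1
    return dispatch, combine


def passes_assertion(tx_depth):
    """Same condition as `tx_depth > 0 && tx_depth < 65536`."""
    return 0 < tx_depth < QP_DEPTH_UPPER_BOUND


if __name__ == "__main__":
    # 2048 is only a placeholder hidden size; it does not affect tx_depth.
    flat = flatten_hidden_states((8192, 8, 2048))
    seq_len, _ = seq_len_from_flat(flat)
    dispatch, combine = qp_tx_depths(seq_len)
    print(seq_len, dispatch, passes_assertion(dispatch))  # prints: 65536 196609 False
```

The key point the model makes explicit is that `x.shape[-2:]` on a 2-D tensor returns the flattened token count, so the "sequence length" handed to the buffer already includes the micro-batch factor.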
Concrete example
With max_length=8192, micro_batch_size=8, packing=true:
| Parameter | Value |
|---|---|
| x.shape[0] (after flatten) | 8192 × 8 = 65536 |
| max_num_of_tokens_per_rank passed to DeepEP | 65536 |
| Dispatch tx_depth | 3 × 65536 + 1 = 196609 |
| Hardware limit | < 65536 |
The dispatch tx_depth is roughly 3× the maximum allowed value of 65535.
Why single-node works but multi-node fails
The RDMA QP initialization (and thus the tx_depth assertion) only runs when num_of_nodes > 1. Single-node setups use NVLink-only communication and never hit this code path.
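A hedged sketch of the node-count gate follows. It assumes, as described above, that `num_of_nodes > 1` is the only condition under which the RDMA QPs and their tx_depth assertion are set up. `qp_setup_would_abort` is a hypothetical helper, not a DeepEP function.

```python
# Sketch: the same token count is harmless on one node and fatal on two.
# Assumption: QP setup (and the tx_depth assertion) only runs when num_of_nodes > 1.


def qp_setup_would_abort(num_of_nodes, max_num_of_tokens_per_rank):
    """Return True if the dispatch or combine tx_depth assertion would fire."""
    if num_of_nodes <= 1:
        # NVLink-only path: no RDMA QPs are created, so no tx_depth check runs.
        return False
    for multiplier in (3, 2):  # dispatch, then combine
        tx_depth = multiplier * max_num_of_tokens_per_rank + 1
        if not (0 < tx_depth < 65536):
            return True
    return False


tokens = 8192 * 8  # value passed as max_num_of_tokens_per_rank in this report
for nodes in (1, 2):
    print(nodes, qp_setup_would_abort(nodes, tokens))
```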
Steps to Reproduce
- Configure a multi-node MoE training run with HybridEP:
--expert_model_parallel_size 16 \
--moe_token_dispatcher_type flex \
--moe_flex_dispatcher_backend hybridep \
--micro_batch_size 8 \
--max_length 8192 \
--packing true
- Run training across 2+ nodes with RDMA/InfiniBand.
- Training crashes immediately during HybridEP buffer initialization with:
Assertion `tx_depth > 0 && tx_depth < 65536' failed.
Note: The issue is triggered when 3 × seq_length × micro_batch_size + 1 > 65535, i.e., seq_length × micro_batch_size ≥ 21845. Common configurations such as 8192 × 4 = 32768 or 4096 × 8 = 32768 hit this.
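The threshold in the note can be derived directly from the two tx_depth formulas. This sketch assumes those formulas and the `tx_depth < 65536` assertion are the only constraints; `max_safe_tokens` is a hypothetical helper. It also shows that the combine QP (2 × N + 1) overflows once N reaches 32768, so the 8192 × 4 and 4096 × 8 configurations fail on both paths.

```python
# Derive the largest token count N each QP accepts, assuming
# dispatch tx_depth = 3*N + 1, combine tx_depth = 2*N + 1, and tx_depth < 65536.

QP_DEPTH_UPPER_BOUND = 65536


def max_safe_tokens(multiplier):
    """Largest N such that multiplier * N + 1 < QP_DEPTH_UPPER_BOUND."""
    # multiplier * N + 1 <= 65535  <=>  N <= (65535 - 1) // multiplier
    return (QP_DEPTH_UPPER_BOUND - 2) // multiplier


DISPATCH_LIMIT = max_safe_tokens(3)  # 21844
COMBINE_LIMIT = max_safe_tokens(2)   # 32767

for seq_length, micro_batch_size in [(8192, 2), (8192, 4), (4096, 8), (8192, 8)]:
    n = seq_length * micro_batch_size
    print(
        f"{seq_length} x {micro_batch_size} = {n}: "
        f"dispatch ok={n <= DISPATCH_LIMIT}, combine ok={n <= COMBINE_LIMIT}"
    )
```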
Affected Code
megatron/core/transformer/moe/fused_a2a.py: HybridEPDispatch.forward() (lines 354-359) and init_hybrid_ep_buffer() (line 316)