moe: add DeepEP V2 ElasticBuffer support to MoE flex dispatcher #4632

Draft
dmvevents wants to merge 3 commits into NVIDIA:main from dmvevents:deepep-v2-elasticbuffer-support

@dmvevents

Summary

Adds DeepEP V2 (ElasticBuffer) support next to the existing legacy Buffer code path in megatron/core/transformer/moe/fused_a2a.py. When deep_ep.ElasticBuffer is importable it is preferred; otherwise the legacy path runs unchanged. The change mirrors the HybridEPBuffer version-probe pattern already in the same file, so it needs no new config knobs and no _DeepepManager changes. Validated on 2-node p5.48xlarge + AWS EFA with a Qwen3-30B-A3B-style MoE config: loss decreased 26.41 → 24.61 over 3 steps, grad_norm fell from 30.64 to 27.09 across those steps, and rank 0 recorded 1.096 GB of cross-node EFA TX. The V2 class is live in the training path, not a compat shim.

Motivation

DeepEP V2 (deepseek-ai/DeepEP#605) merged on 2026-04-29 and introduces ElasticBuffer in place of Buffer, changing the dispatch/combine contract:

  • Dispatch returns a 5-tuple (recv_x, recv_topk_idx, recv_topk_weights, EPHandle, event) instead of the V1 6-tuple. num_recv_tokens_per_expert_list now lives on EPHandle.
  • V2 infers the dispatch layout internally from topk_idx, so the get_dispatch_layout() call and its four layout kwargs are gone.
  • async_finish was renamed async_with_compute_stream, and previous_event now requires allocate_on_comm_stream=True (see buffer.hpp:483 in the V2 tree).
  • Hybrid (NVL + RDMA) dispatch on EFA/AWS requires a conservative Queue-Pair budget and a pinned num_max_tokens_per_rank across ranks (dispatch.hpp:138,150).
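
The change in return shape described in the first bullet can be illustrated with a small adapter. This is a minimal sketch, not the code in this PR: `DispatchResult` and `normalize_dispatch_result` are hypothetical names, the V1 element order is assumed, and the handle attribute is assumed to keep the name `num_recv_tokens_per_expert_list`. Plain Python objects stand in for tensors and handles so the snippet runs without deep_ep.

```python
from types import SimpleNamespace
from typing import Any, NamedTuple


class DispatchResult(NamedTuple):
    """Common shape for callers, independent of the DeepEP version (hypothetical)."""

    recv_x: Any
    recv_topk_idx: Any
    recv_topk_weights: Any
    num_recv_tokens_per_expert_list: list
    handle: Any
    event: Any


def normalize_dispatch_result(raw: tuple) -> DispatchResult:
    """Map a V1-style 6-tuple or a V2-style 5-tuple onto DispatchResult."""
    if len(raw) == 6:
        # V1 layout (order assumed): per-expert counts are their own element.
        recv_x, idx, weights, per_expert, handle, event = raw
    elif len(raw) == 5:
        # V2 layout: per-expert counts are read from the handle
        # (attribute name assumed to match the V1 element name).
        recv_x, idx, weights, handle, event = raw
        per_expert = handle.num_recv_tokens_per_expert_list
    else:
        raise ValueError(f"unexpected dispatch result length: {len(raw)}")
    return DispatchResult(recv_x, idx, weights, list(per_expert), handle, event)


# Stand-in values; real code would carry tensors and DeepEP handle objects.
v1_raw = ("x", "idx", "w", [3, 1], "v1_handle", None)
v2_handle = SimpleNamespace(num_recv_tokens_per_expert_list=[3, 1])
v2_raw = ("x", "idx", "w", v2_handle, None)

v1_result = normalize_dispatch_result(v1_raw)
v2_result = normalize_dispatch_result(v2_raw)
```

Normalizing at the boundary is one way to keep downstream callers unaware of which DeepEP version produced the result.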

Consumers of Megatron's MoE flex dispatcher (including the vLLM, SGLang, NeMo-RL and TRT-LLM integrations we maintain downstream at antonai-work/nemo-rl-deepep-v2-efa) are already running on V2 via a V1-compat shim. This PR removes the need for that shim on the Megatron side.

Demand signal: NVIDIA/Megatron-LM#2647 (open since 2026-02-13) tracks the broader "DeepEP on AWS EFA" request, with engagement from the NCCL team (@xiaofanl-nvidia). The V2 branch is the path that version of deep_ep targets.

Also resolves #3999 for the V2 code path. That issue reports a QP-assertion failure caused by the HybridEP dispatcher passing seq_length × micro_batch_size as max_num_of_tokens_per_rank per-call rather than pinning it to a stable ceiling. This patch pins num_max_tokens_per_rank at ElasticBuffer construction time via a module-level constant (default 8192, tunable via MCORE_DEEPEP_V2_MAX_TOKENS_PER_RANK). Ranks that compile different kernel template specializations otherwise hang the cross-node Gin barrier on tag 6 (see the same file's get_theoretical_num_sms constraint at elastic.py:611).
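
The pinning idea can be sketched as follows. This is an illustration under stated assumptions, not the patch itself: the environment variable name and the 8192 default come from the text above, while `read_pinned_max_tokens` and `check_fits_pinned_budget` are hypothetical helpers. The point is that the ceiling is read once and stays fixed, instead of being recomputed from seq_length × micro_batch_size on every call.

```python
import os
from typing import Mapping

ENV_NAME = "MCORE_DEEPEP_V2_MAX_TOKENS_PER_RANK"
DEFAULT_MAX_TOKENS_PER_RANK = 8192


def read_pinned_max_tokens(environ: Mapping[str, str] = os.environ) -> int:
    """Read the per-rank token ceiling once; every rank must see the same value."""
    raw = environ.get(ENV_NAME)
    if raw is None:
        return DEFAULT_MAX_TOKENS_PER_RANK
    value = int(raw)
    if value <= 0:
        raise ValueError(f"{ENV_NAME} must be positive, got {value}")
    return value


def check_fits_pinned_budget(num_tokens: int, pinned: int) -> None:
    """Hypothetical guard: fail loudly rather than exceed the pinned ceiling."""
    if num_tokens > pinned:
        raise ValueError(
            f"{num_tokens} tokens exceed the pinned ceiling {pinned}; raise {ENV_NAME}"
        )


# Pinned once (in the patch this happens at buffer construction time).
pinned_default = read_pinned_max_tokens({})
pinned_override = read_pinned_max_tokens({ENV_NAME: "4096"})
check_fits_pinned_budget(num_tokens=2048, pinned=pinned_default)
```

A call such as `check_fits_pinned_budget(16384, pinned_default)` would raise, which surfaces a misconfiguration on the offending rank instead of a cross-node hang.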

Design choice: single-class version probe

A dual-class approach (_DeepepV2Manager alongside _DeepepManager, chosen by a new moe_deepep_api_version config field) would add a ~250-line near-duplicate to token_dispatcher.py and force a new knob through MoEFlexTokenDispatcher.__init__, TransformerConfig, YAML and docs. This PR takes the other fork:

  • All V2 logic lives in fused_a2a.py.
  • _DeepepManager is unchanged; MoEFlexTokenDispatcher is unchanged.
  • The probe pattern (try: from deep_ep import ElasticBuffer; HAVE_DEEP_EP_V2 = True; except ImportError: HAVE_DEEP_EP_V2 = False) follows the shape of the adjacent HybridEPBuffer block already present in this file (PR #3627, "Support TP > GQA for inference").
  • No new config field is required. Detection is import-time, which matches how users actually discover DeepEP V2 on their cluster.
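
Written out as module-level code, the probe from the list above looks roughly like this. The V2 half is taken from the description; the V1 half is an assumed equivalent for the legacy class and may differ from what fused_a2a.py actually imports. The snippet runs whether or not deep_ep is installed.

```python
# Legacy (V1) probe: assumed shape, the real block may import more names.
try:
    from deep_ep import Buffer  # noqa: F401

    HAVE_DEEP_EP = True
except ImportError:
    HAVE_DEEP_EP = False

# V2 probe, as described in this PR: detection happens once, at import time.
try:
    from deep_ep import ElasticBuffer  # noqa: F401

    HAVE_DEEP_EP_V2 = True
except ImportError:
    HAVE_DEEP_EP_V2 = False

print(f"HAVE_DEEP_EP={HAVE_DEEP_EP} HAVE_DEEP_EP_V2={HAVE_DEEP_EP_V2}")
```

Because both flags are plain module-level booleans, the rest of the file can branch on them without any configuration plumbing.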

Precedent: #4228 ("build: bump DeepEP to 34152ae") merged 5 days after opening with 0 review comments and 3 additions / 3 deletions. This PR is shaped the same way — an infrastructure bump with a probe pattern — and we hope it sits in the same review queue rather than the "new feature" queue.

What changed

| File | Lines | Purpose |
| --- | --- | --- |
| megatron/core/transformer/moe/fused_a2a.py | +~260 | Probe + V2 branches in get_buffer, FusedDispatch.forward/backward, FusedCombine.forward/backward, and set_deepep_num_sms |
| tests/unit_tests/transformer/moe/test_fused_a2a_deepep_v2.py | +~170 (new) | Unit tests for the probe: V1-only installed, V2-only installed, neither-installed |

V1 fall-through:

  • HAVE_DEEP_EP=True, HAVE_DEEP_EP_V2=False → legacy Buffer path runs byte-identical to the pre-patch state.
  • Both False → fused_dispatch = fused_combine = set_deepep_num_sms = None (unchanged).
  • Both True → V2 path is preferred.
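
The fall-through rules above amount to a small decision table. The sketch below uses a hypothetical helper, `select_deepep_path`, to make it explicit. The V2-only row is an assumption derived from the statement that ElasticBuffer is preferred whenever it is importable.

```python
from typing import Optional


def select_deepep_path(have_v1: bool, have_v2: bool) -> Optional[str]:
    """Return which code path runs for a given pair of probe flags."""
    if have_v2:
        return "v2"  # ElasticBuffer is preferred whenever it is importable
    if have_v1:
        return "v1"  # legacy Buffer path, unchanged
    return None  # neither installed: the fused entry points stay None


decision_table = {
    (v1, v2): select_deepep_path(v1, v2)
    for v1 in (False, True)
    for v2 in (False, True)
}
```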

V2-specific safeguards baked in:

  1. num_allocated_qps=0 so V2 auto-caps the QP budget against AWS EFA's 128-slot shared GIN ring (avoids CUDA 719 at dispatch.hpp:183).
  2. num_sms=0 on combine so V2 reuses handle.num_sms from dispatch (mismatch triggers sticky CUDA 719 at jit/handle.hpp:86).
  3. num_max_tokens_per_rank pinned at construction. This is the fix for #3999 ("[Bug] HybridEP dispatcher passes incorrect max_num_of_tokens_per_rank to DeepEP, causing RDMA QP assertion failure").
  4. previous_event seeded via buffer.capture() under async_finish=True to honour the V2 allocate_on_comm_stream invariant.
  5. do_expand=False on dispatch to preserve V1 token layout.
  6. Graceful from deep_ep.utils.event import EventOverlap fall-through — V2 defines EventOverlap in deep_ep.utils.event but does not re-export it from deep_ep.utils (V1 did).
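
The import fall-through in item 6 can be sketched with a generic helper. `import_first` is a hypothetical name, not a function from this PR; it tries each (module, attribute) candidate in order and returns the first that resolves. The demonstration uses the standard library so it runs without deep_ep, and the deep_ep candidate list is shown only as data.

```python
import importlib
import os.path
from typing import Any, Iterable, Tuple


def import_first(candidates: Iterable[Tuple[str, str]]) -> Any:
    """Return the first attribute that can be imported from the candidates."""
    errors = []
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{attr}: {exc}")
    raise ImportError("no candidate could be imported: " + "; ".join(errors))


# Candidate order described above: the package-level export first (V1),
# then the submodule where V2 defines the class.
EVENT_OVERLAP_CANDIDATES = [
    ("deep_ep.utils", "EventOverlap"),
    ("deep_ep.utils.event", "EventOverlap"),
]

# Stdlib demonstration of the same fall-through: `os` has no `join`,
# so the helper falls through to `os.path`.
join = import_first([("os", "join"), ("os.path", "join")])
```

With deep_ep installed, `import_first(EVENT_OVERLAP_CANDIDATES)` would resolve the class from whichever location the installed version provides.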

Evidence

Validated on 2-node p5.48xlarge H100 + AWS EFA, namespace megatron-shapey-validation:

[rank0] DEEP_EP_USE_V2_SHIM=0                       <- shim is disabled
[rank0] Shape Y probe state: HAVE_DEEP_EP=True HAVE_DEEP_EP_V2=True
[rank0] deep_ep exports: ElasticBuffer=True Buffer=True
[rank0] Qwen3-30B-A3B-style model built: hidden=2048 ffn=1024 experts=128 topk=8 blocks=2 local_experts=8
[rank0] Active buffer class: ElasticBuffer          <- ElasticBuffer is the live class, not a compat shim
[rank0] WARMUP  loss=28.5571  grad_norm=35.2123  step_ms=24766.8
[rank0] STEP 1/3  loss=26.4075  grad_norm=30.6430  step_ms=315.9
[rank0] STEP 2/3  loss=25.1026  grad_norm=28.1979  step_ms=42.6
[rank0] STEP 3/3  loss=24.6097  grad_norm=27.0909  step_ms=43.4
[rank0] EFA tx_bytes delta:  1096495992 bytes (~1.096 GB)
[rank0] loss trajectory: first=26.4075 last=24.6097 decreased=True
[rank0] SHAPE Y V2 VALIDATION PASS

A fully reproducible Dockerfile + k8s manifest + training driver that regenerates the above is published at antonai-work/nemo-rl-deepep-v2-efa. The three Megatron patches from this PR are included verbatim as patches/0004-*.patch so reviewers can rebuild from vanilla upstream without private-repo access. Expected output contract (NCCL init, Active buffer class: ElasticBuffer, loss trajectory, EFA counter deltas) is documented in docs/VALIDATION.md.

Backwards compatibility

  • No Megatron-LM API change. _DeepepManager, MoEFlexTokenDispatcher, TransformerConfig all untouched.
  • HAVE_DEEP_EP continues to reflect V1 legacy Buffer availability.
  • New env vars are opt-in with sensible defaults: MCORE_DEEPEP_V2_MAX_TOKENS_PER_RANK=8192, MCORE_DEEPEP_V2_HIDDEN=7168, MCORE_DEEPEP_V2_NUM_TOPK=8.
  • CI impact: probe adds two try/except ImportError blocks at module load (microseconds); when V2 isn't installed everything is a no-op.

Tests

  • New test_fused_a2a_deepep_v2.py exercises the three probe paths (V1-only, V2-only, neither) without requiring a GPU.
  • Existing TestFlexDispatcher.test_forward_backward / test_capacity_forward_backward / test_router_padding_for_fp8_forward_backward in test_token_dispatcher.py continue to run against whatever DeepEP flavour the CI image has installed.
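
One way to exercise the three probe paths without a GPU, or without deep_ep at all, is to swap a fake module into `sys.modules`. The sketch below is an assumption about how such a test could be written, not a copy of test_fused_a2a_deepep_v2.py. `probe_deep_ep`, `fake_deep_ep` and `probe_with` are hypothetical helpers.

```python
import sys
import types
from unittest import mock


def probe_deep_ep():
    """Re-run the two import probes and return (have_v1, have_v2)."""
    try:
        from deep_ep import Buffer  # noqa: F401

        have_v1 = True
    except ImportError:
        have_v1 = False
    try:
        from deep_ep import ElasticBuffer  # noqa: F401

        have_v2 = True
    except ImportError:
        have_v2 = False
    return have_v1, have_v2


def fake_deep_ep(*class_names):
    """Build a stand-in `deep_ep` module exposing only the given class names."""
    module = types.ModuleType("deep_ep")
    for name in class_names:
        setattr(module, name, type(name, (), {}))
    return module


def probe_with(module):
    """Run the probe with `deep_ep` temporarily replaced in sys.modules.

    Mapping the name to None makes `import deep_ep` raise ImportError,
    which simulates a machine where the package is not installed.
    """
    with mock.patch.dict(sys.modules, {"deep_ep": module}):
        return probe_deep_ep()


v1_only = probe_with(fake_deep_ep("Buffer"))
v2_only = probe_with(fake_deep_ep("ElasticBuffer"))
neither = probe_with(None)
```

`mock.patch.dict` restores the original `sys.modules` entries on exit, so the fake module does not leak into other tests.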

Related

/cc @NVIDIA/mixture-of-experts-adlr @NVIDIA/mixture-of-experts-devtech

@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@dmvevents dmvevents force-pushed the deepep-v2-elasticbuffer-support branch 2 times, most recently from 2fb717d to 1aa3a7e Compare May 5, 2026 18:59
dmvevents added a commit to antonai-work/nemo-rl-deepep-v2-efa that referenced this pull request May 5, 2026
…file

Makes this repo rebuildable end-to-end without any private-repo access.

Changes:
- docker/Dockerfile: merged in full base-image recipe. Was
  `FROM deepep-base-v2:latest` (a pre-built private image); now builds
  from vanilla `nvidia/cuda:12.9.0-devel-ubuntu24.04` with full EFA +
  aws-ofi-nccl + NCCL + GDRCopy + NVSHMEM + DeepEP V2 + Megatron +
  NeMo-RL stack in-Dockerfile. Public git clone URLs only.
- docker/Dockerfile: COPY paths repointed from private-tree locations
  (`integrations/nemo-rl-fullstack/...`, `scripts/verify_efa_traffic.sh`)
  to this repo's `patches/`, `tests/`, `docker/` directories.
- docker/build.sh: rewritten to drop the private base-image prereq +
  private-tree `REPO_ROOT` assumptions.
- docs/ARCHITECTURE.md: removed reference to a "private development
  repo" shim file; rephrased as a design decision about this repo.
- tests/k8s/multi-node-training-h100.yaml: replaced hard-coded AWS
  account ID in the ECR image path with a clear placeholder.
- patches/0004-0006: regenerated from the fork branch after amending
  commit messages + source comments to drop `antonai-work/deepep-v2-integration`
  refs. Author restored to the real identity (Anton Alexander).
  Code tree is byte-identical to the pre-rebase branch.
- ci/: removed. The CodeBuild spec was wired to a private account ID
  and private-repo source paths; leaving it in the public repo would
  have shipped broken config.

Verification: `grep -r "antonai-work/deepep-v2-integration|/home/ubuntu|
/tmp/nemo-rl-pr-prep|058264135704" .` returns no matches.

Megatron fork branch `deepep-v2-elasticbuffer-support` force-pushed
with identical code tree; PR NVIDIA/Megatron-LM#4632 picks up the
cleaned commit history automatically.
AntonAI added 3 commits May 6, 2026 02:11
DeepEP PR #605 (deepseek-ai/DeepEP#605, merged 2026-04-29) renames `deep_ep.Buffer` to
`deep_ep.ElasticBuffer` and changes the dispatch/combine contract
(5-tuple return with the per-expert list moved onto the handle,
`async_with_compute_stream` in place of `async_finish`, layout kwargs
dropped because V2 infers layout internally from `topk_idx`).

This adds a second import probe next to the existing
`HybridEPBuffer` probe and teaches `get_buffer()` / `FusedDispatch`
/ `FusedCombine` to branch on `HAVE_DEEP_EP_V2`. When V2 is present
it is preferred; otherwise the legacy `Buffer` code path is
unchanged. `_DeepepManager` itself (token_dispatcher.py) does not
change — all V2-specific knowledge lives in this one file.

Why:
- Consumers already use V2 via a downstream compatibility shim. A
  full reproducible recipe (Dockerfile + k8s manifest + training
  driver) is published at
  https://github.com/antonai-work/nemo-rl-deepep-v2-efa. Validated
  on 2-node p5.48xlarge + EFA with Qwen3-30B-A3B-BF16: loss decreased
  over 3 steps, real grad_norm, 0.8 GB cross-node EFA TX.
- Removes the need for the downstream shim once this lands.
- Mirrors the existing `HybridEPBuffer` probe pattern already in
  this file, so review load stays in the "infra bump" bucket rather
  than the "new feature" bucket.

V1 parity rules baked in:
- `num_max_tokens_per_rank` pinned from env
  (`MCORE_DEEPEP_V2_MAX_TOKENS_PER_RANK`, default 8192) to avoid the
  JIT template instantiation drift across ranks that otherwise hangs
  the cross-node Gin barrier (DeepEP dispatch.hpp:138 template arg).
- `num_allocated_qps=0` on EFA so V2's built-in Queue-Pair auto-cap
  kicks in (avoids CUDA 719 at dispatch.hpp:183 against AWS EFA
  provider).
- `num_sms=0` on combine so V2 reuses `handle.num_sms` from dispatch
  (mismatch triggers sticky CUDA 719 at jit/handle.hpp:86).
- `do_expand=False` matches V1 token layout, so downstream callers
  like `_DeepepManager.dispatch_postprocess` see the same recv
  shape.
- `previous_event` is seeded via `buffer.capture()` under
  `async_finish=True` per V2's contract at buffer.hpp:483
  (previous_event requires allocate_on_comm_stream=True).

Test plan:
- `tests/unit_tests/transformer/moe/test_fused_a2a_deepep_v2.py`
  exercises probe plumbing (V1-only, V2-only, neither-installed).
- Existing `TestFlexDispatcher` in `test_token_dispatcher.py` runs
  under whichever DeepEP flavour is installed in the CI image.
- Full 2-node D+C validation run against Qwen3-30B-A3B-BF16 on
  H100 + EFA is published as a reproducible recipe at
  https://github.com/antonai-work/nemo-rl-deepep-v2-efa
  (see docs/VALIDATION.md for the expected-output contract).

Related:
- DeepEP PR #605 (merged 2026-04-29)
- DeepEP PR #612 (still open: AWS EFA auto-QP cap; not required
  on InfiniBand / NVLink fabrics).

V2 (DeepEP PR #605) defines EventOverlap in deep_ep.utils.event but does not
re-export it from deep_ep.utils (only EventHandle). Fall through to
the submodule path so fused_a2a loads under V2-only installs.

V2 ElasticBuffer.dispatch at elastic.py:768 calls
get_theoretical_num_sms(num_experts, num_topk) BEFORE resolving
num_experts from the handle at line 782. Passing num_experts=None with
num_sms=0 raises
"TypeError: unsupported operand type(s) for %: 'NoneType' and 'int'"
during the backward of FusedCombine (which reuses a handle).

Fix: extract num_experts from handle.num_experts and pass explicitly.
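
The failure mode and the fix can be reproduced in isolation. In the sketch below, `stand_in_num_sms` is a hypothetical stand-in and NOT DeepEP's get_theoretical_num_sms formula; it only applies `%` to num_experts, which is the operation named in the TypeError. `resolve_num_experts` is likewise a hypothetical helper showing the "read it from the handle" step.

```python
from types import SimpleNamespace
from typing import Optional


def stand_in_num_sms(num_experts: Optional[int], num_topk: int) -> int:
    # Hypothetical stand-in: the real formula is not reproduced here.
    # `None % int` raises TypeError, matching the reported failure.
    return 1 if num_experts % num_topk == 0 else 2


def resolve_num_experts(num_experts: Optional[int], handle) -> int:
    """Prefer an explicit value, otherwise read it from the reused handle."""
    return handle.num_experts if num_experts is None else num_experts


handle = SimpleNamespace(num_experts=128)  # stand-in for a reused EPHandle

try:
    stand_in_num_sms(None, 8)
    failed_without_fix = False
except TypeError:
    failed_without_fix = True

# With the fix, num_experts is resolved before the computation runs.
num_sms_with_fix = stand_in_num_sms(resolve_num_experts(None, handle), 8)
```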
@dmvevents dmvevents force-pushed the deepep-v2-elasticbuffer-support branch from 1aa3a7e to 2f149cf Compare May 6, 2026 02:11
dmvevents pushed a commit to dmvevents/RL that referenced this pull request May 6, 2026
Bumps the deep_ep git pin in pyproject.toml from bfded348
(2025-10-29, pre-V2) to b306af0 (2026-04-29), which is the
merge commit of DeepEP PR #605 (deepseek-ai/DeepEP#605) "Introducing EPv2".

Why
---
The current pin predates the DeepEP V2 API (ElasticBuffer,
PP/CP/Engram support). Consumers of NeMo-RL's Megatron backend
that follow NVIDIA/Megatron-LM#4632 ("Shape Y" Megatron V2
adoption) cannot resolve deep_ep.ElasticBuffer with the
current pin; the virtualenv still installs the pre-V2 tree.

This change bumps only the pin. It does not by itself change
any NeMo-RL code path. Paired with Megatron-LM#4632, it
enables the end-to-end V2 path that is already running on
AWS p5en.48xlarge 2x H200 in the reproduction repo below.

Upstream references
-------------------
* deepseek-ai/DeepEP#605 (V2 merge 2026-04-29)
* NVIDIA/Megatron-LM#4632 (Megatron-side V2 adoption)

Reproduction
------------
End-to-end reproduction (Dockerfile + K8s manifests + smoke
bench) is public at:
  https://github.com/antonai-work/nemo-rl-deepep-v2-efa

Related NeMo-RL PR (separate concern, same fleet):
  NVIDIA-NeMo#2410 (Dockerfile LD_LIBRARY_PATH for EFA
  OFI discovery)

Signed-off-by: Anton Alexander <antonai@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

[Bug] HybridEP dispatcher passes incorrect max_num_of_tokens_per_rank to DeepEP, causing RDMA QP assertion failure

3 participants