Skip to content

Inference Optimized MoEs#3496

Merged
sidsingh-nvidia merged 104 commits into
NVIDIA:mainfrom
sidsingh-nvidia:inf-opt-all-gather-dispatcher
Mar 3, 2026
Merged

Inference Optimized MoEs#3496
sidsingh-nvidia merged 104 commits into
NVIDIA:mainfrom
sidsingh-nvidia:inf-opt-all-gather-dispatcher

Conversation

@sidsingh-nvidia

@sidsingh-nvidia sidsingh-nvidia commented Feb 19, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

Optimizes MoE decode (and small batch prefill) performance by eliminating host synchronizations, enabling end-to-end CUDA graph capture, and using latency-optimized NVLS collectives.

Motivation

The default MoE layer in Megatron-LM is not well-suited for inference. The AlltoAll dispatcher and GroupedGEMM rely on CPU-resident tensors for token-expert assignments, which breaks CUDA graph capture due to host synchronizations. The current workaround — padding to maximum capacity (routing all tokens to all experts) — enables static shapes but wastes significant compute and communication.

This PR introduces an inference-optimized MoE layer that achieves the best of both worlds: compute/communication-optimal routing with full CUDA graph compatibility. These layers are specially designed to optimize the decode phase, where small batch sizes make host synchronization and kernel launch overhead disproportionately expensive.

Optimizations

  • InferenceGroupedMLP — Expert computation layer that operates directly on GPU-resident token-expert splits, eliminating the host synchronizations that block CUDA graph capture in the default GroupedGEMM path. Uses FlashInfer cutlass_fused_moe (fused permute + GEMM) for CUDA-graphed iterations and torch._grouped_mm with GPU-resident cumsum offsets for eager mode. Inherits from TEGroupedMLP for checkpoint compatibility..
  • InferenceTopKRouter — Stripped-down router that removes training overhead (z-loss, auxiliary losses, token dropping) and is optimized via @torch.compile().
  • InferenceCUDAGraphTokenDispatcher — Replaces AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata GPU-resident. Supports latency-optimized NVLS collectives on Hopper+ with automatic NCCL fallback.
  • Fused 3-tensor NVLS all-gather kernel (multimem_all_gather_fused) for routing_map, probs, and hidden_states — single kernel launch + single barrier.

Other Minor Changes

  • InferenceSpecProvider backend for wiring inference-optimized modules into model specs.
  • MoELayer dynamically swaps between standard and inference dispatchers based on is_inference_cuda_graphed_iteration.
  • Centralized NVLS eligibility checks (are_tensors_nvls_eligible, is_device_nvls_capable) shared across TP and EP communication paths.
  • Separate symmetric memory buffer pools for TP and EP (get_global_symmetric_memory_buffer_tp/ep).
  • Multi-tensor packing in symmetric memory buffers (maybe_get_tensors) with 16-byte alignment.
  • Kill-switch via --inference-disable-triton-nvls-kernels config flag. This makes the system fallback to NCCL.
  • Config validation: inference-optimized MoE rejects expert tensor parallelism, capacity-factor routing, and padded routing maps.

How to enable?

these flags -

--transformer-impl inference_optimized \
--moe-router-dtype fp32 # flashinfer only supports fp32 probabilities

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
Loading

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge-conflict are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@yanring yanring left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve on behalf of Pingtian

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22603788165

Merged via the queue into NVIDIA:main with commit 7d1c016 Mar 3, 2026
117 checks passed
@sidsingh-nvidia sidsingh-nvidia deleted the inf-opt-all-gather-dispatcher branch March 3, 2026 02:34
ilml added a commit to ilml/Megatron-LM that referenced this pull request Mar 20, 2026
New files:
  - megatron/core/transformer/moe/token_dispatcher_inference.py
  - tests/unit_tests/inference/test_moe_inference.py
yangbofun pushed a commit to xlm-research/Megatron-LM that referenced this pull request May 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. module: inference module: moe Run functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.