[Dev][feat] Support A2A Overlap for Megatron-FSDP #3796
Upstream dev tip: 77c0f8c
Pulled commits:
- 77c0f8c [Dev][feat] Support A2A Overlap for Megatron-FSDP (NVIDIA#3796)
- 8195337 [dev] [3/5] Qwen3.5 support: SharedExpertMLP meta init (NVIDIA#4749)
- 2672ff5 [DEV] fix(megatron-fsdp): preserve non-meta tensors during meta materialization (NVIDIA#4155)
- cfbd9df [dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout (NVIDIA#4750)
- df12802 [dev] Fix GDN DTensor splitting for FSDP checkpointing (NVIDIA#4799)
Resolution: zero conflicts; git auto-merged 12 shared files in megatron/core/{distributed,models,pipeline_parallel,transformer} and tests/unit_tests/a2a_overlap. No ai-blaise custom files touched.
Gates:
- git diff --check: clean
- conflict markers: none
- py_compile (16 changed .py files): OK
- indexcache: 27/28 pass; the 1 fail (test_nvfp4_non_blackwell_cuda_uses_reference_fallback) reproduces identically at the pre-merge base SHA (sglang occupies all 8 H200s in EXCLUSIVE_PROCESS mode -> cudaErrorDevicesUnavailable). 1 Blackwell-only test auto-skips on H200.
- transformer gdn/mtp/moe suite: 53 failed / 7 passed / 55 skipped / 5 errors -- IDENTICAL numbers at pre-merge base; all failures are the same environmental cudaErrorDevicesUnavailable.
- 2-rank torchrun layer-wise optimizer smoke: blocked (no free GPUs).
Custom preserved: StreamBP, IndexCache config, NVFP4 indexer (7e78f28), HISA topk1024 backward test (c628c13), pyproject emerging_optimizers v0.2.0 pin, mHC/MTP/MoE composition.
The FSDP early-return comment in ``_copy_model_params_to_main_params`` cited specific line numbers in ``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``. Those line numbers shift on every unrelated change to that file (e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace with the stable symbol names — ``_replace_param_with_distributed_if_needed``, ``install_optimized_model_weights``, ``copy_main_weights_to_model_weights`` — so the comment doesn't go stale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
What does this PR do?
PR for main: #3797
Based on: #3766
Support Megatron-FSDP with EP overlap 1F1B schedule
Supported Feature Combinations
Problem
The EP overlap schedule (`overlap_moe_expert_parallel_comm`) interleaves forward and backward passes at the sub-module level (attn, mlp, moe_dispatch, moe_combine) to overlap Expert Parallel A2A communication with compute. This schedule bypasses `TransformerLayer.forward()` entirely, calling sub-modules directly. As a result, FSDP hooks registered on `TransformerLayer` (parameter all-gather, parameter release, gradient reduction, and root pre/post backward) are never triggered, making FSDP with `optim_grads_params` incompatible with this schedule.
FSDP Hook Triggering: Standard vs 1F1B Overlap
`_replace_param_with_raw_if_needed`
- Standard: called in `MegatronFSDP.forward()` at the start of every forward call.
- 1F1B overlap: called in `combined_1f1b_schedule_for_no_pipelining`, before any schedule phase begins.
- Why: the schedule bypasses `MegatronFSDP.forward()` entirely; it calls `GPTModel.build_schedule_plan()` on the unwrapped module directly. The param swap must happen explicitly before the schedule accesses layers.

`_pre_forward_param_unshard` (per-layer all-gather)
- Standard: registered via `nn.Module.register_forward_pre_hook` on each FSDP unit. Triggered automatically by PyTorch when `TransformerLayer.forward()` is called.
- 1F1B overlap: with `enable_fine_grained_param_gather=True`, hooks are registered on every sub-module (not just FSDP units), giving finer-grained all-gather.
- Why: the schedule calls each sub-module's `forward()` via `ScheduleNode`, which triggers the registered forward pre-hooks normally. No manual intervention needed.

`_post_forward` (`post_forward_release_module`): per-layer param release after forward
- Standard: registered via `nn.Module.register_forward_hook` on each FSDP unit. Triggered automatically by PyTorch after `TransformerLayer.forward()` returns.
- 1F1B overlap: set via `set_fsdp_reshard_hooks`. Wired in `combined_forward_backward_step` onto each `TransformerLayerSchedulePlan`, and invoked after the last forward node (mlp/moe_combine) of each layer completes.
- Why: the schedule decomposes `TransformerLayer.forward()` into sub-ops (attn, mlp, moe_dispatch, moe_combine). The `register_forward_hook` registered on `TransformerLayer` is never fired because the schedule never calls `TransformerLayer.forward()` as a whole; it calls sub-modules directly. So the release must be done explicitly.

`_post_backward_release_module`: per-layer param release after backward
- Standard: registered via `_register_post_backward_hook` (a `register_forward_pre_hook` that inserts a `RegisterFSDPBackwardFunction` into the autograd graph). Triggered automatically by autograd during backward.
- 1F1B overlap: set via `set_fsdp_reshard_hooks`. Wired onto `TransformerLayerSchedulePlan.attn` (the last backward node of each layer).
- Why: same reason as `_post_forward`: the schedule runs backward per sub-op, not per `TransformerLayer`. The autograd-based post-backward hook on the FSDP unit output is never reached because the layer-level forward was never called as a whole.

`_pre_backward_param_unshard` (per-layer backward all-gather)
- Standard: registered via `create_custom_backward_hook` (a `register_forward_hook` that registers `register_multi_grad_hook` on the output tensors). Triggered automatically by autograd when gradients arrive at the layer's output.
- 1F1B overlap: with `enable_fine_grained_param_gather=True`, hooks are registered on every sub-module via `create_custom_backward_hook`. Triggered automatically by autograd when gradients flow through each sub-module.
- Why: with `enable_fine_grained_param_gather=False`, hooks only exist on FSDP units; since the layer-level forward is skipped, those hooks are never hit. With `=True`, hooks exist on every sub-module, so autograd still triggers them during the schedule's backward.

`_root_pre_backward`
- Standard: registered via `create_custom_backward_hook` on the root module's output. Triggered automatically by autograd (gradient flows back through root output). Also queues `_root_post_backward` via `queue_callback`.
- 1F1B overlap: called manually via `fsdp_wrapper.pre_backward()` at the start of each `combined_forward_backward_step` that has a backward model (`b_model is not None`). Called with `skip_backward_hook=True` so it does not queue `_root_post_backward`.
- Why: the schedule never calls `MegatronFSDP.forward()`, so no root-level output tensor exists in the autograd graph to trigger the hook. It must be called manually, with `skip_backward_hook=True` because `_root_post_backward` is also called manually.

`_root_post_backward` (grad accumulation + reduce-scatter)
- Standard: queued via `torch.autograd.Variable._execution_engine.queue_callback()` inside `_root_pre_backward`. Triggered automatically by the autograd engine after the entire backward pass finishes.
- 1F1B overlap: called manually via `fsdp_wrapper.post_backward()` at the end of each `combined_forward_backward_step` that has a backward model.
- Why: because `_root_pre_backward` is called with `skip_backward_hook=True`, the autograd `queue_callback` is skipped. The schedule must call `post_backward()` explicitly after all layer backward ops complete to trigger gradient accumulation and reduce-scatter.

`_process_post_backward_gradients` (per-param grad accumulation)
- Standard: registered via `param.register_post_accumulate_grad_hook`. Triggered automatically by autograd when `.grad` is accumulated on each parameter.
- 1F1B overlap: without `delay_wgrad_compute`, still triggered automatically by autograd, since parameters are the same leaf tensors regardless of which code path invokes the computation. With `delay_wgrad_compute`, triggered manually, since the grad accumulation hook needs to be delayed as well.
- Why: the hook fires when autograd writes `.grad`, independent of whether the compute was driven by `module.forward()` or the schedule.

Summary: The root cause of all differences is that the 1F1B overlap schedule decomposes `TransformerLayer.forward()` into sub-operations (attn, mlp, moe_dispatch, moe_combine) and interleaves them across microbatches. This means:
- `MegatronFSDP.forward()` is never called → `_replace_param_with_raw_if_needed` must be explicit.
- `TransformerLayer.forward()` is never called as a whole → module-level `register_forward_hook` / `register_forward_pre_hook` on the FSDP unit are not fired → `_post_forward` and `_post_backward_release_module` must be wired manually via `set_fsdp_reshard_hooks`.
- No root-level output tensor exists in the autograd graph → `_root_pre_backward` / `_root_post_backward` must be called manually.
- Hooks at the parameter level (`_process_post_backward_gradients`) or sub-module level (`_pre_forward_param_unshard`, `_pre_backward_param_unshard` with `enable_fine_grained_param_gather=True`) continue to work automatically because the schedule still calls sub-module forwards, and autograd still writes to parameter `.grad`.
Solution
Bridge the gap between Megatron-FSDP's hook-based lifecycle and the EP overlap schedule's direct sub-module invocation by:
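The hook bypass described under Problem, and the manual wiring this PR uses to close it, can be illustrated with a dependency-free toy. This is only a sketch: `ToyModule`, `ToyLayer`, `release_params`, `run_standard`, and `run_overlap` are hypothetical names that mimic how `nn.Module.__call__` dispatches forward hooks; none of them are Megatron-FSDP or PyTorch APIs.

```python
from typing import Callable, List, Optional


class ToyModule:
    """Toy stand-in for nn.Module: forward hooks fire only through __call__."""

    def __init__(self, name: str, log: List[str]) -> None:
        self.name = name
        self.log = log
        self._forward_hooks: List[Callable[["ToyModule"], None]] = []

    def register_forward_hook(self, hook: Callable[["ToyModule"], None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: int) -> int:
        self.log.append(f"{self.name}.forward")
        return x

    def __call__(self, x: int) -> int:
        out = self.forward(x)
        for hook in self._forward_hooks:  # hooks run after forward returns
            hook(self)
        return out


class ToyLayer(ToyModule):
    """Toy 'TransformerLayer' with two sub-modules sharing one event log."""

    def __init__(self, log: List[str]) -> None:
        super().__init__("layer", log)
        self.attn = ToyModule("attn", log)
        self.mlp = ToyModule("mlp", log)

    def forward(self, x: int) -> int:
        self.log.append("layer.forward")
        return self.mlp(self.attn(x))


def release_params(module: ToyModule) -> None:
    """Hypothetical stand-in for a per-layer parameter release."""
    module.log.append(f"release({module.name})")


def build_layer() -> ToyLayer:
    layer = ToyLayer(log=[])
    layer.register_forward_hook(release_params)  # hook lives on the layer only
    return layer


def run_standard(layer: ToyLayer, x: int) -> int:
    # Standard path: the layer is called as a whole, so its hook fires.
    return layer(x)


def run_overlap(
    layer: ToyLayer,
    x: int,
    post_forward_hook: Optional[Callable[[ToyModule], None]] = None,
) -> int:
    # Overlap-style path: sub-modules are called directly, so the layer-level
    # hook never fires unless the schedule invokes it explicitly.
    hidden = layer.attn(x)
    out = layer.mlp(hidden)
    if post_forward_hook is not None:
        post_forward_hook(layer)  # manual call after the last forward node
    return out


standard = build_layer()
run_standard(standard, 1)

bypassed = build_layer()
run_overlap(bypassed, 1)

wired = build_layer()
run_overlap(wired, 1, post_forward_hook=release_params)

for label, layer in (("standard", standard), ("bypassed", bypassed), ("wired", wired)):
    print(label, layer.log)
```

In the `bypassed` run the release never appears in the log, which is the toy analogue of parameters staying unsharded; passing the hook explicitly to the schedule restores it, mirroring the role `set_fsdp_reshard_hooks` plays in this PR.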
Code changes
`combined_1f1b.py`: Schedule-level FSDP integration
- Add `find_megatron_fsdp()` helper to locate the FSDP wrapper from the model chain.
- Call `fsdp_wrapper._replace_param_with_raw_if_needed()` before the schedule starts (normally done inside `MegatronFSDP.forward()`).
- Call `fsdp_wrapper.pre_backward()` / `post_backward()` around the combined forward+backward step, since the schedule bypasses the autograd hooks that normally trigger them.
- Wire per-layer reshard hooks via `set_fsdp_reshard_hooks()`.

`megatron_fsdp.py`: Expose manual hook handles & support delayed wgrad
- Add `skip_backward_hook` parameter to `_root_pre_backward` so the manual caller can prevent the redundant `queue_callback(_root_post_backward)`.
- Expose `pre_backward`, `post_backward`, `post_forward_release_module`, and `post_backward_release_module` as public attributes for external schedule use.
- Add `is_delayed` parameter to `_process_post_backward_gradients` to distinguish delayed wgrad gradient processing from the standard path, preventing deferred parameters from being filtered out by the `skip_backward_post_hook` guard.
- Register pre-backward unshard hooks (`_register_pre_backward_param_unshard_hook`) on every sub-module when `overlap_moe_expert_parallel_comm` is combined with `optim_grads_params`, so backward all-gather works at the sub-module level.
- In `register_post_accumulate_grad_hook`, skip parameters with `skip_backward_post_hook=True` so their gradient processing is deferred to the delayed wgrad path.

`mcore_fsdp_adapter.py`: Auto-configure FSDP for EP overlap
- Enable `enable_fine_grained_param_gather_hook` when `overlap_moe_expert_parallel_comm` is set (for all sharding strategies), so every sub-module gets its own forward all-gather/unshard hooks.
- Check that `fsdp_double_buffer` and partial CUDA graph scopes are disabled when `overlap_moe_expert_parallel_comm` is set.
- Check that `fsdp_unit_modules` is `[TransformerLayer]` when combined with `optim_grads_params`.

`model_chunk_schedule_plan.py`: Per-layer FSDP hook wiring
- Add `set_fsdp_reshard_hooks()` to `TransformerLayerSchedulePlan`, which attaches post-forward and post-backward release callbacks to the correct schedule nodes (the last forward node, and `attn`, which is the last node to run in backward).

`fine_grained_callables.py`: Node-level FSDP hooks
- Override `forward()` / `backward()` in `TransformerLayerNode` to invoke FSDP reshard hooks at layer boundaries.
- In `backward_dw()`, collect and execute delayed `post_wgrad_grad_acc_hook` for FSDP gradient processing, then invoke the reshard hook.
- Add `set_fsdp_post_forward_reshard_hook()` / `set_fsdp_post_backward_reshard_hook()` registration methods.
- Add a `parameters()` iterator to `_BackwardDWWrapper` so FSDP can discover parameters owned by the attn and shared-expert backward-dw callables.
Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
`.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the `Final Review` label is applied automatically and final reviewers are assigned.
For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the `Approved` label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.
MRs are mergeable after one approval by either `eharper@nvidia.com` or `zijiey@nvidia.com`.