Skip to content

[Dev][feat] Support A2A Overlap for Megatron-FSDP#3796

Merged
Wohox merged 10 commits into
NVIDIA:devfrom
Wohox:pingtian/support_1f1b_fsdp
May 15, 2026
Merged

[Dev][feat] Support A2A Overlap for Megatron-FSDP#3796
Wohox merged 10 commits into
NVIDIA:devfrom
Wohox:pingtian/support_1f1b_fsdp

Conversation

@Wohox

@Wohox Wohox commented Mar 11, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

PR for main: #3797
Based on: #3766

Support Megatron-FSDP with EP overlap 1F1B schedule

Supported Feature Combinations

# without param sharding
--use-megatron-fsdp --data-parallel-sharding-strategy "optim_grads" --overlap-moe-expert-parallel-comm
--use-megatron-fsdp --data-parallel-sharding-strategy "optim_grads" --overlap-moe-expert-parallel-comm --delay-wgrad-compute


# param sharding
--use-megatron-fsdp --data-parallel-sharding-strategy "optim_grads_params" --overlap-moe-expert-parallel-comm
--use-megatron-fsdp --data-parallel-sharding-strategy "optim_grads_params" --overlap-moe-expert-parallel-comm --delay-wgrad-compute

Problem

The EP overlap schedule (overlap_moe_expert_parallel_comm) interleaves forward and backward passes at the sub-module level (attn, mlp, moe_dispatch, moe_combine) to overlap Expert Parallel A2A communication with compute. This schedule bypasses TransformerLayer.forward() entirely, calling sub-modules directly. As a result, FSDP hooks registered on TransformerLayer — including parameter all-gather, parameter release, gradient reduction, and root pre/post backward — are never triggered, making FSDP with optim_grads_params incompatible with this schedule.

FSDP Hook Triggering: Standard vs 1F1B Overlap

Hook Standard Path (non-1F1B) 1F1B Overlap Path Reason for Difference
_replace_param_with_raw_if_needed Called inside MegatronFSDP.forward() at the start of every forward call. Called once at the top of combined_1f1b_schedule_for_no_pipelining, before any schedule phase begins. The overlap schedule bypasses MegatronFSDP.forward() entirely — it calls GPTModel.build_schedule_plan() on the unwrapped module directly. The param swap must happen explicitly before the schedule accesses layers.
_pre_forward_param_unshard (per-layer all-gather) Registered as nn.Module.register_forward_pre_hook on each FSDP unit. Triggered automatically by PyTorch when TransformerLayer.forward() is called. Same mechanism — still triggered by PyTorch forward pre-hooks. With enable_fine_grained_param_gather=True, hooks are registered on every sub-module (not just FSDP units), giving finer-grained all-gather. The schedule still calls each layer's forward() via ScheduleNode, which triggers the registered forward pre-hooks normally. No manual intervention needed.
_post_forward (post_forward_release_module) — per-layer param release after forward Registered as nn.Module.register_forward_hook on each FSDP unit. Triggered automatically by PyTorch after TransformerLayer.forward() returns. Called manually by the schedule plan via set_fsdp_reshard_hooks. Wired in combined_forward_backward_step onto each TransformerLayerSchedulePlan, and invoked after the last forward node (mlp/moe_combine) of each layer completes. The schedule decomposes TransformerLayer.forward() into sub-ops (attn, mlp, moe_dispatch, moe_combine). The registered register_forward_hook on TransformerLayer is never fired because the schedule never calls TransformerLayer.forward() as a whole — it calls sub-modules directly. So the release must be done explicitly.
_post_backward_release_module — per-layer param release after backward Registered via _register_post_backward_hook (a register_forward_pre_hook that inserts a RegisterFSDPBackwardFunction into the autograd graph). Triggered automatically by autograd during backward. Called manually by the schedule plan via set_fsdp_reshard_hooks. Wired onto TransformerLayerSchedulePlan.attn (the last backward node of each layer). Same reason as _post_forward: the schedule runs backward per-sub-op, not per-TransformerLayer. The autograd-based post-backward hook on the FSDP unit output is never reached because the layer-level forward was never called as a whole.
_pre_backward_param_unshard (per-layer backward all-gather) Attached via create_custom_backward_hook (a register_forward_hook that registers register_multi_grad_hook on the output tensors). Triggered automatically by autograd when gradients arrive at the layer's output. With enable_fine_grained_param_gather=True, hooks are registered on every sub-module via create_custom_backward_hook. Triggered automatically by autograd when gradients flow through each sub-module. When enable_fine_grained_param_gather=False, hooks only exist on FSDP units; since the layer-level forward is skipped, those hooks are never hit. With =True, hooks exist on every sub-module, so autograd still triggers them during the schedule's backward.
_root_pre_backward Attached via create_custom_backward_hook on the root module's output. Triggered automatically by autograd (gradient flows back through root output). Also queues _root_post_backward via queue_callback. Called manually via fsdp_wrapper.pre_backward() at the start of each combined_forward_backward_step that has a backward model (b_model is not None). Called with skip_backward_hook=True so it does not queue _root_post_backward. The overlap schedule does not call MegatronFSDP.forward(), so no root-level output tensor exists in the autograd graph to trigger the hook. It must be called manually. skip_backward_hook=True because _root_post_backward is also called manually.
_root_post_backward (grad accumulation + reduce-scatter) Queued via torch.autograd.Variable._execution_engine.queue_callback() inside _root_pre_backward. Triggered automatically by the autograd engine after the entire backward pass finishes. Called manually via fsdp_wrapper.post_backward() at the end of each combined_forward_backward_step that has a backward model. Since _root_pre_backward is called with skip_backward_hook=True, the autograd queue_callback is skipped. The schedule must call post_backward() explicitly after all layer backward ops complete to trigger gradient accumulation and reduce-scatter.
_process_post_backward_gradients (per-param grad accumulation) Registered via param.register_post_accumulate_grad_hook. Triggered automatically by autograd when .grad is accumulated on each parameter. Same mechanism w.o. delay_wgrad_compute — still triggered automatically by autograd. Parameters are the same leaf tensors regardless of which code path invokes the computation. Manual Triggered with delay_wgrad_compute since grad acc hook needs to be delayed as well. This is a parameter-level hook (not module-level), so it fires whenever autograd writes to .grad, independent of whether the compute was driven by module.forward() or the schedule.

Summary: The root cause of all differences is that the 1F1B overlap schedule decomposes TransformerLayer.forward() into sub-operations (attn, mlp, moe_dispatch, moe_combine) and interleaves them across microbatches. This means:

  1. MegatronFSDP.forward() is never called → _replace_param_with_raw_if_needed must be explicit.
  2. TransformerLayer.forward() is never called as a whole → module-level register_forward_hook / register_forward_pre_hook on the FSDP unit are not fired → _post_forward and _post_backward_release_module must be wired manually via set_fsdp_reshard_hooks.
  3. No root-level output tensor in autograd graph → _root_pre_backward / _root_post_backward must be called manually.
  4. Hooks that operate at the parameter level (_process_post_backward_gradients) or sub-module level (_pre_forward_param_unshard, _pre_backward_param_unshard with enable_fine_grained=True) continue to work automatically because the schedule still calls sub-module forwards, and autograd still writes to parameter .grad.

Solution

Bridge the gap between Megatron-FSDP's hook-based lifecycle and the EP overlap schedule's direct sub-module invocation by:

  1. Enabling fine-grained per-sub-module all-gather hooks so each sub-module can unshard its own parameters.
  2. Exposing manual handles for FSDP lifecycle events that the schedule must drive explicitly.

Code changes

combined_1f1b.py — Schedule-level FSDP integration

  • Add find_megatron_fsdp() helper to locate the FSDP wrapper from the model chain.
  • Call fsdp_wrapper._replace_param_with_raw_if_needed() before the schedule starts (normally done inside MegatronFSDP.forward()).
  • Manually invoke fsdp_wrapper.pre_backward() / post_backward() around the combined forward+backward step since the schedule bypasses the autograd hooks that normally trigger them.
  • Wire per-layer FSDP release callbacks via set_fsdp_reshard_hooks().

megatron_fsdp.py — Expose manual hook handles & support delayed wgrad

  • Add skip_backward_hook parameter to _root_pre_backward so the manual caller can prevent the redundant queue_callback(_root_post_backward).
  • Expose pre_backward, post_backward, post_forward_release_module, and post_backward_release_module as public attributes for external schedule use.
  • Add is_delayed parameter to _process_post_backward_gradients to distinguish delayed wgrad gradient processing from the standard path, preventing deferred parameters from being filtered out by the skip_backward_post_hook guard.
  • Register fine-grained backward all-gather hooks (_register_pre_backward_param_unshard_hook) on every sub-module when overlap_moe_expert_parallel_comm + optim_grads_params, so backward all-gather works at the sub-module level.
  • In per-parameter register_post_accumulate_grad_hook, skip parameters with skip_backward_post_hook=True so their gradient processing is deferred to the delayed wgrad path.

mcore_fsdp_adapter.py — Auto-configure FSDP for EP overlap

  • Enable enable_fine_grained_param_gather_hook when overlap_moe_expert_parallel_comm is set (for all sharding strategies), so every sub-module gets its own forward all-gather/unshard hooks.
  • Assert that fsdp_double_buffer and partial CUDA graph scopes are disabled when overlap_moe_expert_parallel_comm is set.
  • Assert that fsdp_unit_modules must be [TransformerLayer] when combined with optim_grads_params.

model_chunk_schedule_plan.py — Per-layer FSDP hook wiring

  • Add set_fsdp_reshard_hooks() to TransformerLayerSchedulePlan that attaches post-forward and post-backward release callbacks to the correct schedule nodes (last forward node, first backward node).

fine_grained_callables.py — Node-level FSDP hooks

  • Override forward() / backward() in TransformerLayerNode to invoke FSDP reshard hooks at layer boundaries.
  • In backward_dw(), collect and execute delayed post_wgrad_grad_acc_hook for FSDP gradient processing, then invoke the reshard hook.
  • Add set_fsdp_post_forward_reshard_hook() / set_fsdp_post_backward_reshard_hook() registration methods.
  • Add parameters() iterator to _BackwardDWWrapper so FSDP can discover parameters owned by the attn and shared-expert backward-dw callables.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@copy-pr-bot

copy-pr-bot Bot commented Mar 11, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Wohox Wohox changed the title (Draft) Support A2A Overlap for Megatron-FSDP (Draft)[Dev][feat] Support A2A Overlap for Megatron-FSDP Mar 11, 2026
@Wohox Wohox force-pushed the pingtian/support_1f1b_fsdp branch from 0c8c10a to 67c3e5f Compare March 16, 2026 03:54
@Wohox Wohox changed the title (Draft)[Dev][feat] Support A2A Overlap for Megatron-FSDP [Dev][feat] Support A2A Overlap for Megatron-FSDP Mar 16, 2026
@Wohox

Wohox commented Mar 16, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 4676377

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 16, 2026
@Wohox Wohox force-pushed the pingtian/support_1f1b_fsdp branch from 593456e to e5a27f0 Compare April 7, 2026 09:25
@Wohox

Wohox commented Apr 7, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test e5a27f0

@Wohox Wohox force-pushed the pingtian/support_1f1b_fsdp branch from e95d80c to 553c34f Compare April 9, 2026 07:27
@Wohox

Wohox commented Apr 9, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 5c3847c

@Wohox

Wohox commented Apr 9, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 1de0217

@Wohox Wohox force-pushed the pingtian/support_1f1b_fsdp branch from f8b2044 to 9660240 Compare April 15, 2026 08:13
@Wohox Wohox marked this pull request as ready for review April 15, 2026 08:13
@Wohox Wohox requested review from a team as code owners April 15, 2026 08:13
@Wohox

Wohox commented Apr 15, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 9660240

@Wohox

Wohox commented Apr 15, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 1a0c098

@copy-pr-bot

copy-pr-bot Bot commented Apr 15, 2026

Copy link
Copy Markdown

/ok to test 1a0c098

@Wohox, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@Wohox

Wohox commented Apr 15, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test bbaca66

@Wohox Wohox closed this May 7, 2026
@Wohox Wohox reopened this May 14, 2026
@Wohox Wohox force-pushed the pingtian/support_1f1b_fsdp branch from 2543009 to 6bc436d Compare May 14, 2026 07:59
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
@Wohox

Wohox commented May 14, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test a477539

@Wohox

Wohox commented May 14, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test cb83f70

@Wohox Wohox added this pull request to the merge queue May 14, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25860314806

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 14, 2026
@Wohox

Wohox commented May 15, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test db40f1d

@Wohox Wohox enabled auto-merge May 15, 2026 01:44
@Wohox Wohox added this pull request to the merge queue May 15, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25904045091

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25905070464

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 15, 2026
@Wohox Wohox added this pull request to the merge queue May 15, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25914455079

Merged via the queue into NVIDIA:dev with commit 77c0f8c May 15, 2026
65 of 66 checks passed
@Wohox Wohox deleted the pingtian/support_1f1b_fsdp branch May 15, 2026 15:09
SpencerGarnets added a commit to ai-blaise/Megatron-LM that referenced this pull request May 16, 2026
Upstream dev tip: 77c0f8c

Pulled commits:

- 77c0f8c [Dev][feat] Support A2A Overlap for Megatron-FSDP (NVIDIA#3796)

- 8195337 [dev] [3/5] Qwen3.5 support: SharedExpertMLP meta init (NVIDIA#4749)

- 2672ff5 [DEV] fix(megatron-fsdp): preserve non-meta tensors during meta materialization (NVIDIA#4155)

- cfbd9df [dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout (NVIDIA#4750)

- df12802 [dev] Fix GDN DTensor splitting for FSDP checkpointing (NVIDIA#4799)

Resolution: zero conflicts; git auto-merged 12 shared files in megatron/core/{distributed,models,pipeline_parallel,transformer} and tests/unit_tests/a2a_overlap. No ai-blaise custom files touched.

Gates:

- git diff --check: clean

- conflict markers: none

- py_compile (16 changed .py files): OK

- indexcache: 27/28 pass; the 1 fail (test_nvfp4_non_blackwell_cuda_uses_reference_fallback) reproduces identically at the pre-merge base SHA (sglang occupies all 8 H200s in EXCLUSIVE_PROCESS mode -> cudaErrorDevicesUnavailable). 1 Blackwell-only test auto-skips on H200.

- transformer gdn/mtp/moe suite: 53 failed / 7 passed / 55 skipped / 5 errors -- IDENTICAL numbers at pre-merge base; all failures are the same environmental cudaErrorDevicesUnavailable.

- 2-rank torchrun layer-wise optimizer smoke: blocked (no free GPUs).

Custom preserved: StreamBP, IndexCache config, NVFP4 indexer (7e78f28), HISA topk1024 backward test (c628c13), pyproject emerging_optimizers v0.2.0 pin, mHC/MTP/MoE composition.
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
The FSDP early-return comment in ``_copy_model_params_to_main_params``
cited specific line numbers in
``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``.
Those line numbers shift on every unrelated change to that file
(e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace
with the stable symbol names — ``_replace_param_with_distributed_if_needed``,
``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``
— so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
The FSDP early-return comment in ``_copy_model_params_to_main_params``
cited specific line numbers in
``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``.
Those line numbers shift on every unrelated change to that file
(e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace
with the stable symbol names — ``_replace_param_with_distributed_if_needed``,
``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``
— so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
The FSDP early-return comment in ``_copy_model_params_to_main_params``
cited specific line numbers in
``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``.
Those line numbers shift on every unrelated change to that file
(e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace
with the stable symbol names — ``_replace_param_with_distributed_if_needed``,
``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``
— so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants