Skip to content

[Dev][feat] Support overlapping A2A Combine backprop with wgrad GEMM#3766

Merged
Wohox merged 4 commits into
NVIDIA:devfrom
Wohox:pingtian/support_backawrd_dw_for_fsdp
Apr 7, 2026
Merged

[Dev][feat] Support overlapping A2A Combine backprop with wgrad GEMM#3766
Wohox merged 4 commits into
NVIDIA:devfrom
Wohox:pingtian/support_backawrd_dw_for_fsdp

Conversation

@Wohox

@Wohox Wohox commented Mar 10, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

PR for main: #3795

Problem

In MoE models, the expert weight gradient (wgrad) computation during backward is serialized on the main CUDA stream. This blocks the data gradient (dgrad) from flowing to earlier layers until the expert wgrad finishes, even though there is no data dependency between them. The result is wasted GPU cycles — earlier layers' backward pass sits idle waiting for expert wgrad to complete.

With FSDP, this is further compounded because the gradient reduce-scatter for expert parameters is also blocked on the same critical path.

Solution

This PR introduces a new flag --overlap-dispatch-backward-with-experts-wgrad that separates the expert wgrad computation from the main backward stream:

  1. Two autograd functions are inserted into the MoE layer's forward graph:

    • _RecordExpertDgradCompletion — placed before the expert computation; during backward, it records a CUDA event once the expert dgrad is done.
    • _RegisterDelayedWgradForExperts — placed at the dispatch boundary; during backward, it waits on the dgrad event, then launches backward_dw() on a dedicated CUDA stream, and synchronizes back to the main stream before proceeding.
  2. FSDP integration — When used with MegatronFSDP, expert parameters are marked with _fsdp_delay_grad_reduce = True so the normal post-accumulate-grad hook skips them. A callback is registered via register_process_expert_grads_fn() that triggers the FSDP reduce-scatter for expert parameters only after the delayed wgrad computation completes.

  3. TE GroupedLinear is configured with delay_wgrad_compute=True, which tells Transformer Engine to skip wgrad during the normal autograd backward and instead wait for an explicit backward_dw() call.

How to enable

--overlap-dispatch-backward-with-experts-wgrad

Requirements:

  • Transformer Engine >= 2.3.0
  • moe_grouped_gemm enabled (not legacy grouped gemm)
  • Mutually exclusive with --delay-wgrad-compute (the existing A2A-overlap-based delay)
  • Mutually exclusive with --overlap-moe-expert-parallel-comm

Works with both FSDP and 3-D parallelism (TP/EP/PP).

What is achieved

The expert wgrad computation runs on a separate CUDA stream, overlapping with the EP communication within the same transformer layer. This reduces the wall-clock time of the backward pass without changing numerical results — the feature is bit-exact with the non-delayed baseline (verified by unit tests comparing per-step losses and final weights over multiple optimizer steps).

Changes

File Description
megatron/core/model_parallel_config.py New config flag delay_wgrad_compute_for_te_grouped_gemm
megatron/core/transformer/transformer_config.py Validation assertions for the new flag
megatron/core/transformer/moe/moe_layer.py Autograd functions for delayed wgrad + dedicated CUDA stream/event + register_process_expert_grads_fn callback
megatron/core/extensions/transformer_engine.py Pass delay_wgrad_compute=True to TE GroupedLinear when the new flag is set
megatron/core/distributed/fsdp/.../megatron_fsdp.py FSDP hook to defer reduce-scatter for expert params and trigger it after delayed wgrad
tests/unit_tests/a2a_overlap/test_delay_wgrad_compute.py Unit tests covering basic, shared-expert, multi-layer, and FSDP scenarios

Test plan

  • Unit test: test_delay_wgrad_compute_for_te_grouped_gemm — full-model training loop (forward → backward → optimizer) comparing delayed vs. non-delayed across num_layers × shared_experts × dispatcher_type × fp8_flag
  • Unit test: test_delay_wgrad_compute_for_te_grouped_gemm_with_fsdp — same comparison with MegatronFSDP wrapping (fully_shard_model + fully_shard_optimizer), verifying the deferred reduce-scatter path

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@copy-pr-bot

copy-pr-bot Bot commented Mar 10, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Wohox

Wohox commented Mar 10, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 48b1c29

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 10, 2026
@Wohox Wohox requested review from lhb8125 and shjwudp March 10, 2026 10:54
@Wohox

Wohox commented Mar 10, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread tests/unit_tests/a2a_overlap/utils.py Outdated
Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py Outdated
@Wohox

Wohox commented Mar 10, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test adb70af

@Wohox

Wohox commented Mar 11, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 6d5d7be

@Wohox Wohox added dev branch Dev branch related issues and development module: moe labels Mar 11, 2026
@Wohox

Wohox commented Mar 11, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test a7d2b2b

@shjwudp shjwudp left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@Wohox Wohox marked this pull request as ready for review March 13, 2026 05:01
@Wohox Wohox requested review from a team as code owners March 13, 2026 05:01
@Wohox Wohox force-pushed the pingtian/support_backawrd_dw_for_fsdp branch from a146c95 to f17a325 Compare March 13, 2026 05:04
@Wohox Wohox changed the title (Draft)[Dev][feat] Support overlapping A2A Combine backprop with wgrad GEMM [Dev][feat] Support overlapping A2A Combine backprop with wgrad GEMM Mar 13, 2026
@Wohox Wohox force-pushed the pingtian/support_backawrd_dw_for_fsdp branch from f17a325 to 3466356 Compare March 13, 2026 05:24
@Wohox

Wohox commented Mar 13, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 3466356

@Wohox

Wohox commented Mar 13, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 9d33e8f

@Wohox Wohox force-pushed the pingtian/support_backawrd_dw_for_fsdp branch from 5caa4db to 658dddf Compare April 1, 2026 02:47
@Wohox

Wohox commented Apr 1, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 658dddf

@Wohox Wohox enabled auto-merge April 1, 2026 02:49
@Wohox Wohox disabled auto-merge April 1, 2026 02:56
Signed-off-by: Cory Ye <cye@nvidia.com>
@Wohox

Wohox commented Apr 3, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 625d76b

@Wohox Wohox enabled auto-merge April 3, 2026 03:23
@Wohox

Wohox commented Apr 3, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 4e58bcb

@Wohox Wohox added this pull request to the merge queue Apr 3, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23934909561

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 3, 2026
@Wohox Wohox added this pull request to the merge queue Apr 3, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23936620425

@Victarry

Victarry commented Apr 3, 2026

Copy link
Copy Markdown
Contributor

Current dev branch is not healthy, trying to fix it with #4123

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 3, 2026
@Wohox Wohox enabled auto-merge April 7, 2026 01:02
@Wohox

Wohox commented Apr 7, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 7b978c4

@Wohox Wohox added this pull request to the merge queue Apr 7, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24061461932

Merged via the queue into NVIDIA:dev with commit 37a4cee Apr 7, 2026
60 of 61 checks passed
@Wohox Wohox deleted the pingtian/support_backawrd_dw_for_fsdp branch April 7, 2026 03:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

complexity: medium dev branch Dev branch related issues and development module: moe

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants