Skip to content

Combine GEMM + SwiGLU fused MLP PRs (3890, 4071, 4095, 4219, 4311, 4324) → main#4636

Merged
Connor-XY merged 33 commits into
NVIDIA:mainfrom
Connor-XY:yxu1/pr3971-rebase-fix-checks
May 16, 2026
Merged

Combine GEMM + SwiGLU fused MLP PRs (3890, 4071, 4095, 4219, 4311, 4324) → main#4636
Connor-XY merged 33 commits into
NVIDIA:mainfrom
Connor-XY:yxu1/pr3971-rebase-fix-checks

Conversation

@Connor-XY

@Connor-XY Connor-XY commented May 5, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

This PR consolidates six dev-branch PRs related to TE fused grouped MLP into a single PR targeting main. It supersedes #3971 (which mirrored only #3890 to main) and brings the full feature set in one merge.

Source PRs (all merged to dev)

dev PR Title What it adds
#3890 Support GEMM + Swiglu fused MLP TE op-fuser path for TEGroupedMLP (_is_fused_impl_supported, _make_fused_ops, _fused_forward, _make_fused_impl_pre_forward_hook); GLU-interleaved layout (moe_mlp_glu_interleave_size); grouped-quantized-tensor support in DistributedOptimizer (_is_grouped_quantized_tensor, _expand_quantized_param_shard_for_cast); _normalize_grouped_parameter_keys checkpoint-key compat hook on TEGroupedLinear; CUTLASS 256-byte alignment in get_align_size_for_quantization when op-fuser is on; new CLI flag --use-transformer-engine-op-fuser.
#4071 Skip routed expert padding for graph-safe MoE Refactors the moe_router_padding_for_quantization check into a skip_routed_expert_padding(config) helper that also skips quantization padding when the token dispatcher is flex with hybridep backend (the dispatcher applies padding itself). Eliminates a double-pad on graph-safe HybridEP paths.
#4095 TE fused grouped MLP with grouped bias and delayed wgrad Adds single_grouped_bias to the op-fuser path (renames single_grouped_parametersingle_grouped_weight); threads delay_wgrad_compute through te.pytorch.ops.GroupedLinear; drops the need_backward_dw() gate; gates _is_fused_impl_supported on TE ≥ 2.14.0; adds a fused-aware backward_dw() override.
#4219 Enabled fused grouped MLP for quick_gelu and add config for grouped params Extends _is_fused_impl_supported and _make_fused_ops to support quick_gelu via te.pytorch.ops.ScaledClampedQGeGLU; adds moe_single_grouped_weight / moe_single_grouped_bias TransformerConfig fields with validation; threads them through TEGroupedLinear.__init__ and _set_arg for checkpoint loading; updates the test_hybrid_moe_model GOLDEN_CONFIG.
#4311 Fix fused grouped MLP wgrad hooks for DDP reduce-scatter After the fused children's backward_dw() call, explicitly invoke linear_fc{1,2}._trigger_wgrad_accumulation_and_reduce_hooks(). The wgrad hooks (DDP reduce-scatter, etc.) live on the original linear_fc1/fc2 modules but backward_dw() is called on the new GroupedLinear instances created by _make_fused_ops(). Without the explicit trigger, param.grad is never zeroed and AccumulateGrad performs a spurious add_ into main_grad.
#4324 Fix checkpoint loading with load_main_params_from_ckpt=True for grouped weight Adds _normalize_state_dict_for_grouped_params to DistributedOptimizer._build_model_param_to_state_dict_param_map. Mirrors the normalize_grouped_parameter_keys hook (which only fires during load_state_dict) so the optimizer reload path also handles single-grouped ↔ indexed weight layout drift.

Rebase-fix work on top

Beyond the six source PRs, this branch carries fixes needed to land the combined feature on current main:

  • Fix the self.bias_act_func → local bias_act_func typo on the non-recompute forward path (the closure is defined locally; calling it via self. AttributeError'd at runtime).
  • Promote normalize_grouped_parameter_keys from an inline __init__ closure to a method _normalize_grouped_parameter_keys so it's directly testable.
  • Use device="meta" for op-shell construction in _make_fused_ops (existing weights are reattached after, so the GPU allocation was a wasted transient).
  • Read the wrappers' combined delay_wgrad_compute (config.delay_wgrad_compute or config.overlap_dispatch_backward_with_experts_wgrad) in _make_fused_ops and the backward_dw() gate, not the raw config flag — otherwise runs that enable wgrad delay through overlap_dispatch_backward_with_experts_wgrad would silently lose the overlap optimization.
  • Fix _make_fused_ops activation comment typo "GEGL" → "GeGLU".
  • Add unit-test coverage that the original PRs didn't carry: _apply_bias, _remove_glu_interleaving, _make_fused_impl_pre_forward_hook, _make_fused_ops shapes (with mocked TE) including the quick_gelu/ScaledClampedQGeGLU branch and the single_grouped_bias=True path, _fused_forward arg threading, _split_grouped_checkpoint_tensor (all branches), _normalize_grouped_parameter_keys (8 cases including round-trip), _normalize_state_dict_for_grouped_params (8 cases — indexed↔grouped, bias path, module.-prefix stripping, ambiguous-match guard, incomplete-set guard, target-already-present guard, non-TE-grouped skip), backward_dw fused dispatch including PR 4311 hook trigger, backward_dw fallback when delay is off, _expand_quantized_param_shard_for_cast, _is_grouped_quantized_tensor. Plus a GPU-gated test_gpu_make_fused_ops_constructs_with_real_te for the meta-device + weight-reattach path against real TE.
  • Comment improvements in param_and_grad_buffer.py (no behavior change), addressing reviewer questions on the original PR.
  • Various lint and config alignments to keep CI green.

Files changed (vs main)

megatron/core/distributed/param_and_grad_buffer.py            (comments only)
megatron/core/extensions/transformer_engine.py                (_normalize_grouped_parameter_keys + _split_grouped_checkpoint_tensor + extra_kwargs for grouped params)
megatron/core/optimizer/distrib_optimizer.py                  (grouped-quantized-tensor support + _normalize_state_dict_for_grouped_params)
megatron/core/transformer/moe/experts.py                      (TEGroupedMLP fused path + DDP wgrad hook trigger + skip_routed_expert_padding)
megatron/core/transformer/moe/moe_utils.py                    (CUTLASS alignment + skip_routed_expert_padding helper)
megatron/core/transformer/transformer_config.py               (new flags: use_transformer_engine_op_fuser, moe_mlp_glu_interleave_size, moe_single_grouped_{weight,bias})
megatron/training/checkpointing.py                            (_set_arg for new flags)
tests/unit_tests/models/test_hybrid_moe_model.py              (GOLDEN_CONFIG)
tests/unit_tests/optimizer/test_distrib_optimizer_grouped_quantized.py   (NEW)
tests/unit_tests/transformer/moe/test_grouped_mlp.py
tests/unit_tests/transformer/test_transformer_engine_grouped_linear.py   (NEW)

Validation

  • Linters clean (black --skip-magic-trailing-comma --skip-string-normalization --check, isort --check, pylint 10/10, ruff check) on the full lint scope.
  • CPU mock + GPU integration tests pass on draco's batch_singlenode partition with the canonical post-rename TE container (TE 2.14.0+f031cf87): 46 passed in 6.54s, including test_gpu_forward_backward, test_gpu_make_fused_ops_constructs_with_real_te, all the new coverage above, and the existing _split_grouped_checkpoint_tensor / _normalize_grouped_parameter_keys round-trip tests.

Contribution process

Outstanding follow-up

  • @zhongbozhu noted on distrib_optimizer.py:1083 that the single-grouped-weight path needs more numerical fixes from dev. Awaiting pointers to specific commits/PRs to either fold into this PR or track in a follow-up.

@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Connor-XY Connor-XY force-pushed the yxu1/pr3971-rebase-fix-checks branch from a2c5b96 to f3873ce Compare May 5, 2026 22:16
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/claude review

Comment on lines +1794 to +1836
def normalize_grouped_parameter_keys(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Make grouped checkpoint keys compatible across parameter layouts."""

def maybe_remap_param(param_name: str) -> None:
grouped_key = f"{prefix}{param_name}"
indexed_keys = [
f"{prefix}{param_name}{gemm_idx}" for gemm_idx in range(self.num_gemms)
]
has_grouped_key = grouped_key in state_dict
has_any_indexed_key = any(key in state_dict for key in indexed_keys)
has_all_indexed_keys = all(key in state_dict for key in indexed_keys)

if getattr(self, "single_grouped_weight", False):
if has_grouped_key or not has_all_indexed_keys:
return
state_dict[grouped_key] = torch.stack(
[state_dict.pop(key) for key in indexed_keys], dim=0
)
else:
if has_any_indexed_key or not has_grouped_key:
return
split_tensors = self._split_grouped_checkpoint_tensor(
state_dict.pop(grouped_key), grouped_key
)
for gemm_idx, tensor in enumerate(split_tensors):
state_dict[f"{prefix}{param_name}{gemm_idx}"] = tensor

maybe_remap_param("weight")
if self.use_bias:
maybe_remap_param("bias")

self._register_load_state_dict_pre_hook(
normalize_grouped_parameter_keys, with_module=True
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: normalize_grouped_parameter_keys is checkpoint-critical logic (it silently mutates state_dict during load_state_dict) but has no direct unit test. _split_grouped_checkpoint_tensor is well-covered in test_transformer_engine_grouped_linear.py, but the hook's own orchestration — early-return conditions, the torch.stack path for single_grouped_weight, key naming — is untested.

A small test that round-trips both remapping directions (indexed→grouped and grouped→indexed) via the hook would guard against regressions in checkpoint loading across weight layouts.

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code looks correct and well-structured. Good test coverage for the new distributed optimizer helpers and the fused ops construction. One suggestion: the normalize_grouped_parameter_keys checkpoint-loading hook would benefit from a direct unit test (see inline comment).

Connor-XY added a commit to Connor-XY/Megatron-LM that referenced this pull request May 5, 2026
Promote the inner state_dict pre-hook to a `_normalize_grouped_parameter_keys`
method on `TEGroupedLinear` so it can be invoked directly from tests.
Behavior at registration time is preserved: the unbound method is registered
with `with_module=True`, so PyTorch passes the module as `self`.

Add unit tests covering the hook's orchestration:
- indexed→grouped fold under `single_grouped_weight=True` (weight only,
  weight + bias)
- grouped→indexed split under `single_grouped_weight=False` (weight only,
  weight + bias)
- bias is left untouched when `use_bias=False`
- early-return when target layout is already present
- early-return when the indexed key set is incomplete
- round-trip: indexed checkpoint → grouped model → indexed model preserves
  the per-GEMM tensors

Address PR NVIDIA#4636 review comment — `_split_grouped_checkpoint_tensor` is
already covered, but the surrounding hook (key naming, early-returns, the
torch.stack path) was untested.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Connor-XY added a commit to Connor-XY/Megatron-LM that referenced this pull request May 5, 2026
Promote the inner state_dict pre-hook to a `_normalize_grouped_parameter_keys`
method on `TEGroupedLinear` so it can be invoked directly from tests.
Behavior at registration time is preserved: the unbound method is registered
with `with_module=True`, so PyTorch passes the module as `self`.

Add unit tests covering the hook's orchestration:
- indexed→grouped fold under `single_grouped_weight=True` (weight only,
  weight + bias)
- grouped→indexed split under `single_grouped_weight=False` (weight only,
  weight + bias)
- bias is left untouched when `use_bias=False`
- early-return when target layout is already present
- early-return when the indexed key set is incomplete
- round-trip: indexed checkpoint → grouped model → indexed model preserves
  the per-GEMM tensors

Address PR NVIDIA#4636 review comment — `_split_grouped_checkpoint_tensor` is
already covered, but the surrounding hook (key naming, early-returns, the
torch.stack path) was untested.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY Connor-XY force-pushed the yxu1/pr3971-rebase-fix-checks branch from 49dde6b to fe61201 Compare May 5, 2026 23:08
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test fe61201

@Connor-XY Connor-XY marked this pull request as ready for review May 5, 2026 23:14
@Connor-XY Connor-XY requested review from a team as code owners May 5, 2026 23:14
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 5, 2026 23:14
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test 4b0b739

@Connor-XY Connor-XY mentioned this pull request May 7, 2026
5 tasks
@Connor-XY Connor-XY requested review from a team as code owners May 7, 2026 20:37
@Connor-XY Connor-XY changed the title Support GEMM + SwiGLU fused MLP (rebased from #3971) Combine GEMM + SwiGLU fused MLP PRs (3890, 4095, 4219, 4324) main → main May 7, 2026
Connor-XY added a commit to Connor-XY/Megatron-LM that referenced this pull request May 7, 2026
Promote the inner state_dict pre-hook to a `_normalize_grouped_parameter_keys`
method on `TEGroupedLinear` so it can be invoked directly from tests.
Behavior at registration time is preserved: the unbound method is registered
with `with_module=True`, so PyTorch passes the module as `self`.

Add unit tests covering the hook's orchestration:
- indexed→grouped fold under `single_grouped_weight=True` (weight only,
  weight + bias)
- grouped→indexed split under `single_grouped_weight=False` (weight only,
  weight + bias)
- bias is left untouched when `use_bias=False`
- early-return when target layout is already present
- early-return when the indexed key set is incomplete
- round-trip: indexed checkpoint → grouped model → indexed model preserves
  the per-GEMM tensors

Address PR NVIDIA#4636 review comment — `_split_grouped_checkpoint_tensor` is
already covered, but the surrounding hook (key naming, early-returns, the
torch.stack path) was untested.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY Connor-XY force-pushed the yxu1/pr3971-rebase-fix-checks branch from 61fdfb2 to 1858cdc Compare May 7, 2026 23:23
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test 1858cdc

@Connor-XY

Copy link
Copy Markdown
Contributor Author

/claude review

Per follow-up review: enable TE's cuDSL fused grouped MLP gate in the
MoE unit-test CI buckets so the kernel is automatically picked up when
Blackwell joins the unit-test matrix. Today the kernel also requires
SM100, so on H100/A100 hardware this is a no-op (TE's other gates fail
and it falls back to basic-op). The conditional matches
`transformer/moe` and `test_moe_experts` buckets so non-MoE tests stay
unchanged.

Set in spec.script before the run_ci_test.sh invocation so the env var
is in place when pytest starts and TE imports — TE's
ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() reads the env
var during its import-time register_forward_fusion call, so it must be
set in the process env, not in a fixture.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test 46be5d1

Upstream dev PR NVIDIA#4621 ("Fix single grouped weight
when enabling MXFP8 primary weight") fixes a numerical bug in the
dist-optimizer's quantized-param shard path on the single-grouped-
weight storage. That fix is still on dev and overlaps with code this
PR already touches.

The buggy combo isn't exercised by this PR's tests (bf16, no fp8), so
this PR can ship without the fix — but users who set
moe_single_grouped_weight=True outside the mxfp8 recipe would silently
train on a broken numerical path. Add a __post_init__ guard that
raises early with a clear pointer to NVIDIA#4621.

Scoped to moe_single_grouped_weight only, not moe_single_grouped_bias:
the fix in NVIDIA#4621 touches only the quantized-tensor code paths in
distrib_optimizer.py and param_and_grad_buffer.py, and biases don't
enter those paths (they aren't quantized in mxfp8 training). Upstream
NVIDIA#4621's validation gates both for safety; ours is tighter because the
actual bug is weight-only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY

Connor-XY commented May 13, 2026

Copy link
Copy Markdown
Contributor Author

Picking up this thread now that #4621 is visible.

Status check:

To address your "ideally no known bug" concern without coupling this PR's schedule to #4621, I added a fail-fast __post_init__ guard in transformer_config.py (commit 8d3920d). It rejects moe_single_grouped_weight=True unless the user is also on fp8 mode with the mxfp8 recipe:

if self.moe_single_grouped_weight:
    if self.fp4 or not self.fp8 or self.fp8_recipe != Fp8Recipe.mxfp8:
        raise ValueError(
            "moe_single_grouped_weight is currently supported only with fp8 mode "
            "and fp8_recipe='mxfp8'."
        )

Comment on lines +764 to +768
# TODO: find a better place to invoke _trigger_wgrad_accumulation_and_reduce_hooks.
# The wgrad hook registration lives in TE while the trigger is issued here
# in MCore, so the hook lifecycle is split across both codebases. Consolidate
# ownership on one side (either register+trigger entirely in TE, or expose
# the fused backward_dw through MCore) to remove this fragmentation.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a plan to address this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @gdengk

@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test 8d3920d

…r hunk

The `_cached_param_buffer_shards_grad_enabled` field, its read site in
`start_param_sync()`, and the `with torch.no_grad()` wrap around the
coalescing manager all originated in NVIDIA#3890 on the dev branch. The dev
sync merge `79aeecfe0` (Mar 25 2026) explicitly removed the read site
and the no_grad wrap during conflict resolution when it pulled in the
layerwise-optimizer code from main — only the field init survived as an
orphan in `__init__`. The active logic was deliberately dropped, no
regression was reported on dev or main in the intervening months, and
zhongbozhu flagged this exact block on this PR (r3211212707) noting it
was removed in dev.

For a PR targeting main, resurrecting a hunk that was specifically
dropped during a merge — without a fresh repro proving main needs it —
is the wrong default. Remove all three pieces (the orphan init, the
read site, the no_grad wrap) so this file matches main's shape except
for the changes that are genuinely part of this PR's scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY

Copy link
Copy Markdown
Contributor Author

/ok to test d49ad27

@zhongbozhu zhongbozhu left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future reference, let's aggregate all the PRs necessary and run E2E convergence test with optimization toggles turned on.

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25890757151

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25902398542

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25908295291

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25930508620

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25938099343

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25938125921

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25942162810

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25947850886

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25952777925

@ksivaman

Copy link
Copy Markdown
Member

Confirmed numerics with TE tests are equivalent as that with dev branch. LGTM.

Comment on lines +46 to +53
# Enable TE's cuDSL fused grouped MLP path for MoE buckets. The kernel
# additionally requires SM100 (Blackwell), so on H100/A100 CI this is a
# no-op; setting it here means the kernel is picked up automatically when
# Blackwell hardware joins the unit-test matrix.
if [[ "$BUCKET" == *"transformer/moe"* || "$BUCKET" == *"test_moe_experts"* ]]; then
export NVTE_CUTEDSL_FUSED_GROUPED_MLP=1
fi

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't set env vars here.. it makes it difficult to spot how/where behavior comes from. We should only set env vars within the unit tests (i.e. with a fixture).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

26.06 Approved All necessary approvals have been made complexity: high dev2main: mbridge dev to main: this PR is needed in main for mbridge nemotron Run functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.