Combine GEMM + SwiGLU fused MLP PRs (3890, 4071, 4095, 4219, 4311, 4324) → main#4636
Conversation
a2c5b96 to
f3873ce
Compare
|
/claude review |
    def normalize_grouped_parameter_keys(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Make grouped checkpoint keys compatible across parameter layouts."""

        def maybe_remap_param(param_name: str) -> None:
            grouped_key = f"{prefix}{param_name}"
            indexed_keys = [
                f"{prefix}{param_name}{gemm_idx}" for gemm_idx in range(self.num_gemms)
            ]
            has_grouped_key = grouped_key in state_dict
            has_any_indexed_key = any(key in state_dict for key in indexed_keys)
            has_all_indexed_keys = all(key in state_dict for key in indexed_keys)

            if getattr(self, "single_grouped_weight", False):
                if has_grouped_key or not has_all_indexed_keys:
                    return
                state_dict[grouped_key] = torch.stack(
                    [state_dict.pop(key) for key in indexed_keys], dim=0
                )
            else:
                if has_any_indexed_key or not has_grouped_key:
                    return
                split_tensors = self._split_grouped_checkpoint_tensor(
                    state_dict.pop(grouped_key), grouped_key
                )
                for gemm_idx, tensor in enumerate(split_tensors):
                    state_dict[f"{prefix}{param_name}{gemm_idx}"] = tensor

        maybe_remap_param("weight")
        if self.use_bias:
            maybe_remap_param("bias")

    self._register_load_state_dict_pre_hook(
        normalize_grouped_parameter_keys, with_module=True
    )
nit: normalize_grouped_parameter_keys is checkpoint-critical logic (it silently mutates state_dict during load_state_dict) but has no direct unit test. _split_grouped_checkpoint_tensor is well-covered in test_transformer_engine_grouped_linear.py, but the hook's own orchestration — early-return conditions, the torch.stack path for single_grouped_weight, key naming — is untested.
A small test that round-trips both remapping directions (indexed→grouped and grouped→indexed) via the hook would guard against regressions in checkpoint loading across weight layouts.
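A round trip of this kind can be sketched without PyTorch. The following is a minimal, framework-free model of the hook's remapping logic, not the test that landed in the PR: plain Python lists stand in for tensors, a list of lists stands in for the `torch.stack` result, and `remap_param` (with its `single_grouped` flag) and `round_trip` are hypothetical helpers that mirror the early-return conditions of `maybe_remap_param`.

```python
from typing import Any, Dict


def remap_param(
    state_dict: Dict[str, Any],
    prefix: str,
    param_name: str,
    num_gemms: int,
    single_grouped: bool,
) -> None:
    """Hypothetical stand-in for maybe_remap_param; lists replace tensors."""
    grouped_key = f"{prefix}{param_name}"
    indexed_keys = [f"{prefix}{param_name}{i}" for i in range(num_gemms)]
    has_grouped = grouped_key in state_dict
    has_any_indexed = any(k in state_dict for k in indexed_keys)
    has_all_indexed = all(k in state_dict for k in indexed_keys)

    if single_grouped:
        # Fold indexed -> grouped only when the full indexed set is present
        # and the grouped key does not already exist.
        if has_grouped or not has_all_indexed:
            return
        state_dict[grouped_key] = [state_dict.pop(k) for k in indexed_keys]
    else:
        # Split grouped -> indexed only when no indexed key exists yet.
        if has_any_indexed or not has_grouped:
            return
        for i, item in enumerate(state_dict.pop(grouped_key)):
            state_dict[f"{prefix}{param_name}{i}"] = item


def round_trip(indexed_ckpt: Dict[str, Any], prefix: str, num_gemms: int) -> Dict[str, Any]:
    """Indexed checkpoint -> grouped layout -> indexed layout."""
    sd = dict(indexed_ckpt)
    remap_param(sd, prefix, "weight", num_gemms, single_grouped=True)
    remap_param(sd, prefix, "weight", num_gemms, single_grouped=False)
    return sd
```

A real test would presumably drive the registered hook on a `TEGroupedLinear` instance and compare tensors with `torch.equal`; the sketch only pins down the key naming and the early-return behavior.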
Code looks correct and well-structured. Good test coverage for the new distributed optimizer helpers and the fused ops construction. One suggestion: the normalize_grouped_parameter_keys checkpoint-loading hook would benefit from a direct unit test (see inline comment).
Promote the inner state_dict pre-hook to a `_normalize_grouped_parameter_keys` method on `TEGroupedLinear` so it can be invoked directly from tests. Behavior at registration time is preserved: the unbound method is registered with `with_module=True`, so PyTorch passes the module as `self`.

Add unit tests covering the hook's orchestration:
- indexed→grouped fold under `single_grouped_weight=True` (weight only, weight + bias)
- grouped→indexed split under `single_grouped_weight=False` (weight only, weight + bias)
- bias is left untouched when `use_bias=False`
- early-return when target layout is already present
- early-return when the indexed key set is incomplete
- round-trip: indexed checkpoint → grouped model → indexed model preserves the per-GEMM tensors

Address PR NVIDIA#4636 review comment: `_split_grouped_checkpoint_tensor` is already covered, but the surrounding hook (key naming, early-returns, the torch.stack path) was untested.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
49dde6b to
fe61201
Compare
|
/ok to test fe61201 |
|
/ok to test 4b0b739 |
61fdfb2 to
1858cdc
Compare
|
/ok to test 1858cdc |
|
/claude review |
Per follow-up review: enable TE's cuDSL fused grouped MLP gate in the MoE unit-test CI buckets so the kernel is automatically picked up when Blackwell joins the unit-test matrix. Today the kernel also requires SM100, so on H100/A100 hardware this is a no-op (TE's other gates fail and it falls back to basic-op).

The conditional matches `transformer/moe` and `test_moe_experts` buckets so non-MoE tests stay unchanged.

Set in spec.script before the run_ci_test.sh invocation so the env var is in place when pytest starts and TE imports. TE's ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() reads the env var during its import-time register_forward_fusion call, so it must be set in the process env, not in a fixture.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/ok to test 46be5d1 |
Upstream dev PR NVIDIA#4621 ("Fix single grouped weight when enabling MXFP8 primary weight") fixes a numerical bug in the dist-optimizer's quantized-param shard path on the single-grouped-weight storage. That fix is still on dev and overlaps with code this PR already touches.

The buggy combo isn't exercised by this PR's tests (bf16, no fp8), so this PR can ship without the fix, but users who set moe_single_grouped_weight=True outside the mxfp8 recipe would silently train on a broken numerical path. Add a __post_init__ guard that raises early with a clear pointer to NVIDIA#4621.

Scoped to moe_single_grouped_weight only, not moe_single_grouped_bias: the fix in NVIDIA#4621 touches only the quantized-tensor code paths in distrib_optimizer.py and param_and_grad_buffer.py, and biases don't enter those paths (they aren't quantized in mxfp8 training). Upstream NVIDIA#4621's validation gates both for safety; ours is tighter because the actual bug is weight-only.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
Picking up this thread now that #4621 is visible. Status check:
To address your "ideally no known bug" concern without coupling this PR's schedule to #4621, I added a fail-fast check:

    if self.moe_single_grouped_weight:
        if self.fp4 or not self.fp8 or self.fp8_recipe != Fp8Recipe.mxfp8:
            raise ValueError(
                "moe_single_grouped_weight is currently supported only with fp8 mode "
                "and fp8_recipe='mxfp8'."
            )
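The guard can be tried in isolation with a toy config. This is a sketch under stated assumptions, not the real `TransformerConfig`: `ToyConfig`, its field types and defaults, and the two-member `Fp8Recipe` enum are hypothetical stand-ins that carry only the fields the check reads.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Fp8Recipe(str, Enum):
    """Illustrative subset of recipe names; only mxfp8 matters for the guard."""

    delayed = "delayed"
    mxfp8 = "mxfp8"


@dataclass
class ToyConfig:
    """Hypothetical stand-in holding only the fields the guard inspects."""

    moe_single_grouped_weight: bool = False
    fp8: Optional[str] = None  # assumed: a format string when fp8 is enabled
    fp4: Optional[str] = None  # assumed: a format string when fp4 is enabled
    fp8_recipe: Fp8Recipe = Fp8Recipe.delayed

    def __post_init__(self) -> None:
        # Fail at construction time rather than deep inside training.
        if self.moe_single_grouped_weight:
            if self.fp4 or not self.fp8 or self.fp8_recipe != Fp8Recipe.mxfp8:
                raise ValueError(
                    "moe_single_grouped_weight is currently supported only with fp8 mode "
                    "and fp8_recipe='mxfp8'."
                )
```

With these toy defaults, `ToyConfig(moe_single_grouped_weight=True)` raises `ValueError`, while adding `fp8="e4m3", fp8_recipe=Fp8Recipe.mxfp8` constructs normally.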
    # TODO: find a better place to invoke _trigger_wgrad_accumulation_and_reduce_hooks.
    # The wgrad hook registration lives in TE while the trigger is issued here
    # in MCore, so the hook lifecycle is split across both codebases. Consolidate
    # ownership on one side (either register+trigger entirely in TE, or expose
    # the fused backward_dw through MCore) to remove this fragmentation.
Is there a plan to address this?
|
/ok to test 8d3920d |
…r hunk

The `_cached_param_buffer_shards_grad_enabled` field, its read site in `start_param_sync()`, and the `with torch.no_grad()` wrap around the coalescing manager all originated in NVIDIA#3890 on the dev branch. The dev sync merge `79aeecfe0` (Mar 25 2026) explicitly removed the read site and the no_grad wrap during conflict resolution when it pulled in the layerwise-optimizer code from main; only the field init survived as an orphan in `__init__`.

The active logic was deliberately dropped, no regression was reported on dev or main in the intervening months, and zhongbozhu flagged this exact block on this PR (r3211212707) noting it was removed in dev. For a PR targeting main, resurrecting a hunk that was specifically dropped during a merge, without a fresh repro proving main needs it, is the wrong default.

Remove all three pieces (the orphan init, the read site, the no_grad wrap) so this file matches main's shape except for the changes that are genuinely part of this PR's scope.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/ok to test d49ad27 |
zhongbozhu
left a comment
For future reference, let's aggregate all the PRs necessary and run E2E convergence test with optimization toggles turned on.
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25890757151 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25902398542 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25908295291 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25930508620 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25938099343 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25938125921 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25942162810 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25947850886 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25952777925 |
|
Confirmed numerics with TE tests are equivalent as that with |
    # Enable TE's cuDSL fused grouped MLP path for MoE buckets. The kernel
    # additionally requires SM100 (Blackwell), so on H100/A100 CI this is a
    # no-op; setting it here means the kernel is picked up automatically when
    # Blackwell hardware joins the unit-test matrix.
    if [[ "$BUCKET" == *"transformer/moe"* || "$BUCKET" == *"test_moe_experts"* ]]; then
        export NVTE_CUTEDSL_FUSED_GROUPED_MLP=1
    fi
We shouldn't set env vars here; it makes it difficult to spot how and where behavior comes from. We should only set env vars within the unit tests (i.e. with a fixture).
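A test-scoped way to set the flag could look like the sketch below. It uses only the standard library; the helper name `fused_grouped_mlp_enabled` is hypothetical, and the variable name is taken from the CI snippet under review. Under pytest the same effect is usually obtained with the built-in `monkeypatch.setenv` inside a fixture.

```python
import os
from contextlib import contextmanager
from typing import Iterator
from unittest import mock

ENV_VAR = "NVTE_CUTEDSL_FUSED_GROUPED_MLP"  # name taken from the CI snippet


@contextmanager
def fused_grouped_mlp_enabled(value: str = "1") -> Iterator[None]:
    """Set the flag for the duration of the block, then restore the old environment."""
    # patch.dict snapshots os.environ and restores it on exit, even on error.
    with mock.patch.dict(os.environ, {ENV_VAR: value}):
        yield
```

One caveat follows from the earlier commit message in this thread: if the consuming library reads the variable at import time, a fixture or context manager that runs after the import has no effect on that read, so this approach only helps for code that consults the environment later.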
What does this PR do?
This PR consolidates six dev-branch PRs related to TE fused grouped MLP into a single PR targeting main. It supersedes #3971 (which mirrored only #3890 to main) and brings the full feature set in one merge.

Source PRs (all merged to dev)
- TEGroupedMLP (_is_fused_impl_supported, _make_fused_ops, _fused_forward, _make_fused_impl_pre_forward_hook); GLU-interleaved layout (moe_mlp_glu_interleave_size); grouped-quantized-tensor support in DistributedOptimizer (_is_grouped_quantized_tensor, _expand_quantized_param_shard_for_cast); _normalize_grouped_parameter_keys checkpoint-key compat hook on TEGroupedLinear; CUTLASS 256-byte alignment in get_align_size_for_quantization when op-fuser is on; new CLI flag --use-transformer-engine-op-fuser.
- moe_router_padding_for_quantization check into a skip_routed_expert_padding(config) helper that also skips quantization padding when the token dispatcher is flex with hybridep backend (the dispatcher applies padding itself). Eliminates a double-pad on graph-safe HybridEP paths.
- single_grouped_bias to the op-fuser path (renames single_grouped_parameter → single_grouped_weight); threads delay_wgrad_compute through te.pytorch.ops.GroupedLinear; drops the need_backward_dw() gate; gates _is_fused_impl_supported on TE ≥ 2.14.0; adds a fused-aware backward_dw() override.
- quick_gelu and add config for grouped params: _is_fused_impl_supported and _make_fused_ops to support quick_gelu via te.pytorch.ops.ScaledClampedQGeGLU; adds moe_single_grouped_weight / moe_single_grouped_bias TransformerConfig fields with validation; threads them through TEGroupedLinear.__init__ and _set_arg for checkpoint loading; updates the test_hybrid_moe_model GOLDEN_CONFIG.
- backward_dw() call, explicitly invoke linear_fc{1,2}._trigger_wgrad_accumulation_and_reduce_hooks(). The wgrad hooks (DDP reduce-scatter, etc.) live on the original linear_fc1/fc2 modules but backward_dw() is called on the new GroupedLinear instances created by _make_fused_ops(). Without the explicit trigger, param.grad is never zeroed and AccumulateGrad performs a spurious add_ into main_grad.
- load_main_params_from_ckpt=True for grouped weight: _normalize_state_dict_for_grouped_params to DistributedOptimizer._build_model_param_to_state_dict_param_map. Mirrors the normalize_grouped_parameter_keys hook (which only fires during load_state_dict) so the optimizer reload path also handles single-grouped ↔ indexed weight layout drift.

Rebase-fix work on top
Beyond the six source PRs, this branch carries fixes needed to land the combined feature on current main:
- self.bias_act_func → local bias_act_func typo on the non-recompute forward path (the closure is defined locally; calling it via self. AttributeError'd at runtime).
- normalize_grouped_parameter_keys from an inline __init__ closure to a method _normalize_grouped_parameter_keys so it's directly testable.
- device="meta" for op-shell construction in _make_fused_ops (existing weights are reattached after, so the GPU allocation was a wasted transient).
- delay_wgrad_compute (config.delay_wgrad_compute or config.overlap_dispatch_backward_with_experts_wgrad) in _make_fused_ops and the backward_dw() gate, not the raw config flag; otherwise runs that enable wgrad delay through overlap_dispatch_backward_with_experts_wgrad would silently lose the overlap optimization.
- _make_fused_ops activation comment typo "GEGL" → "GeGLU".
- _apply_bias, _remove_glu_interleaving, _make_fused_impl_pre_forward_hook, _make_fused_ops shapes (with mocked TE) including the quick_gelu/ScaledClampedQGeGLU branch and the single_grouped_bias=True path, _fused_forward arg threading, _split_grouped_checkpoint_tensor (all branches), _normalize_grouped_parameter_keys (8 cases including round-trip), _normalize_state_dict_for_grouped_params (8 cases: indexed↔grouped, bias path, module.-prefix stripping, ambiguous-match guard, incomplete-set guard, target-already-present guard, non-TE-grouped skip), backward_dw fused dispatch including PR 4311 hook trigger, backward_dw fallback when delay is off, _expand_quantized_param_shard_for_cast, _is_grouped_quantized_tensor. Plus a GPU-gated test_gpu_make_fused_ops_constructs_with_real_te for the meta-device + weight-reattach path against real TE.
- param_and_grad_buffer.py (no behavior change), addressing reviewer questions on the original PR.

Files changed (vs main)

Validation
- black --skip-magic-trailing-comma --skip-string-normalization --check, isort --check, pylint 10/10, ruff check) on the full lint scope.
- batch_singlenode partition with the canonical post-rename TE container (TE 2.14.0+f031cf87): 46 passed in 6.54s, including test_gpu_forward_backward, test_gpu_make_fused_ops_constructs_with_real_te, all the new coverage above, and the existing _split_grouped_checkpoint_tensor / _normalize_grouped_parameter_keys round-trip tests.

Contribution process
- quick_gelu and add config for grouped params #4219, Fix fused grouped MLP wgrad hooks for DDP reduce-scatter #4311, Fix checkpoint loading with load_main_params_from_ckpt=True for grouped weight #4324 applied in their dev merge order (rebase-fix work for Support GEMM + Swiglu fused MLP #3890 follows Support GEMM + Swiglu fused MLP #3971's pattern).
- _normalize_state_dict_for_grouped_params cross-module ambiguity guard.

Outstanding follow-up
- distrib_optimizer.py:1083 that the single-grouped-weight path needs more numerical fixes from dev. Awaiting pointers to specific commits/PRs to either fold into this PR or track in a follow-up.