feat(optim): support --fp8-param-gather for muon + mxfp8 in LayerWise#4987
feat(optim): support --fp8-param-gather for muon + mxfp8 in LayerWise#4987Wohox wants to merge 17 commits into
Conversation
Wires the standard distopt MXFP8 + reuse_grad_buf_for_mxfp8_param_ag flow through LayerWiseDistributedOptimizer so muon-managed 2D weights can be all-gathered via the bf16 staging path. Adds the following on top of PR NVIDIA#4889: - LayerWise inner Float16Optimizer is constructed with a shallow-copied config that flips reuse_grad_buf_for_mxfp8_param_ag to False (the inner step path expects DistOpt-only _copy_main_params_to_param_buffer; the LayerWise wrapper owns the bf16 staging write directly). - _skip_mxfp8_in_copy_main_to_model filters MXFP8 model params out of the inner's _copy_main_params_to_model_params so the inner step does not perform a wasted MXFP8 quantize that the post-AG quantize would overwrite anyway. - _write_owned_mxfp8_masters_to_param_buffer writes the bf16 cast of each owned MXFP8 param's fp32 master into the DDP bf16 staging buffer before the AG dispatches. Mirrors DistOpt._copy_main_params_to_param_buffer. - start_param_sync_for_bucket_group_subset triggers the standard distopt buffer AG only on LayerWise-managed bucket groups, so a sibling DistributedOptimizer's own start_param_sync call is not duplicated. - _restore_high_precision_init_val overwrites the fp32 master of any MXFP8 model param with the BF16 init values that TE preserves on CPU (model_param.get_high_precision_init_val()) before the quantize-wrap. Without this, the inner Float16Optimizer's default param.detach().clone().float() dequantizes the MXFP8 storage and bakes ~FP8 precision noise into the master from iter 0, which amplifies through the subsequent bf16⇒MXFP8 round-trip and yields a small muon loss lag vs the fp8_param_gather=False baseline that doesn't close even by iter 100. DistOpt does the equivalent fix-up at distrib_optimizer.py:405-416; this mirrors that path. - training/arguments.py asserts --fp8-recipe=mxfp8 + LayerWise opt + --fp8-param-gather requires --reuse-grad-buf-for-mxfp8-param-ag so misconfiguration fails fast (the bf16 staging buffer for the MXFP8 AG is the idle grad buffer; the two flags must be co-enabled). Verified on OCI-HSG (1×4 GB200, DSV4 proxy, 100 iters): ON v7 iter-10 lm loss = 10.6085 vs OFF baseline 10.5746 (delta 0.034 nats) ON v7 iter-100 lm loss = 0.01605 vs OFF baseline 0.01673 (within noise) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ather Adds tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py asserting that the LayerWise bf16⇒MXFP8 staging round-trip introduced by --fp8-param-gather + --reuse-grad-buf-for-mxfp8-param-ag produces bitwise-identical training trajectories vs --fp8-param-gather=off for the same small GPT model. Wraps both runs in deterministic_mode() (mirrors tests/unit_tests/a2a_overlap/utils.py) and compares per-step loss, forward output, per-parameter main_grad, and per-parameter fp32 master with atol=rtol=0 across 5 steps. Skips when MXFP8 is not supported by the device or TE version. Without _restore_high_precision_init_val in layer_wise_optimizer.py, the fp32 master created by inner Float16Optimizer .detach().clone().float() dequantizes the MXFP8 storage and carries ~FP8 precision noise relative to the bf16 init — this test catches the resulting per-step master divergence as a bitwise mismatch.
Ensure sibling DistributedOptimizer buckets inside mixed bucket groups are synchronized when LayerWise Muon uses MXFP8 param gather with grad-buffer reuse.
…e in test Three follow-ups on the muon LayerWise FP8 param-gather path: 1. Add hard assertion in `arguments.py` that blocks `--fp8-param-gather` with any recipe other than mxfp8 when the optimizer auto-promotes to `LayerWiseDistributedOptimizer`. The bf16-staging + post-AG quantize round-trip is wired up only for MXFP8's `_rowwise_data` / `_columnwise_data` storage; using blockwise FP8 today would silently gather stale storage. Fail fast at arg-parse time. 2. Strip experiment-specific data from new comments in `layer_wise_optimizer.py` (per-iter loss numbers, cluster names, absolute nat lags). Keep the *what* and *why* of each helper so the intent stays clear without hard-coding ephemeral run details. 3. Make the unit test parametrize over `fp8_recipe in [mxfp8, blockwise]` and explicitly `pytest.skip` blockwise. Today the LayerWise path only supports mxfp8; the parametrization documents that surface while the skip keeps blockwise from accidentally running and producing a misleading bitwise mismatch. Also fix the typo `type(a ra m)(` -> `type(bucket_group)(` in `distrib_optimizer.py:3019` (was a paste artifact in the prior commit).
…rted Add an explicit `check_mxfp8_support()` skipif on top of the existing arch / TE-version / fp8-available guards so the test exits cleanly on environments where MXFP8 is not actually available (e.g. pre-Blackwell GPUs, sm_120+ which TE currently rejects, or builds without MXFP8 kernels), with TE's own diagnostic string as the skip reason.
|
/claude review |
There was a problem hiding this comment.
Thorough PR — the bug descriptions in the PR body are excellent. A few items:
Bug (medium): LayerWiseDistributedOptimizer.start_param_sync_for_bucket_group_subset still checks only buckets[0] to decide bucket group ownership, while this same PR upgraded the DistOpt side (distrib_optimizer.py:3006) and the deferred-sync path (optimizer.py:1336) to per-bucket filtering for mixed bucket groups. In partition_buckets Case 3, bf16 DistOpt-managed buckets can be merged into the last FP8 bucket group, causing LayerWise to AG the entire group (including DistOpt buckets) — a double-sync. See inline comment for a suggested fix.
Minor: Unused variable original_get_pairs in _skip_mxfp8_in_copy_main_to_model, and vocal_size → vocab_size typo in the new test file (lines 225, 244).
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Address review comments on PR NVIDIA#4987: 1. ``LayerWise.start_param_sync_for_bucket_group_subset`` now filters buckets per-bucket instead of checking only ``buckets[0]``, and builds a sub-group containing just the LayerWise-managed buckets when the group is mixed-ownership. ``partition_buckets`` Case 3 (FP8 present, ``reduce_scatter_with_fp32_accumulation=False``) merges non-FP8 DistOpt-managed bf16 buckets (biases, layernorms) into the last FP8 bucket group; without per-bucket filtering, LayerWise dispatched AG for the entire group while the sibling DistOpt also synced those same bf16 buckets via its own per-bucket filter, double-gathering them. Mirrors the per-bucket pattern that ``d1ae16a0a`` applied to the DistOpt deferred-sync path. 2. Drop the unused ``original_get_pairs = inner_opt ._get_model_and_main_params_data_float16`` line in ``_skip_mxfp8_in_copy_main_to_model``. Leftover from development; never referenced.
|
/claude review |
wujingyue
left a comment
There was a problem hiding this comment.
Thanks for the PR! I'll wait until expert reviews (esp dist-optimizer) are done.
…ction Address review comment on PR NVIDIA#4987: the per-bucket filter and the ``type(bucket_group)(distopt_buckets, …)`` sub-group construction added in ``DistributedOptimizer.start_param_sync_for_bucket_group_subset`` had no inline justification. Without that context it isn't obvious why we walk one bucket at a time instead of dispatching the whole group, or when the synthesized sub-group path is taken. Extends the docstring to call out the ``partition_buckets`` Case 3 scenario where the last FP8 bucket group holds both DistOpt-managed bf16 buckets (biases / layernorms) and LayerWise-managed FP8 buckets, and adds inline comments on each of the three branches (entire group LayerWise-owned, pure DistOpt group, mixed group) describing what is dispatched and why the synthesized sub-group inherits the parent's ddp_config / DP group / DP world size so the AG collective lands on the same comm. Pure documentation; no behaviour change.
| # write is gated on ``reuse_grad_buf_for_mxfp8_param_ag``, so require | ||
| # the flag explicitly to fail fast on misconfiguration. | ||
| will_use_layer_wise_distributed_optimizer = ( | ||
| args.optimizer not in ('sgd', 'adam') and args.use_distributed_optimizer |
There was a problem hiding this comment.
@kunlunl to also comment.
My rough understanding is fp8-gather and reuse flags are recommended to be used together even before any muon changes for mxfp8? and at least for mxfp8 there is no reason to use only fp8-gather without reuse
Maybe let's take the opportunity to finally clean up these flag that doesn't quite show what they do as name suggests. I'm thinking something like:
- mxfp8 + fp8-param-gather: equivalent to old fp8-param-gather+reuse, deprecate the reuse flag, support layerwise
- non-mxfp8 + fp8-param-gather: equivalent to before, layerwise not supported
There was a problem hiding this comment.
I agree with the proposal but have a concern regarding the concurrent use of FP8 parameter gather, reuse gradient reduce buffer, and extra overlap gradient reduce. Our testing showed that enabling these features simultaneously can trigger NaN issues.
Currently, customers can workaround this by enabling reuse gradient reduce while keeping overlap parameter gather disabled. The proposed design would remove this bypass.
@kunlunl Could you confirm if this issue has been resolved or if there is a PR tracking it?
There was a problem hiding this comment.
As far as I know, all NaN bugs related to mxfp8 / fp8 param gather / reuse grad buf / overlap param gather are resolved in dev branch, except an open PR (#4994) that also will be merged soon.
And I think the proposal by @FDecaYed doesn't force use overlap param gather. It just deprecate the reuse_xxx, because it's actually not an option, if you want to use fp8 param + mxfp8, you must enable it. if you want to use fp8 param + other fp8 recipe, you must disable reuse_xxx, so no meaning to keep it.
| model_chunk._start_bucket_group_param_sync(lw_bucket_group, force_sync=False) | ||
|
|
||
| @torch.no_grad() | ||
| def _write_owned_mxfp8_masters_to_param_buffer(self) -> None: |
There was a problem hiding this comment.
I have a hunch that this is all over-complicated because we originally designed layerwise optimizer to 'replace' dist-opt, so it inherited chainedoptimizer, while distopt inherited mixprecision optimizer
since now we already made up our mind to get layerwise/distopt work alongside each other, maybe we change that? so something like a chained optimizer chaining both layerwise and distopt instance. This way, both can implement _copy_main_params_to_param_buffer that being invoked with same code path?
That said, the code here make sense, and if we want to get this in first and do refactor later, I'm also fine with that
There was a problem hiding this comment.
To summarise here, we have 2 pending refactorings:
- Unify
reuse_grad_buf_for_mxfp8_param_agandfp8_param_gather. - Let LayerwiseDistOpt inherit from MixedPrecisionOpt instead of ChainedOpt
I would suggest to keep this PR simple and support the above refactorings in another PR.
…aram_gather mode PR 4987 enabled fp8_param_gather for the LayerWise (muon) optimizer with --fp8-recipe mxfp8 + --reuse-grad-buf-for-mxfp8-param-ag, but only validated overlap_param_gather=False. With overlap_param_gather=True every muon-managed 2D weight froze at init (loss flat ~12.5, grad norm pinned ~205) while Adam (standard DistributedOptimizer) was unaffected. Root cause: reuse_grad_buf_for_mxfp8_param_ag aliases the bf16 param-AG staging buffer onto the grad buffer. In overlap mode the LayerWise param all-gather + post-AG copy into the MXFP8 param.data is deferred to the next forward's pre-hook, but zero_grad_buffer() runs at the start of that iteration and zeroes the shared buffer first. The deferred AG then ships a zeroed staging buffer and the post-AG copy overwrites every muon param.data with zeros. The standard DistributedOptimizer avoids this because train_step re-stages its masters via _copy_main_params_to_param_buffer() AFTER zero_grad_buffer() — but that loop was gated to isinstance(DistributedOptimizer) and skipped the LayerWise optimizer, making the bug muon-only. Fix: give LayerWiseDistributedOptimizer a duck-typed _copy_main_params_to_param_buffer() that re-writes the owned MXFP8 masters into the staging buffer post-zero, and widen the train_step post-zero loop to isinstance(opt, (DistributedOptimizer, LayerWiseDistributedOptimizer)) so both optimizers are handled by a single polymorphic call. Verified on DSv4 Flash Proxy / GB200x16 / TP1 PP2 EP8 / mxfp8 / muon: overlap + fp8_param_gather now descends matching the fp8pg-off and non-overlap baselines, with normal grad norms. The deterministic bitwise ON-vs-OFF unit test (test_muon_mxfp8_fp8_param_gather.py) passes.
8d093ad to
477dc5a
Compare
|
Reviews (1): Last reviewed commit: "fix(layer_wise): support fp8_param_gathe..." | Re-trigger Greptile |
|
/claude review |
There was a problem hiding this comment.
One bug: the eval-path _copy_main_params_to_param_buffer call at training.py:3697 was not widened to include LayerWiseDistributedOptimizer, unlike the training-loop instance at line 2222 which this PR correctly updated. With reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather, eval will all-gather zeroed buffers for muon-managed MXFP8 weights.
See inline comment for details and the fix.
…gather test
Parametrize the bitwise ON-vs-OFF test over overlap_param_gather in
{False, True}. The True case is the regression guard for the frozen-loss
bug on the deferred forward-pre-hook all-gather path: the bf16 staging
buffer is aliased onto the grad buffer that zero_grad_buffer() zeroes each
iteration, so without re-staging the masters post-zero the all-gather ships
zeros and the muon-managed weights never update.
_run_steps now mirrors train_step's post-zero_grad_buffer re-stage
(_copy_main_params_to_param_buffer on each DistributedOptimizer /
LayerWiseDistributedOptimizer) when reuse_grad_buf_for_mxfp8_param_ag and
overlap_param_gather are set; DDP auto-registers the forward pre-hook when
overlap_param_gather is True. --overlap-param-gather requires
--overlap-grad-reduce, so the two are co-enabled.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…h paths The training-loop re-stage (post-zero_grad_buffer _copy_main_params_to_param_buffer) was widened to LayerWiseDistributedOptimizer, but two structurally identical sites were missed: - training.py eval block: with reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather, the pre-eval zero_grad_buffer() zeroes the staging buffer aliased onto the grad buffer, then only DistributedOptimizer params were re-staged. disable_forward_pre_hook( param_sync=True) would then all-gather zeroed buffers for muon-managed MXFP8 weights during eval. - paged_stash.py _try_copy_main_params: reachable when moe_expert_rank_capacity_factor is set together with reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather (copy_main_params is gated on exactly those flags); LayerWise was skipped, leaving its staging buffer zeroed. Both now re-stage LayerWiseDistributedOptimizer too. _copy_main_params_to_param_buffer self-guards on use_buffer_param_sync + reuse_grad_buf_for_mxfp8_param_ag, so the call is a no-op for configs that don't need it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
/claude review |
Summary
--fp8-param-gather+--reuse-grad-buf-for-mxfp8-param-ag(mxfp8) was already wired for the standardDistributedOptimizer, but the muonLayerWiseDistributedOptimizerwas missing the pieces needed to take that same path. This PR fills those gaps so--fp8-param-gatherworks for muon-managed 2D weights when--fp8-recipe=mxfp8, on top of #4889's muon-aware bucket alignment, with bitwise parity against the--fp8-param-gather=offbaseline.Design principle: reuse the DistOpt machinery instead of growing a parallel LayerWise path. The bf16-staging write is exposed as a duck-typed
_copy_main_params_to_param_buffer()sotrain_stepdrives both optimizers through one code path, and the per-bucket param-sync filtering mirrors DistOpt's exactly.Background: with this feature the model params are kept as persistent MXFP8 tensors; the bf16
param_bufferis aliased onto the idle grad buffer; the all-gather ships bf16 and then quantizes back to MXFP8 in place on every rank. Every gap below is a consequence of the muon/LayerWise path not yet handling some part of that lifecycle that DistOpt already did.The gaps fall into four areas.
Area 1 — Master-weight numerics for persistent FP8 params
These two gaps made the muon loss drift from the
fp8_param_gather=offbaseline. Both exist only because the params are now persistent MXFP8 tensors.1a — fp32 master was initialized from dequantized MXFP8.
The inner
Float16OptimizerWithFloat16Params.__init__builds the master viaparam.detach().clone().float(). For an MXFP8 param,.float()dequantizes, so the master is born from FP8-rounded values and carries ~FP8 noise from iter 0. DistOpt avoids this by readingmodel_param.get_high_precision_init_val()(the bf16 init TE preserves on CPU). Without the fix, the iter-10 lm-loss gap is ≈0.27 nats.→ filled by
_restore_high_precision_init_val(mirrorsdistrib_optimizer.py:405-416).1b — the inner step re-quantized MXFP8 params redundantly.
_copy_main_params_to_model_paramswould runmodel.data.copy_(main_fp32)on each MXFP8 model param, triggering an extraQuantizedTensorquantize_(fp32)that the post-AGquantize_(bf16)then overwrites — wasted work that also perturbs muon convergence. DistOpt'sreuse_grad_bufstep path never touches MXFP8 storage at the inner step.→ filled by
_skip_mxfp8_in_copy_main_to_model(+ a shallow-copiedinner_configthat hidesreuse_grad_buf_for_mxfp8_param_agfrom the inner optimizer;self.configis restored aftersuper().__init__so the outer ChainedOptimizer's shared-config assertion still holds).Area 2 — Staging the bf16 param buffer before the all-gather
With
reuse_grad_buf_for_mxfp8_param_agthe bf16param_bufferis separate from the MXFP8 storage, and the all-gather shipsparam_buffer. Two gaps left it holding the wrong bytes.2a — the updated master was never written into the staging buffer.
The optimizer step updates the fp32 master, but nothing copied it (cast to bf16) into
param_bufferbefore the AG, so the AG shipped stale bytes.→ filled by
_copy_main_params_to_param_buffer— a duck-typed twin ofDistributedOptimizer._copy_main_params_to_param_buffer, sotrain_stepcan call it polymorphically over(DistributedOptimizer, LayerWiseDistributedOptimizer). Self-guards onuse_buffer_param_sync+reuse_grad_buf_for_mxfp8_param_ag, so it is a no-op for configs that don't need it.2b — the staging buffer is zeroed before the deferred all-gather (overlap path).
Because the staging buffer is aliased onto the grad buffer, the next iteration's
zero_grad_buffer()zeroes it after the in-step write but before the deferred forward-pre-hook all-gather. Under--overlap-param-gatherthe AG therefore shipped a zeroed buffer and the muon-managed weights never updated (loss frozen). The fix is not to disable overlap, but to re-stage post-zero on every site that force-syncs through the param buffer:train_step(training loop) — the same hook DistOpt already uses;disable_forward_pre_hook(param_sync=True)force-syncs params for eval);paged_stash_try_copy_main_params(reachable whenmoe_expert_rank_capacity_factoris set withreuse_grad_buf + overlap).All three now re-stage
LayerWiseDistributedOptimizertoo, not justDistributedOptimizer.Area 3 — Disjoint param-sync when LayerWise and DistOpt coexist
muon manages 2D weights via LayerWise; everything else goes through a sibling
DistributedOptimizer(#4771). They share the same DDP buffers, so the gap is making each optimizer all-gather only its own buckets.3a — per-bucket ownership / mixed bucket groups.
A bucket group can be mixed-ownership: in
partition_bucketsCase 3 (FP8 present,reduce_scatter_with_fp32_accumulation=False) non-FP8 DistOpt-managed bf16 buckets (biases, layernorms) get appended into the last FP8 bucket group. Checking onlybuckets[0]and dispatching AG on the whole group double-syncs the foreign buckets. Both optimizers now filter per-bucket and, for mixed groups, synthesize a thin transient bucket group containing only the buckets they own. Applied symmetrically in three places:DistributedOptimizer.start_param_sync_for_bucket_group_subset,LayerWiseDistributedOptimizer.start_param_sync_for_bucket_group_subset, and the deferred-sync path inoptimizer.py(which also recurses the chained-optimizer tree so nested DistOpt instances are reached).Area 4 — Fail-fast configuration guards
The remaining gap was silent misconfiguration: combinations that would quietly all-gather stale storage.
4a — mxfp8 + LayerWise +
--fp8-param-gathernow requires--reuse-grad-buf-for-mxfp8-param-ag(the two must be co-enabled).4b — non-mxfp8 + LayerWise +
--fp8-param-gatheris explicitly rejected (only MXFP8's_rowwise/_columnwiseround-trip is wired up today).Deliberately NOT in this PR
quantizer.set_usage(rowwise=True, columnwise=True)before each AG, guessing TE's bwd toggledcolumnwise_usage. Running the bitwise test indeterministic_mode()without that re-assert still reproducesatol=rtol=0ON-vs-OFF parity, proving it unnecessary; the helper was removed to keep the LayerWise surface minimal.--fp8-param-gather— rejected at arg-parse (4b) and skipped in the test; needs a separate_rowwise/_columnwiselayout + post-AG path.Test plan
tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py::test_on_vs_off_bitwise_identical— builds the same small GPT twice (fp8_param_gather=False/True), restores the ON run's init state from the OFF snapshot so both start bit-identical, runs 5 muon steps indeterministic_mode(), and assertsatol=rtol=0on loss, forward output, per-parametermain_grad, fp32 master, and (non-MXFP8) model params. Parametrized overoverlap_param_gather ∈ {False, True}— theTruecase is the regression guard for Area 2b. Skips automatically when MXFP8 is unsupported (pre-Blackwell, sm_120+, TE without MXFP8). Verified on GB200:[mxfp8-False]and[mxfp8-True]both pass; blockwise variants skip by design.Changes by file
megatron/core/optimizer/layer_wise_optimizer.pymegatron/core/optimizer/distrib_optimizer.pymegatron/core/optimizer/optimizer.pymegatron/training/training.pymegatron/core/transformer/moe/paged_stash.pymegatron/training/arguments.pytests/unit_tests/test_muon_mxfp8_fp8_param_gather.py🤖 Generated with Claude Code