Skip to content

feat(optim): support --fp8-param-gather for muon + mxfp8 in LayerWise#4987

Open
Wohox wants to merge 17 commits into
NVIDIA:mainfrom
Wohox:wohox/muon-mxfp8-param-gather-main
Open

feat(optim): support --fp8-param-gather for muon + mxfp8 in LayerWise#4987
Wohox wants to merge 17 commits into
NVIDIA:mainfrom
Wohox:wohox/muon-mxfp8-param-gather-main

Conversation

@Wohox

@Wohox Wohox commented May 26, 2026

Copy link
Copy Markdown
Contributor

Summary

--fp8-param-gather + --reuse-grad-buf-for-mxfp8-param-ag (mxfp8) was already wired for the standard DistributedOptimizer, but the muon LayerWiseDistributedOptimizer was missing the pieces needed to take that same path. This PR fills those gaps so --fp8-param-gather works for muon-managed 2D weights when --fp8-recipe=mxfp8, on top of #4889's muon-aware bucket alignment, with bitwise parity against the --fp8-param-gather=off baseline.

Design principle: reuse the DistOpt machinery instead of growing a parallel LayerWise path. The bf16-staging write is exposed as a duck-typed _copy_main_params_to_param_buffer() so train_step drives both optimizers through one code path, and the per-bucket param-sync filtering mirrors DistOpt's exactly.

Background: with this feature the model params are kept as persistent MXFP8 tensors; the bf16 param_buffer is aliased onto the idle grad buffer; the all-gather ships bf16 and then quantizes back to MXFP8 in place on every rank. Every gap below is a consequence of the muon/LayerWise path not yet handling some part of that lifecycle that DistOpt already did.

The gaps fall into four areas.


Area 1 — Master-weight numerics for persistent FP8 params

These two gaps made the muon loss drift from the fp8_param_gather=off baseline. Both exist only because the params are now persistent MXFP8 tensors.

1a — fp32 master was initialized from dequantized MXFP8.
The inner Float16OptimizerWithFloat16Params.__init__ builds the master via param.detach().clone().float(). For an MXFP8 param, .float() dequantizes, so the master is born from FP8-rounded values and carries ~FP8 noise from iter 0. DistOpt avoids this by reading model_param.get_high_precision_init_val() (the bf16 init TE preserves on CPU). Without the fix, the iter-10 lm-loss gap is ≈0.27 nats.
→ filled by _restore_high_precision_init_val (mirrors distrib_optimizer.py:405-416).

1b — the inner step re-quantized MXFP8 params redundantly.
_copy_main_params_to_model_params would run model.data.copy_(main_fp32) on each MXFP8 model param, triggering an extra QuantizedTensor quantize_(fp32) that the post-AG quantize_(bf16) then overwrites — wasted work that also perturbs muon convergence. DistOpt's reuse_grad_buf step path never touches MXFP8 storage at the inner step.
→ filled by _skip_mxfp8_in_copy_main_to_model (+ a shallow-copied inner_config that hides reuse_grad_buf_for_mxfp8_param_ag from the inner optimizer; self.config is restored after super().__init__ so the outer ChainedOptimizer's shared-config assertion still holds).

Area 2 — Staging the bf16 param buffer before the all-gather

With reuse_grad_buf_for_mxfp8_param_ag the bf16 param_buffer is separate from the MXFP8 storage, and the all-gather ships param_buffer. Two gaps left it holding the wrong bytes.

2a — the updated master was never written into the staging buffer.
The optimizer step updates the fp32 master, but nothing copied it (cast to bf16) into param_buffer before the AG, so the AG shipped stale bytes.
→ filled by _copy_main_params_to_param_buffer — a duck-typed twin of DistributedOptimizer._copy_main_params_to_param_buffer, so train_step can call it polymorphically over (DistributedOptimizer, LayerWiseDistributedOptimizer). Self-guards on use_buffer_param_sync + reuse_grad_buf_for_mxfp8_param_ag, so it is a no-op for configs that don't need it.

2b — the staging buffer is zeroed before the deferred all-gather (overlap path).
Because the staging buffer is aliased onto the grad buffer, the next iteration's zero_grad_buffer() zeroes it after the in-step write but before the deferred forward-pre-hook all-gather. Under --overlap-param-gather the AG therefore shipped a zeroed buffer and the muon-managed weights never updated (loss frozen). The fix is not to disable overlap, but to re-stage post-zero on every site that force-syncs through the param buffer:

  • train_step (training loop) — the same hook DistOpt already uses;
  • the eval block (disable_forward_pre_hook(param_sync=True) force-syncs params for eval);
  • paged_stash _try_copy_main_params (reachable when moe_expert_rank_capacity_factor is set with reuse_grad_buf + overlap).
    All three now re-stage LayerWiseDistributedOptimizer too, not just DistributedOptimizer.

Area 3 — Disjoint param-sync when LayerWise and DistOpt coexist

muon manages 2D weights via LayerWise; everything else goes through a sibling DistributedOptimizer (#4771). They share the same DDP buffers, so the gap is making each optimizer all-gather only its own buckets.

3a — per-bucket ownership / mixed bucket groups.
A bucket group can be mixed-ownership: in partition_buckets Case 3 (FP8 present, reduce_scatter_with_fp32_accumulation=False) non-FP8 DistOpt-managed bf16 buckets (biases, layernorms) get appended into the last FP8 bucket group. Checking only buckets[0] and dispatching AG on the whole group double-syncs the foreign buckets. Both optimizers now filter per-bucket and, for mixed groups, synthesize a thin transient bucket group containing only the buckets they own. Applied symmetrically in three places: DistributedOptimizer.start_param_sync_for_bucket_group_subset, LayerWiseDistributedOptimizer.start_param_sync_for_bucket_group_subset, and the deferred-sync path in optimizer.py (which also recurses the chained-optimizer tree so nested DistOpt instances are reached).

Area 4 — Fail-fast configuration guards

The remaining gap was silent misconfiguration: combinations that would quietly all-gather stale storage.

4a — mxfp8 + LayerWise + --fp8-param-gather now requires --reuse-grad-buf-for-mxfp8-param-ag (the two must be co-enabled).
4b — non-mxfp8 + LayerWise + --fp8-param-gather is explicitly rejected (only MXFP8's _rowwise/_columnwise round-trip is wired up today).


Deliberately NOT in this PR

  • Quantizer dual-usage re-assert. An earlier revision re-asserted quantizer.set_usage(rowwise=True, columnwise=True) before each AG, guessing TE's bwd toggled columnwise_usage. Running the bitwise test in deterministic_mode() without that re-assert still reproduces atol=rtol=0 ON-vs-OFF parity, proving it unnecessary; the helper was removed to keep the LayerWise surface minimal.
  • Blockwise FP8 + LayerWise + --fp8-param-gather — rejected at arg-parse (4b) and skipped in the test; needs a separate _rowwise/_columnwise layout + post-AG path.

Test plan

  • tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py::test_on_vs_off_bitwise_identical — builds the same small GPT twice (fp8_param_gather=False/True), restores the ON run's init state from the OFF snapshot so both start bit-identical, runs 5 muon steps in deterministic_mode(), and asserts atol=rtol=0 on loss, forward output, per-parameter main_grad, fp32 master, and (non-MXFP8) model params. Parametrized over overlap_param_gather ∈ {False, True} — the True case is the regression guard for Area 2b. Skips automatically when MXFP8 is unsupported (pre-Blackwell, sm_120+, TE without MXFP8). Verified on GB200: [mxfp8-False] and [mxfp8-True] both pass; blockwise variants skip by design.
  • End-to-end convergence smoke (DSV4 proxy, GB200, muon, e4m3 mxfp8): ON tracks the OFF baseline (iter-10 within ≈0.03 nats; ≈0.27 nats without Area 1a). Overlap ON: before Area 2b the loss was frozen; after, the overlap run tracks the non-overlap run.

Changes by file

File Area
megatron/core/optimizer/layer_wise_optimizer.py 1a, 1b, 2a, 2b, 3a
megatron/core/optimizer/distrib_optimizer.py 3a (per-bucket filter / mixed-group sub-group)
megatron/core/optimizer/optimizer.py 3a (deferred-sync per-bucket filter + recursion)
megatron/training/training.py 2b (train-loop + eval re-stage)
megatron/core/transformer/moe/paged_stash.py 2b (paged-stash re-stage)
megatron/training/arguments.py 4a, 4b
tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py bitwise ON-vs-OFF, overlap on/off

🤖 Generated with Claude Code

Wires the standard distopt MXFP8 + reuse_grad_buf_for_mxfp8_param_ag
flow through LayerWiseDistributedOptimizer so muon-managed 2D weights
can be all-gathered via the bf16 staging path. Adds the following on
top of PR NVIDIA#4889:

- LayerWise inner Float16Optimizer is constructed with a shallow-copied
  config that flips reuse_grad_buf_for_mxfp8_param_ag to False (the
  inner step path expects DistOpt-only _copy_main_params_to_param_buffer;
  the LayerWise wrapper owns the bf16 staging write directly).
- _skip_mxfp8_in_copy_main_to_model filters MXFP8 model params out of
  the inner's _copy_main_params_to_model_params so the inner step does
  not perform a wasted MXFP8 quantize that the post-AG quantize would
  overwrite anyway.
- _write_owned_mxfp8_masters_to_param_buffer writes the bf16 cast of
  each owned MXFP8 param's fp32 master into the DDP bf16 staging buffer
  before the AG dispatches. Mirrors DistOpt._copy_main_params_to_param_buffer.
- start_param_sync_for_bucket_group_subset triggers the standard distopt
  buffer AG only on LayerWise-managed bucket groups, so a sibling
  DistributedOptimizer's own start_param_sync call is not duplicated.
- _restore_high_precision_init_val overwrites the fp32 master of any
  MXFP8 model param with the BF16 init values that TE preserves on CPU
  (model_param.get_high_precision_init_val()) before the quantize-wrap.
  Without this, the inner Float16Optimizer's default
  param.detach().clone().float() dequantizes the MXFP8 storage and
  bakes ~FP8 precision noise into the master from iter 0, which
  amplifies through the subsequent bf16⇒MXFP8 round-trip and yields a
  small muon loss lag vs the fp8_param_gather=False baseline that
  doesn't close even by iter 100. DistOpt does the equivalent fix-up
  at distrib_optimizer.py:405-416; this mirrors that path.
- training/arguments.py asserts --fp8-recipe=mxfp8 + LayerWise opt +
  --fp8-param-gather requires --reuse-grad-buf-for-mxfp8-param-ag so
  misconfiguration fails fast (the bf16 staging buffer for the MXFP8
  AG is the idle grad buffer; the two flags must be co-enabled).

Verified on OCI-HSG (1×4 GB200, DSV4 proxy, 100 iters):
  ON v7 iter-10 lm loss = 10.6085 vs OFF baseline 10.5746 (delta 0.034 nats)
  ON v7 iter-100 lm loss = 0.01605 vs OFF baseline 0.01673 (within noise)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented May 26, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Wohox added 4 commits May 27, 2026 09:55
…ather

Adds tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py asserting that
the LayerWise bf16⇒MXFP8 staging round-trip introduced by
--fp8-param-gather + --reuse-grad-buf-for-mxfp8-param-ag produces
bitwise-identical training trajectories vs --fp8-param-gather=off for the
same small GPT model.

Wraps both runs in deterministic_mode() (mirrors
tests/unit_tests/a2a_overlap/utils.py) and compares per-step loss,
forward output, per-parameter main_grad, and per-parameter fp32 master
with atol=rtol=0 across 5 steps. Skips when MXFP8 is not supported by
the device or TE version.

Without _restore_high_precision_init_val in
layer_wise_optimizer.py, the fp32 master created by inner Float16Optimizer
.detach().clone().float() dequantizes the MXFP8 storage and carries
~FP8 precision noise relative to the bf16 init — this test catches the
resulting per-step master divergence as a bitwise mismatch.
Ensure sibling DistributedOptimizer buckets inside mixed bucket groups are synchronized when LayerWise Muon uses MXFP8 param gather with grad-buffer reuse.
…e in test

Three follow-ups on the muon LayerWise FP8 param-gather path:

1. Add hard assertion in `arguments.py` that blocks `--fp8-param-gather`
   with any recipe other than mxfp8 when the optimizer auto-promotes to
   `LayerWiseDistributedOptimizer`. The bf16-staging + post-AG quantize
   round-trip is wired up only for MXFP8's `_rowwise_data` /
   `_columnwise_data` storage; using blockwise FP8 today would silently
   gather stale storage. Fail fast at arg-parse time.

2. Strip experiment-specific data from new comments in
   `layer_wise_optimizer.py` (per-iter loss numbers, cluster names,
   absolute nat lags). Keep the *what* and *why* of each helper so the
   intent stays clear without hard-coding ephemeral run details.

3. Make the unit test parametrize over `fp8_recipe in [mxfp8, blockwise]`
   and explicitly `pytest.skip` blockwise. Today the LayerWise path only
   supports mxfp8; the parametrization documents that surface while the
   skip keeps blockwise from accidentally running and producing a
   misleading bitwise mismatch.

Also fix the typo `type(a ra m)(` -> `type(bucket_group)(` in
`distrib_optimizer.py:3019` (was a paste artifact in the prior commit).
…rted

Add an explicit `check_mxfp8_support()` skipif on top of the existing
arch / TE-version / fp8-available guards so the test exits cleanly on
environments where MXFP8 is not actually available (e.g. pre-Blackwell
GPUs, sm_120+ which TE currently rejects, or builds without MXFP8
kernels), with TE's own diagnostic string as the skip reason.
@Wohox

Wohox commented May 27, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread megatron/core/optimizer/layer_wise_optimizer.py Outdated
Comment thread megatron/core/optimizer/layer_wise_optimizer.py Outdated
Comment thread tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py Outdated
Comment thread tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py Outdated

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thorough PR — the bug descriptions in the PR body are excellent. A few items:

Bug (medium): LayerWiseDistributedOptimizer.start_param_sync_for_bucket_group_subset still checks only buckets[0] to decide bucket group ownership, while this same PR upgraded the DistOpt side (distrib_optimizer.py:3006) and the deferred-sync path (optimizer.py:1336) to per-bucket filtering for mixed bucket groups. In partition_buckets Case 3, bf16 DistOpt-managed buckets can be merged into the last FP8 bucket group, causing LayerWise to AG the entire group (including DistOpt buckets) — a double-sync. See inline comment for a suggested fix.

Minor: Unused variable original_get_pairs in _skip_mxfp8_in_copy_main_to_model, and vocal_sizevocab_size typo in the new test file (lines 225, 244).

Wohox and others added 3 commits May 28, 2026 13:34
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Address review comments on PR NVIDIA#4987:

1. ``LayerWise.start_param_sync_for_bucket_group_subset`` now filters
   buckets per-bucket instead of checking only ``buckets[0]``, and
   builds a sub-group containing just the LayerWise-managed buckets
   when the group is mixed-ownership. ``partition_buckets`` Case 3 (FP8
   present, ``reduce_scatter_with_fp32_accumulation=False``) merges
   non-FP8 DistOpt-managed bf16 buckets (biases, layernorms) into the
   last FP8 bucket group; without per-bucket filtering, LayerWise
   dispatched AG for the entire group while the sibling DistOpt also
   synced those same bf16 buckets via its own per-bucket filter,
   double-gathering them. Mirrors the per-bucket pattern that
   ``d1ae16a0a`` applied to the DistOpt deferred-sync path.

2. Drop the unused ``original_get_pairs = inner_opt
   ._get_model_and_main_params_data_float16`` line in
   ``_skip_mxfp8_in_copy_main_to_model``. Leftover from development;
   never referenced.
@Wohox

Wohox commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Wohox Wohox requested review from FDecaYed and kunlunl May 28, 2026 07:00
@Wohox Wohox marked this pull request as ready for review May 28, 2026 07:08
@Wohox Wohox requested review from a team as code owners May 28, 2026 07:08
@yaox12 yaox12 mentioned this pull request May 28, 2026
3 tasks
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team May 28, 2026 07:08
Comment thread megatron/core/optimizer/distrib_optimizer.py

@wujingyue wujingyue left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I'll wait until expert reviews (esp dist-optimizer) are done.

Comment thread megatron/core/optimizer/distrib_optimizer.py
…ction

Address review comment on PR NVIDIA#4987: the per-bucket filter and the
``type(bucket_group)(distopt_buckets, …)`` sub-group construction added
in ``DistributedOptimizer.start_param_sync_for_bucket_group_subset`` had
no inline justification. Without that context it isn't obvious why we
walk one bucket at a time instead of dispatching the whole group, or
when the synthesized sub-group path is taken.

Extends the docstring to call out the ``partition_buckets`` Case 3
scenario where the last FP8 bucket group holds both DistOpt-managed
bf16 buckets (biases / layernorms) and LayerWise-managed FP8 buckets,
and adds inline comments on each of the three branches (entire group
LayerWise-owned, pure DistOpt group, mixed group) describing what is
dispatched and why the synthesized sub-group inherits the parent's
ddp_config / DP group / DP world size so the AG collective lands on
the same comm.

Pure documentation; no behaviour change.
# write is gated on ``reuse_grad_buf_for_mxfp8_param_ag``, so require
# the flag explicitly to fail fast on misconfiguration.
will_use_layer_wise_distributed_optimizer = (
args.optimizer not in ('sgd', 'adam') and args.use_distributed_optimizer

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kunlunl to also comment.
My rough understanding is fp8-gather and reuse flags are recommended to be used together even before any muon changes for mxfp8? and at least for mxfp8 there is no reason to use only fp8-gather without reuse
Maybe let's take the opportunity to finally clean up these flag that doesn't quite show what they do as name suggests. I'm thinking something like:

  • mxfp8 + fp8-param-gather: equivalent to old fp8-param-gather+reuse, deprecate the reuse flag, support layerwise
  • non-mxfp8 + fp8-param-gather: equivalent to before, layerwise not supported

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the proposal but have a concern regarding the concurrent use of FP8 parameter gather, reuse gradient reduce buffer, and extra overlap gradient reduce. Our testing showed that enabling these features simultaneously can trigger NaN issues.

Currently, customers can workaround this by enabling reuse gradient reduce while keeping overlap parameter gather disabled. The proposed design would remove this bypass.

@kunlunl Could you confirm if this issue has been resolved or if there is a PR tracking it?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I know, all NaN bugs related to mxfp8 / fp8 param gather / reuse grad buf / overlap param gather are resolved in dev branch, except an open PR (#4994) that also will be merged soon.

And I think the proposal by @FDecaYed doesn't force use overlap param gather. It just deprecate the reuse_xxx, because it's actually not an option, if you want to use fp8 param + mxfp8, you must enable it. if you want to use fp8 param + other fp8 recipe, you must disable reuse_xxx, so no meaning to keep it.

Comment thread tests/unit_tests/test_muon_mxfp8_fp8_param_gather.py Outdated
model_chunk._start_bucket_group_param_sync(lw_bucket_group, force_sync=False)

@torch.no_grad()
def _write_owned_mxfp8_masters_to_param_buffer(self) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a hunch that this is all over-complicated because we originally designed layerwise optimizer to 'replace' dist-opt, so it inherited chainedoptimizer, while distopt inherited mixprecision optimizer
since now we already made up our mind to get layerwise/distopt work alongside each other, maybe we change that? so something like a chained optimizer chaining both layerwise and distopt instance. This way, both can implement _copy_main_params_to_param_buffer that being invoked with same code path?

That said, the code here make sense, and if we want to get this in first and do refactor later, I'm also fine with that

@Wohox Wohox Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To summarise here, we have 2 pending refactorings:

  • Unify reuse_grad_buf_for_mxfp8_param_ag and fp8_param_gather.
  • Let LayerwiseDistOpt inherit from MixedPrecisionOpt instead of ChainedOpt

I would suggest to keep this PR simple and support the above refactorings in another PR.

@Wohox Wohox requested review from a team as code owners June 2, 2026 11:07
…aram_gather mode

PR 4987 enabled fp8_param_gather for the LayerWise (muon) optimizer with
--fp8-recipe mxfp8 + --reuse-grad-buf-for-mxfp8-param-ag, but only validated
overlap_param_gather=False. With overlap_param_gather=True every muon-managed
2D weight froze at init (loss flat ~12.5, grad norm pinned ~205) while Adam
(standard DistributedOptimizer) was unaffected.

Root cause: reuse_grad_buf_for_mxfp8_param_ag aliases the bf16 param-AG
staging buffer onto the grad buffer. In overlap mode the LayerWise param
all-gather + post-AG copy into the MXFP8 param.data is deferred to the next
forward's pre-hook, but zero_grad_buffer() runs at the start of that
iteration and zeroes the shared buffer first. The deferred AG then ships a
zeroed staging buffer and the post-AG copy overwrites every muon param.data
with zeros. The standard DistributedOptimizer avoids this because train_step
re-stages its masters via _copy_main_params_to_param_buffer() AFTER
zero_grad_buffer() — but that loop was gated to isinstance(DistributedOptimizer)
and skipped the LayerWise optimizer, making the bug muon-only.

Fix: give LayerWiseDistributedOptimizer a duck-typed
_copy_main_params_to_param_buffer() that re-writes the owned MXFP8 masters
into the staging buffer post-zero, and widen the train_step post-zero loop to
isinstance(opt, (DistributedOptimizer, LayerWiseDistributedOptimizer)) so both
optimizers are handled by a single polymorphic call.

Verified on DSv4 Flash Proxy / GB200x16 / TP1 PP2 EP8 / mxfp8 / muon:
overlap + fp8_param_gather now descends matching the fp8pg-off and
non-overlap baselines, with normal grad norms. The deterministic bitwise
ON-vs-OFF unit test (test_muon_mxfp8_fp8_param_gather.py) passes.
@Wohox Wohox force-pushed the wohox/muon-mxfp8-param-gather-main branch from 8d093ad to 477dc5a Compare June 2, 2026 11:10
@greptile-apps

greptile-apps Bot commented Jun 2, 2026

Copy link
Copy Markdown

Reviews (1): Last reviewed commit: "fix(layer_wise): support fp8_param_gathe..." | Re-trigger Greptile

@Wohox

Wohox commented Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread megatron/training/training.py

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One bug: the eval-path _copy_main_params_to_param_buffer call at training.py:3697 was not widened to include LayerWiseDistributedOptimizer, unlike the training-loop instance at line 2222 which this PR correctly updated. With reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather, eval will all-gather zeroed buffers for muon-managed MXFP8 weights.

See inline comment for details and the fix.

Wohox and others added 2 commits June 3, 2026 09:32
…gather test

Parametrize the bitwise ON-vs-OFF test over overlap_param_gather in
{False, True}. The True case is the regression guard for the frozen-loss
bug on the deferred forward-pre-hook all-gather path: the bf16 staging
buffer is aliased onto the grad buffer that zero_grad_buffer() zeroes each
iteration, so without re-staging the masters post-zero the all-gather ships
zeros and the muon-managed weights never update.

_run_steps now mirrors train_step's post-zero_grad_buffer re-stage
(_copy_main_params_to_param_buffer on each DistributedOptimizer /
LayerWiseDistributedOptimizer) when reuse_grad_buf_for_mxfp8_param_ag and
overlap_param_gather are set; DDP auto-registers the forward pre-hook when
overlap_param_gather is True. --overlap-param-gather requires
--overlap-grad-reduce, so the two are co-enabled.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…h paths

The training-loop re-stage (post-zero_grad_buffer _copy_main_params_to_param_buffer)
was widened to LayerWiseDistributedOptimizer, but two structurally identical sites
were missed:

- training.py eval block: with reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather,
  the pre-eval zero_grad_buffer() zeroes the staging buffer aliased onto the grad buffer,
  then only DistributedOptimizer params were re-staged. disable_forward_pre_hook(
  param_sync=True) would then all-gather zeroed buffers for muon-managed MXFP8 weights
  during eval.
- paged_stash.py _try_copy_main_params: reachable when moe_expert_rank_capacity_factor
  is set together with reuse_grad_buf_for_mxfp8_param_ag + overlap_param_gather (copy_main_params
  is gated on exactly those flags); LayerWise was skipped, leaving its staging buffer zeroed.

Both now re-stage LayerWiseDistributedOptimizer too. _copy_main_params_to_param_buffer
self-guards on use_buffer_param_sync + reuse_grad_buf_for_mxfp8_param_ag, so the call is
a no-op for configs that don't need it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Wohox Wohox requested review from a team as code owners June 3, 2026 02:09
@Wohox

Wohox commented Jun 3, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dingqingy-nv dingqingy-nv added the deepseekv4 DeepSeek V4 PRs label Jun 11, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants