Add preliminary Muon+M-FSDP support#4486
| if cached_indices is not None: | ||
| return cached_indices | ||
|
|
||
| has_mfsdp_params = any(getattr(param, "orig_param", None) is not None for param in params) |
BTW, this isn't a good way to tell whether we are using Megatron-FSDP. Besides, all we really care about is whether `DTensor.shape != DTensor._local_tensor.shape`, which is when we need to all-gather, and that check does not depend on M-FSDP.
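The shape comparison suggested in this comment can be sketched without referring to Megatron-FSDP at all. This is a minimal illustration, not code from the PR: `needs_all_gather` is a hypothetical helper name, and the `SimpleNamespace` objects stand in for real DTensors so the snippet runs without a process group.

```python
from types import SimpleNamespace


def needs_all_gather(param) -> bool:
    """Return True when the local shard does not cover the global shape.

    Hypothetical helper: it relies only on `shape` and `_local_tensor`, so it
    does not need to know which FSDP implementation produced the parameter.
    """
    local = getattr(param, "_local_tensor", None)
    if local is None:
        # Plain tensor: nothing is sharded, so there is nothing to gather.
        return False
    return tuple(param.shape) != tuple(local.shape)


# Stand-ins for DTensor-like objects (assumption: real code passes DTensors).
sharded = SimpleNamespace(shape=(8, 4), _local_tensor=SimpleNamespace(shape=(2, 4)))
replicated = SimpleNamespace(shape=(8, 4), _local_tensor=SimpleNamespace(shape=(8, 4)))
plain = SimpleNamespace(shape=(8, 4))

print(needs_all_gather(sharded), needs_all_gather(replicated), needs_all_gather(plain))
```

Because the predicate looks only at shapes, it gives the same answer for any sharding scheme that leaves the local tensor smaller than the global one.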
| local_chunks_info = [ | ||
| {"shape": torch.Size(local_tensor.shape), "offset": tuple(local_chunk_metadata.offsets)} | ||
| ] | ||
| shard_group = dtensor_ref.device_mesh.get_group(shard_mesh_dims[0]) |
This doesn't look right. Why is the shard group just the first mesh dim? HFSDP has two sharding groups, for instance, so you need to gather across both (or flatten them into one group).
| torch.empty(chunk_info["numel"], dtype=local_tensor.dtype, device=local_tensor.device) | ||
| for chunk_info in plan["chunk_infos"] | ||
| ] | ||
| torch.distributed.all_gather(group_tensors, local_buffer, group=plan["shard_group"]) |
See above: I don't see how this could work with HFSDP. You need to all-gather across two groups, not one. `chunk_info["numel"]` looks to be the size of a full shard, and I assume this all-gathers into a partial shard, likely in the wrong order (inner first [1], then outer [0]).
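The concern can be made concrete with a small single-process simulation. It is a sketch under stated assumptions: rows are split evenly over a 2x2 (outer, inner) mesh, the helper names are hypothetical, and both chunk layouts below are illustrative because the thread does not state which layout Megatron-FSDP uses.

```python
def shard_rows(rows, outer, inner, chunk_index):
    """Split rows evenly over an (outer, inner) mesh.

    `chunk_index(o, i)` says which chunk mesh rank (o, i) owns; the layout is
    an assumption made for illustration only.
    """
    size = len(rows) // (outer * inner)
    shards = {}
    for o in range(outer):
        for i in range(inner):
            k = chunk_index(o, i)
            shards[(o, i)] = rows[k * size:(k + 1) * size]
    return shards


def gather_inner(shards, o, inner):
    """Simulate an all-gather over the inner group of outer index `o`."""
    return [row for i in range(inner) for row in shards[(o, i)]]


def gather_inner_then_outer(shards, outer, inner):
    """Simulate gathering over the inner group, then over the outer group."""
    return [row for o in range(outer) for row in gather_inner(shards, o, inner)]


outer, inner = 2, 2
full = list(range(8))  # eight "rows" of a gradient matrix

# Outer-major ownership: rank (o, i) owns chunk o * inner + i.
outer_major = shard_rows(full, outer, inner, lambda o, i: o * inner + i)
partial = gather_inner(outer_major, 0, inner)  # single-group gather
rebuilt = gather_inner_then_outer(outer_major, outer, inner)

# Inner-major ownership: rank (o, i) owns chunk i * outer + o.
inner_major = shard_rows(full, outer, inner, lambda o, i: i * outer + o)
scrambled = gather_inner_then_outer(inner_major, outer, inner)

print(partial)    # only half of the rows
print(rebuilt)    # matches `full` under the outer-major layout
print(scrambled)  # same gather order, different layout: rows interleave
```

A gather over one group returns only part of the tensor, and the two-stage gather reproduces the original row order only when the gather order matches the ownership layout.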
Route the emerging-optimizer factory through a Megatron-FSDP-specific
path when `ddp_config.use_megatron_fsdp` is set. Megatron-FSDP attaches
grads via `finish_grad_sync()` on DTensor params instead of via DDP's
main_grad buffers, so the standard `Float16OptimizerWithFloat16Params`
wrapper does not apply; we always wrap with `FP32Optimizer` instead and
drive the FSDP step contract from a thin `FSDPMuonChainedOptimizer`
adapter that calls `finish_grad_sync()` and
`install_optimized_model_weights()` around the inner step.
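
The step contract described above can be sketched with recording stand-ins. This is a hypothetical illustration rather than the PR's `FSDPMuonChainedOptimizer`: the class name `FSDPStepAdapter`, its constructor arguments, and the assumption that both hooks live on a model object are all made up for the example.

```python
calls = []


class RecordingModel:
    """Stand-in for an FSDP-wrapped model (assumed to own both hooks)."""

    def finish_grad_sync(self):
        calls.append("finish_grad_sync")

    def install_optimized_model_weights(self):
        calls.append("install_optimized_model_weights")


class RecordingOptimizer:
    """Stand-in for the wrapped inner optimizer."""

    param_groups = [{"lr": 0.02}]

    def step(self):
        calls.append("step")
        return True


class FSDPStepAdapter:
    """Hypothetical thin adapter: sync grads, step, then install weights."""

    def __init__(self, inner_optimizer, fsdp_model):
        self.inner = inner_optimizer
        self.model = fsdp_model

    def step(self):
        self.model.finish_grad_sync()  # attach reduced grads to the params
        result = self.inner.step()  # run the wrapped optimizer
        self.model.install_optimized_model_weights()  # publish new weights
        return result

    def __getattr__(self, name):
        # Called only when normal lookup fails: delegate to the inner optimizer.
        return getattr(self.inner, name)


adapter = FSDPStepAdapter(RecordingOptimizer(), RecordingModel())
adapter.step()
print(calls)
```

Delegating unknown attributes through `__getattr__` lets callers keep reading fields such as `param_groups` from the adapter as if it were the inner optimizer.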
For now this supports ZeRO-0 ("no_shard") only. ZeRO-1/2/3 run through
the wiring without errors but need a sharding-aware Muon variant for
numerical correctness, which will be added in a follow-up.
Also patch `LayerWiseDistributedOptimizer._allgather_helper` to read
DTensor-backed params via `_local_tensor`, so the layer-wise + FSDP
combination can flatten the local shard rather than the global DTensor.
Add `FSDPZeROTensorParallelMuon`, a `TensorParallelMuon` subclass that:

1. Extracts the `Shard(0)` local tensor from each gradient DTensor (`finish_grad_sync` produces a row-shard per DP rank for `optim`, `optim_grads` and `optim_grads_params`).
2. Allgathers the shards across the DP group to reconstruct the TP-local, DP-full gradient matrix.
3. Trims FSDP bucket-padding rows using the DTensor's declared global shape.
4. Delegates Newton-Schulz to the parent class (which handles the TP dimension via `newton_schulz_tp`).
5. Re-shards the orthogonalized result back to a `Shard(0)` DTensor with matching placements so the in-place update in `OrthogonalizedOptimizer.step` does not promote to `Replicate` and trip the global-shape check.

The FSDP factory in `_build_megatron_fsdp_emerging_optimizer` now picks `FSDPZeROTensorParallelMuon` for any sharded inner-DP strategy and passes `pg_collection.dp_cp` for dense params and `pg_collection.expt_dp` for expert params (since expert grads reduce-scatter over a different group). "no_shard" continues to use plain `TensorParallelMuon`.

DTensor is imported at module scope with a `_HAVE_DTENSOR` guard so the isinstance checks stay cheap and the module still imports on stacks without `torch.distributed.tensor`.
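
The gather, trim, orthogonalize and re-shard cycle listed above can be simulated on one process with plain lists. This is a rough sketch, not the PR's implementation: `orthogonalize_sharded` is a hypothetical name, Newton-Schulz is replaced by a pluggable callable, and the row-wise layout with zero padding on the last rank is an assumption.

```python
def orthogonalize_sharded(local_rows_per_rank, global_rows, orthogonalize):
    """Simulate gather -> trim -> orthogonalize -> re-shard for all ranks.

    local_rows_per_rank: one list of rows per DP rank, equal length (padded).
    global_rows: declared global row count; rows beyond it are padding.
    orthogonalize: callable on the full matrix (placeholder for Newton-Schulz).
    """
    # All-gather: concatenate the row shards in rank order.
    gathered = [row for shard in local_rows_per_rank for row in shard]
    # Trim bucket-padding rows using the declared global shape.
    full = gathered[:global_rows]
    # Orthogonalize the full matrix (stand-in for the parent-class call).
    update = orthogonalize(full)
    # Re-shard: each rank keeps only the rows it originally owned.
    shard_len = len(local_rows_per_rank[0])
    return [
        update[rank * shard_len:(rank + 1) * shard_len]
        for rank in range(len(local_rows_per_rank))
    ]


# Two ranks, three real rows; the last row on rank 1 is padding.
rank_shards = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [0.0, 0.0]],
]


def negate(matrix):
    # Placeholder transform so the data flow is easy to follow.
    return [[-x for x in row] for row in matrix]


shards_out = orthogonalize_sharded(rank_shards, global_rows=3, orthogonalize=negate)
print(shards_out)
```

In this simulation the last rank ends up with fewer rows than its padded buffer held; how the real code maps that back onto the padded `Shard(0)` DTensor is not covered here.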
Three phases of tests for the Muon + Megatron-FSDP integration:

- Phase 1: `FSDPMuonChainedOptimizer` adapter (single-rank, mock-based). Verifies the step contract (`finish_grad_sync` -> inner step -> `install_optimized_model_weights`) and attribute delegation.
- Phase 2: `FSDPZeROTensorParallelMuon.orthogonalize` (multi-rank). Asserts the allgather -> Newton-Schulz -> reshard cycle is numerically equivalent to running NS on the full gradient and extracting the local row-shard, including FSDP padding edge cases. Includes a DTensor round-trip test that catches the `p.add_(orthogonalized_dtensor)` placement-promotion bug.
- Phase 3: `_build_megatron_fsdp_emerging_optimizer` factory. Confirms the factory dispatches plain `TensorParallelMuon` for `no_shard` and `FSDPZeROTensorParallelMuon` for sharded strategies, and that expert vs. non-expert Muon instances receive `expt_dp` vs. `dp_cp` as their allgather group.
Signed-off-by: Cory Ye <cye@nvidia.com>
The `full_tensor` should be reconstructed fully.
Also:

- Fix the `--empty-unused-memory-level` default. Setting it to 1 is too detrimental to performance; setting it to 0 still gives relatively stable throughput.
- Assert that the full tensor reconstruction actually covers the full tensor.
Trade peak memory for slightly increased throughput.
Also make batched gathering optional due to its effect on peak memory.
- Improve boundary detection
- Do not perform Newton-Schulz (NS) on ranks with empty shards
... by replacing `DTensor.to_local()` with `DTensor._local_tensor`. As suggested by @cspades.
Keeps M-FSDP gradients as DTensors when collecting grads for norm/zero-count stats, then makes the grad-stat helpers reduce over the actual DTensor shard mesh dimensions. This fixes incorrect or duplicated reductions for FSDP/HSDP/HFSDP, especially avoiding double-counting replicated outer-DP ranks while still including all sharded data.
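
The double-counting problem can be seen in a small simulation of a gradient-norm reduction over a 2x2 mesh. This is an illustrative sketch only: `global_grad_norm` is a hypothetical helper, placements are given as plain strings rather than DTensor placement objects, and a real implementation would all-reduce over the shard mesh groups instead of looping over ranks.

```python
import math


def global_grad_norm(local_sq_sums, placements):
    """Combine per-rank sums of squares into one L2 norm.

    local_sq_sums: maps mesh coordinates (tuples) to that rank's local sum
        of squared gradient entries.
    placements: one entry per mesh dim, either "shard" or "replicate".
    """
    total = 0.0
    for coord, sq in local_sq_sums.items():
        # Ranks that differ only along a replicated dim hold the same data,
        # so count just the index-0 copy along each replicated dim.
        if all(c == 0 for c, p in zip(coord, placements) if p == "replicate"):
            total += sq
    return math.sqrt(total)


# Gradient [3, 4] row-sharded over the inner dim, replicated over the outer dim.
hsdp_sq_sums = {(0, 0): 9.0, (0, 1): 16.0, (1, 0): 9.0, (1, 1): 16.0}

correct = global_grad_norm(hsdp_sq_sums, ("replicate", "shard"))
naive = math.sqrt(sum(hsdp_sq_sums.values()))  # reduces over every rank

print(correct, naive)
```

Summing over every rank inflates the norm by the square root of the replication factor, while reducing only over the shard dimension recovers the norm of the unsharded gradient.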
Fixes an HSDP/HFSDP numerics bug where the inner-DP gradient reduction could complete without the required outer-DP reduction before optimizer gradients were attached. It tracks buckets that still need outer-DP reduction and completes that work during gradient synchronization/reset, so HSDP/HFSDP see correctly reduced gradients before optimizer steps.
When using Muon, we also have to use Adam. Using both leads to calling `install_optimized_model_weights` twice, once after each optimizer. We now delay the call until all `chained_optimizers` in the `ChainedOptimizer` have run.
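
The deferred call can be sketched as follows. This is a hypothetical illustration, not the actual `ChainedOptimizer`: the class name `DeferredInstallChain` and the `install_fn` callback are assumptions made for the example.

```python
log = []


class NamedOptimizer:
    """Stand-in optimizer that records when it steps."""

    def __init__(self, name):
        self.name = name

    def step(self):
        log.append(f"{self.name}.step")


class DeferredInstallChain:
    """Hypothetical sketch: step every optimizer, then install weights once."""

    def __init__(self, chained_optimizers, install_fn):
        self.chained_optimizers = chained_optimizers
        self.install_fn = install_fn

    def step(self):
        for optimizer in self.chained_optimizers:
            optimizer.step()
        # Install only after all optimizers have updated their parameters.
        self.install_fn()


chain = DeferredInstallChain(
    [NamedOptimizer("muon"), NamedOptimizer("adam")],
    install_fn=lambda: log.append("install_optimized_model_weights"),
)
chain.step()
print(log)
```

With this ordering the install hook runs exactly once per step, regardless of how many optimizers are chained.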
Introduce Muon support to M-FSDP. Currently 1.5×–2.7× as slow as an Adam baseline on a 1B–8B DeepSeek-V3 proxy model. Peak memory is slightly lower than with Adam (4–7% less).