Skip to content

Integrate LayerWiseDistributedOptimizer with DDP buffer infrastructure#4509

Merged
deepakn94 merged 4 commits into
NVIDIA:mainfrom
deepakn94:dnarayanan/layerwise_param_layout
May 13, 2026
Merged

Integrate LayerWiseDistributedOptimizer with DDP buffer infrastructure#4509
deepakn94 merged 4 commits into
NVIDIA:mainfrom
deepakn94:dnarayanan/layerwise_param_layout

Conversation

@deepakn94

@deepakn94 deepakn94 commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds a shard-aligned parameter layout for LayerWiseDistributedOptimizer that guarantees no parameter is split across shard boundaries — the invariant optimizers like Muon need so Newton-Schulz iteration can run on full weight matrices. The optimizer publishes this layout to DDP, which means DDP can manage the buffers exactly as it does for DistributedOptimizer:

  • Gradient reduction: reduce-scatter via use_distributed_optimizer=True, replacing the previous all-reduce-then-pick-your-shard scheme.
  • Parameter sync: DDP's standard buffer all-gather via model_chunk.start_param_sync(), replacing the legacy flatten / all_gather_v / unflatten path that copied data on both ends.

Why

Before this PR, LayerWiseDistributedOptimizer ran outside DDP's buffer infrastructure:

  • DDP was configured with use_distributed_optimizer=False, so gradients were all-reduced rather than reduce-scattered. Each rank held a full reduced gradient buffer even though it would only update its own shard.
  • Parameter sync was a hand-rolled all_gather_v that flattened each rank's params into a contiguous tensor, all-gathered, then unflattened and copy_-ed the results back into model params. Two extra full-size copies per sync.

Both fell out of "the optimizer doesn't know how the buffer is laid out, so DDP can't be told to slice it." Once the optimizer can publish a layout that guarantees "every param fits inside one DP shard," DDP's existing reduce-scatter and buffer all-gather work directly.

Approach

Shard-aligned size-matching layout

LayerWiseDistributedOptimizer._compute_per_buffer_param_layout produces a shard-aligned layout per buffer using a size-matching algorithm: each round claims one param for shard 0 and assigns same-sized params (or padding) to shards 1..dp-1, so all shards grow by exactly the same amount in lockstep. Shard sizes stay equal by construction; for repeated-layer models (N identical transformer blocks) padding overhead is zero. Every param ends up fully contained within a single DP-rank's shard — the invariant Muon's whole-tensor update requires.

Shared (tied) embeddings get isolated buckets — they are placed alone in shard 0 of their own bucket, with shards 1..dp-1 padded to the same size. This preserves the no-shard-crossing invariant but pays a (dp_size - 1) * pad(numel) cost per shared embedding. Eliminating that cost is the goal of a follow-up PR that routes Adam-managed params through a separate DistributedOptimizer.

Optimizer state partitioning derived from the layout

_shard_params_from_layout derives each parameter's owning rank directly from the published layout ((start - bucket_start) // shard_size) instead of the previous independent ping-pong-by-numel assignment. Optimizer and DDP now agree on shard assignments by construction, eliminating the class of bugs where the two could disagree and the optimizer would step on a parameter whose gradient was reduce-scattered to a different rank. A defensive assertion in _shard_params_from_layout catches any param straddling a shard boundary, so layout regressions fail loudly rather than silently corrupting the emerging optimizer's update.

Step flow

DDP is now configured with use_distributed_optimizer=True for layerwise mode, so model params are views into bucket.param_data and grads are reduce-scattered into each rank's shard of the gradient buffer during backward. On step:

  1. LayerWise.step runs the Muon update on its local-rank shard of every layerwise buffer; the standard "main → model_param" copy updates the param buffer in place (because model_param is a view into the buffer).
  2. model_chunk.start_param_sync(force_sync=True) syncs the buffer across DP ranks via DDP's standard all-gather — no flatten/unflatten copies.

DDP wiring centralised in a helper

wrap_model_chunks_with_ddp (new, in megatron/training/training.py) centralises the DDP-construction logic shared between get_model and unit tests:

  • Forces ddp_config.use_distributed_optimizer = True when use_layer_wise_distributed_optimizer=True (needed for reduce-scatter).
  • Computes the per-chunk full_param_layout via LayerWiseDistributedOptimizer.compute_full_param_layout (layerwise) or DistributedOptimizer.compute_full_param_layout (standard distopt).
  • Wraps each chunk with the layout passed through.

Legacy paths kept as fallback (for the no-layout case)

The pre-PR sync paths are retained for callers that don't supply a layout yet and are marked for removal once all call sites pass one:

  • LayerWiseDistributedOptimizer.allgather_params (flatten / all_gather_v / unflatten path).
  • LayerWiseDistributedOptimizer.set_bucket_layerwise_params_list.
  • LayerWiseDistributedOptimizer._shard_params_ping_pong (the ping-pong-by-numel fallback used when no layout is supplied).
  • The variable-size all-gather branch in _ParamAndGradBucketGroup.start_param_sync.

Tests

  • New tests/unit_tests/distributed/test_layerwise_param_layout.py: covers the size-matching layout (uniform / mixed / shared-embedding / dp-divisibility / backprop ordering / dp_size ∈ {1, 2, 4, 8}).
  • test_emerging_optimizers and dist_checkpointing/test_layer_wise_optimizer: updated to construct DDP via the new wrap_model_chunks_with_ddp helper so the layerwise wiring matches training.get_model.

Test plan

  • Verify test_layerwise_param_layout.py passes (shard divisor, size-matching layout, shared-embedding isolation, bucket alignment, backprop ordering, full layout grouping)
  • Verify existing test_param_layout.py tests still pass
  • Verify test_emerging_optimizers tests pass against LayerWiseDistributedOptimizer
  • Verify dist_checkpointing/test_layer_wise_optimizer tests pass
  • Run Muon training with --use-layer-wise-distributed-optimizer and verify convergence matches the existing layerwise baseline
  • Verify reduce-scatter is used (not allreduce) when the layerwise optimizer is active — i.e. ddp_config.use_distributed_optimizer=True is in effect and grads are reduce-scattered into per-rank shards
  • Cluster: small Muon training run at dp_size > 1 to verify gradient flow + convergence end-to-end
image

Follow-up

Routing Adam-managed parameters through a separate DistributedOptimizer (sub-tensor sharding) so tied embeddings avoid the (dp_size - 1) * pad(numel) per-rank padding cost. That work is staged on a separate branch and will land as a follow-up PR.

@copy-pr-bot

copy-pr-bot Bot commented Apr 29, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Comment thread megatron/core/optimizer/layer_wise_optimizer.py Outdated
@deepakn94

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from f834483 to bfb4f2c Compare April 29, 2026 14:57
@deepakn94

Copy link
Copy Markdown
Contributor Author

/claude review

@deepakn94 deepakn94 marked this pull request as ready for review April 29, 2026 14:59
@deepakn94 deepakn94 requested review from a team as code owners April 29, 2026 14:59
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 29, 2026 14:59

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from 24d6864 to eb82980 Compare April 29, 2026 18:40
@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from eb82980 to 47ff44d Compare April 29, 2026 19:33
@jaredcasper jaredcasper requested review from FDecaYed and skyw May 1, 2026 21:47
@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from 47ff44d to 8d2876b Compare May 6, 2026 17:46
@deepakn94

Copy link
Copy Markdown
Contributor Author

Addressed review comment: renamed S to param_numel throughout _compute_per_buffer_param_layout and updated the docstring to match distrib_optimizer.py conventions.

Also in the latest push: all buffers now use the shard-aligned layout (not just Muon buffers). This ensures no param is split across shard boundaries regardless of optimizer type. The separate BufferKey grouping still keeps Muon and Adam params in different buffers to minimize size-matching padding.

@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from 8d2876b to c2a0018 Compare May 7, 2026 05:11
@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from c2a0018 to 8b54011 Compare May 8, 2026 16:33
@deepakn94 deepakn94 force-pushed the dnarayanan/layerwise_param_layout branch from 8b54011 to 60b1ef7 Compare May 8, 2026 19:36
@deepakn94 deepakn94 changed the title Shard-aligned param layout for layerwise distributed optimizer Integrate LayerWiseDistributedOptimizer with DDP buffer infrastructure May 8, 2026
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved All necessary approvals have been made label May 12, 2026
deepakn94 and others added 4 commits May 12, 2026 13:52
Adds a shard-aligned parameter layout for LayerWiseDistributedOptimizer that
guarantees no parameter is split across shard boundaries — the invariant
optimizers like Muon need so Newton-Schulz iteration can run on full weight
matrices. The optimizer publishes this layout to DDP, which means DDP can
manage the buffers exactly as it does for DistributedOptimizer:

- Gradient reduction: reduce-scatter via use_distributed_optimizer=True,
  replacing the previous all-reduce-then-pick-your-shard scheme.
- Parameter sync: DDP's standard buffer all-gather via start_param_sync(),
  replacing the legacy flatten / all_gather_v / unflatten path.

Existing layerwise optimizer tests were patched to also exercise the new
code path. Code that computes param_layouts is separately tested.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds a use_layer_wise_param_layout=True kwarg to wrap_model_chunks_with_ddp.
Layout computation now requires both use_layer_wise_distributed_optimizer
and use_layer_wise_param_layout. The get_model production call site passes
False so live training runs stay on the legacy LayerWise sync path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When the size-matching loop pops a seed with no exact-numel peers (e.g. an
embedding), the remaining shard slots are greedily packed with the next
unassigned smaller params from the queue (respecting 64-element alignment)
instead of being filled with pure padding. For a unique-large seed at the
top of the backprop pool, this turns ``(dp_size - 1) * param_numel`` of
empty padding into productive bucket content. Also renames the
param-layout test file to ``test_layer_wise_param_layout.py`` to match the
``layer_wise_`` convention used by the optimizer module.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25770105108

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25776760140

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 13, 2026
@deepakn94 deepakn94 added this pull request to the merge queue May 13, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25805685255

Merged via the queue into NVIDIA:main with commit c1e938b May 13, 2026
75 checks passed
@hxbai hxbai mentioned this pull request May 29, 2026
3 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Approved All necessary approvals have been made complexity: medium

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants