
[1/N] Megatron FSDP: Introduce Megatron-FSDP2 with per-module fully_shard() API #4435

Open
shjwudp wants to merge 42 commits into NVIDIA:main from shjwudp:mfsdp_refactor_main

Conversation

@shjwudp shjwudp commented Apr 22, 2026

Contributor

What does this PR do?

Summary

Introduce Megatron-FSDP2 — a refactored per-module sharding implementation (fully_shard_v2) with FSDP2-compatible API, communication-compute overlap, activation recompute support, and a pooled memory allocator for Megatron Core.

Key Features

  • fully_shard() API: FSDP2-compatible wrapping interface that converts modules to FSDPModule dynamically, groups parameters into ParameterGroups with flat DataParallelBuffers, and installs forward/backward hooks for the unshard → forward → reshard → backward → reduce lifecycle.
  • Communication-compute overlap (two streams):
    • Unshard prefetch (ag_stream): all-gathers parameters for the next module while the current module computes forward.
    • Async reduce-grad (rs_stream): reduce-scatters gradients while later modules compute backward. Sliding drain rule keeps ≤ 2 gradient buffers live at any time.
  • Activation recompute support: Derived backward_module tracking + persistent unshard_done_events prevent redundant all-gathers and premature resharding during gradient checkpointing recompute passes.
  • TracePoolAllocator: Three-phase (trace → plan → optimized) pooled memory allocator using greedy left-edge interval coloring. Eliminates per-call torch.empty overhead with a static pool replayed across micro-batches.
  • Four sharding strategies: no_shard, optim (ZeRO-1), optim_grads (ZeRO-2), optim_grads_params (ZeRO-3), with uneven DTensor and distributed checkpoint support.
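
The trace → plan → optimized phases named in the TracePoolAllocator bullet can be illustrated with a small sketch of greedy left-edge interval coloring. This is not the code in allocator.py: `Lifetime`, `plan_slots`, and `slot_sizes` are hypothetical names, and the sketch assumes each traced buffer has a half-open lifetime measured in allocation steps, so two buffers can share a pool slot whenever their lifetimes do not overlap.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class Lifetime:
    """One traced allocation, live over the half-open step range [start, end)."""
    key: str
    start: int
    end: int
    nbytes: int


def plan_slots(trace: List[Lifetime]) -> Dict[str, int]:
    """Greedy left-edge coloring: visit intervals by start, reuse the first free slot."""
    slot_free_at: List[int] = []  # step at which each slot becomes free again
    assignment: Dict[str, int] = {}
    for item in sorted(trace, key=lambda t: (t.start, t.end)):
        for slot, free_at in enumerate(slot_free_at):
            if free_at <= item.start:  # previous tenant was already released
                slot_free_at[slot] = item.end
                assignment[item.key] = slot
                break
        else:
            slot_free_at.append(item.end)  # no reusable slot: open a new one
            assignment[item.key] = len(slot_free_at) - 1
    return assignment


def slot_sizes(trace: List[Lifetime], assignment: Dict[str, int]) -> List[int]:
    """Each slot must fit the largest buffer ever placed in it."""
    sizes = [0] * (max(assignment.values()) + 1)
    for item in trace:
        slot = assignment[item.key]
        sizes[slot] = max(sizes[slot], item.nbytes)
    return sizes


# Trace phase (made-up data): weights of three layers, each gathered one step early.
trace = [
    Lifetime("layer0.w", 0, 2, 4096),
    Lifetime("layer1.w", 1, 3, 8192),
    Lifetime("layer2.w", 2, 4, 4096),
]
plan = plan_slots(trace)        # plan phase
pool = slot_sizes(trace, plan)  # static pool a real allocator would replay
```

In this toy trace `layer2.w` reuses the slot of `layer0.w`, so two slots serve three buffers. A plan like this is only valid while the allocation order stays the same as in the traced iteration.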

Files Changed (27 files, +5,014 / −9)

| Area | Files | Description |
| --- | --- | --- |
| v2/ package | `__init__.py`, `fully_shard.py`, `fsdp_module.py`, `hooks.py`, `param_group.py`, `dp_buffer.py`, `allocator.py`, `utils.py`, `mixed_precision.py`, `design.md`, `README.md` | Core implementation and design docs |
| Integration | `mcore_fsdp_adapter.py`, `src/megatron_fsdp/__init__.py`, `megatron_fsdp.py`, `distributed_data_parallel_config.py` | Adapter supporting legacy and v2 paths |
| Example | `examples/megatron_fsdp/fsdp_toy.py` | Standalone training example |
| Tests | `tests/.../v2/test_allocator.py`, `test_param_group.py`, `test_mcore_fully_shard_api.py`, `test_checkpoint_online_convert.py` | Unit and integration tests |

Experimental Results — Per-Module Sharding Rewrite

Verification was performed comparing the refactored implementation against the baseline.

| Metric | Baseline | Megatron-FSDP2 | Change |
| --- | --- | --- | --- |
| Performance (TFLOPs) | 496.78 | 513.90 | Slight improvement, within expected fluctuation range |
| Memory (GB) | 48.30 | 47.15 | Memory usage reduction |
| Convergence | | | Comparable to baseline; observed difference ≈ 0.01, likely attributable to changes in parameter grouping |

The refactored implementation achieves performance parity or better, modest memory savings, and convergence comparable to the baseline.
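
For readers who want the reported numbers as relative changes, the arithmetic is shown below. `relative_change` is a hypothetical helper; the inputs are the four values from the results above, which work out to roughly +3.4% throughput and -2.4% memory.

```python
def relative_change(baseline: float, candidate: float) -> float:
    """Signed percentage change of candidate relative to baseline."""
    return (candidate - baseline) / baseline * 100.0


tflops_change = relative_change(496.78, 513.90)  # Performance (TFLOPs)
memory_change = relative_change(48.30, 47.15)    # Memory (GB)

print(f"TFLOPs: {tflops_change:+.2f}%")
print(f"Memory: {memory_change:+.2f}%")
```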

W&B link: https://wandb.ai/nvidia/megatron-fsdp/reports/M-FSDP-Rewrite-Convergence-Test-llama3-8b---VmlldzoxNjgzNTg3OQ

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

copy-pr-bot Bot commented Apr 22, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

shjwudp and others added 6 commits April 23, 2026 20:01
Preserve the Megatron-FSDP fully_shard optimizer contract by wiring main-gradient access into the DTensor-backed path and copying optimizer-updated main weights back into model-weight buffers after optimizer.step().

This squash also folds in the review fixes needed for the debug branch changes:
- allocate gradient reduce buffers on demand during reduce-scatter
- align ParameterGroup.reduce_grad with the fully_shard caller
- propagate NaN-check flags to every FSDP module
- validate unsharded parameters across all parameter groups
- drop leftover post-backward debug code and guard optional grad checks

Documentation for this merge is captured in this commit message per request.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Move fully_shard_v2.py, param_group.py, dp_buffer.py, allocator.py to
  megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard_rewrite/
- Reference uneven_dtensor from parent module instead of copying
- Fix test imports in test_param_group.py and test_allocator.py
- Add __init__.py with exports
- Add README.md documentation
- Rename fully_shard_v2 to fully_shard_rewrite in fsdp_toy.py and adapter
- Add extensive docstrings to fully_shard.py and param_group.py
- Fix ParameterGroup call: pass mesh=mesh instead of params[0]._dtype
…_rewrite

- Add _FSDPRootContext with dedicated CUDA streams for all-gather and
  reduce-scatter to enable computation-communication overlap
- Implement unshard prefetch that pipelines parameter loading to the
  next module in forward/backward execution order
- Enable async reduce-scatter overlapped with backward pass
- Plumb mixed precision policy (main_params_dtype, main_grads_dtype)
  from adapter layer through to ParameterGroup gradient buffers
- Simplify dp_buffer.reduce_grad to always accumulate reduced shards
  synchronously into persistent buffer for gradient accumulation
- Fix no_sync property (nullcontext() -> nullcontext)
- Add --use-fsdp-fully-shard-api training argument
Fix multiple correctness bugs in the unshard-prefetch and reduce-scatter overlap
implementation for fully_shard_rewrite:

- Fix `_init_fsdp_state` to pass `enable_unshard_prefetch`/`enable_async_reduce_grad`
  to child FSDP modules (was calling without required args, causing TypeError).
- Fix `reduce_grad_buckets` from positional list to `Dict[id(module)]` mapping, fixing
  incorrect bucket index lookup that used `forward_order.index(self)`.
- Fix async reduce_grad path: remove premature `event.wait()` + `release_grad_buffer()`
  after recording event, replacing with deferred sliding drain (2-module lag).
- Fix `_post_backward_final_callback` to skip modules already handled via
  `post_backward_issued` flag and drain remaining buckets correctly.
- Fix `_scale_gradients` to operate on `dist_grad._local_tensor` instead of
  `main_grad_buffer.data`.
- Fix `_copy_main_weights_to_model_weights` to zero main grads afterward to
  avoid stale gradients.
- Fix `param_group.py`: include `"optim"` in `shard_grads` condition, fix
  gradient buffer dtype selection, fix `_init_dist_params` for `no_shard` and
  missing buffers, fix `release_grad_buffer` to drop `main_grad` views (prevent
  memory leak with TE gradient-accumulation-fusion).
- Fix `allocator.py`: change `param_group_id` from `int` to `ParamGroupIdx`,
  fix `free` to delete bucket after freeing storage, add `_free_storage` safety
  helper, remove `_resize_` in `allocate` that incorrectly truncated reused buckets.
- Add `grad_added_to_main_grad` flag support for TE gradient-accumulation-fusion.
- Wire `enable_unshard_prefetch`/`enable_async_reduce_grad` config flags through
  `mcore_fsdp_adapter.py`.
- Extract `ParamGroupIdx`, `RegisterFSDPBackwardFunction`, `_replace_module_parameter`
  into new `utils.py`.
- Add comprehensive `design_overlap.md` documenting the full overlap architecture.
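
The "deferred sliding drain (2-module lag)" fix listed above can be pictured with a minimal sketch that uses plain Python objects in place of CUDA events and gradient buffers. `FakeBucket`, `issue_async_reduce`, `final_drain`, and `MAX_LIVE_BUCKETS` are hypothetical names, and the exact lag policy in fully_shard.py may differ; the point is only that a bucket is waited on and released once two newer reductions have been issued, not immediately after its own event is recorded.

```python
from collections import deque
from typing import Deque, List

MAX_LIVE_BUCKETS = 2  # assumption, taken from the "≤ 2 gradient buffers live" rule


class FakeBucket:
    """Stand-in for a gradient bucket and the event recorded after its reduce-scatter."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.released = False

    def wait_and_release(self) -> None:
        # Real code would wait on the recorded event, then free the grad buffer.
        self.released = True


def issue_async_reduce(pending: Deque[FakeBucket], bucket: FakeBucket) -> List[str]:
    """Queue a new in-flight reduce, then drain the oldest ones beyond the limit."""
    pending.append(bucket)
    drained = []
    while len(pending) > MAX_LIVE_BUCKETS:
        oldest = pending.popleft()
        oldest.wait_and_release()
        drained.append(oldest.name)
    return drained


def final_drain(pending: Deque[FakeBucket]) -> List[str]:
    """End of backward: wait for whatever is still in flight."""
    drained = []
    while pending:
        bucket = pending.popleft()
        bucket.wait_and_release()
        drained.append(bucket.name)
    return drained


# Backward visits modules in reverse forward order.
pending: Deque[FakeBucket] = deque()
log = [issue_async_reduce(pending, FakeBucket(f"layer{i}")) for i in (3, 2, 1, 0)]
tail = final_drain(pending)
```

Here `layer3` is released only when `layer1` issues its reduction, which is the two-module lag; the last two buckets are handled by the final drain.
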
…lap path

Bug fixes:

- Add stream.wait_stream(current_stream) in unshard() before launching
  async all-gather on ag_stream, ensuring main-stream writes to parameter
  data complete before the NCCL all-gather reads them.  This was the root
  cause of convergence divergence in the computation-communication overlap
  path where stale or partially-written parameter shards were gathered.
- Fix shard_grads in _init_buffers() to exclude "optim" strategy: ZeRO-1
  should shard optimizer states only, not distribute the gradient buffer.

New:

- Implement stop_communication() for the fully_shard path (was raising
  NotImplementedError): waits on ag_stream and rs_stream to bring all
  communication into the main CUDA stream before the optimizer step.
- Add NVTX ranges (MFSDP unshard/reshard/reduce_grad) for profiling.

Memory:

- Free original full-parameter storage via _free_storage() after copying
  data into weight buffers during _init_buffers(), reducing peak memory.

Docs:

- Rename design_overlap.md → design.md and expand to cover stream
  barriers, NVTX profiling, memory optimization, stop_communication().
- Fix README.md: correct file tree, sharding strategy table (optim row
  was incorrectly marked as sharding weights/grads), add ZeRO analogues.
- Remove stale FIXME comment, polish code comments and docstrings.
Register the final callback that runs after all backward passes complete.

This is only registered by the root FSDP module to avoid duplicate
callbacks. It reshards all modules and reduces gradients at the end

Contributor

@shjwudp Isn't a gradient tensor reduced after its containing FSDP module for overlapping? (I might have misunderstood your definition of "root".)

Contributor Author

Per-module post_backward hooks are registered via autograd on grad-carrying inputs, but in several situations those hooks may be skipped or never fire (e.g., traceable/compiled modes that explicitly skip FSDP hooks, no-grad forwards or detached inputs, conditional/MoE usage, or nonstandard multi-forward/multi-backward scheduling).

To keep things robust, root_backward_final_callback acts as a safety net that, at the end of the global backward, drives each post_backward so sharding, collectives, and state cleanup still happen correctly even when per-module hooks were not registered or invoked.
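
The safety-net behaviour described here can be sketched without PyTorch. `FakeFSDPModule` is a hypothetical stand-in, and resetting the flag at the end of the callback is an assumption made for the sketch; only the `post_backward_issued` flag and the `getattr` guard come from the commit messages in this PR.

```python
from typing import List


class FakeFSDPModule:
    """Stand-in for an FSDP-wrapped module; it only tracks hook bookkeeping."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.reduced = False

    def post_backward(self) -> None:
        self.post_backward_issued = True  # set lazily, as a hook would
        self.reduced = True


def root_backward_final_callback(modules: List[FakeFSDPModule]) -> List[str]:
    """Run post_backward for every module whose per-module hook never fired."""
    rescued = []
    for module in modules:
        # getattr guards against modules that never set the flag at all.
        if getattr(module, "post_backward_issued", False):
            continue
        module.post_backward()
        rescued.append(module.name)
    # Assumed cleanup: reset flags so the next backward pass starts clean.
    for module in modules:
        module.post_backward_issued = False
    return rescued


modules = [FakeFSDPModule("embed"), FakeFSDPModule("expert_unused"), FakeFSDPModule("head")]
modules[0].post_backward()  # per-module hooks fired normally for these two
modules[2].post_backward()
rescued = root_backward_final_callback(modules)
```

Only `expert_unused`, whose hook never ran, is handled by the final callback.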

Contributor

Thanks for the context!

If a user chooses to skip the hooks, they should be responsible for invoking them manually and ensuring the calls remain symmetric (e.g., the same set of parameters are all-gathered and reduce-scattered). Correct?

We could still keep a final callback, but it should be limited to validating that the required constraints have been satisfied, rather than silently performing work on behalf of the user.

Contributor Author

In the final callback, the hook-related logic is primarily around gradient reduction. If users want to skip gradient reduction for certain modules or parameters, we could expose some interfaces — for example, a module/parameter tagging mechanism — to allow them to opt out selectively.

BTW, leveraging the final backward callback for state cleanup/reclamation is quite a neat and convenient pattern.

shjwudp and others added 6 commits May 11, 2026 07:25
This patchset resolves multiple convergence-critical correctness bugs
discovered during Llama3-8B convergence testing with the rewrite
(`fully_shard_rewrite`) Megatron-FSDP path, and adds diagnostic
tooling to prevent future regressions.

## Correctness Fixes (Convergence Critical)

### dist_param attribute propagation (mcore_fsdp_adapter.py)
Copy `is_embedding_parameter`, `is_embedding_or_output_parameter`,
`sequence_parallel`, `partition_dim`, `partition_stride`,
`_tensor_parallel_mode`, and other metadata from original parameters
to the FSDP DTensor dist_params.  The optimizer param-group builder
relies on these attributes; missing them causes parameters to be
assigned to wrong optimizer groups (wrong weight-decay / LR
multipliers), leading directly to convergence divergence.

### Zero-numel DTensor grads cause silent FusedAdam corruption (param_group.py)
When a parameter's local shard has numel=0 on a DP rank, creating a
DTensor with an empty local tensor and passing it to fused
multi-tensor optimizers (e.g. TE FusedAdam) silently corrupts updates
for neighboring non-empty parameters in the same group.  The optimizer
runs without error — only convergence divergence reveals the bug.
Fix: skip DTensor creation when `grad_data.numel() == 0`, record
`None` instead.  Also guard `_scale_gradients` with `dist_grad is None`.
Documented in design.md § Pitfall and README § Gotchas.
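
A small pure-Python sketch shows how a zero-numel shard can arise and how the `None` placeholder is then skipped. The even ceil-sized chunk layout in `shard_numels` is an assumption for illustration (the real buffer layout in dp_buffer.py may differ), and `make_dist_grad` and `scale_gradients` are hypothetical stand-ins that use lists in place of DTensors.

```python
from typing import List, Optional


def shard_numels(numel: int, world_size: int) -> List[int]:
    """Per-rank element counts when a flat tensor is split into ceil-sized chunks."""
    chunk = -(-numel // world_size)  # ceiling division
    return [max(0, min(chunk, numel - rank * chunk)) for rank in range(world_size)]


def make_dist_grad(local_grad: List[float]) -> Optional[List[float]]:
    """Record None for an empty shard instead of wrapping it."""
    if len(local_grad) == 0:
        return None
    return local_grad


def scale_gradients(dist_grads: List[Optional[List[float]]], factor: float) -> None:
    """Scale in place, skipping the None placeholders."""
    for grad in dist_grads:
        if grad is None:
            continue
        for i in range(len(grad)):
            grad[i] *= factor


sizes = shard_numels(numel=5, world_size=4)  # the last rank gets nothing
flat_grad = [1.0, 2.0, 3.0, 4.0, 5.0]
offsets = [sum(sizes[:rank]) for rank in range(4)]
dist_grads = [make_dist_grad(flat_grad[o:o + n]) for o, n in zip(offsets, sizes)]
scale_gradients(dist_grads, 0.5)
```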

### FusedAdam master_weights always disabled (optimizer/__init__.py)
Remove the `use_precision_aware_optimizer` gate; FusedAdam's internal
master_weights are redundant with Megatron-FSDP's main-weight buffers
and cause correctness issues with 1D DTensor parameters.  Always
disable them.

### Fragment bin-packing in dp_buffer (dp_buffer.py)
Replace ad-hoc leftover-fragment placement with a proper bin-packing
algorithm that fills chunk_size_factor-sized alignment grids.  This
fixes buffer layout edge cases where misaligned fragments caused
overlaps or out-of-bounds accesses.

### Reduce-scatter output aliasing (dp_buffer.py)
Use a separate output tensor in `reduce_scatter_tensor` instead of
slicing the full input buffer directly.  Aliasing the input can cause
the collective to overwrite its own input, silently corrupting
gradients.

### Inverted async reduce wait condition (fully_shard.py)
`_wait_for_previous_async_reduce_grad` returned early when
`enable_async_reduce_grad=True`, opposite of the intended behavior
(the legacy code waited only in the async path).  This prevented
waiting for in-flight reduce events, causing premature grad-buffer
release.  Fix: invert condition to `if not ctx.enable_async_reduce_grad`.

### Meta-device init DP-rank parameter sync (fully_shard.py)
After materializing meta parameters with `reset_parameters()`, each DP
rank may have different random values (due to divergent RNG states).
Broadcast full parameters from DP rank 0 within the DP-CP mesh
*before* FSDP param groups create DTensors, so every rank's shard is
a correct slice of the same full parameter.  Also fix `broadcast_params`
in adapter: changed from `not_implemented_op` to `noop` (the training
loop calls it under `--data-parallel-random-init`; the init-time
broadcast above covers the sync).

### Meta-device init RNG tracker forking (fully_shard.py)
The legacy path wraps `reset_parameters()` with `ResetParametersContext`
which forks the model-parallel RNG tracker for non-TE modules when
TE >= 0.9.0 is present.  Without this fork, TP ranks consume different
RNG sequences during init, producing inconsistent values for
TP-duplicated parameters (LayerNorm weights, biases).  Add equivalent
forking in `_materialize_meta_module`.

### post_backward_issued guard (fully_shard.py)
Use `getattr(module, "post_backward_issued", False)` to avoid
AttributeError when the attribute is missing.

## New Features

### Gradient scaling factor & ReduceOp (fully_shard.py, dp_buffer.py,
mcore_fsdp_adapter.py)
Support `calculate_per_token_loss`, `average_in_collective`, and
default (1/dp_world_size) scaling.  Use `_make_nccl_premul_sum` for
FP32/FP16 and manual pre-scaling for BF16.
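
The default (1/dp_world_size) scaling can be checked with plain arithmetic: multiplying every rank's contribution by the factor before summing gives the data-parallel mean. The sketch below emulates that with lists; `premul_sum` and `default_scaling_factor` are hypothetical names and no NCCL call is involved.

```python
from typing import List


def premul_sum(per_rank_grads: List[List[float]], factor: float) -> List[float]:
    """Emulate a pre-multiplied SUM reduction: scale each contribution, then add."""
    length = len(per_rank_grads[0])
    return [sum(grads[i] * factor for grads in per_rank_grads) for i in range(length)]


def default_scaling_factor(dp_world_size: int) -> float:
    """Default scaling: average over data-parallel ranks."""
    return 1.0 / dp_world_size


per_rank_grads = [[2.0, 4.0], [6.0, 8.0], [1.0, 0.0], [3.0, 4.0]]  # 4 DP ranks
reduced = premul_sum(per_rank_grads, default_scaling_factor(4))
mean = [sum(col) / 4 for col in zip(*per_rank_grads)]
```

Manual pre-scaling follows the same order of operations (scale locally, then reduce with a plain sum), so both paths should agree up to rounding.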

### Stream synchronization (mcore_fsdp_adapter.py)
Implement `finish_grad_sync` and `synchronize_param_gather` for the
rewrite path (formerly no-ops), synchronizing `rs_stream` and
`ag_stream` with the main CUDA stream.

### grad_added_to_main_grad handling (fully_shard.py)
When TE gradient-accumulation-fusion writes directly to `main_grad`,
discard the dummy `.grad` tensor instead of zeroing/overwriting
`main_grad`.

## Debug / Diagnostic Tooling

### per-param norm logging (fully_shard.py, megatron_fsdp.py,
training.py, config)
New `--log-per-param-norm` config flag logs per-parameter L2 norms
for both params and grads, globally reduced across DP ranks.
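
One way to obtain a global per-parameter L2 norm from sharded data is to sum the squared local norms across ranks and take the square root. The sketch below assumes that approach; the PR text does not state how the reduction is implemented, and `local_sq_norms` and `global_l2_norms` are hypothetical names.

```python
import math
from typing import Dict, List


def local_sq_norms(shards: Dict[str, List[float]]) -> Dict[str, float]:
    """Squared L2 norm of each parameter's local shard (0.0 for an empty shard)."""
    return {name: sum(x * x for x in values) for name, values in shards.items()}


def global_l2_norms(per_rank: List[Dict[str, float]]) -> Dict[str, float]:
    """Emulate a SUM all-reduce of squared norms followed by a square root."""
    names = per_rank[0].keys()
    return {name: math.sqrt(sum(rank[name] for rank in per_rank)) for name in names}


rank0 = local_sq_norms({"linear.weight": [3.0], "linear.bias": [0.0, 2.0]})
rank1 = local_sq_norms({"linear.weight": [4.0], "linear.bias": []})
norms = global_l2_norms([rank0, rank1])
```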

### Parameter group diagnostics (fully_shard.py)
`_log_parameter_groups()` prints compact buffer-layout summaries
with memory metrics.  `check_all_fsdp_buffers()` validates no
local-slice overlaps in any FSDP buffer at runtime.
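
The kind of check performed by `check_all_fsdp_buffers()` can be approximated by a sort-and-sweep over `(name, start, end)` slices of one buffer. This is a sketch under that assumption, not the PR's implementation; `Slice` and `find_overlaps` are hypothetical names.

```python
from typing import List, Optional, Tuple

Slice = Tuple[str, int, int]  # (param name, start offset, end offset), end exclusive


def find_overlaps(slices: List[Slice]) -> List[Tuple[str, str]]:
    """Report each slice that starts before an earlier slice has ended."""
    overlaps = []
    # Empty slices cannot overlap anything, so they are dropped up front.
    ordered = sorted((s for s in slices if s[2] > s[1]), key=lambda s: s[1])
    furthest: Optional[Slice] = None  # slice with the largest end seen so far
    for current in ordered:
        if furthest is not None and current[1] < furthest[2]:
            overlaps.append((furthest[0], current[0]))
        if furthest is None or current[2] > furthest[2]:
            furthest = current
    return overlaps


good_layout = [("w1", 0, 128), ("b1", 128, 136), ("empty", 136, 136)]
bad_layout = [("w1", 0, 128), ("b1", 120, 136)]
good = find_overlaps(good_layout)
bad = find_overlaps(bad_layout)
```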

### Logging callbacks wired to adapter (mcore_fsdp_adapter.py)
`log_per_param_norms`, `compute_per_param_norms`, `log_parameter_groups`
are now accessible from the `FullyShardedDataParallel` wrapper.

## Documentation
- design.md: new "Pitfall: Zero-Numel Gradient Shards" section
- README.md: new "Gotchas / Pitfalls" section
refactor(mfsdp): simplify mixed precision policy defaults
if use_megatron_fsdp:
    from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import get_state_dict
    from megatron.core.distributed.fsdp.src.megatron_fsdp.v2 import FSDPModule, fully_shard
    sys.modules["get_state_dict"] = get_state_dict

Contributor

I'm pretty nervous about this monkeypatch. I sort of know your intention of keeping the optimizer unchanged but this arguably hides too much from the user.



@dataclasses.dataclass
class Bucket:

Contributor

This abstraction doesn't look very useful?

@wujingyue wujingyue left a comment

Contributor

Thanks! I'll review the high levels this week.

Contributor

Can this PR test fully_shard directly?

There’s a ToyModel example, but it isn’t covered by CI and tests very little. I’m fine with the model being toy-sized, but I’d like key performance-related features, such as peak memory, to be tested continuously in CI.

Contributor Author

This PR primarily focuses on functional regression for MCore, as we have a solid baseline for comparison.

Unit tests for the fully_shard API will also be included, but they are not the main focus at this stage. I will try leveraging AI to help generate validation checks, ensuring that critical functionalities are not broken during development.

Contributor

> Unit tests for the fully_shard API will also be included, but they are not the main focus at this stage.

I think this is really a matter of execution order.

First of all, we’ll need to split this PR into a few “substages” anyway 😄

When we do that, fully_shard would land before the MCore adapter changes. By extension, fully_shard would also be tested independently before the MCore adapter integration. Correct?

Contributor Author

Sure, that’s a good suggestion—fingers crossed. It would be better if we can get the most critical parts merged first, and I’ll try to split this PR.

@shjwudp shjwudp requested a review from cspades May 14, 2026 02:42
shjwudp and others added 11 commits May 14, 2026 13:14
- Add Apache 2.0 Copyright headers to 10 new files, update year to 2026
- Replace print() with logging: logger.warning for validation,
  logger.info for param/grad norms, logger.debug for trace dumps
- Add missing docstrings: Bucket, phase, ParamGroupIdx,
  make_uneven_dtensor, get_state_dict
- Remove unused Optional import from allocator.py
- Remove misleading # TODO from fully_shard.py mp_policy param
- Revert megatron/core/optimizer/__init__.py
- isort clean
Add FSDP v2 MXFP8 mixed precision support
…uards

Squash merge of mfsdp_refactor_main_wgrad_fusion.

Key fixes:
- Set overwrite_main_grad=True for sharded params in pre_backward_hook to
  prevent gradient doubling/NaN when TE gradient_accumulation_fusion is active
- Propagate 'allreduce' attribute from original params to DTensor dist_params
  in mcore_fsdp_adapter to fix expert param misclassification
- Add runtime safety check in _init_fsdp_state() rejecting re-initialization
  while child FSDPModules are still unsharded or have pending reduce-scatters
- Add _init_default_fully_shard_mesh() for default mesh when none is provided
- Improve NaN check error messages to include parameter names

Documentation:
- Document safety constraint, overwrite_main_grad semantics, and attribute
  propagation pitfall in design.md

Tests:
- Add comprehensive test_fully_shard.py (basic API, LLM/multimodal/MoE
  scenarios, mixed precision, lifecycle, checkpoint, safety)
- Rename test_mcore_fully_shard_api.py to test_mcore_nd_parallel.py
- Enable gradient_accumulation_fusion in E2E test config

Formatting:
- Autoformat allocator.py, mixed_precision.py, param_group.py, fsdp_module.py
…u/merge-mfsdp-refactor-main-20260514

# Conflicts:
#	megatron/core/distributed/fsdp/src/megatron_fsdp/v2/mixed_precision.py
#	megatron/core/distributed/fsdp/src/megatron_fsdp/v2/param_group.py

@wujingyue wujingyue left a comment

Contributor

About halfway there; I'll review the rest by Monday!



@dataclass(frozen=True)
class FullyShardMixedPrecisionPolicy:

Contributor

Suggested change
- class FullyShardMixedPrecisionPolicy:
+ class MixedPrecisionPolicy:

should be good enough. The import path should distinguish where it comes from.

Contributor Author

I'm fine with both options. Let's leave a comment for @Autumn1998 to gather his input.

Contributor

I'm also fine with both options. MixedPrecisionPolicy makes sense to me as well.

params: List[torch.nn.Parameter],
param_group_id: ParamGroupIdx,
*,
mp_policy: FullyShardMixedPrecisionPolicy,

Contributor

Nit: Does ParameterGroup need to know all of MixedPrecisionPolicy? I'd probably flatten it and pass in only the attributes necessary.

Contributor

The param group indeed doesn't need to know every member of MixedPrecisionPolicy, but it does need to use many parts of it. I think passing in the policy directly is cleaner and more beneficial for future extensibility.

_free_storage(self.buckets[key].data)


class TracePoolAllocator(BucketAllocator):

Contributor

I don't follow the name. Does this happen to be the old FixedPoolAllocator but with tracing?

Contributor Author

This is an experimental allocator intended to replace FixedPoolAllocator. Its main advantage is easier debugging; however, it can fail in certain scenarios — e.g. when a validation pass (forward-only) is interleaved within the training loop, which breaks TracePoolAllocator.

You can find more details in design.md#tracepoolallocator.

data: torch.Tensor


class BucketAllocator:

Contributor

I'm a bit overwhelmed by how many allocators have been implemented 😄 What allocation/deallocation algorithms do we really need? How much of it can be done with MemPool with custom allocators, which is torch native? E.g.
https://docs.pytorch.org/docs/2.12/symmetric_memory.html#torch.distributed._symmetric_memory.get_mem_pool

param_group_id: ParamGroupIdx,
*,
allocator: Optional[BucketAllocator] = None,
buffer_role: str = "model_weight",

Contributor

Good to have for debugging!

Suggested change
- buffer_role: str = "model_weight",
+ name: str = "model_weight",

to be general. I may need it within an optimizer to convert between flat-sharded and tensor-atomic.

*,
allocator: Optional[BucketAllocator] = None,
buffer_role: str = "model_weight",
is_distributed: bool = False,

Contributor

It's hard to define is_distributed. Some buffers might be replicated across dp_outer but sharded across dp_inner. Therefore, in the new design, I've been using Placements.

That said, can this be computed from existing attributes? If so, it doesn't need to be another attribute or argument.

Contributor Author

I agree. Do you think implementing the placement idea based on this code would introduce any conflicts?

If not, I would strongly prefer that we move forward and incorporate this idea into v2. cc @Autumn1998

is_distributed: bool = False,
gradient_scaling_factor: Optional[float] = None,
chunk_size_factor: int = 1,
sharding_strategy: str = "no_shard",

Contributor

This looks concerning — we’re leaking high-level abstractions like the global sharding strategy and buffer semantics into a low-level DataParallelBuffer.

I’d probably make DataParallelBuffer more self-contained, similar to DStorage in FlexShard. That feels like a cleaner abstraction and also makes the system much easier to reason about and debug.

@shjwudp shjwudp requested review from a team as code owners May 18, 2026 02:42
@shjwudp shjwudp force-pushed the mfsdp_refactor_main branch from 04f488b to a24e371 Compare May 18, 2026 05:16
shjwudp added 3 commits May 18, 2026 15:16
Reverted stage-2 files to nvidia/main baseline:
- mcore_fsdp_adapter.py, optimizer_config.py

Removed stage-2-only test files:
- test_checkpoint_online_convert.py, test_mcore_nd_parallel.py

All other MCore integration files already match nvidia/main
after merge.  Stage-2 preserved on mfsdp_refactor_main_stage2.
- distributed_data_parallel_config.py
- training/arguments.py
- training_config.py
- training.py
@shjwudp shjwudp force-pushed the mfsdp_refactor_main branch from 33d0e32 to 694c09d Compare May 18, 2026 07:18
@shjwudp shjwudp changed the title M-FSDP Rewrite for Per-Module fully_shard [1/N] Megatron FSDP: Introduce Megatron-FSDP2 with per-module fully_shard() API May 18, 2026

@wujingyue wujingyue left a comment

Contributor

I made it to fully_shard.py! I will focus on mixed precision tomorrow.

In general, as I expected earlier, I’m more concerned about the low-level constructs like DataParallelBuffer and allocators and Optimizer. I’ll probably end up rewriting them yet again to properly support the different sharding strategies and formats. I started some of this work in #4835 and added you as a reviewer. We should try to converge on the design.

The high-level constructs also need a fair amount of cleanup and/or clarification, but I can leave comments on the PRs after you split this one up.

Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/v2/fsdp_module.py
# CUDA streams (communication overlap)
# ------------------------------------------------------------------
ag_stream: torch.cuda.Stream # all-gather / unshard stream
rs_stream: torch.cuda.Stream # reduce-scatter stream

Contributor

FYI, I think we'll need at least one extra for inter-node communication in HSDP/HFSDP.


This ordering is used to:
- Schedule prefetching of parameters (unshard)
- Ensure correct overlap between compute and communication

Contributor

I don't follow this. Even without an explicit order, correctness is guaranteed by letting forward wait_stream allgather and reduce_scatter wait_stream backward. What am I missing?

Contributor Author

"Ensure correct overlap" here means allowing the communication and computation pipelines to run efficiently in parallel instead of randomly blocking each other. Because there are data dependencies between them, establishing the right execution order is crucial.
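
The role of the forward order in prefetching can be shown with a toy schedule that records the order in which operations are issued. `forward_schedule` is a hypothetical function; it models issue order only and does not model CUDA streams or the actual hooks in this PR.

```python
from typing import List


def forward_schedule(forward_order: List[str], prefetch: bool) -> List[str]:
    """Return the issue order of unshard (all-gather), forward, and reshard steps."""
    events: List[str] = []
    unsharded = set()
    for index, name in enumerate(forward_order):
        if name not in unsharded:
            events.append(f"unshard:{name}")  # no prefetch hit: gather on demand
            unsharded.add(name)
        if prefetch and index + 1 < len(forward_order):
            nxt = forward_order[index + 1]
            events.append(f"unshard:{nxt}")  # issued before compute so the two can overlap
            unsharded.add(nxt)
        events.append(f"forward:{name}")
        events.append(f"reshard:{name}")
        unsharded.discard(name)
    return events


order = ["layer0", "layer1", "layer2"]
with_prefetch = forward_schedule(order, prefetch=True)
without_prefetch = forward_schedule(order, prefetch=False)
```

With prefetch enabled, the all-gather for `layer1` is issued before `forward:layer0`, so it has a chance to run alongside that compute. Without a known forward order there is no "next module" to gather early, and each all-gather is only issued when its module is reached.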

FSDP modules in actual forward execution order.

This ordering is used to:
- Schedule prefetching of parameters (unshard)

Contributor

How important is prefetching with per-module control? The user can choose to fully_shard on a higher-level module if they want more parameters to be unsharded in one go, sacrificing memory. But I might be misunderstanding what you mean by prefetching.
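The trade-off in this question can be put in numbers with a toy model. The sketch below assumes each FSDP unit is resharded right after its forward and that up to `prefetch_depth` following units are already all-gathered; the function name and the layer sizes are made up for illustration.

```python
from typing import List


def peak_unsharded(unit_sizes: List[int], prefetch_depth: int) -> int:
    """Peak number of unsharded parameter elements live at once (toy model)."""
    if not unit_sizes:
        return 0
    window = prefetch_depth + 1  # the current unit plus the prefetched ones
    return max(sum(unit_sizes[i : i + window]) for i in range(len(unit_sizes)))


layers = [100, 100, 100, 100]  # four equally sized layers (made-up sizes)

fine_no_prefetch = peak_unsharded(layers, prefetch_depth=0)  # one layer live
fine_prefetch = peak_unsharded(layers, prefetch_depth=1)     # two layers live
coarse = peak_unsharded([sum(layers)], prefetch_depth=0)     # parent wrapped as one unit

print(fine_no_prefetch, fine_prefetch, coarse)
```

In this toy model, per-module wrapping with one-step prefetch holds two layers unsharded while still overlapping communication with compute, whereas wrapping the parent holds all four layers and has no earlier compute inside that unit to hide its single all-gather behind.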

mp_policy: FullyShardMixedPrecisionPolicy,
mesh: Optional[DeviceMesh] = None,
sharding_strategy: str = "optim_grads_params",
gradient_scaling_factor: Optional[float] = None,

Contributor

Curious what this is for. It lacks documentation at this moment.

"completed gradient reduction before re-initializing FSDP state."
)

root_context = _FSDPRootContext(

Contributor

I'm sure I'm missing something--how did you make sure this is only constructed for the root module?

return (None,) + grads


def _replace_module_parameter(module: nn.Module, name: str, new_param: nn.Parameter):

Contributor

Suggested change
- def _replace_module_parameter(module: nn.Module, name: str, new_param: nn.Parameter):
+ def _replace_module_parameter(module: nn.Module, fqn: str, new_param: nn.Parameter):

allocator=bucket_allocator,
gradient_scaling_factor=gradient_scaling_factor,
)
setattr(self, "_fsdp_param_groups", fsdp_param_groups)

Contributor

Suggested change
- setattr(self, "_fsdp_param_groups", fsdp_param_groups)
+ self._param_groups = fsdp_param_groups

raise AssertionError("Current module not found in forward module order")


class FSDPModule(nn.Module):

Contributor

Contributor Author

I see, thanks for pointing it out.

"""
param_groups = {}

for param in module.parameters():

Contributor

Something seems to be missing. Shouldn't we skip any submodule that is already an FSDPModule? Because it forms its own unit.
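A minimal sketch of the traversal this comment describes: collect only the parameters owned by the current unit and stop at children that are already wrapped. It is duck-typed against the `nn.Module` methods `named_parameters(recurse=False)` and `named_children()`. The helper name, the `is_fsdp_unit` predicate, and the `Node` stand-in class are hypothetical, and whether the PR already handles this elsewhere is not visible in this hunk.

```python
from typing import Callable, Iterator, Tuple


def iter_owned_parameters(
    module, is_fsdp_unit: Callable[[object], bool], prefix: str = ""
) -> Iterator[Tuple[str, object]]:
    """Yield (fqn, param) for parameters that belong to `module`'s own FSDP unit."""
    for name, param in module.named_parameters(recurse=False):
        yield prefix + name, param
    for child_name, child in module.named_children():
        if is_fsdp_unit(child):
            continue  # already wrapped: it forms its own unit
        yield from iter_owned_parameters(child, is_fsdp_unit, f"{prefix}{child_name}.")


class Node:
    """Tiny stand-in exposing the two nn.Module methods used above."""

    def __init__(self, params=(), children=None, wrapped=False):
        self._params = {name: object() for name in params}
        self._children = dict(children or {})
        self.wrapped = wrapped

    def named_parameters(self, recurse=False):
        return iter(self._params.items())

    def named_children(self):
        return iter(self._children.items())


root = Node(
    params=["bias"],
    children={
        "attn": Node(params=["qkv", "proj"], wrapped=True),  # already an FSDP unit
        "mlp": Node(params=["fc1", "fc2"]),
    },
)
owned = [fqn for fqn, _ in iter_owned_parameters(root, lambda m: m.wrapped)]
print(owned)  # ['bias', 'mlp.fc1', 'mlp.fc2']
```

With real modules the predicate could be `lambda m: isinstance(m, FSDPModule)`. This sketch does not deduplicate shared (tied) parameters.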

else:
data = param.data.detach()

dist_param = torch.nn.Parameter(

Contributor

Question: Conceptually, we need two sets of nn.Parameters, one for model computation (in low precision; views of model_weight_buffer) and the other for optimization (in high precision; views of main_weight_buffer). How's that handled in the code?

Contributor

I guess this is for the latter; the other "raw", full-shape parameter is probably constructed elsewhere I've yet to find.

param_group_id: ParamGroupIdx,
*,
mp_policy: FullyShardMixedPrecisionPolicy,
mesh: Optional[DeviceMesh] = None,

Contributor

Suggested change
- mesh: Optional[DeviceMesh] = None,
+ mesh: DeviceMesh,

Given it's an internal data structure, I don't see a really good reason to complicate the support surface.

from packaging.version import Version as PkgVersion

try:
import transformer_engine as te

Contributor

Do we still care about the non-TE path? It feels like the performance isn't going to be competitive at all without TE.

Comment on lines +142 to +146
fp8_param_gather: bool = False
fp8_recipe: Optional[str] = None
keep_fp8_transpose_cache: bool = False
use_decoupled_grad: bool = False
fp8: Optional[FullyShardFP8Policy] = field(default=None, repr=False)

Contributor

Request for comments. Also, can't they all be merged into FullyShardFP8Policy?

return tensor.dtype
return ("quantized", type(tensor).__name__, self.fp8.recipe)

def validate_param_group(self, params: List[torch.Tensor]) -> None:

Contributor

I think many methods of this class (like this one) should really be moved to a different class or a helper method. E.g. ParameterGroup.validate(). This makes MixedPrecisionPolicy a leaf-level class with minimal dependency, following https://en.wikipedia.org/wiki/Single-responsibility_principle.

Comment on lines +281 to +282
if self.is_fp8_param(tensor) and self.main_params_dtype is None:
return torch.float32

Contributor

Why? If a parameter is fp8, main_weight is fp32 regardless of main_params_dtype?
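Read literally, the two quoted lines make float32 a default for FP8 parameters when `main_params_dtype` is unset, not an unconditional override. The sketch below spells that reading out. Only the first branch mirrors the quoted code; the remaining branches, the function name, and the use of strings in place of `torch.dtype` are assumptions made for illustration.

```python
from typing import Optional


def resolve_main_dtype(is_fp8: bool, param_dtype: str, main_params_dtype: Optional[str]) -> str:
    """Pick the dtype of the high-precision main weight for one parameter."""
    # Mirrors the quoted lines: FP8 params fall back to float32 only when
    # no main_params_dtype was configured.
    if is_fp8 and main_params_dtype is None:
        return "float32"
    # Assumption: an explicitly configured dtype is used as-is.
    if main_params_dtype is not None:
        return main_params_dtype
    # Assumption: otherwise the main weight keeps the parameter's own dtype.
    return param_dtype


print(resolve_main_dtype(True, "fp8", None))        # float32
print(resolve_main_dtype(True, "fp8", "bfloat16"))  # bfloat16
print(resolve_main_dtype(False, "bfloat16", None))  # bfloat16
```

One plausible motivation for the first branch is that a quantized tensor has no high-precision dtype of its own to inherit, so some default is needed; the PR itself does not state this.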

) -> List[torch.Tensor]:
"""Return original parameter storages that FSDP buffers have replaced."""
# The buffers are ownership signals, not data sources: non-FP8 params can
# be backed by either model or main weight storage, while FP8 params are

Contributor

> non-FP8 params can be backed by either model or main weight storage

Really? I thought params are either sharded params (backed by main_weight and main_grad) or unsharded params (sort of model_weight but allgathered).

Contributor Author

I suspect this refers to the backing model weight buffer in its distributed form (distributed state) and the main weight shard for the high-precision version. I think the description here could be more precise. FYI @Autumn1998

if main_weight_buffer is None:
return

assert model_weight_buffer is not None, "main weights require a model-weight buffer"

Contributor

Side question: for high precision like bf16, main_weight_buffer == model_weight_buffer?

Comment on lines +433 to +434
if len(model_params) == 0:
return

Contributor

Suggested change
- if len(model_params) == 0:
-     return

Comment on lines +438 to +444
args = [
model_params,
main_params,
start_offsets,
data_parallel_group,
fsdp_shard_model_params,
]

Contributor

Suggested change
- args = [
-     model_params,
-     main_params,
-     start_offsets,
-     data_parallel_group,
-     fsdp_shard_model_params,
- ]

I'd just inline them to the call site.

main_params: List[Optional[torch.Tensor]],
start_offsets: List[Optional[int]],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],

Contributor

What's this in addition to model_params and main_params?


4 participants