[1/N] Megatron FSDP: Introduce Megatron-FSDP2 with per-module fully_shard() API#4435
shjwudp wants to merge 42 commits into
Conversation
2. init fully_shard v2 api
Preserve the Megatron-FSDP fully_shard optimizer contract by wiring main-gradient access into the DTensor-backed path and copying optimizer-updated main weights back into model-weight buffers after optimizer.step().
This squash also folds in the review fixes needed for the debug branch changes:
- allocate gradient reduce buffers on demand during reduce-scatter
- align ParameterGroup.reduce_grad with the fully_shard caller
- propagate NaN-check flags to every FSDP module
- validate unsharded parameters across all parameter groups
- drop leftover post-backward debug code and guard optional grad checks
Documentation for this merge is captured in this commit message per request.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

- Move fully_shard_v2.py, param_group.py, dp_buffer.py, allocator.py to megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard_rewrite/
- Reference uneven_dtensor from parent module instead of copying
- Fix test imports in test_param_group.py and test_allocator.py
- Add __init__.py with exports
- Add README.md documentation

- Rename fully_shard_v2 to fully_shard_rewrite in fsdp_toy.py and adapter
- Add extensive docstrings to fully_shard.py and param_group.py
- Fix ParameterGroup call: pass mesh=mesh instead of params[0]._dtype
…_rewrite
- Add _FSDPRootContext with dedicated CUDA streams for all-gather and reduce-scatter to enable computation-communication overlap
- Implement unshard prefetch that pipelines parameter loading to the next module in forward/backward execution order
- Enable async reduce-scatter overlapped with backward pass
- Plumb mixed precision policy (main_params_dtype, main_grads_dtype) from adapter layer through to ParameterGroup gradient buffers
- Simplify dp_buffer.reduce_grad to always accumulate reduced shards synchronously into persistent buffer for gradient accumulation
- Fix no_sync property (nullcontext() -> nullcontext)
- Add --use-fsdp-fully-shard-api training argument
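The `no_sync` fix in the list above is terse, so here is a minimal sketch of one plausible reading: if `no_sync` is a property that returns a `nullcontext()` instance, then `with module.no_sync():` tries to call that instance and fails, whereas returning the `nullcontext` class lets the call happen at the use site. `ToyFSDPModule` and `BrokenToyModule` are hypothetical stand-ins, not classes from this PR.

```python
from contextlib import nullcontext


class ToyFSDPModule:
    """Hypothetical stand-in for a wrapped module (not the PR's class)."""

    @property
    def no_sync(self):
        # Return the class itself: `module.no_sync()` then builds a fresh
        # no-op context manager each time it is called.
        return nullcontext


class BrokenToyModule:
    """Same property, but returning an instance (the assumed old behavior)."""

    @property
    def no_sync(self):
        return nullcontext()


entered = False
with ToyFSDPModule().no_sync():
    entered = True

try:
    with BrokenToyModule().no_sync():  # calls a nullcontext *instance*
        pass
    call_failed = False
except TypeError:
    call_failed = True
```

Whether the real property is written this way is an assumption; the sketch only shows why the two spellings behave differently for callers that write `no_sync()`.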
Fix multiple correctness bugs in the unshard-prefetch and reduce-scatter overlap implementation for fully_shard_rewrite:
- Fix `_init_fsdp_state` to pass `enable_unshard_prefetch`/`enable_async_reduce_grad` to child FSDP modules (was calling without required args, causing TypeError).
- Fix `reduce_grad_buckets` from positional list to `Dict[id(module)]` mapping, fixing incorrect bucket index lookup that used `forward_order.index(self)`.
- Fix async reduce_grad path: remove premature `event.wait()` + `release_grad_buffer()` after recording event, replacing with deferred sliding drain (2-module lag).
- Fix `_post_backward_final_callback` to skip modules already handled via `post_backward_issued` flag and drain remaining buckets correctly.
- Fix `_scale_gradients` to operate on `dist_grad._local_tensor` instead of `main_grad_buffer.data`.
- Fix `_copy_main_weights_to_model_weights` to zero main grads afterward to avoid stale gradients.
- Fix `param_group.py`: include `"optim"` in `shard_grads` condition, fix gradient buffer dtype selection, fix `_init_dist_params` for `no_shard` and missing buffers, fix `release_grad_buffer` to drop `main_grad` views (prevent memory leak with TE gradient-accumulation-fusion).
- Fix `allocator.py`: change `param_group_id` from `int` to `ParamGroupIdx`, fix `free` to delete bucket after freeing storage, add `_free_storage` safety helper, remove `_resize_` in `allocate` that incorrectly truncated reused buckets.
- Add `grad_added_to_main_grad` flag support for TE gradient-accumulation-fusion.
- Wire `enable_unshard_prefetch`/`enable_async_reduce_grad` config flags through `mcore_fsdp_adapter.py`.
- Extract `ParamGroupIdx`, `RegisterFSDPBackwardFunction`, `_replace_module_parameter` into new `utils.py`.
- Add comprehensive `design_overlap.md` documenting the full overlap architecture.
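The "deferred sliding drain (2-module lag)" mentioned above can be pictured with a small simulation. This is only a sketch under assumptions: it models each module's gradient bucket as a name in a queue, treats "wait on the event and release the buffer" as a pop, and assumes the final callback drains whatever remains. The function name and the `max_live` parameter are hypothetical.

```python
from collections import deque


def simulate_sliding_drain(backward_order, max_live=2):
    """Return (release_order, peak_live) for a toy reduce-scatter pipeline."""
    in_flight = deque()  # buckets whose async reduce-scatter has been issued
    released = []
    peak_live = 0
    for name in backward_order:
        # Issue the async reduce-scatter for this module and remember it.
        in_flight.append(name)
        # Only now wait on buckets that lag too far behind, oldest first,
        # and release their gradient buffers.
        while len(in_flight) > max_live:
            released.append(in_flight.popleft())
        peak_live = max(peak_live, len(in_flight))
    # End of backward (final callback): drain everything still in flight.
    while in_flight:
        released.append(in_flight.popleft())
    return released, peak_live


order, peak = simulate_sliding_drain(["layer3", "layer2", "layer1", "layer0"])
```

In this toy model no more than two buckets are ever live, while the wait for a bucket is postponed until two later modules have issued their own reductions, which is what leaves room for overlap with backward compute.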
…lap path
Bug fixes:
- Add stream.wait_stream(current_stream) in unshard() before launching async all-gather on ag_stream, ensuring main-stream writes to parameter data complete before the NCCL all-gather reads them. This was the root cause of convergence divergence in the computation-communication overlap path where stale or partially-written parameter shards were gathered.
- Fix shard_grads in _init_buffers() to exclude "optim" strategy: ZeRO-1 should shard optimizer states only, not distribute the gradient buffer.
New:
- Implement stop_communication() for the fully_shard path (was raising NotImplementedError): waits on ag_stream and rs_stream to bring all communication into the main CUDA stream before the optimizer step.
- Add NVTX ranges (MFSDP unshard/reshard/reduce_grad) for profiling.
Memory:
- Free original full-parameter storage via _free_storage() after copying data into weight buffers during _init_buffers(), reducing peak memory.
Docs:
- Rename design_overlap.md → design.md and expand to cover stream barriers, NVTX profiling, memory optimization, stop_communication().
- Fix README.md: correct file tree, sharding strategy table (optim row was incorrectly marked as sharding weights/grads), add ZeRO analogues.
- Remove stale FIXME comment, polish code comments and docstrings.
| Register the final callback that runs after all backward passes complete. | ||
|
|
||
| This is only registered by the root FSDP module to avoid duplicate | ||
| callbacks. It reshards all modules and reduces gradients at the end |
There was a problem hiding this comment.
@shjwudp Isn't a gradient tensor reduced after its containing FSDP module for overlapping? (I might have misunderstood your definition of "root".)
There was a problem hiding this comment.
Per-module post_backward hooks are registered via autograd on grad-carrying inputs, but in several situations those hooks may be skipped or never fire (e.g., traceable/compiled modes that explicitly skip FSDP hooks, no-grad forwards or detached inputs, conditional/MoE usage, or nonstandard multi-forward/multi-backward scheduling).
To keep things robust, root_backward_final_callback acts as a safety net that, at the end of the global backward, drives each post_backward so sharding, collectives, and state cleanup still happen correctly even when per-module hooks were not registered or invoked.
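A minimal sketch of the safety-net behavior described in this reply, assuming a per-module `post_backward_issued` flag (named in the commit messages) and otherwise hypothetical names (`ToyModule`, `post_backward`, `root_final_callback`). It is not the PR's implementation; it only shows how a root-level callback can run the post-backward step for modules whose own hook never fired, without repeating it for those that did.

```python
class ToyModule:
    """Hypothetical stand-in for an FSDP-wrapped module."""

    def __init__(self, name):
        self.name = name


def post_backward(module, log):
    # Stand-in for reshard + gradient reduction on one module.
    log.append(module.name)
    module.post_backward_issued = True


def root_final_callback(modules, log):
    # Runs once, after the whole backward pass, on the root module only.
    for module in modules:
        # getattr with a default tolerates modules that never set the flag.
        if getattr(module, "post_backward_issued", False):
            continue
        post_backward(module, log)
    # Reset the flags so the next iteration starts clean.
    for module in modules:
        module.post_backward_issued = False


modules = [ToyModule("embed"), ToyModule("block0"), ToyModule("head")]
log = []
post_backward(modules[1], log)     # only block0's per-module hook fired
root_final_callback(modules, log)  # the callback covers embed and head
```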
There was a problem hiding this comment.
Thanks for the context!
If a user chooses to skip the hooks, they should be responsible for invoking them manually and ensuring the calls remain symmetric (e.g., the same set of parameters are all-gathered and reduce-scattered). Correct?
We could still keep a final callback, but it should be limited to validating that the required constraints have been satisfied, rather than silently performing work on behalf of the user.
There was a problem hiding this comment.
In the final callback, the hook-related logic is primarily around gradient reduction. If users want to skip gradient reduction for certain modules or parameters, we could expose some interfaces — for example, a module/parameter tagging mechanism — to allow them to opt out selectively.
BTW, leveraging the final backward callback for state cleanup/reclamation is quite a neat and convenient pattern.
This patchset resolves multiple convergence-critical correctness bugs discovered during Llama3-8B convergence testing with the rewrite (`fully_shard_rewrite`) Megatron-FSDP path, and adds diagnostic tooling to prevent future regressions.

## Correctness Fixes (Convergence Critical)

### dist_param attribute propagation (mcore_fsdp_adapter.py)
Copy `is_embedding_parameter`, `is_embedding_or_output_parameter`, `sequence_parallel`, `partition_dim`, `partition_stride`, `_tensor_parallel_mode`, and other metadata from original parameters to the FSDP DTensor dist_params. The optimizer param-group builder relies on these attributes; missing them causes parameters to be assigned to wrong optimizer groups (wrong weight-decay / LR multipliers), leading directly to convergence divergence.

### Zero-numel DTensor grads cause silent FusedAdam corruption (param_group.py)
When a parameter's local shard has numel=0 on a DP rank, creating a DTensor with an empty local tensor and passing it to fused multi-tensor optimizers (e.g. TE FusedAdam) silently corrupts updates for neighboring non-empty parameters in the same group. The optimizer runs without error; only convergence divergence reveals the bug. Fix: skip DTensor creation when `grad_data.numel() == 0`, record `None` instead. Also guard `_scale_gradients` with `dist_grad is None`. Documented in design.md § Pitfall and README § Gotchas.

### FusedAdam master_weights always disabled (optimizer/__init__.py)
Remove the `use_precision_aware_optimizer` gate; FusedAdam's internal master_weights are redundant with Megatron-FSDP's main-weight buffers and cause correctness issues with 1D DTensor parameters. Always disable them.

### Fragment bin-packing in dp_buffer (dp_buffer.py)
Replace ad-hoc leftover-fragment placement with a proper bin-packing algorithm that fills chunk_size_factor-sized alignment grids. This fixes buffer layout edge cases where misaligned fragments caused overlaps or out-of-bounds accesses.

### Reduce-scatter output aliasing (dp_buffer.py)
Use a separate output tensor in `reduce_scatter_tensor` instead of slicing the full input buffer directly. Aliasing the input can cause the collective to overwrite its own input, silently corrupting gradients.

### Inverted async reduce wait condition (fully_shard.py)
`_wait_for_previous_async_reduce_grad` returned early when `enable_async_reduce_grad=True`, opposite of the intended behavior (the legacy code waited only in the async path). This prevented waiting for in-flight reduce events, causing premature grad-buffer release. Fix: invert condition to `if not ctx.enable_async_reduce_grad`.

### Meta-device init DP-rank parameter sync (fully_shard.py)
After materializing meta parameters with `reset_parameters()`, each DP rank may have different random values (due to divergent RNG states). Broadcast full parameters from DP rank 0 within the DP-CP mesh *before* FSDP param groups create DTensors, so every rank's shard is a correct slice of the same full parameter. Also fix `broadcast_params` in adapter: changed from `not_implemented_op` to `noop` (the training loop calls it under `--data-parallel-random-init`; the init-time broadcast above covers the sync).

### Meta-device init RNG tracker forking (fully_shard.py)
The legacy path wraps `reset_parameters()` with `ResetParametersContext` which forks the model-parallel RNG tracker for non-TE modules when TE >= 0.9.0 is present. Without this fork, TP ranks consume different RNG sequences during init, producing inconsistent values for TP-duplicated parameters (LayerNorm weights, biases). Add equivalent forking in `_materialize_meta_module`.

### post_backward_issued guard (fully_shard.py)
Use `getattr(module, "post_backward_issued", False)` to avoid AttributeError when the attribute is missing.

## New Features

### Gradient scaling factor & ReduceOp (fully_shard.py, dp_buffer.py, mcore_fsdp_adapter.py)
Support `calculate_per_token_loss`, `average_in_collective`, and default (1/dp_world_size) scaling. Use `_make_nccl_premul_sum` for FP32/FP16 and manual pre-scaling for BF16.

### Stream synchronization (mcore_fsdp_adapter.py)
Implement `finish_grad_sync` and `synchronize_param_gather` for the rewrite path (formerly no-ops), synchronizing `rs_stream` and `ag_stream` with the main CUDA stream.

### grad_added_to_main_grad handling (fully_shard.py)
When TE gradient-accumulation-fusion writes directly to `main_grad`, discard the dummy `.grad` tensor instead of zeroing/overwriting `main_grad`.

## Debug / Diagnostic Tooling

### per-param norm logging (fully_shard.py, megatron_fsdp.py, training.py, config)
New `--log-per-param-norm` config flag logs per-parameter L2 norms for both params and grads, globally reduced across DP ranks.

### Parameter group diagnostics (fully_shard.py)
`_log_parameter_groups()` prints compact buffer-layout summaries with memory metrics. `check_all_fsdp_buffers()` validates no local-slice overlaps in any FSDP buffer at runtime.

### Logging callbacks wired to adapter (mcore_fsdp_adapter.py)
`log_per_param_norms`, `compute_per_param_norms`, `log_parameter_groups` are now accessible from the `FullyShardedDataParallel` wrapper.

## Documentation
- design.md: new "Pitfall: Zero-Numel Gradient Shards" section
- README.md: new "Gotchas / Pitfalls" section
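To make the zero-numel pitfall from this commit message concrete, here is a sketch of how a rank can end up owning no elements of a parameter when a flat buffer is split evenly across data-parallel ranks. The layout is a deliberate simplification (contiguous packing, no alignment or bin-packing), and `local_param_slices` is a hypothetical helper; the only part taken from the commit message is the rule of recording `None` for an empty local shard rather than creating an empty tensor.

```python
def local_param_slices(param_sizes, world_size):
    """For each rank, return each parameter's local (start, end) slice or None.

    Assumes parameters are packed back to back in one flat buffer that is
    split into equal contiguous shards, one per rank (simplified layout).
    """
    total = sum(param_sizes)
    shard = -(-total // world_size)  # ceil division
    per_rank = []
    for rank in range(world_size):
        lo, hi = rank * shard, min((rank + 1) * shard, total)
        offset = 0
        slices = []
        for size in param_sizes:
            start, end = max(lo, offset), min(hi, offset + size)
            # An empty intersection means this rank holds nothing of the
            # parameter: record None instead of an empty shard.
            slices.append((start - offset, end - offset) if end > start else None)
            offset += size
        per_rank.append(slices)
    return per_rank


# Two parameters of 6 and 2 elements, sharded across 4 ranks.
slices = local_param_slices([6, 2], world_size=4)
```

Here ranks 0 to 2 hold pieces of the first parameter only, and rank 3 holds the whole second parameter and nothing of the first, so every rank has at least one `None` entry that a fused optimizer should never see as an empty tensor.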
Refactor fully shard runtime modules
refactor(mfsdp): simplify mixed precision policy defaults
| if use_megatron_fsdp: | ||
| from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import get_state_dict | ||
| from megatron.core.distributed.fsdp.src.megatron_fsdp.v2 import FSDPModule, fully_shard | ||
| sys.modules["get_state_dict"] = get_state_dict |
There was a problem hiding this comment.
I'm pretty nervous about this monkeypatch. I sort of know your intention of keeping the optimizer unchanged but this arguably hides too much from the user.
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class Bucket: |
There was a problem hiding this comment.
This abstraction doesn't look very useful?
wujingyue left a comment
There was a problem hiding this comment.
Thanks! I'll review the high levels this week.
There was a problem hiding this comment.
Can this PR test fully_shard directly?
There’s a ToyModel example, but it isn’t covered by CI and tests very little. I’m fine with the model being toy-sized, but I’d like key performance-related features, such as peak memory, to be tested continuously in CI.
There was a problem hiding this comment.
This PR primarily focuses on functional regression for MCore, as we have a solid baseline for comparison.
Unit tests for the fully_shard API will also be included, but they are not the main focus at this stage. I will try leveraging AI to help generate validation checks, ensuring that critical functionalities are not broken during development.
There was a problem hiding this comment.
> Unit tests for the fully_shard API will also be included, but they are not the main focus at this stage.
I think this is really a matter of execution order.
First of all, we’ll need to split this PR into a few “substages” anyway 😄
When we do that, fully_shard would land before the MCore adapter changes. By extension, fully_shard would also be tested independently before the MCore adapter integration. Correct?
There was a problem hiding this comment.
Sure, that’s a good suggestion—fingers crossed. It would be better if we can get the most critical parts merged first, and I’ll try to split this PR.
- Add Apache 2.0 Copyright headers to 10 new files, update year to 2026
- Replace print() with logging: logger.warning for validation, logger.info for param/grad norms, logger.debug for trace dumps
- Add missing docstrings: Bucket, phase, ParamGroupIdx, make_uneven_dtensor, get_state_dict
- Remove unused Optional import from allocator.py
- Remove misleading # TODO from fully_shard.py mp_policy param
- Revert megatron/core/optimizer/__init__.py
- isort clean
Add FSDP v2 MXFP8 mixed precision support
…uards
Squash merge of mfsdp_refactor_main_wgrad_fusion.
Key fixes:
- Set overwrite_main_grad=True for sharded params in pre_backward_hook to prevent gradient doubling/NaN when TE gradient_accumulation_fusion is active
- Propagate 'allreduce' attribute from original params to DTensor dist_params in mcore_fsdp_adapter to fix expert param misclassification
- Add runtime safety check in _init_fsdp_state() rejecting re-initialization while child FSDPModules are still unsharded or have pending reduce-scatters
- Add _init_default_fully_shard_mesh() for default mesh when none is provided
- Improve NaN check error messages to include parameter names
Documentation:
- Document safety constraint, overwrite_main_grad semantics, and attribute propagation pitfall in design.md
Tests:
- Add comprehensive test_fully_shard.py (basic API, LLM/multimodal/MoE scenarios, mixed precision, lifecycle, checkpoint, safety)
- Rename test_mcore_fully_shard_api.py to test_mcore_nd_parallel.py
- Enable gradient_accumulation_fusion in E2E test config
Formatting:
- Autoformat allocator.py, mixed_precision.py, param_group.py, fsdp_module.py

…u/merge-mfsdp-refactor-main-20260514
# Conflicts:
# megatron/core/distributed/fsdp/src/megatron_fsdp/v2/mixed_precision.py
# megatron/core/distributed/fsdp/src/megatron_fsdp/v2/param_group.py
wujingyue left a comment
There was a problem hiding this comment.
About halfway there. I'll review the rest by Monday!
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class FullyShardMixedPrecisionPolicy: |
There was a problem hiding this comment.
| class FullyShardMixedPrecisionPolicy: | |
| class MixedPrecisionPolicy: |
should be good enough. The import path should distinguish where it comes from.
There was a problem hiding this comment.
I'm fine with both options. Let's leave a comment for @Autumn1998 to gather his input.
There was a problem hiding this comment.
I'm fine with both options too. MixedPrecisionPolicy also makes sense to me.
| params: List[torch.nn.Parameter], | ||
| param_group_id: ParamGroupIdx, | ||
| *, | ||
| mp_policy: FullyShardMixedPrecisionPolicy, |
There was a problem hiding this comment.
Nit: Does ParameterGroup need to know all of MixedPrecisionPolicy? I'd probably flatten it and pass in only the attributes necessary.
There was a problem hiding this comment.
The param group indeed doesn't need to know every member of MixedPrecisionPolicy, but it does need to use many parts of it. I think passing in the policy directly is cleaner and more beneficial for future extensibility.
| _free_storage(self.buckets[key].data) | ||
|
|
||
|
|
||
| class TracePoolAllocator(BucketAllocator): |
There was a problem hiding this comment.
I don't follow the name. Does this happen to be the old FixedPoolAllocator but with tracing?
There was a problem hiding this comment.
This is an experimental allocator intended to replace FixedPoolAllocator. Its main advantage is easier debugging; however, it can fail in certain scenarios — e.g. when a validation pass (forward-only) is interleaved within the training loop, which breaks TracePoolAllocator.
You can find more details in design.md#tracepoolallocator.
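The PR description characterizes `TracePoolAllocator` as a trace → plan → optimized allocator based on greedy left-edge interval coloring. The planning step can be sketched as below, under simplifying assumptions: every buffer has the same size, a lifetime is a half-open `(alloc_step, free_step)` pair recorded during the trace, and `left_edge_assign` is a hypothetical name rather than the PR's function.

```python
def left_edge_assign(lifetimes):
    """Assign each traced buffer to a pool slot with greedy left-edge coloring.

    lifetimes: dict name -> (alloc_step, free_step), free_step exclusive.
    Returns (assignment, number_of_slots).
    """
    slot_free_at = []  # slot index -> step at which that slot becomes free
    assignment = {}
    # Visit buffers by increasing allocation step (the "left edge").
    for name, (start, end) in sorted(lifetimes.items(), key=lambda kv: kv[1][0]):
        for slot, free_at in enumerate(slot_free_at):
            if free_at <= start:  # slot is free again: reuse it
                slot_free_at[slot] = end
                assignment[name] = slot
                break
        else:
            # No reusable slot: grow the pool by one.
            slot_free_at.append(end)
            assignment[name] = len(slot_free_at) - 1
    return assignment, len(slot_free_at)


# Lifetimes traced from one micro-batch (toy numbers).
trace = {"w0": (0, 2), "w1": (1, 3), "w2": (2, 4), "w3": (3, 5)}
plan, num_slots = left_edge_assign(trace)
```

Because the plan is tied to the traced sequence of allocations, replaying it against a different sequence (for example the interleaved forward-only validation pass mentioned above) no longer matches, which is consistent with the failure mode described in the reply.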
| data: torch.Tensor | ||
|
|
||
|
|
||
| class BucketAllocator: |
There was a problem hiding this comment.
I'm a bit overwhelmed by how many allocators have been implemented 😄 What allocation/deallocation algorithms do we really need? How much of it can be done with MemPool with custom allocators, which is torch native? E.g.
https://docs.pytorch.org/docs/2.12/symmetric_memory.html#torch.distributed._symmetric_memory.get_mem_pool
| param_group_id: ParamGroupIdx, | ||
| *, | ||
| allocator: Optional[BucketAllocator] = None, | ||
| buffer_role: str = "model_weight", |
There was a problem hiding this comment.
Good to have for debugging!
| buffer_role: str = "model_weight", | |
| name: str = "model_weight", |
to be general. I may need it within an optimizer to convert between flat-sharded and tensor-atomic.
| *, | ||
| allocator: Optional[BucketAllocator] = None, | ||
| buffer_role: str = "model_weight", | ||
| is_distributed: bool = False, |
There was a problem hiding this comment.
It's hard to define is_distributed. Some buffers might be replicated across dp_outer but sharded across dp_inner. Therefore, in the new design, I've been using Placements.
That said, can this be computed from existing attributes? If so, it doesn't need to be another attribute or argument.
There was a problem hiding this comment.
I agree. Do you think implementing the placement idea based on this code would introduce any conflicts?
If not, I would strongly prefer that we move forward and incorporate this idea into v2. cc @Autumn1998
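As a sketch of the placement idea discussed in this thread, the boolean `is_distributed` can be derived from a per-mesh-dimension placement tuple instead of being passed in. The `Replicate` and `Shard` classes below are self-contained stand-ins modelled on DTensor's placement types, and `is_distributed` / `local_numel_upper_bound` are hypothetical helpers; how the new design actually encodes this is not shown in the conversation.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Replicate:
    """Stand-in placement: the buffer is replicated along this mesh dim."""


@dataclass(frozen=True)
class Shard:
    """Stand-in placement: the buffer is sharded along this mesh dim."""
    dim: int = 0


Placement = Union[Replicate, Shard]


def is_distributed(placements: Tuple[Placement, ...]) -> bool:
    # Derived, not stored: sharded on at least one mesh dimension.
    return any(isinstance(p, Shard) for p in placements)


def local_numel_upper_bound(global_numel, mesh_shape, placements):
    # Ceil-divide by the size of every sharded mesh dimension.
    n = global_numel
    for size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            n = -(-n // size)
    return n


# Replicated across dp_outer (size 2), sharded across dp_inner (size 8).
hsdp = (Replicate(), Shard(0))
```

With this encoding the "replicated across dp_outer but sharded across dp_inner" case from the comment is expressible directly, which a single boolean cannot do.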
| is_distributed: bool = False, | ||
| gradient_scaling_factor: Optional[float] = None, | ||
| chunk_size_factor: int = 1, | ||
| sharding_strategy: str = "no_shard", |
There was a problem hiding this comment.
This looks concerning — we’re leaking high-level abstractions like the global sharding strategy and buffer semantics into a low-level DataParallelBuffer.
I’d probably make DataParallelBuffer more self-contained, similar to DStorage in FlexShard. That feels like a cleaner abstraction and also makes the system much easier to reason about and debug.
Tongliu/fsdp v2 mixed precision
04f488b to a24e371
Reverted stage-2 files to nvidia/main baseline:
- mcore_fsdp_adapter.py, optimizer_config.py
Removed stage-2-only test files:
- test_checkpoint_online_convert.py, test_mcore_nd_parallel.py
All other MCore integration files already match nvidia/main after merge. Stage-2 preserved on mfsdp_refactor_main_stage2.

- distributed_data_parallel_config.py
- training/arguments.py
- training_config.py
- training.py
33d0e32 to 694c09d
There was a problem hiding this comment.
I made it to fully_shard.py! I will focus on mixed precision tomorrow.
In general, as I expected earlier, I’m more concerned about the low-level constructs like DataParallelBuffer and allocators and Optimizer. I’ll probably end up rewriting them yet again to properly support the different sharding strategies and formats. I started some of this work in #4835 and added you as a reviewer. We should try to converge on the design.
The high-level constructs also need a fair amount of cleanup and/or clarification, but I can leave comments on the PRs after you split this one up.
| # CUDA streams (communication overlap) | ||
| # ------------------------------------------------------------------ | ||
| ag_stream: torch.cuda.Stream # all-gather / unshard stream | ||
| rs_stream: torch.cuda.Stream # reduce-scatter stream |
There was a problem hiding this comment.
FYI, I think we'll need at least one extra for inter-node communication in HSDP/HFSDP.
|
|
||
| This ordering is used to: | ||
| - Schedule prefetching of parameters (unshard) | ||
| - Ensure correct overlap between compute and communication |
There was a problem hiding this comment.
I don't follow this. Even without an explicit order, correctness is guaranteed by letting forward wait_stream allgather and reduce_scatter wait_stream backward. What am I missing?
There was a problem hiding this comment.
"Ensure correct overlap" here means allowing the communication and computation pipelines to run efficiently in parallel instead of randomly blocking each other. Because there are data dependencies between them, establishing the right execution order is crucial.
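One way to picture what the recorded forward order is used for, as a sketch only: when a module starts computing, the all-gather for the module that runs next is launched so it overlaps with the current compute. The helper name `prefetch_schedule` is hypothetical, and the assumption that backward visits modules in exactly the reverse of the forward order is a simplification.

```python
def prefetch_schedule(forward_order, backward=False):
    """Pair each module with the module whose parameters to prefetch next.

    Returns a list of (running_module, module_to_unshard_next) tuples;
    the last entry has None because nothing follows it.
    """
    order = list(reversed(forward_order)) if backward else list(forward_order)
    return [
        (current, order[i + 1] if i + 1 < len(order) else None)
        for i, current in enumerate(order)
    ]


layers = ["embed", "block0", "block1", "head"]
fwd = prefetch_schedule(layers)
bwd = prefetch_schedule(layers, backward=True)
```

Without a known order the runtime could still be correct by waiting on streams, as the question above points out; the order is what lets the next all-gather be issued early enough to hide its latency.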
| FSDP modules in actual forward execution order. | ||
|
|
||
| This ordering is used to: | ||
| - Schedule prefetching of parameters (unshard) |
There was a problem hiding this comment.
How important is prefetching with per-module control? The user can choose to fully_shard on a higher-level module if they want more parameters to be unsharded in one go, sacrificing memory. But I might be misunderstanding what you mean by prefetching.
| mp_policy: FullyShardMixedPrecisionPolicy, | ||
| mesh: Optional[DeviceMesh] = None, | ||
| sharding_strategy: str = "optim_grads_params", | ||
| gradient_scaling_factor: Optional[float] = None, |
There was a problem hiding this comment.
Curious what this is for. It lacks documentation at this moment.
| "completed gradient reduction before re-initializing FSDP state." | ||
| ) | ||
|
|
||
| root_context = _FSDPRootContext( |
There was a problem hiding this comment.
I'm sure I'm missing something--how did you make sure this is only constructed for the root module?
| return (None,) + grads | ||
|
|
||
|
|
||
| def _replace_module_parameter(module: nn.Module, name: str, new_param: nn.Parameter): |
There was a problem hiding this comment.
| def _replace_module_parameter(module: nn.Module, name: str, new_param: nn.Parameter): | |
| def _replace_module_parameter(module: nn.Module, fqn: str, new_param: nn.Parameter): |
| allocator=bucket_allocator, | ||
| gradient_scaling_factor=gradient_scaling_factor, | ||
| ) | ||
| setattr(self, "_fsdp_param_groups", fsdp_param_groups) |
There was a problem hiding this comment.
| setattr(self, "_fsdp_param_groups", fsdp_param_groups) | |
| self._param_groups = fsdp_param_groups |
| raise AssertionError("Current module not found in forward module order") | ||
|
|
||
|
|
||
| class FSDPModule(nn.Module): |
There was a problem hiding this comment.
Since it's a mixin, I don't think it needs to be (or should be) an nn.Module. Refer to https://github.com/pytorch/pytorch/blob/67464b8261bc79dd8516f95964e4095d0546c533/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L318
There was a problem hiding this comment.
I see, thanks for pointing it out.
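The mixin pattern the reviewer refers to can be sketched in plain Python: the mixin does not inherit from the module base class, and wrapping swaps the instance's class for a dynamically created subclass of both the mixin and the original class. `FSDPMixin`, `Linear`, and `to_fsdp` are hypothetical stand-ins (a plain class plays the role of an `nn.Module` subclass), so this illustrates the pattern rather than the PR's or PyTorch's code.

```python
class FSDPMixin:
    """Hypothetical mixin: adds FSDP methods, is not itself a module."""

    def unshard(self):
        return f"unshard {type(self).__name__}"


class Linear:
    """Stand-in for an nn.Module subclass."""

    def forward(self, x):
        return 2 * x


def to_fsdp(module):
    cls = type(module)
    # Build FSDP<Original> with the mixin first in the MRO, then retarget
    # the existing instance so its state and identity are preserved.
    new_cls = type(f"FSDP{cls.__name__}", (FSDPMixin, cls), {})
    module.__class__ = new_cls
    return module


layer = to_fsdp(Linear())
result = layer.forward(3)
```

The wrapped object is still an instance of its original class, so existing `isinstance` checks and attribute access keep working, while the mixin's methods become available on it.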
| """ | ||
| param_groups = {} | ||
|
|
||
| for param in module.parameters(): |
There was a problem hiding this comment.
Something seems to be missing. Shouldn't we skip any submodule that is already an FSDPModule? Because it forms its own unit.
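The behavior the reviewer is asking about can be sketched as a traversal that stops at nested FSDP units, so each parameter is owned by exactly one unit. `Module`, `FSDPUnit`, and `owned_params` are hypothetical stand-ins using string parameter names; whether the PR's parameter collection needs this change is the open question in the comment, not something this sketch settles.

```python
class Module:
    """Minimal stand-in for a module tree node."""

    def __init__(self, name, params=(), children=()):
        self.name = name
        self.params = list(params)
        self.children = list(children)


class FSDPUnit(Module):
    """Stand-in for a module that has already been wrapped as an FSDP unit."""


def owned_params(root):
    """Collect parameters under root, skipping nested FSDP units."""
    owned = list(root.params)
    for child in root.children:
        if isinstance(child, FSDPUnit):
            continue  # the child manages its own parameters
        owned.extend(owned_params(child))
    return owned


block = FSDPUnit("block", params=["block.w"])
head = Module("head", params=["head.w"])
model = FSDPUnit("model", params=["embed.w"], children=[block, head])
```

By contrast, a flat walk over all descendants (as `module.parameters()` does on a real module) would also return `block.w` for the outer unit.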
| else: | ||
| data = param.data.detach() | ||
|
|
||
| dist_param = torch.nn.Parameter( |
There was a problem hiding this comment.
Question: Conceptually, we need two sets of nn.Parameters, one for model computation (in low precision; views of model_weight_buffer) and the other for optimization (in high precision; views of main_weight_buffer). How's that handled in the code?
There was a problem hiding this comment.
I guess this is for the latter; the other "raw", full-shape parameter is probably constructed elsewhere I've yet to find.
| param_group_id: ParamGroupIdx, | ||
| *, | ||
| mp_policy: FullyShardMixedPrecisionPolicy, | ||
| mesh: Optional[DeviceMesh] = None, |
There was a problem hiding this comment.
| mesh: Optional[DeviceMesh] = None, | |
| mesh: DeviceMesh, |
Given it's an internal data structure, I don't see a really good reason to complicate the support surface.
| from packaging.version import Version as PkgVersion | ||
|
|
||
| try: | ||
| import transformer_engine as te |
There was a problem hiding this comment.
Do we still care about the non-TE path? It feels like the performance isn't going to be competitive at all without TE.
| fp8_param_gather: bool = False | ||
| fp8_recipe: Optional[str] = None | ||
| keep_fp8_transpose_cache: bool = False | ||
| use_decoupled_grad: bool = False | ||
| fp8: Optional[FullyShardFP8Policy] = field(default=None, repr=False) |
There was a problem hiding this comment.
Request for comments. Also, can't they all be merged into FullyShardFP8Policy?
| return tensor.dtype | ||
| return ("quantized", type(tensor).__name__, self.fp8.recipe) | ||
|
|
||
| def validate_param_group(self, params: List[torch.Tensor]) -> None: |
There was a problem hiding this comment.
I think many methods of this class (like this one) should really be moved to a different class or a helper method. E.g. ParameterGroup.validate(). This makes MixedPrecisionPolicy a leaf-level class with minimal dependency, following https://en.wikipedia.org/wiki/Single-responsibility_principle.
| if self.is_fp8_param(tensor) and self.main_params_dtype is None: | ||
| return torch.float32 |
There was a problem hiding this comment.
Why? If a parameter is fp8, main_weight is fp32 regardless of main_params_dtype?
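Read literally, the quoted branch applies the float32 default only when `main_params_dtype` is unset, not regardless of it. A small sketch of that reading follows, with dtypes as plain strings so it runs without torch. `ToyPolicy` and `main_dtype_for` are hypothetical, and the two fallback branches after the quoted condition are assumptions, since the rest of the method is not visible in the diff excerpt.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyPolicy:
    """Hypothetical stand-in for the mixed precision policy."""

    main_params_dtype: Optional[str] = None

    def main_dtype_for(self, param_dtype: str, is_fp8: bool) -> str:
        # Branch quoted in the diff: FP8 parameter and no explicit setting.
        if is_fp8 and self.main_params_dtype is None:
            return "float32"
        # Assumption: an explicit setting wins for every parameter kind.
        if self.main_params_dtype is not None:
            return self.main_params_dtype
        # Assumption: otherwise keep the parameter's own dtype.
        return param_dtype


default_fp8 = ToyPolicy().main_dtype_for("fp8", is_fp8=True)
explicit_fp8 = ToyPolicy("bfloat16").main_dtype_for("fp8", is_fp8=True)
default_bf16 = ToyPolicy().main_dtype_for("bfloat16", is_fp8=False)
```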
| ) -> List[torch.Tensor]: | ||
| """Return original parameter storages that FSDP buffers have replaced.""" | ||
| # The buffers are ownership signals, not data sources: non-FP8 params can | ||
| # be backed by either model or main weight storage, while FP8 params are |
There was a problem hiding this comment.
> non-FP8 params can be backed by either model or main weight storage
Really? I thought params are either sharded params (backed by main_weight and main_grad) or unsharded params (sort of model_weight but allgathered).
There was a problem hiding this comment.
I suspect this refers to the model weight buffer in its distributed (sharded) form and the main weight shard for the high-precision version. I think the description here could be more precise. FYI @Autumn1998
| if main_weight_buffer is None: | ||
| return | ||
|
|
||
| assert model_weight_buffer is not None, "main weights require a model-weight buffer" |
There was a problem hiding this comment.
Side question: for high precision like bf16, main_weight_buffer == model_weight_buffer?
| if len(model_params) == 0: | ||
| return |
There was a problem hiding this comment.
| if len(model_params) == 0: | |
| return |
| args = [ | ||
| model_params, | ||
| main_params, | ||
| start_offsets, | ||
| data_parallel_group, | ||
| fsdp_shard_model_params, | ||
| ] |
There was a problem hiding this comment.
| args = [ | |
| model_params, | |
| main_params, | |
| start_offsets, | |
| data_parallel_group, | |
| fsdp_shard_model_params, | |
| ] |
I'd just inline them to the call site.
| main_params: List[Optional[torch.Tensor]], | ||
| start_offsets: List[Optional[int]], | ||
| data_parallel_group: torch.distributed.ProcessGroup, | ||
| fsdp_shard_model_params: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]], |
There was a problem hiding this comment.
What's this in addition to model_params and main_params?
What does this PR do ?
Summary
Introduce Megatron-FSDP2, a refactored per-module sharding implementation (`fully_shard_v2`) with an FSDP2-compatible API, communication-compute overlap, activation recompute support, and a pooled memory allocator for Megatron Core.
Key Features
- `fully_shard()` API: FSDP2-compatible wrapping interface that converts modules to `FSDPModule` dynamically, groups parameters into `ParameterGroup`s with flat `DataParallelBuffer`s, and installs forward/backward hooks for the unshard → forward → reshard → backward → reduce lifecycle.
- `ag_stream`: all-gathers parameters for the next module while the current module computes forward.
- `rs_stream`: reduce-scatters gradients while later modules compute backward. A sliding drain rule keeps ≤ 2 gradient buffers live at any time.
- `backward_module` tracking + persistent `unshard_done_events` prevent redundant all-gathers and premature resharding during gradient checkpointing recompute passes.
- `TracePoolAllocator`: three-phase (trace → plan → optimized) pooled memory allocator using greedy left-edge interval coloring. Eliminates per-call `torch.empty` overhead with a static pool replayed across micro-batches.
- `no_shard`, `optim` (ZeRO-1), `optim_grads` (ZeRO-2), `optim_grads_params` (ZeRO-3), with uneven DTensor and distributed checkpoint support.
Files Changed (27 files, +5,014 / −9)
- `v2/` package: `__init__.py`, `fully_shard.py`, `fsdp_module.py`, `hooks.py`, `param_group.py`, `dp_buffer.py`, `allocator.py`, `utils.py`, `mixed_precision.py`, `design.md`, `README.md`
- `mcore_fsdp_adapter.py`, `src/megatron_fsdp/__init__.py`, `megatron_fsdp.py`, `distributed_data_parallel_config.py`
- `examples/megatron_fsdp/fsdp_toy.py`
- `tests/.../v2/test_allocator.py`, `test_param_group.py`, `test_mcore_fully_shard_api.py`, `test_checkpoint_online_convert.py`
Experimental Results: Per-Module Sharding Rewrite
Verification was performed comparing the refactored implementation against the baseline.
W&B link: https://wandb.ai/nvidia/megatron-fsdp/reports/M-FSDP-Rewrite-Convergence-Test-llama3-8b---VmlldzoxNjgzNTg3OQ
Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
.github/CODEOWNERS.
Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the `Final Review` label is applied automatically and final reviewers are assigned.
For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the `Approved` label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.
MRs are mergeable after one approval by either `eharper@nvidia.com` or `zijiey@nvidia.com`.