Skip to content

Fix Megatron-FSDP optimizer CPU offload and checkpointing#4623

Open
wplf wants to merge 4 commits into
NVIDIA:devfrom
wplf:jinliang/fix-mfsdp-optimizer-offload
Open

Fix Megatron-FSDP optimizer CPU offload and checkpointing#4623
wplf wants to merge 4 commits into
NVIDIA:devfrom
wplf:jinliang/fix-mfsdp-optimizer-offload

Conversation

@wplf

@wplf wplf commented May 5, 2026

Copy link
Copy Markdown
Member

Summary

  • Handle Megatron-FSDP DTensor parameters and gradients by operating on local shards before optimizer CPU offload copies.
  • Save HybridDeviceOptimizer state for fsdp_dtensor checkpoints as deterministic FSDP DTensors, skipping duplicate master_param entries and attaching local DCP chunk metadata without metadata collectives.
  • Allow SWiGLU/GDN optimizer checkpoint preprocessing to split flat optimizer DTensors using the corresponding model parameter metadata, and restore loaded DTensor optimizer state back to local tensors.

Fixes #4910.

Tests

  • uv run isort megatron/core/optimizer/distrib_optimizer.py
  • python -m py_compile megatron/core/optimizer/distrib_optimizer.py megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py megatron/core/transformer/fsdp_dtensor_checkpoint.py
  • PYTHONDONTWRITEBYTECODE=1 PYTHONPATH=. uv run python -m pytest tests/unit_tests/test_optimizer_cpu_offloading.py -q (72 passed)
  • Local Qwen3.5-VL proxy repro with optimizer_cpu_offload=True, optimizer_offload_fraction=1, overlap_cpu_optimizer_d2h_h2d=True, use_precision_aware_optimizer=True, and fsdp_dtensor optimizer checkpointing: checkpoint saves at iterations 10 and 12 completed successfully.

@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@wplf wplf changed the title Fix optimizer CPU offload for DTensor params Fix optimizer CPU offload for megatron-fsdp dtensor param May 5, 2026
@wplf wplf self-assigned this May 5, 2026
@wplf wplf marked this pull request as ready for review May 5, 2026 04:50
@wplf wplf requested review from a team as code owners May 5, 2026 04:50
@wplf

wplf commented May 5, 2026

Copy link
Copy Markdown
Member Author

/ok to test

@copy-pr-bot

copy-pr-bot Bot commented May 5, 2026

Copy link
Copy Markdown

/ok to test

@wplf, there was an error processing your request: E1

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/1/

@wplf

wplf commented May 5, 2026

Copy link
Copy Markdown
Member Author

/ok to test 5da7d67

@yaox12 yaox12 requested a review from shjwudp May 6, 2026 01:38

@cspades cspades left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wplf Could you share some information on what the bug was? Curious what was the root cause and how it is related to DTensors!

Comment thread megatron/core/optimizer/cpu_offloading/hybrid_optimizer.py Outdated
@wplf wplf changed the title Fix optimizer CPU offload for megatron-fsdp dtensor param Fix Megatron-FSDP optimizer CPU offload and checkpointing May 21, 2026
wplf added 3 commits May 20, 2026 23:41
Handle Megatron-FSDP DTensor parameters and gradients by operating on local shards before CPU optimizer offload copies. This avoids dispatching pin_memory/is_pinned through DTensor and lets pin_cpu_params control CPU parameter pinning.
@wplf wplf force-pushed the jinliang/fix-mfsdp-optimizer-offload branch from eff7bfb to cdaac89 Compare May 21, 2026 06:44
@wplf wplf force-pushed the jinliang/fix-mfsdp-optimizer-offload branch from cdaac89 to 2f2442b Compare May 27, 2026 03:49
@wplf wplf force-pushed the jinliang/fix-mfsdp-optimizer-offload branch from 2f2442b to 21ba0b9 Compare May 27, 2026 03:54

@cspades cspades left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some important nits, basically if we can move most of this code into MFSDP source it would be much cleaner and better.

INFO: Could you also add more commentary explaining the lifecycle of the offloaded DTensors here? Just want to make sure if this is the right way to implement this.

Comment on lines +1625 to +1631
if isinstance(self.optimizer, HybridDeviceOptimizer):
packed_state = self._pack_hybrid_optimizer_fsdp_state_dict()
else:
packed_state = {
(self._param_name(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This optimizer isn't just used for FSDP right? We should check that the parameters are FSDP parameters before doing this.

Comment on lines +1670 to +1673
def _pack_hybrid_optimizer_fsdp_state(
self, param: torch.nn.Parameter, param_name: str, state: dict[str, Any]
) -> dict[str, Any]:
"""Convert HybridDeviceOptimizer FSDP-local tensor state to DTensors."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use preprocess_state_dict_for_uneven_dtensor or update_uneven_dtensor_chunk_metadata instead of re-implementing it here? (Also, we should have all of these types of utilities inside the Megatron-FSDP source directory unless it absolutely needs to be in the training loop / optimizer code!)

Comment on lines +308 to +314
local_tensor = data.to_local()
else:
assert data.numel() == dist_param.numel(), (
f"DTensor shape mismatch: data.shape={data.shape} vs "
f"dist_param.shape={global_shape}"
)
local_tensor = data.to_local().view(-1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These can also be DTensor._local_tensor BTW! Plus more instances of this in this same file.

Comment on lines +1725 to +1735
value = make_fsdp_dtensor(
value.data.view(-1),
flat_param,
dist_index=param.megatron_fsdp_dist_index,
is_expert_param=is_expert_param,
run_check=False,
update_uneven_dtensor_chunk_meta=False,
)
self._set_flat_fsdp_dtensor_chunk_metadata(
value, param.megatron_fsdp_slice
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So TL;DR I assume the packed_state is our offloaded state, and we want them to be DTensors as well?

Comment on lines +1747 to +1748
def _set_flat_fsdp_dtensor_chunk_metadata(tensor: DTensor, fsdp_slice: slice) -> None:
"""Attach DCP chunk metadata for a flat FSDP-local shard without collectives."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just re-factor the existing function in uneven_dtensor.py to skip the metadata AG update! Less code outside MFSDP the better!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants