fix mfsdp unwrap stuck at MegatronFSDP [dev]#4273

Merged
yaox12 merged 3 commits into NVIDIA:dev from wplf:jinliang/fix-fsdp-unwrap
Apr 15, 2026

wplf commented Apr 13, 2026

Member

What does this PR do?

main PR: #4274

Problem: unwrap_model() in megatron/core/utils.py stops one layer too early ("gets stuck") when unwrapping a model wrapped with Megatron-FSDP. The wrapping hierarchy is:

FullyShardedDataParallel (mcore adapter)
└── .module → MegatronFSDP (core FSDP impl)
    └── .module → actual model (e.g., GPTModel)

The old code only knew how to peel through DDP, torch_FSDP, megatron_FSDP (the adapter), and Float16Module. It would unwrap the outer FullyShardedDataParallel but then hit the inner MegatronFSDP and stop — returning MegatronFSDP instead of the actual model.

Fix: One-line change — adds MegatronFSDP (from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp) to the default module_instances tuple, so the while isinstance(...) loop can peel through both wrapper layers.
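
The peeling loop described above can be illustrated with a minimal, self-contained sketch. The classes here are empty stand-ins, not the real Megatron classes, and `unwrap` is a hypothetical simplification of `unwrap_model`: it is assumed to follow `.module` for as long as the object is an instance of one of the listed wrapper types.

```python
class Wrapper:
    """Stand-in base class: holds the wrapped object in .module."""

    def __init__(self, module):
        self.module = module


class FullyShardedDataParallel(Wrapper):
    """Stand-in for the mcore adapter (outer wrapper)."""


class MegatronFSDP(Wrapper):
    """Stand-in for the core FSDP implementation (inner wrapper)."""


class GPTModel:
    """Stand-in for the actual model."""


def unwrap(model, module_instances):
    """Follow .module while the object is one of the known wrapper types."""
    while isinstance(model, module_instances):
        model = model.module
    return model


gpt = GPTModel()
wrapped = FullyShardedDataParallel(MegatronFSDP(gpt))

# Assumed shape of the defaults before and after the fix (other wrappers omitted).
old_defaults = (FullyShardedDataParallel,)
new_defaults = (FullyShardedDataParallel, MegatronFSDP)

before = unwrap(wrapped, old_defaults)
after = unwrap(wrapped, new_defaults)

print(type(before).__name__)             # MegatronFSDP
print(type(after).__name__, after is gpt)  # GPTModel True
```

With only the outer wrapper in the tuple, the loop exits as soon as it reaches the inner wrapper; listing both types lets it continue down to the model.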

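You can run the script below to reproduce the behavior. It initializes torch.distributed with the NCCL backend, so launch it with a distributed launcher such as torchrun.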

import torch
import torch.distributed as dist

import megatron.core.parallel_state as mpu
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
from megatron.core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp import MegatronFSDP
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import unwrap_model


def init_distributed():
    """Initialize torch.distributed and Megatron parallel state."""
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    mpu.initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
    )


def build_fsdp_model():
    """Build a GPTModel wrapped with FullyShardedDataParallel, just like training does."""
    transformer_config = TransformerConfig(
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        use_cpu_initialization=True,
    )

    transformer_layer_spec = get_gpt_layer_local_spec()

    gpt_model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=256,
        max_sequence_length=128,
        pre_process=True,
        post_process=True,
    ).cuda()

    ddp_config = DistributedDataParallelConfig(
        data_parallel_sharding_strategy="optim_grads_params",
        overlap_grad_reduce=True,
        overlap_param_gather=True,
        bucket_size=10000,
        use_megatron_fsdp=True,
    )

    fsdp_model = FullyShardedDataParallel(
        config=transformer_config,
        ddp_config=ddp_config,
        module=gpt_model,
        fsdp_unit_modules=[TransformerLayer],
    )

    return gpt_model, fsdp_model


def main():
    init_distributed()

    gpt_model, fsdp_model = build_fsdp_model()

    # Print the wrapping hierarchy
    print("=" * 60)
    print("Model wrapping hierarchy:")
    print(f"  fsdp_model              : {type(fsdp_model).__name__}")
    print(f"  fsdp_model.module       : {type(fsdp_model.module).__name__}")
    print(f"  fsdp_model.module.module: {type(fsdp_model.module.module).__name__}")
    print()

    # Verify the hierarchy:
    #   FullyShardedDataParallel -> MegatronFSDP -> GPTModel
    assert isinstance(fsdp_model, FullyShardedDataParallel)
    assert isinstance(fsdp_model.module, MegatronFSDP)
    assert isinstance(fsdp_model.module.module, GPTModel)

    # Now test unwrap_model
    unwrapped = unwrap_model(fsdp_model)

    print("unwrap_model result:")
    print(f"  type  : {type(unwrapped).__name__}")
    print(f"  is original GPTModel: {unwrapped is gpt_model}")

    # Cleanup
    mpu.destroy_model_parallel()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Result (before this fix):

============================================================
Model wrapping hierarchy:
  fsdp_model              : FullyShardedDataParallel
  fsdp_model.module       : MegatronFSDP
  fsdp_model.module.module: GPTModel

unwrap_model result:
  type  : MegatronFSDP
  is original GPTModel: False

@wplf wplf requested review from a team as code owners April 13, 2026 10:02

copy-pr-bot Bot commented Apr 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@BestJuly

Contributor

Can we add a unit test to guard this?

cspades commented Apr 13, 2026

Member

> Can we add a unit test to guard this?

Agree with @BestJuly. Also, several files use this helper function; can we double-check and justify that the Megatron-FSDP model logic around each use of this utility is still valid?

cspades commented Apr 13, 2026

Member

/ok to test ff0ed7e

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Apr 13, 2026

cspades commented Apr 13, 2026

Member

Needs linting!

wplf changed the title from "fix mfsdp unwrap stuck at MegatronFSDP" to "fix mfsdp unwrap stuck at MegatronFSDP [dev]" Apr 14, 2026
wplf and others added 2 commits April 14, 2026 07:07
Guard that unwrap_model correctly peels through both DDP and
Megatron-FSDP wrapping hierarchies to reach the underlying GPTModel.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@wplf wplf force-pushed the jinliang/fix-fsdp-unwrap branch from ff0ed7e to 3c2e66d Compare April 14, 2026 14:09

wplf commented Apr 14, 2026

Member Author

/ok to test 67077c9

@yaox12 yaox12 added this pull request to the merge queue Apr 15, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24432478244

@wplf wplf self-assigned this Apr 15, 2026
Merged via the queue into NVIDIA:dev with commit 9a7c5dd Apr 15, 2026
61 of 62 checks passed
yaoyu-33 added a commit to NVIDIA-NeMo/Megatron-Bridge that referenced this pull request Apr 20, 2026
…reprocessing

MCore's unwrap_model now strips the MegatronFSDP layer (added in
NVIDIA/Megatron-LM#4273), so preprocess_fsdp_dtensor_state_dict receives
a fully unwrapped GPTModel. The downstream MCore functions
(handle_swiglu_in_state_dict, handle_gdn_in_state_dict) call
model.get_parameter("module.{key}") which requires a .module wrapper.
Re-wrap the model when it arrives without one.

Fixes: AttributeError: GPTModel has no attribute `module`

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
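
The failure and the re-wrap workaround described in this commit message can be sketched without Megatron or PyTorch. Everything below is a stand-in: `FakeModel`, `ModuleWrapper`, `lookup`, and `ensure_module_wrapper` are hypothetical names, and `lookup` only mimics the kind of dotted-path resolution that a call like `get_parameter("module.{key}")` relies on.

```python
from functools import reduce


class FakeModel:
    """Stand-in for a fully unwrapped model with one parameter-like attribute."""

    def __init__(self):
        self.weight = [1.0, 2.0]


class ModuleWrapper:
    """Hypothetical thin wrapper that only restores the .module attribute."""

    def __init__(self, module):
        self.module = module


def lookup(root, dotted_path):
    """Resolve a dotted path such as 'module.weight' by attribute access."""
    return reduce(getattr, dotted_path.split("."), root)


def ensure_module_wrapper(model):
    """Re-wrap the model only if it arrives without a .module attribute."""
    return model if hasattr(model, "module") else ModuleWrapper(model)


model = FakeModel()

# A "module."-prefixed lookup fails on the bare model.
try:
    lookup(model, "module.weight")
    failed = False
except AttributeError:
    failed = True

# After re-wrapping, the same prefixed path resolves again.
rewrapped = ensure_module_wrapper(model)
value = lookup(rewrapped, "module.weight")

print(failed, value)  # True [1.0, 2.0]
```

The helper is idempotent in this sketch: a model that already exposes `.module` is returned unchanged, so callers that still pass a wrapped model are unaffected.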

wplf commented Apr 20, 2026

Member Author

Hi @cspades @yaoyu-33
Sorry for breaking the MFSDP + SwiGLU checkpoint saving.

We may need to add a test to cover this case and remove the checkpoint assumption for Model.MegatronFSDPModule.GPTModel in handle_swiglu_in_state_dict.


5 participants