[dev] Fix GDN DTensor splitting for FSDP checkpointing #4799

Merged
wplf merged 1 commit into NVIDIA:dev from conver334:fix-gdn-dtensor-split on May 15, 2026
Conversation

@conver334

Contributor

What does this PR do?

When using convert_checkpoints_fsdp.py to convert Qwen3.5-35B-A3B HuggingFace weights into an fsdp_dtensor checkpoint (TP=2, CP=2, EP=2), the following error occurs:

traceback : Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/root/mfsdp-intergration/Megatron-Bridge/examples/conversion/mfsdp/convert_checkpoints_fsdp.py", line 197, in import_hf_to_megatron_fsdp
    save_native_megatron_model(
  File "/root/mfsdp-intergration/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 711, in save_megatron_model
    save_checkpoint(
  File "/root/mfsdp-intergration/Megatron-Bridge/src/megatron/bridge/training/checkpointing.py", line 987, in save_checkpoint
    state_dict = preprocess_fsdp_dtensor_state_dict(cfg, state_dict, model[0])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/mfsdp-intergration/Megatron-Bridge/src/megatron/bridge/training/checkpointing.py", line 1663, in preprocess_fsdp_dtensor_state_dict
    model_state_dict, _ = handle_gdn_in_state_dict(model, state_dict["model"], None)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/mfsdp-intergration/Megatron-LM/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 531, in handle_gdn_in_state_dict
    sub_tensors = split_gdn_fused(model_state_dict[key], dist_param, sizes, dim)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/mfsdp-intergration/Megatron-LM/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 509, in split_gdn_fused
    dtensor = make_fsdp_dtensor(
              ^^^^^^^^^^^^^^^^^^
  File "/root/mfsdp-intergration/Megatron-LM/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py", line 4699, in make_fsdp_dtensor
    validate_uneven_dtensor(fsdp_tensor)
  File "/root/mfsdp-intergration/Megatron-LM/megatron/core/distributed/fsdp/src/megatron_fsdp/uneven_dtensor.py", line 196, in validate_uneven_dtensor
    assert torch.all(boundary_checks), (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: [Megatron-FSDP] DTensor chunk metadata boundary check failed. Offsets: (0, 0, 0), Sizes: (512, 1, 4), Global shape: torch.Size([1024, 1, 4]), Local shape: torch.Size([512, 1, 4]), Device mesh: DeviceMesh((dp_cp=4), 'cuda', stride=(2,)).

The bug is in the DTensor splitting step of the FSDP save preprocessing for the GDN fused projection: a DTensor that is already TP-local is rebuilt a second time through the FSDP/TP logic, which produces invalid uneven DTensor metadata.

Fix: when the data is already a TP-local DTensor and data.shape[split_dim] == sum(split_sizes), split it directly with split_dtensor(..., update_uneven_dtensor_chunk_meta=True) instead of calling make_fsdp_dtensor(...) again.
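The precondition in the fix, data.shape[split_dim] == sum(split_sizes), can be illustrated with a minimal pure-Python sketch. It uses a plain list in place of a tensor dimension, and `split_by_sizes` is a hypothetical helper written for illustration only; it is not the `split_dtensor` implementation from Megatron-FSDP.

```python
from itertools import accumulate
from typing import List, Sequence


def split_by_sizes(values: Sequence[int], split_sizes: Sequence[int]) -> List[List[int]]:
    """Split `values` into consecutive chunks whose lengths are `split_sizes`.

    Hypothetical stand-in for splitting one tensor dimension. The split is
    only well defined when the dimension length equals the sum of the sizes,
    which mirrors the data.shape[split_dim] == sum(split_sizes) condition.
    """
    total = sum(split_sizes)
    if len(values) != total:
        raise ValueError(
            f"dimension length {len(values)} does not match sum(split_sizes)={total}"
        )
    ends = list(accumulate(split_sizes))   # exclusive end offset of each chunk
    starts = [0] + ends[:-1]               # start offset of each chunk
    return [list(values[s:e]) for s, e in zip(starts, ends)]


chunks = split_by_sizes(list(range(6)), [2, 3, 1])
print(chunks)
```

When the lengths agree, every chunk has a well-defined offset and size, so the data can be split as it is without first being rebuilt at a different shape.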

Issue tracking

Linked issue: N/A - small bug fix found during downstream Qwen3.5 Megatron-FSDP checkpoint saving.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

@conver334 conver334 requested review from a team as code owners May 14, 2026 12:10
@copy-pr-bot

copy-pr-bot Bot commented May 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@conver334 conver334 force-pushed the fix-gdn-dtensor-split branch from a32437c to 2fab5a0 Compare May 14, 2026 12:16
Signed-off-by: conver334 <conver334@gmail.com>
@conver334 conver334 force-pushed the fix-gdn-dtensor-split branch from 2fab5a0 to 66f5efd Compare May 14, 2026 12:19
@xuwchen

xuwchen commented May 15, 2026

Contributor

/ok to test 66f5efd

@wplf wplf enabled auto-merge May 15, 2026 02:11
@wplf wplf added this pull request to the merge queue May 15, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25907529695

wplf added a commit to wplf/Megatron-LM that referenced this pull request May 15, 2026
The early-return ``isinstance(data, DTensor) and data.shape[split_dim] == total_split``
branch in ``split_gdn_fused`` is an independent fix that landed on this
PR via conver334's internal contribution. NVIDIA#4799 by
the same author proposes the same fast-path against ``dev`` directly,
so we remove it here to avoid a textual conflict when NVIDIA#4799 merges.

The remaining changes in this PR are still needed and complementary
to NVIDIA#4799 (plain-Tensor support in splitters, ``dist_param``-coordinate
computation for non-DTensor optimizer states, top-level metadata
stash for Bridge-converted DCP checkpoints, and the FSDP no-op in
``DistributedOptimizer._copy_model_params_to_main_params``).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 15, 2026
NVIDIA#4799 by conver334 proposes the same TP-local GDN
fast-path against dev directly. Remove our copy here so this integration
branch doesn't textually conflict when NVIDIA#4799 merges.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25910357149

Merged via the queue into NVIDIA:dev with commit df12802 May 15, 2026
64 of 66 checks passed
SpencerGarnets added a commit to ai-blaise/Megatron-LM that referenced this pull request May 16, 2026
Upstream dev tip: 77c0f8c

Pulled commits:

- 77c0f8c [Dev][feat] Support A2A Overlap for Megatron-FSDP (NVIDIA#3796)

- 8195337 [dev] [3/5] Qwen3.5 support: SharedExpertMLP meta init (NVIDIA#4749)

- 2672ff5 [DEV] fix(megatron-fsdp): preserve non-meta tensors during meta materialization (NVIDIA#4155)

- cfbd9df [dev] [4/5] Qwen3.5 support: Interleaved MRoPE layout (NVIDIA#4750)

- df12802 [dev] Fix GDN DTensor splitting for FSDP checkpointing (NVIDIA#4799)

Resolution: zero conflicts; git auto-merged 12 shared files in megatron/core/{distributed,models,pipeline_parallel,transformer} and tests/unit_tests/a2a_overlap. No ai-blaise custom files touched.

Gates:

- git diff --check: clean

- conflict markers: none

- py_compile (16 changed .py files): OK

- indexcache: 27/28 pass; the 1 fail (test_nvfp4_non_blackwell_cuda_uses_reference_fallback) reproduces identically at the pre-merge base SHA (sglang occupies all 8 H200s in EXCLUSIVE_PROCESS mode -> cudaErrorDevicesUnavailable). 1 Blackwell-only test auto-skips on H200.

- transformer gdn/mtp/moe suite: 53 failed / 7 passed / 55 skipped / 5 errors -- IDENTICAL numbers at pre-merge base; all failures are the same environmental cudaErrorDevicesUnavailable.

- 2-rank torchrun layer-wise optimizer smoke: blocked (no free GPUs).

Custom preserved: StreamBP, IndexCache config, NVFP4 indexer (7e78f28), HISA topk1024 backward test (c628c13), pyproject emerging_optimizers v0.2.0 pin, mHC/MTP/MoE composition.
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
The merge of ``dev`` into this branch produced broken Python in
``split_gdn_fused``:

- inconsistent indentation (``             else:`` at 13 spaces vs
  ``if`` at 12), tripping black with
  ``IndentationError: unindent does not match any outer indentation level``;
- an unclosed ``assert data.shape == global_shape, (...`` (missing
  closing ``)``);
- incorrect nesting of the new TP-local fast-path (from NVIDIA#4799, now on
  ``dev``) inside ``if isinstance(data, DTensor):``.

Restore the intended structure:

1. TP-local fast-path as a flat top-level early return
   (matches the form NVIDIA#4799 landed on ``dev``).
2. PR-2's ``global_shape = dist_param.shape`` and the
   ``data.shape == global_shape`` assert for full-global DTensor data,
   kept as a sibling check after the fast-path.
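The restored control flow described in the two steps above could look roughly like the following sketch. `FakeDTensor` and `plan_split` are hypothetical names, and the function only reports which path would be taken; the real `split_gdn_fused` performs the actual splitting and may differ in detail.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass
class FakeDTensor:
    """Hypothetical stand-in for a DTensor; only the shape matters here."""
    shape: Tuple[int, ...]


def plan_split(data, global_shape: Tuple[int, ...],
               split_sizes: Sequence[int], split_dim: int) -> str:
    total_split = sum(split_sizes)

    # 1. Flat, top-level early return for data that is already TP-local.
    if isinstance(data, FakeDTensor) and data.shape[split_dim] == total_split:
        return "fast_path"

    # 2. Sibling check (not nested inside the fast-path): any remaining
    #    DTensor is expected to carry the full global shape.
    if isinstance(data, FakeDTensor):
        assert data.shape == global_shape, (
            f"expected global shape {global_shape}, got {data.shape}"
        )
    return "rebuild_path"


# Shapes and split sizes below are illustrative only.
print(plan_split(FakeDTensor((512, 1, 4)), (1024, 1, 4), [256, 256], 0))
print(plan_split(FakeDTensor((1024, 1, 4)), (1024, 1, 4), [256, 256], 0))
```

Keeping the two checks as siblings means the fast-path condition is evaluated once, before any assumption about the global shape is made.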

CI ``BASE_REF=dev`` autoformat now passes (black, isort, pylint 10/10,
ruff).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
This reverts 947c91e. NVIDIA#4799 has landed on ``dev`` (and was synced
into PR NVIDIA#4748 via a dev-merge), so we removed our copy from the dev-
target PR to avoid a textual conflict.

``main`` does **not** yet have NVIDIA#4799, however. Without the TP-local
GDN fast-path on the main side, ``split_gdn_fused`` falls into the
TP-mesh-based branch even when ``data`` is already TP-local
(``data.shape[split_dim] == sum(split_sizes)``), which yields wrong
``data_size`` / ``view_shape`` for Bridge-converted DTensors that
encode the per-rank size as the global shape.

Restore the fast-path on this PR (target ``main``) until NVIDIA#4799's
main mirror lands. When that happens, we should drop the fast-path
here too (mirror of what we did on PR NVIDIA#4748 dev) to avoid the
textual conflict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf added a commit to wplf/Megatron-LM that referenced this pull request May 29, 2026
The split_gdn_fused() fast-path for already-TP-local DTensors (added by
NVIDIA#4799) calls split_dtensor(..., update_uneven_dtensor_chunk_meta=True).
split_dtensor() itself unconditionally calls gather_and_compute_chunk_metadata
(line 457 in uneven_dtensor.py), which issues an all_gather_object on the
DTensor's shard groups.

Inside checkpoint save preprocessing, this triggers the same deadlock as
issue NVIDIA#4910: ranks enter collectives on different mesh groups in different
orders and the NCCL watchdog aborts after 10 min.

PR #2's explicit-chunk-metadata fix only covered the slow rebuild path; the
fast-path was untouched. Disable it so GDN model state falls through to the
slow path.
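The failure mode described above, ranks entering collectives on different groups in different orders, can be checked offline with a small sketch like the one below. It is a pure-Python illustration with hypothetical names (`find_order_mismatch`, the `calls` and `members` inputs); it does not use torch.distributed or NCCL and is not taken from the Megatron-LM code.

```python
from itertools import combinations
from typing import Dict, List, Optional, Set, Tuple


def find_order_mismatch(
    calls: Dict[int, List[str]],
    members: Dict[str, Set[int]],
) -> Optional[Tuple[int, int]]:
    """Return the first pair of ranks that disagree on collective order.

    calls:   rank -> ordered list of group names it issues collectives on
    members: group name -> set of ranks in that group

    Two ranks can block each other if, restricted to the groups they share,
    they issue collectives in a different order.
    """
    for a, b in combinations(sorted(calls), 2):
        shared = {g for g, ranks in members.items() if a in ranks and b in ranks}
        seq_a = [g for g in calls[a] if g in shared]
        seq_b = [g for g in calls[b] if g in shared]
        if seq_a != seq_b:
            return (a, b)
    return None


members = {"dp": {0, 1}, "tp": {0, 1}}
print(find_order_mismatch({0: ["dp", "tp"], 1: ["tp", "dp"]}, members))
print(find_order_mismatch({0: ["dp", "tp"], 1: ["dp", "tp"]}, members))
```

In the first call rank 0 waits on the "dp" group while rank 1 waits on "tp", so neither collective can complete; the second call shows a consistent ordering.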