Skip to content

[dev] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility#4748

Merged
yaox12 merged 10 commits into
NVIDIA:devfrom
wplf:fix/fsdp-dtensor-bridge-checkpoint
May 19, 2026
Merged

[dev] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility#4748
yaox12 merged 10 commits into
NVIDIA:devfrom
wplf:fix/fsdp-dtensor-bridge-checkpoint

Conversation

@wplf

@wplf wplf commented May 12, 2026

Copy link
Copy Markdown
Member

Qwen3.5 support series

This is part of a 5-PR series adding Qwen3.5-VL support, split for review clarity.

Dev PRs (this series):

Main PRs (corresponding mirrors):


Summary

Three related fixes that together let fsdp_dtensor resume from a checkpoint converted with Megatron-Bridge:

  1. split_swiglu_linear_fc1 / split_gdn_fused: accept plain Tensor (not only DTensor) for data.to_local(). Bridge writes some already-local tensors that previously crashed the splitter.

  2. Stash missing top-level keys in _load_base_checkpoint: keys not present as DCP storage metadata (e.g. args, iteration) are stashed before dcp.load and restored after. Megatron-saved checkpoints contain these as BytesStorageMetadata; Bridge-converted ones don't, and DCP errors without this stash.

  3. DistributedOptimizer._copy_model_params_to_main_params: under use_megatron_fsdp, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights — both are correct at this call site.

  4. split_swiglu_linear_fc1 / split_gdn_fused use dist_param coordinates (cherry-pick of [DEV] fix(megatron-fsdp): compute SWiGLU/GDN split in item coordinates for non-DTensor optimizer states #4424 by @xuwchen): when data is a plain Tensor (FusedAdam optimizer state), data.numel() / data.shape describe the local item layout, not the global TP-sharded layout, so the splitter computed wrong data_size, view_shape, and per_tp_rank_shape. Switch to global_shape = dist_param.shape and dist_param.numel() everywhere, with a DTensor-shape consistency assert.

  5. resolve conflict from [dev] Fix GDN DTensor splitting for FSDP checkpointing #4799

Also wires handle_gdn_in_state_dict into preprocess_fsdp_dtensor_state_dict.

Why bundled

All three fixes are needed together for the Bridge → fsdp_dtensor resume path to work end-to-end. Splitting further yields PRs that individually leave the path broken.

Note on key-unwrapping

An earlier revision of this branch also rewrote model.get_parameter(f"module.{key}")model.get_parameter(key) (and the analogous optimizer-state lookups) on the assumption that unwrap_model strips MegatronFSDPModule. That root cause was resolved in #4393 ("Revert 'fix mfsdp unwrap stuck at MegatronFSDP [dev] (#4273)'"), which removed MegatronFSDPModule from unwrap_model's tuple, so the module. prefix is correct again. Those four lookup edits have been dropped from this PR.

Test plan

  • Save a checkpoint under FSDP-DTensor with Megatron-LM, resume — unchanged.
  • Convert a checkpoint with Megatron-Bridge, resume under FSDP-DTensor — succeeds (was previously failing on missing-key error from DCP, then on the GDN splitter for TP-local tensors).
  • Hybrid model with GDN + SwiGLU MoE: loss curve matches pre-resume curve.

Notes

Extracted from a larger Qwen3.5-VL development branch to keep these fixes reviewable on their own.

🤖 Generated with Claude Code

Update (2026-05-14): model→main copy reverted

A previous revision of this PR (commit 4c2c3cbfb on dev / f4ba23b11 on main) implemented an FSDP copy_model_weights_to_main_weights and called it from _copy_model_params_to_main_params, on the theory that the silent return left fp32 main params uninitialized under --no-load-optim / --finetune.

Empirical buffer-diff probe showed that's wrong: under fsdp_dtensor, main_weight_buffer is exposed as the module's nn.Parameters via MegatronFSDP's state_dict pre-hook (_replace_param_with_distributed_if_needed), so the preceding model.load_state_dict(state_dict["model"]) writes the checkpoint's fp32 tensors directly into main_weight_buffer — full fp32 precision, even with --no-load-optim, because main params live in the model state dict (not the optimizer state dict) under fsdp_dtensor. The load_state_dict post-hook (copy_main_weights_to_model_weights) then casts main → bf16 to refresh model_weight_buffer.

A 50-step finetune+resume on a Qwen3.5-VL 0.8B model with and without the copy gave indistinguishable loss curves (wandb runs q8d7l7nm / o4tyuf5z / wwuz7y4z under https://wandb.ai/wplf/fsdp-fix-verify/). The pre-copy buffer diff was exactly bf16 rounding noise; the post-copy diff dropped to 0 only because main had been downgraded to bf16 precision.

The model→main copy commit has been reverted. The plain return is correct, and the comment in _copy_model_params_to_main_params now documents the mechanism so this isn't re-litigated.

Update (2026-05-14): TP-local GDN fast-path dropped

The original 2nd bullet (TP-local GDN split via split_dtensor) was the same fix proposed by @conver334 in #4799 against dev directly. We've dropped that bullet here so this PR doesn't textually conflict when #4799 merges; the remaining bullets are still complementary fixes that #4799 doesn't cover.

@copy-pr-bot

copy-pr-bot Bot commented May 12, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test aeaf0d1

Three small fixes that together let FSDP-DTensor resume from a checkpoint
converted with Megatron-Bridge:

1. `split_swiglu_linear_fc1` / `split_gdn_fused`: accept a plain `Tensor`
   (not only `DTensor`) for `data.to_local()`. Bridge writes some
   already-local tensors that previously crashed the splitter.

2. `split_gdn_fused`: when the GDN tensor is already TP-local
   (`data.shape[split_dim] == sum(split_sizes)`), use
   `split_dtensor(..., update_uneven_dtensor_chunk_meta=True)` instead
   of the TP-mesh-based split that assumes the full unsharded shape.

3. `megatron/training/checkpointing.py::_load_base_checkpoint`: stash
   top-level state-dict keys that DCP storage metadata doesn't know
   about (`args`, `iteration`, ...), then restore them after
   `dcp.load`. Megatron-saved checkpoints contain these as
   `BytesStorageMetadata`; Bridge-converted ones don't, and DCP errors
   without this stash.

4. `DistributedOptimizer._copy_model_params_to_main_params`: under
   `use_megatron_fsdp`, return instead of raising. DCP already loads
   directly into the fp32 main buffer, and a post-load hook copies
   main → model weights — both are correct at this call site.

Also wires `handle_gdn_in_state_dict` into
`preprocess_fsdp_dtensor_state_dict`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Co-Authored-By: xuwchen <79835960+xuwchen@users.noreply.github.com>
Co-Authored-By: conver334 <56124251+conver334@users.noreply.github.com>

Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
@wplf wplf force-pushed the fix/fsdp-dtensor-bridge-checkpoint branch from 50c107a to c9dfa74 Compare May 13, 2026 10:24
@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test c9dfa74

@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test 7f9e6b6

@wplf wplf left a comment

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

…path

Collapse the multi-line ``split_dtensor(...)`` call inside the
``isinstance(data, DTensor) and data.shape[split_dim] == total_split``
fast-path onto a single line per black/autoformat.sh rules.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf

wplf commented May 13, 2026

Copy link
Copy Markdown
Member Author

/ok to test 021b98d

wplf added a commit to wplf/Megatron-LM that referenced this pull request May 14, 2026
This reverts commit 29db3b3 (integration branch counterpart of the
PR-2 revert). Empirical verification showed the FSDP model->main copy
was unneeded — and slightly harmful (downgrades fp32 main to bf16-
quantized) — because under fsdp_dtensor the fp32 main_weight_buffer
is loaded directly from the checkpoint's model state dict via the
``MegatronFSDP`` ``state_dict`` pre-hook
(_replace_param_with_distributed_if_needed) and the
load_state_dict post-hook (copy_main_weights_to_model_weights).
See PRs NVIDIA#4748 / NVIDIA#4753 commit messages and the wandb runs at
https://wandb.ai/wplf/fsdp-fix-verify/ for the verification trace.

Also expands the comment on the FSDP early-return to document the
mechanism so this isn't re-litigated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf

wplf commented May 14, 2026

Copy link
Copy Markdown
Member Author

/ok to test 52f5d6e

@copy-pr-bot

copy-pr-bot Bot commented May 14, 2026

Copy link
Copy Markdown

/ok to test 52f5d6e

@wplf, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@wplf

wplf commented May 14, 2026

Copy link
Copy Markdown
Member Author

/ok to test 7bdf9d8

The early-return ``isinstance(data, DTensor) and data.shape[split_dim] == total_split``
branch in ``split_gdn_fused`` is an independent fix that landed on this
PR via conver334's internal contribution. NVIDIA#4799 by
the same author proposes the same fast-path against ``dev`` directly,
so we remove it here to avoid a textual conflict when NVIDIA#4799 merges.

The remaining changes in this PR are still needed and complementary
to NVIDIA#4799 (plain-Tensor support in splitters, ``dist_param``-coordinate
computation for non-DTensor optimizer states, top-level metadata
stash for Bridge-converted DCP checkpoints, and the FSDP no-op in
``DistributedOptimizer._copy_model_params_to_main_params``).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf

wplf commented May 15, 2026

Copy link
Copy Markdown
Member Author

/ok to test ca6a716

@wplf

wplf commented May 18, 2026

Copy link
Copy Markdown
Member Author

/ok to test 99e3a5f

The merge of ``dev`` into this branch produced broken Python in
``split_gdn_fused``:

- inconsistent indentation (``             else:`` at 13 spaces vs
  ``if`` at 12), tripping black with
  ``IndentationError: unindent does not match any outer indentation level``;
- an unclosed ``assert data.shape == global_shape, (...`` (missing
  closing ``)``);
- incorrect nesting of the new TP-local fast-path (from NVIDIA#4799, now on
  ``dev``) inside ``if isinstance(data, DTensor):``.

Restore the intended structure:

1. TP-local fast-path as a flat top-level early return
   (matches the form NVIDIA#4799 landed on ``dev``).
2. PR-2's ``global_shape = dist_param.shape`` and the
   ``data.shape == global_shape`` assert for full-global DTensor data,
   kept as a sibling check after the fast-path.

CI ``BASE_REF=dev`` autoformat now passes (black, isort, pylint 10/10,
ruff).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf

wplf commented May 18, 2026

Copy link
Copy Markdown
Member Author

/ok to test af5bc08

wplf added a commit to wplf/Megatron-LM that referenced this pull request May 18, 2026
This reverts 947c91e. NVIDIA#4799 has landed on ``dev`` (and was synced
into PR NVIDIA#4748 via a dev-merge), so we removed our copy from the dev-
target PR to avoid a textual conflict.

``main`` does **not** yet have NVIDIA#4799, however. Without the TP-local
GDN fast-path on the main side, ``split_gdn_fused`` falls into the
TP-mesh-based branch even when ``data`` is already TP-local
(``data.shape[split_dim] == sum(split_sizes)``), which yields wrong
``data_size`` / ``view_shape`` for Bridge-converted DTensors that
encode the per-rank size as the global shape.

Restore the fast-path on this PR (target ``main``) until NVIDIA#4799's
main mirror lands. When that happens, we should drop the fast-path
here too (mirror of what we did on PR NVIDIA#4748 dev) to avoid the
textual conflict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread megatron/core/optimizer/distrib_optimizer.py Outdated
@xuwchen xuwchen disabled auto-merge May 18, 2026 14:55
xuwchen and others added 2 commits May 18, 2026 09:12
The FSDP early-return comment in ``_copy_model_params_to_main_params``
cited specific line numbers in
``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``.
Those line numbers shift on every unrelated change to that file
(e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace
with the stable symbol names — ``_replace_param_with_distributed_if_needed``,
``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``
— so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@wplf

wplf commented May 18, 2026

Copy link
Copy Markdown
Member Author

/ok to test 30b5852

@yaox12 yaox12 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve since it passed the expert review.

@yaox12 yaox12 added this pull request to the merge queue May 19, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26072531765

Merged via the queue into NVIDIA:dev with commit 92ab682 May 19, 2026
181 of 183 checks passed
@wplf wplf deleted the fix/fsdp-dtensor-bridge-checkpoint branch May 19, 2026 04:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants