[main] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility #4753

Open
wplf wants to merge 10 commits into NVIDIA:main from wplf:fix/fsdp-dtensor-bridge-checkpoint-main

wplf (Member) commented May 12, 2026

Qwen3.5 support series

This is part of a 5-PR series adding Qwen3.5-VL support, split for review clarity.

Main PRs (this series):

Dev PRs (corresponding mirrors):


Summary

Five related fixes that together let fsdp_dtensor resume from a checkpoint converted with Megatron-Bridge:

  1. split_swiglu_linear_fc1 / split_gdn_fused: accept a plain Tensor, not only a DTensor, at the data.to_local() call. Bridge writes some already-local tensors that previously crashed the splitter.

  2. TP-local GDN split: when the GDN tensor is already TP-local (data.shape[split_dim] == sum(split_sizes)), use split_dtensor(..., update_uneven_dtensor_chunk_meta=True) instead of the TP-mesh-based split that assumes the full unsharded shape.

  3. Stash missing top-level keys in _load_base_checkpoint: keys not present as DCP storage metadata (e.g. args, iteration) are stashed before dcp.load and restored after. Megatron-saved checkpoints contain these as BytesStorageMetadata; Bridge-converted ones don't, and DCP errors without this stash.

  4. DistributedOptimizer._copy_model_params_to_main_params: under use_megatron_fsdp, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights, so both buffers are already correct at this call site.

  5. split_swiglu_linear_fc1 / split_gdn_fused use dist_param coordinates (cherry-pick of #4424, "[DEV] fix(megatron-fsdp): compute SWiGLU/GDN split in item coordinates for non-DTensor optimizer states", by @xuwchen): when data is a plain Tensor (FusedAdam optimizer state), data.numel() and data.shape describe the local item layout, not the global TP-sharded layout, so the splitter computed the wrong data_size, view_shape, and per_tp_rank_shape. The fix switches to global_shape = dist_param.shape and dist_param.numel() everywhere, and adds a DTensor-shape consistency assert.

Also wires handle_gdn_in_state_dict into preprocess_fsdp_dtensor_state_dict.
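The stash described in fix 3 can be illustrated with plain dictionaries. This is a minimal sketch, not the code in _load_base_checkpoint: load_with_stash, make_strict_loader, and the stored_keys parameter are hypothetical names, and the strict loader only mimics a loader that errors on keys the checkpoint does not contain, as the summary says DCP does.

```python
from typing import Any, Callable, Dict, Set


def make_strict_loader(checkpoint: Dict[str, Any]) -> Callable[[Dict[str, Any]], None]:
    """Build a toy in-place loader that rejects keys absent from `checkpoint`."""

    def strict_load(state_dict: Dict[str, Any]) -> None:
        missing = [key for key in state_dict if key not in checkpoint]
        if missing:
            raise KeyError(f"keys missing from checkpoint: {missing}")
        for key in state_dict:
            state_dict[key] = checkpoint[key]

    return strict_load


def load_with_stash(
    state_dict: Dict[str, Any],
    stored_keys: Set[str],
    load_fn: Callable[[Dict[str, Any]], None],
) -> Dict[str, Any]:
    """Load only keys the checkpoint knows about, then restore the rest.

    `stored_keys` stands in for the top-level keys present in the
    checkpoint's storage metadata; `load_fn` stands in for the real loader.
    """
    # Set aside top-level entries the checkpoint has no record of.
    stash = {key: state_dict.pop(key) for key in list(state_dict) if key not in stored_keys}
    try:
        load_fn(state_dict)
    finally:
        # Restore the stashed entries so callers see the original key set.
        state_dict.update(stash)
    return state_dict


# A Bridge-style checkpoint in this toy model: weights only, no args/iteration.
checkpoint = {"model": "weights-from-bridge"}
state = {"model": None, "args": "cli-args", "iteration": 0}
load_with_stash(state, set(checkpoint), make_strict_loader(checkpoint))
```

Calling the strict loader directly on the full dictionary raises in this toy setup, which is the failure mode the stash avoids.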

Why bundled

All five fixes are needed together for the Bridge → fsdp_dtensor resume path to work end-to-end.

Note on key-unwrapping

An earlier revision of this branch also rewrote model.get_parameter(f"module.{key}") → model.get_parameter(key). That root cause was resolved in #4393 ("Revert 'fix mfsdp unwrap stuck at MegatronFSDP [dev] (#4273)'"), so the module. prefix is correct again and those four lookup edits have been dropped.

Notes

Mirror of #4748 (same patch, targeting main instead of dev).

🤖 Generated with Claude Code

Update (2026-05-14): model→main copy reverted

A previous revision of this PR (commit 4c2c3cbfb on dev / f4ba23b11 on main) implemented an FSDP copy_model_weights_to_main_weights and called it from _copy_model_params_to_main_params, on the theory that the silent return left fp32 main params uninitialized under --no-load-optim / --finetune.

An empirical buffer-diff probe showed that theory is wrong. Under fsdp_dtensor, main_weight_buffer is exposed as the module's nn.Parameters via MegatronFSDP's state_dict pre-hook (_replace_param_with_distributed_if_needed), so the preceding model.load_state_dict(state_dict["model"]) writes the checkpoint's fp32 tensors directly into main_weight_buffer. This gives full fp32 precision even with --no-load-optim, because under fsdp_dtensor the main params live in the model state dict, not the optimizer state dict. The load_state_dict post-hook (copy_main_weights_to_model_weights) then casts main → bf16 to refresh model_weight_buffer.

A 50-step finetune+resume on a Qwen3.5-VL 0.8B model with and without the copy gave indistinguishable loss curves (wandb runs q8d7l7nm / o4tyuf5z / wwuz7y4z under https://wandb.ai/wplf/fsdp-fix-verify/). The pre-copy buffer diff was exactly bf16 rounding noise; the post-copy diff dropped to 0 only because main had been downgraded to bf16 precision.

The model→main copy commit has been reverted. The plain return is correct, and the comment in _copy_model_params_to_main_params now documents the mechanism so this isn't re-litigated.
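The precision argument above can be reproduced on a single scalar. The following is a sketch under stated assumptions: to_fp32 and to_bf16 are hypothetical helpers that emulate float32 and bfloat16 rounding with the standard library (round-to-nearest-even, no NaN/inf handling), and 0.1 is an arbitrary example weight, not a value from the PR's probe.

```python
import struct


def to_fp32(value: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def to_bf16(value: float) -> float:
    """Round to bfloat16 and return the result as a Python float.

    bfloat16 keeps the top 16 bits of the float32 bit pattern; the low
    16 bits are rounded away with round-to-nearest-even.
    """
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding increment
    bits &= 0xFFFF0000                   # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# fp32 main weight as loaded from the checkpoint, and the bf16 model weight
# produced from it by the main -> model cast.
main_before = to_fp32(0.1)
model = to_bf16(main_before)
pre_copy_diff = abs(main_before - model)   # non-zero: bf16 rounding noise

# A model -> main copy upcasts the bf16 value exactly, so the diff becomes
# zero only because main now holds the rounded value.
main_after = model
post_copy_diff = abs(main_after - model)
precision_lost = main_after != main_before
```

In this toy case the zero post-copy diff coincides with main_after no longer matching the original fp32 value, which is the same reading the probe gave for the reverted copy.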

Update (2026-05-18): TP-local GDN fast-path restored on main

The TP-local GDN fast-path is back on this main-target PR. #4799 (conver334) provides the same fast-path against dev, and the dev-target PR #4748 dropped its copy to avoid a textual conflict; but #4799 has not yet landed on main, so this PR still needs to provide it. Once #4799's main mirror lands, this bullet can be removed here too (same flow as PR #4748).

wplf added the Run tests label May 12, 2026

copy-pr-bot (Bot) commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Four small fixes that together let FSDP-DTensor resume from a checkpoint
converted with Megatron-Bridge:

1. `split_swiglu_linear_fc1` / `split_gdn_fused`: accept a plain `Tensor`
   (not only `DTensor`) for `data.to_local()`. Bridge writes some
   already-local tensors that previously crashed the splitter.

2. `split_gdn_fused`: when the GDN tensor is already TP-local
   (`data.shape[split_dim] == sum(split_sizes)`), use
   `split_dtensor(..., update_uneven_dtensor_chunk_meta=True)` instead
   of the TP-mesh-based split that assumes the full unsharded shape.

3. `megatron/training/checkpointing.py::_load_base_checkpoint`: stash
   top-level state-dict keys that DCP storage metadata doesn't know
   about (`args`, `iteration`, ...), then restore them after
   `dcp.load`. Megatron-saved checkpoints contain these as
   `BytesStorageMetadata`; Bridge-converted ones don't, and DCP errors
   without this stash.

4. `DistributedOptimizer._copy_model_params_to_main_params`: under
   `use_megatron_fsdp`, return instead of raising. DCP already loads
   directly into the fp32 main buffer, and a post-load hook copies
   main → model weights — both are correct at this call site.

Also wires `handle_gdn_in_state_dict` into
`preprocess_fsdp_dtensor_state_dict`.
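Fix 1 amounts to normalizing the input before splitting. A minimal sketch of that idea follows, written without PyTorch so it runs anywhere: LocalShardHolder is a hypothetical stand-in for a DTensor, as_local is a hypothetical helper, and the duck-typed check replaces whatever type check the splitters actually use (the commit messages in this PR mention ``isinstance(data, DTensor)``).

```python
class LocalShardHolder:
    """Hypothetical stand-in for a DTensor: it only exposes to_local()."""

    def __init__(self, shard):
        self._shard = shard

    def to_local(self):
        return self._shard


def as_local(data):
    """Return the local shard of a DTensor-like object, or `data` unchanged.

    Plain (already-local) inputs have no to_local() and pass straight
    through instead of raising AttributeError.
    """
    to_local = getattr(data, "to_local", None)
    return to_local() if callable(to_local) else data


plain = [1.0, 2.0, 3.0]                      # already-local values
wrapped = LocalShardHolder([1.0, 2.0, 3.0])  # DTensor-like wrapper

local_from_plain = as_local(plain)
local_from_wrapped = as_local(wrapped)
```

Both calls yield the same local values, so downstream splitting code can treat the two input kinds uniformly.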

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Co-Authored-By: xuwchen <79835960+xuwchen@users.noreply.github.com>
Co-Authored-By: conver334 <56124251+conver334@users.noreply.github.com>

Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
wplf force-pushed the fix/fsdp-dtensor-bridge-checkpoint-main branch from 6f94567 to 9c93a67 on May 13, 2026 10:24

wplf (Member, Author) commented May 13, 2026

/ok to test bb5dbe7

wplf and others added 3 commits May 13, 2026 06:07
…path

Collapse the multi-line ``split_dtensor(...)`` call inside the
``isinstance(data, DTensor) and data.shape[split_dim] == total_split``
fast-path onto a single line per black/autoformat.sh rules.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

`DistributedOptimizer._copy_model_params_to_main_params` is called from
the ``--finetune`` / ``--no-load-optim`` checkpoint path
(``megatron/training/checkpointing.py:2261-2266``) to refresh fp32 main
params from the bf16/fp16 model params that DCP just loaded — i.e.,
the resume case where the checkpoint has no optimizer state. Under
``use_megatron_fsdp`` we previously returned without doing anything, so
fp32 main params would silently hold their pre-checkpoint init values
and training would diverge.

Add ``ParamAndGradBuffer.copy_model_weights_to_main_weights`` as the
mirror of ``copy_main_weights_to_model_weights``: walk each parameter
group, fetch the local (sharded) ``model_weight_buffer`` and
``main_weight_buffer`` slices via ``get_item(..., only_shard=...)``,
and copy ``model -> main`` with implicit bf16/fp16 -> fp32 upcast.

Wire ``DistributedOptimizer._copy_model_params_to_main_params`` to
call the new method per model chunk under ``use_megatron_fsdp``.

FP8 model params are out of scope (would need a separate fp8 -> fp32
dequantization path) and raise ``NotImplementedError``. The full Bridge
resume path (which loads optimizer state via DCP) does not reach this
function and is unaffected.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

This reverts commit f4ba23b. The model->main copy was based on a
non-FSDP mental model of where the fp32 main params live and turned out
to be unneeded — and slightly harmful — under ``fsdp_dtensor``.

Empirical verification (PR NVIDIA#4753 wandb runs, see also the buffer-diff
probe added in the discussion): under
``use_megatron_fsdp + --no-load-optim + --finetune`` the
``main_weight_buffer`` is already populated to **full fp32 precision**
from the checkpoint by the time ``_copy_model_params_to_main_params``
is invoked. The mechanism is two ``MegatronFSDP`` hooks documented in
``megatron_fsdp.py``:

1. A ``state_dict`` pre-hook (``_replace_param_with_distributed_if_needed``,
   megatron_fsdp.py:1099-1102 / :1327-1344) swaps each module
   ``nn.Parameter`` for a fp32 DTensor pointing at the matching
   ``main_weight_buffer`` slice, so DCP / ``model.load_state_dict``
   writes the checkpoint's fp32 tensor directly into
   ``main_weight_buffer``.
2. A ``load_state_dict`` post-hook
   (``install_optimized_model_weights`` -> ``copy_main_weights_to_model_weights``,
   megatron_fsdp.py:1093-1095 / :1408-1413) immediately casts
   ``main_weight_buffer`` down to bf16/fp16 to refresh
   ``model_weight_buffer``.

``--no-load-optim`` only skips ``state_dict["optimizer"]`` (Adam moments
and group metadata). Under ``fsdp_dtensor`` the main params live in the
**model** state dict, not the optimizer state dict, so they are loaded
regardless of ``--no-load-optim``.

The model->main copy this commit reverts would overwrite the freshly-
loaded full-fp32 ``main_weight_buffer`` with ``model_weight.float()``
(i.e., bf16-rounded-then-recast), a strict precision regression.
A 50-step finetune+resume with and without the copy gave indistinguishable
loss curves; the buffer-diff probe showed pre-copy diff = exact bf16
rounding noise (full main + bf16-rounded model agreeing modulo ULP),
post-copy diff = 0 only because main had been downgraded to bf16
precision. The plain ``return`` is correct.

Also expand the comment on the FSDP early-return so a future reader does
not re-litigate this.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

wplf added a commit to wplf/Megatron-LM that referenced this pull request May 14, 2026

This reverts commit 29db3b3 (integration branch counterpart of the
PR-2 revert). Empirical verification showed the FSDP model->main copy
was unneeded — and slightly harmful (downgrades fp32 main to bf16-
quantized) — because under fsdp_dtensor the fp32 main_weight_buffer
is loaded directly from the checkpoint's model state dict via the
``MegatronFSDP`` ``state_dict`` pre-hook
(_replace_param_with_distributed_if_needed) and the
load_state_dict post-hook (copy_main_weights_to_model_weights).
See PRs NVIDIA#4748 / NVIDIA#4753 commit messages and the wandb runs at
https://wandb.ai/wplf/fsdp-fix-verify/ for the verification trace.

Also expands the comment on the FSDP early-return to document the
mechanism so this isn't re-litigated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The early-return ``isinstance(data, DTensor) and data.shape[split_dim] == total_split``
branch in ``split_gdn_fused`` is an independent fix that landed on this
PR via conver334's internal contribution. NVIDIA#4799 by
the same author proposes the same fast-path against ``dev`` directly,
so we remove it here to avoid a textual conflict when NVIDIA#4799 merges.

The remaining changes in this PR are still needed and complementary
to NVIDIA#4799 (plain-Tensor support in splitters, ``dist_param``-coordinate
computation for non-DTensor optimizer states, top-level metadata
stash for Bridge-converted DCP checkpoints, and the FSDP no-op in
``DistributedOptimizer._copy_model_params_to_main_params``).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

wplf and others added 4 commits May 17, 2026 20:50

This reverts 947c91e. NVIDIA#4799 has landed on ``dev`` (and was synced
into PR NVIDIA#4748 via a dev-merge), so we removed our copy from the dev-
target PR to avoid a textual conflict.

``main`` does **not** yet have NVIDIA#4799, however. Without the TP-local
GDN fast-path on the main side, ``split_gdn_fused`` falls into the
TP-mesh-based branch even when ``data`` is already TP-local
(``data.shape[split_dim] == sum(split_sizes)``), which yields wrong
``data_size`` / ``view_shape`` for Bridge-converted DTensors that
encode the per-rank size as the global shape.

Restore the fast-path on this PR (target ``main``) until NVIDIA#4799's
main mirror lands. When that happens, we should drop the fast-path
here too (mirror of what we did on PR NVIDIA#4748 dev) to avoid the
textual conflict.
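The fast-path condition quoted above is a pure shape check, sketched here on tuples. This is illustrative only: is_tp_local and split_ranges are hypothetical helpers, the sizes are made up, and the real fast-path calls ``split_dtensor(..., update_uneven_dtensor_chunk_meta=True)`` rather than computing ranges by hand.

```python
from typing import List, Sequence, Tuple


def is_tp_local(shape: Sequence[int], split_dim: int, split_sizes: Sequence[int]) -> bool:
    """True when the split dimension already equals sum(split_sizes)."""
    return shape[split_dim] == sum(split_sizes)


def split_ranges(split_sizes: Sequence[int]) -> List[Tuple[int, int]]:
    """(start, stop) index ranges along the split dimension, one per piece."""
    ranges, start = [], 0
    for size in split_sizes:
        ranges.append((start, start + size))
        start += size
    return ranges


split_sizes = [4, 4, 2]   # hypothetical piece sizes
local_shape = (10, 16)    # split dim equals sum(split_sizes): fast-path
larger_shape = (20, 16)   # split dim exceeds sum(split_sizes): other branch

takes_fast_path = is_tp_local(local_shape, 0, split_sizes)
takes_mesh_path = not is_tp_local(larger_shape, 0, split_sizes)
pieces = split_ranges(split_sizes)
```

When the check passes, the pieces can be cut directly from the tensor as it stands, with no assumption about a larger unsharded shape.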

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

The FSDP early-return comment in ``_copy_model_params_to_main_params``
cited specific line numbers in
``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``.
Those line numbers shift on every unrelated change to that file
(e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace
with the stable symbol names — ``_replace_param_with_distributed_if_needed``,
``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``
— so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

# Conflicts:
#	megatron/core/transformer/fsdp_dtensor_checkpoint.py
wplf force-pushed the fix/fsdp-dtensor-bridge-checkpoint-main branch from ede610a to 6fd7b35 on June 4, 2026 10:16
wplf marked this pull request as ready for review June 4, 2026 10:18
wplf requested review from a team as code owners June 4, 2026 10:18

wplf (Member, Author) commented Jun 4, 2026

/ok to test 6fd7b35
