[main] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility#4753
Open
wplf wants to merge 10 commits into
Open
[main] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility#4753wplf wants to merge 10 commits into
wplf wants to merge 10 commits into
Conversation
This was referenced May 12, 2026
This was referenced May 12, 2026
cbd9425 to
6f94567
Compare
Three small fixes that together let FSDP-DTensor resume from a checkpoint converted with Megatron-Bridge: 1. `split_swiglu_linear_fc1` / `split_gdn_fused`: accept a plain `Tensor` (not only `DTensor`) for `data.to_local()`. Bridge writes some already-local tensors that previously crashed the splitter. 2. `split_gdn_fused`: when the GDN tensor is already TP-local (`data.shape[split_dim] == sum(split_sizes)`), use `split_dtensor(..., update_uneven_dtensor_chunk_meta=True)` instead of the TP-mesh-based split that assumes the full unsharded shape. 3. `megatron/training/checkpointing.py::_load_base_checkpoint`: stash top-level state-dict keys that DCP storage metadata doesn't know about (`args`, `iteration`, ...), then restore them after `dcp.load`. Megatron-saved checkpoints contain these as `BytesStorageMetadata`; Bridge-converted ones don't, and DCP errors without this stash. 4. `DistributedOptimizer._copy_model_params_to_main_params`: under `use_megatron_fsdp`, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights — both are correct at this call site. Also wires `handle_gdn_in_state_dict` into `preprocess_fsdp_dtensor_state_dict`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-Authored-By: xuwchen <79835960+xuwchen@users.noreply.github.com> Co-Authored-By: conver334 <56124251+conver334@users.noreply.github.com> Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
6f94567 to
9c93a67
Compare
Member
Author
|
/ok to test bb5dbe7 |
…path Collapse the multi-line ``split_dtensor(...)`` call inside the ``isinstance(data, DTensor) and data.shape[split_dim] == total_split`` fast-path onto a single line per black/autoformat.sh rules. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`DistributedOptimizer._copy_model_params_to_main_params` is called from the ``--finetune`` / ``--no-load-optim`` checkpoint path (``megatron/training/checkpointing.py:2261-2266``) to refresh fp32 main params from the bf16/fp16 model params that DCP just loaded — i.e., the resume case where the checkpoint has no optimizer state. Under ``use_megatron_fsdp`` we previously returned without doing anything, so fp32 main params would silently hold their pre-checkpoint init values and training would diverge. Add ``ParamAndGradBuffer.copy_model_weights_to_main_weights`` as the mirror of ``copy_main_weights_to_model_weights``: walk each parameter group, fetch the local (sharded) ``model_weight_buffer`` and ``main_weight_buffer`` slices via ``get_item(..., only_shard=...)``, and copy ``model -> main`` with implicit bf16/fp16 -> fp32 upcast. Wire ``DistributedOptimizer._copy_model_params_to_main_params`` to call the new method per model chunk under ``use_megatron_fsdp``. FP8 model params are out of scope (would need a separate fp8 -> fp32 dequantization path) and raise ``NotImplementedError``. The full Bridge resume path (which loads optimizer state via DCP) does not reach this function and is unaffected. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This reverts commit f4ba23b. The model->main copy was based on a non-FSDP mental model of where the fp32 main params live and turned out to be unneeded — and slightly harmful — under ``fsdp_dtensor``. Empirical verification (PR NVIDIA#4753 wandb runs, see also the buffer-diff probe added in the discussion): under ``use_megatron_fsdp + --no-load-optim + --finetune`` the ``main_weight_buffer`` is already populated to **full fp32 precision** from the checkpoint by the time ``_copy_model_params_to_main_params`` is invoked. The mechanism is two ``MegatronFSDP`` hooks documented in ``megatron_fsdp.py``: 1. A ``state_dict`` pre-hook (``_replace_param_with_distributed_if_needed``, megatron_fsdp.py:1099-1102 / :1327-1344) swaps each module ``nn.Parameter`` for a fp32 DTensor pointing at the matching ``main_weight_buffer`` slice, so DCP / ``model.load_state_dict`` writes the checkpoint's fp32 tensor directly into ``main_weight_buffer``. 2. A ``load_state_dict`` post-hook (``install_optimized_model_weights`` -> ``copy_main_weights_to_model_weights``, megatron_fsdp.py:1093-1095 / :1408-1413) immediately casts ``main_weight_buffer`` down to bf16/fp16 to refresh ``model_weight_buffer``. ``--no-load-optim`` only skips ``state_dict["optimizer"]`` (Adam moments and group metadata). Under ``fsdp_dtensor`` the main params live in the **model** state dict, not the optimizer state dict, so they are loaded regardless of ``--no-load-optim``. The model->main copy this commit reverts would overwrite the freshly- loaded full-fp32 ``main_weight_buffer`` with ``model_weight.float()`` (i.e., bf16-rounded-then-recast), a strict precision regression. A 50-step finetune+resume with and without the copy gave indistinguishable loss curves; the buffer-diff probe showed pre-copy diff = exact bf16 rounding noise (full main + bf16-rounded model agreeing modulo ULP), post-copy diff = 0 only because main had been downgraded to bf16 precision. The plain ``return`` is correct. Also expand the comment on the FSDP early-return so a future reader does not re-litigate this. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
wplf
added a commit
to wplf/Megatron-LM
that referenced
this pull request
May 14, 2026
This reverts commit 29db3b3 (integration branch counterpart of the PR-2 revert). Empirical verification showed the FSDP model->main copy was unneeded — and slightly harmful (downgrades fp32 main to bf16- quantized) — because under fsdp_dtensor the fp32 main_weight_buffer is loaded directly from the checkpoint's model state dict via the ``MegatronFSDP`` ``state_dict`` pre-hook (_replace_param_with_distributed_if_needed) and the load_state_dict post-hook (copy_main_weights_to_model_weights). See PRs NVIDIA#4748 / NVIDIA#4753 commit messages and the wandb runs at https://wandb.ai/wplf/fsdp-fix-verify/ for the verification trace. Also expands the comment on the FSDP early-return to document the mechanism so this isn't re-litigated. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The early-return ``isinstance(data, DTensor) and data.shape[split_dim] == total_split`` branch in ``split_gdn_fused`` is an independent fix that landed on this PR via conver334's internal contribution. NVIDIA#4799 by the same author proposes the same fast-path against ``dev`` directly, so we remove it here to avoid a textual conflict when NVIDIA#4799 merges. The remaining changes in this PR are still needed and complementary to NVIDIA#4799 (plain-Tensor support in splitters, ``dist_param``-coordinate computation for non-DTensor optimizer states, top-level metadata stash for Bridge-converted DCP checkpoints, and the FSDP no-op in ``DistributedOptimizer._copy_model_params_to_main_params``). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This reverts 947c91e. NVIDIA#4799 has landed on ``dev`` (and was synced into PR NVIDIA#4748 via a dev-merge), so we removed our copy from the dev- target PR to avoid a textual conflict. ``main`` does **not** yet have NVIDIA#4799, however. Without the TP-local GDN fast-path on the main side, ``split_gdn_fused`` falls into the TP-mesh-based branch even when ``data`` is already TP-local (``data.shape[split_dim] == sum(split_sizes)``), which yields wrong ``data_size`` / ``view_shape`` for Bridge-converted DTensors that encode the per-rank size as the global shape. Restore the fast-path on this PR (target ``main``) until NVIDIA#4799's main mirror lands. When that happens, we should drop the fast-path here too (mirror of what we did on PR NVIDIA#4748 dev) to avoid the textual conflict. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The FSDP early-return comment in ``_copy_model_params_to_main_params`` cited specific line numbers in ``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``. Those line numbers shift on every unrelated change to that file (e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev). Replace with the stable symbol names — ``_replace_param_with_distributed_if_needed``, ``install_optimized_model_weights``, ``copy_main_weights_to_model_weights`` — so the comment doesn't go stale. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts: # megatron/core/transformer/fsdp_dtensor_checkpoint.py
ede610a to
6fd7b35
Compare
Member
Author
|
/ok to test 6fd7b35 |
shjwudp
approved these changes
Jun 4, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Qwen3.5 support series
This is part of a 5-PR series adding Qwen3.5-VL support, split for review clarity.
Main PRs (this series):
Dev PRs (corresponding mirrors):
Summary
Four related fixes that together let
fsdp_dtensorresume from a checkpoint converted with Megatron-Bridge:split_swiglu_linear_fc1/split_gdn_fused: accept plainTensor(not onlyDTensor) fordata.to_local(). Bridge writes some already-local tensors that previously crashed the splitter.TP-local GDN split: when the GDN tensor is already TP-local (
data.shape[split_dim] == sum(split_sizes)), usesplit_dtensor(..., update_uneven_dtensor_chunk_meta=True)instead of the TP-mesh-based split that assumes the full unsharded shape.Stash missing top-level keys in
_load_base_checkpoint: keys not present as DCP storage metadata (e.g.args,iteration) are stashed beforedcp.loadand restored after. Megatron-saved checkpoints contain these asBytesStorageMetadata; Bridge-converted ones don't, and DCP errors without this stash.DistributedOptimizer._copy_model_params_to_main_params: underuse_megatron_fsdp, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights — both are correct at this call site.split_swiglu_linear_fc1/split_gdn_fusedusedist_paramcoordinates (cherry-pick of [DEV] fix(megatron-fsdp): compute SWiGLU/GDN split in item coordinates for non-DTensor optimizer states #4424 by @xuwchen): whendatais a plainTensor(FusedAdam optimizer state),data.numel()/data.shapedescribe the local item layout, not the global TP-sharded layout, so the splitter computed wrongdata_size,view_shape, andper_tp_rank_shape. Switch toglobal_shape = dist_param.shapeanddist_param.numel()everywhere, with a DTensor-shape consistency assert.Also wires
handle_gdn_in_state_dictintopreprocess_fsdp_dtensor_state_dict.Why bundled
All four fixes are needed together for the Bridge →
fsdp_dtensorresume path to work end-to-end.Note on key-unwrapping
An earlier revision of this branch also rewrote
model.get_parameter(f"module.{key}")→model.get_parameter(key). That root cause was resolved in #4393 ("Revert 'fix mfsdp unwrap stuck at MegatronFSDP [dev] (#4273)'"), so themodule.prefix is correct again and those four lookup edits have been dropped.Notes
Mirror of #4748 (same patch, targeting
maininstead ofdev).🤖 Generated with Claude Code
Update (2026-05-14): model→main copy reverted
A previous revision of this PR (commit
4c2c3cbfbon dev /f4ba23b11on main) implemented an FSDPcopy_model_weights_to_main_weightsand called it from_copy_model_params_to_main_params, on the theory that the silentreturnleft fp32 main params uninitialized under--no-load-optim/--finetune.Empirical buffer-diff probe showed that's wrong: under
fsdp_dtensor,main_weight_bufferis exposed as the module'snn.Parameters viaMegatronFSDP'sstate_dictpre-hook (_replace_param_with_distributed_if_needed), so the precedingmodel.load_state_dict(state_dict["model"])writes the checkpoint's fp32 tensors directly intomain_weight_buffer— full fp32 precision, even with--no-load-optim, because main params live in the model state dict (not the optimizer state dict) underfsdp_dtensor. The load_state_dict post-hook (copy_main_weights_to_model_weights) then casts main → bf16 to refreshmodel_weight_buffer.A 50-step finetune+resume on a Qwen3.5-VL 0.8B model with and without the copy gave indistinguishable loss curves (wandb runs
q8d7l7nm/o4tyuf5z/wwuz7y4zunder https://wandb.ai/wplf/fsdp-fix-verify/). The pre-copy buffer diff was exactly bf16 rounding noise; the post-copy diff dropped to 0 only because main had been downgraded to bf16 precision.The model→main copy commit has been reverted. The plain
returnis correct, and the comment in_copy_model_params_to_main_paramsnow documents the mechanism so this isn't re-litigated.Update (2026-05-18): TP-local GDN fast-path restored on main
The TP-local GDN fast-path is back on this main-target PR. #4799 (
conver334) provides the same fast-path againstdev, and the dev-target PR #4748 dropped its copy to avoid a textual conflict; but #4799 has not yet landed onmain, so this PR still needs to provide it. Once #4799's main mirror lands, this bullet can be removed here too (same flow as PR #4748).