[dev] [2/5] Qwen3.5 support: FSDP DTensor Bridge checkpoint compatibility#4748
Merged
329a368 to
aeaf0d1
Compare
This was referenced May 12, 2026
Member
Author
|
/ok to test aeaf0d1 |
aeaf0d1 to
50c107a
Compare
Four small fixes that together let FSDP-DTensor resume from a checkpoint converted with Megatron-Bridge:

1. `split_swiglu_linear_fc1` / `split_gdn_fused`: accept a plain `Tensor` (not only `DTensor`) for `data.to_local()`. Bridge writes some already-local tensors that previously crashed the splitter.
2. `split_gdn_fused`: when the GDN tensor is already TP-local (`data.shape[split_dim] == sum(split_sizes)`), use `split_dtensor(..., update_uneven_dtensor_chunk_meta=True)` instead of the TP-mesh-based split that assumes the full unsharded shape.
3. `megatron/training/checkpointing.py::_load_base_checkpoint`: stash top-level state-dict keys that DCP storage metadata doesn't know about (`args`, `iteration`, ...), then restore them after `dcp.load`. Megatron-saved checkpoints contain these as `BytesStorageMetadata`; Bridge-converted ones don't, and DCP errors without this stash.
4. `DistributedOptimizer._copy_model_params_to_main_params`: under `use_megatron_fsdp`, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights, so both are correct at this call site.

Also wires `handle_gdn_in_state_dict` into `preprocess_fsdp_dtensor_state_dict`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: xuwchen <79835960+xuwchen@users.noreply.github.com>
Co-Authored-By: conver334 <56124251+conver334@users.noreply.github.com>
Co-Authored-By: BestJuly <19769279+BestJuly@users.noreply.github.com>
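The plain-`Tensor` case in the first fix can be illustrated with a minimal sketch. It uses duck typing in place of a real `isinstance(data, DTensor)` check so that it runs without PyTorch; `as_local` and `_FakeDTensor` are hypothetical names for illustration, not code from this PR.

```python
def as_local(data):
    """Return the local shard of a DTensor-like object, or the input unchanged.

    Assumption: anything exposing a callable ``to_local`` is treated as a
    DTensor; a real implementation would more likely use ``isinstance``.
    """
    to_local = getattr(data, "to_local", None)
    return to_local() if callable(to_local) else data


class _FakeDTensor:
    """Hypothetical stand-in for a DTensor, holding only its local shard."""

    def __init__(self, local):
        self._local = local

    def to_local(self):
        return self._local


# A DTensor-like input is unwrapped; an already-local value passes through.
local_from_dtensor = as_local(_FakeDTensor([1.0, 2.0]))
local_from_plain = as_local([1.0, 2.0])
```

Both calls yield the same local data, which is the behaviour the splitters need when Bridge has written an already-local tensor.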
50c107a to
c9dfa74
Compare
Member
Author
|
/ok to test c9dfa74 |
Member
Author
|
/ok to test 7f9e6b6 |
…path

Collapse the multi-line ``split_dtensor(...)`` call inside the ``isinstance(data, DTensor) and data.shape[split_dim] == total_split`` fast-path onto a single line per black/autoformat.sh rules.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member
Author
|
/ok to test 021b98d |
wplf
added a commit
to wplf/Megatron-LM
that referenced
this pull request
May 14, 2026
This reverts commit 29db3b3 (integration branch counterpart of the PR-2 revert).

Empirical verification showed the FSDP model->main copy was unneeded, and slightly harmful (it downgrades fp32 main to bf16-quantized), because under fsdp_dtensor the fp32 main_weight_buffer is loaded directly from the checkpoint's model state dict via the ``MegatronFSDP`` ``state_dict`` pre-hook (_replace_param_with_distributed_if_needed) and the load_state_dict post-hook (copy_main_weights_to_model_weights).

See PRs NVIDIA#4748 / NVIDIA#4753 commit messages and the wandb runs at https://wandb.ai/wplf/fsdp-fix-verify/ for the verification trace.

Also expands the comment on the FSDP early-return to document the mechanism so this isn't re-litigated.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
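The precision argument can be reproduced on a single value. This is a minimal sketch, assuming a pure-Python emulation of bf16 rounding (keep the top 16 bits of the fp32 pattern, round to nearest even, no NaN/inf handling); `to_bf16_and_back` and the variable names are hypothetical and not part of Megatron-LM.

```python
import struct


def to_bf16_and_back(x: float) -> float:
    """Round x to fp32, quantize to bf16 (nearest even), and widen back."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest even on the 16 mantissa bits that bf16 discards.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# fp32 "main" weight and its bf16 "model" copy.
main_fp32 = struct.unpack("<f", struct.pack("<f", 0.1))[0]
model_bf16 = to_bf16_and_back(main_fp32)

# Before any model->main copy, the two differ only by bf16 rounding.
pre_copy_diff = abs(main_fp32 - model_bf16)

# Copying model -> main zeroes the diff, but only by discarding fp32 precision.
main_after_copy = model_bf16
post_copy_diff = abs(main_after_copy - model_bf16)
```

Here `pre_copy_diff` is nonzero and `post_copy_diff` is zero, yet `main_after_copy` no longer equals the original fp32 value, which is why the copy counts as a downgrade rather than a fix.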
Member
Author
|
/ok to test 52f5d6e |
@wplf, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/ |
Member
Author
|
/ok to test 7bdf9d8 |
The early-return ``isinstance(data, DTensor) and data.shape[split_dim] == total_split`` branch in ``split_gdn_fused`` is an independent fix that landed on this PR via conver334's internal contribution. NVIDIA#4799 by the same author proposes the same fast-path against ``dev`` directly, so we remove it here to avoid a textual conflict when NVIDIA#4799 merges.

The remaining changes in this PR are still needed and complementary to NVIDIA#4799 (plain-Tensor support in splitters, ``dist_param``-coordinate computation for non-DTensor optimizer states, top-level metadata stash for Bridge-converted DCP checkpoints, and the FSDP no-op in ``DistributedOptimizer._copy_model_params_to_main_params``).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member
Author
|
/ok to test ca6a716 |
Member
Author
|
/ok to test 99e3a5f |
The merge of ``dev`` into this branch produced broken Python in ``split_gdn_fused``:

- inconsistent indentation (`` else:`` at 13 spaces vs ``if`` at 12), tripping black with ``IndentationError: unindent does not match any outer indentation level``;
- an unclosed ``assert data.shape == global_shape, (...`` (missing closing ``)``);
- incorrect nesting of the new TP-local fast-path (from NVIDIA#4799, now on ``dev``) inside ``if isinstance(data, DTensor):``.

Restore the intended structure:

1. TP-local fast-path as a flat top-level early return (matches the form NVIDIA#4799 landed on ``dev``).
2. PR-2's ``global_shape = dist_param.shape`` and the ``data.shape == global_shape`` assert for full-global DTensor data, kept as a sibling check after the fast-path.

CI ``BASE_REF=dev`` autoformat now passes (black, isort, pylint 10/10, ruff).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member
Author
|
/ok to test af5bc08 |
wplf
added a commit
to wplf/Megatron-LM
that referenced
this pull request
May 18, 2026
This reverts 947c91e.

NVIDIA#4799 has landed on ``dev`` (and was synced into PR NVIDIA#4748 via a dev-merge), so we removed our copy from the dev-target PR to avoid a textual conflict. ``main`` does **not** yet have NVIDIA#4799, however.

Without the TP-local GDN fast-path on the main side, ``split_gdn_fused`` falls into the TP-mesh-based branch even when ``data`` is already TP-local (``data.shape[split_dim] == sum(split_sizes)``), which yields wrong ``data_size`` / ``view_shape`` for Bridge-converted DTensors that encode the per-rank size as the global shape.

Restore the fast-path on this PR (target ``main``) until NVIDIA#4799's main mirror lands. When that happens, we should drop the fast-path here too (mirror of what we did on PR NVIDIA#4748 dev) to avoid the textual conflict.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
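The TP-local condition quoted in this commit message is a plain shape check. A minimal sketch follows, assuming `split_sizes` holds the per-TP-rank sizes of the fused components as the message implies; `is_tp_local` and the example shapes are hypothetical and chosen only for illustration.

```python
from typing import Sequence


def is_tp_local(shape: Sequence[int], split_dim: int, split_sizes: Sequence[int]) -> bool:
    """True when the split dimension already has the per-rank (TP-local) size."""
    return shape[split_dim] == sum(split_sizes)


# Hypothetical layout: three fused components of per-rank sizes 4, 4 and 2,
# with tensor parallel size 2, so the unsharded split dimension would be 20.
split_sizes = [4, 4, 2]
local_case = is_tp_local((10, 8), 0, split_sizes)   # already TP-local
global_case = is_tp_local((20, 8), 0, split_sizes)  # full unsharded shape
```

When the check holds, the tensor can be split directly by `split_sizes`; routing it through a branch that assumes the unsharded shape would derive sizes from the wrong total.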
xuwchen
reviewed
May 18, 2026
xuwchen
approved these changes
May 18, 2026
The FSDP early-return comment in ``_copy_model_params_to_main_params`` cited specific line numbers in ``megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py``. Those line numbers shift on every unrelated change to that file (e.g. the A2A overlap PR NVIDIA#3796 already shifted them on dev).

Replace with the stable symbol names (``_replace_param_with_distributed_if_needed``, ``install_optimized_model_weights``, ``copy_main_weights_to_model_weights``) so the comment doesn't go stale.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member
Author
|
/ok to test 30b5852 |
yaox12
approved these changes
May 19, 2026
yaox12
left a comment
Member
Approve since it passed the expert review.
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26072531765 |
Qwen3.5 support series
This is part of a 5-PR series adding Qwen3.5-VL support, split for review clarity.
Dev PRs (this series):
Main PRs (corresponding mirrors):
Summary
Related fixes that together let `fsdp_dtensor` resume from a checkpoint converted with Megatron-Bridge:

- `split_swiglu_linear_fc1` / `split_gdn_fused`: accept plain `Tensor` (not only `DTensor`) for `data.to_local()`. Bridge writes some already-local tensors that previously crashed the splitter.
- Stash missing top-level keys in `_load_base_checkpoint`: keys not present as DCP storage metadata (e.g. `args`, `iteration`) are stashed before `dcp.load` and restored after. Megatron-saved checkpoints contain these as `BytesStorageMetadata`; Bridge-converted ones don't, and DCP errors without this stash.
- `DistributedOptimizer._copy_model_params_to_main_params`: under `use_megatron_fsdp`, return instead of raising. DCP already loads directly into the fp32 main buffer, and a post-load hook copies main → model weights, so both are correct at this call site.
- `split_swiglu_linear_fc1` / `split_gdn_fused` use `dist_param` coordinates (cherry-pick of #4424, "[DEV] fix(megatron-fsdp): compute SWiGLU/GDN split in item coordinates for non-DTensor optimizer states", by @xuwchen): when `data` is a plain `Tensor` (FusedAdam optimizer state), `data.numel()` / `data.shape` describe the local item layout, not the global TP-sharded layout, so the splitter computed wrong `data_size`, `view_shape`, and `per_tp_rank_shape`. Switch to `global_shape = dist_param.shape` and `dist_param.numel()` everywhere, with a DTensor-shape consistency assert.
- Resolve the conflict from #4799, "[dev] Fix GDN DTensor splitting for FSDP checkpointing".
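The stash-and-restore step can be sketched in isolation. This is a minimal illustration, not the code in `_load_base_checkpoint`: `load_with_stash`, `known_keys`, and `load_fn` are hypothetical names, with `known_keys` standing in for whatever key set the DCP storage metadata reports and `load_fn` standing in for the `dcp.load` call.

```python
def load_with_stash(state_dict, known_keys, load_fn):
    """Hide top-level keys the loader does not know, load, then restore them."""
    stashed = {
        key: state_dict.pop(key)
        for key in list(state_dict)
        if key not in known_keys
    }
    try:
        load_fn(state_dict)  # the loader only sees keys present in the checkpoint
    finally:
        state_dict.update(stashed)  # restore the stash even if loading fails
    return state_dict


def fake_load(sd):
    """Stand-in loader that rejects unknown keys, as described for DCP above."""
    unknown = set(sd) - {"model"}
    if unknown:
        raise KeyError(f"missing from checkpoint metadata: {sorted(unknown)}")
    sd["model"] = {"w": 1.0}


state = {"model": {"w": 0.0}, "args": None, "iteration": 0}
result = load_with_stash(state, {"model"}, fake_load)
```

After the call, `result["model"]` holds the loaded values while `args` and `iteration` keep the entries they had before the load, which mirrors the behaviour described for Bridge-converted checkpoints.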
Also wires `handle_gdn_in_state_dict` into `preprocess_fsdp_dtensor_state_dict`.

Why bundled
These fixes are all needed together for the Bridge → `fsdp_dtensor` resume path to work end-to-end. Splitting further yields PRs that individually leave the path broken.

Note on key-unwrapping
An earlier revision of this branch also rewrote `model.get_parameter(f"module.{key}")` → `model.get_parameter(key)` (and the analogous optimizer-state lookups) on the assumption that `unwrap_model` strips `MegatronFSDPModule`. That root cause was resolved in #4393 ("Revert 'fix mfsdp unwrap stuck at MegatronFSDP [dev] (#4273)'"), which removed `MegatronFSDPModule` from `unwrap_model`'s tuple, so the `module.` prefix is correct again. Those four lookup edits have been dropped from this PR.

Test plan
Notes
Extracted from a larger Qwen3.5-VL development branch to keep these fixes reviewable on their own.
🤖 Generated with Claude Code
Update (2026-05-14): model→main copy reverted
A previous revision of this PR (commit `4c2c3cbfb` on dev / `f4ba23b11` on main) implemented an FSDP `copy_model_weights_to_main_weights` and called it from `_copy_model_params_to_main_params`, on the theory that the silent `return` left fp32 main params uninitialized under `--no-load-optim` / `--finetune`.

An empirical buffer-diff probe showed that's wrong: under `fsdp_dtensor`, `main_weight_buffer` is exposed as the module's `nn.Parameter`s via `MegatronFSDP`'s `state_dict` pre-hook (`_replace_param_with_distributed_if_needed`), so the preceding `model.load_state_dict(state_dict["model"])` writes the checkpoint's fp32 tensors directly into `main_weight_buffer` at full fp32 precision, even with `--no-load-optim`, because main params live in the model state dict (not the optimizer state dict) under `fsdp_dtensor`. The `load_state_dict` post-hook (`copy_main_weights_to_model_weights`) then casts main → bf16 to refresh `model_weight_buffer`.

A 50-step finetune+resume on a Qwen3.5-VL 0.8B model with and without the copy gave indistinguishable loss curves (wandb runs `q8d7l7nm` / `o4tyuf5z` / `wwuz7y4z` under https://wandb.ai/wplf/fsdp-fix-verify/). The pre-copy buffer diff was exactly bf16 rounding noise; the post-copy diff dropped to 0 only because main had been downgraded to bf16 precision.

The model→main copy commit has been reverted. The plain `return` is correct, and the comment in `_copy_model_params_to_main_params` now documents the mechanism so this isn't re-litigated.

Update (2026-05-14): TP-local GDN fast-path dropped
The original 2nd bullet (TP-local GDN split via `split_dtensor`) was the same fix proposed by @conver334 in #4799 against `dev` directly. We've dropped that bullet here so this PR doesn't textually conflict when #4799 merges; the remaining bullets are still complementary fixes that #4799 doesn't cover.