[feat] HybridStack grouped syntax + checkpoint compat + EP-overlap (2/4 of #4798)#4942
Draft
Connor-XY wants to merge 10 commits into
Draft
[feat] HybridStack grouped syntax + checkpoint compat + EP-overlap (2/4 of #4798)#4942Connor-XY wants to merge 10 commits into
Connor-XY wants to merge 10 commits into
Conversation
This was referenced May 22, 2026
5fa940f to
5e5dee3
Compare
Move model-agnostic schedule-plan helpers out of gpt/fine_grained_callables.py
into megatron/core/models/common/ so non-GPT models (HybridStack) can build the
same combined-1F1B / EP-overlap schedule plans. No behavior change for GPT/MTP.
- Add megatron/core/models/common/{utils.py, fine_grained_callables.py,
model_chunk_schedule_plan.py} with the shared abstractions.
- Reduce megatron/core/models/gpt/fine_grained_callables.py to GPT-specific
pieces; reuse the new common base classes.
- Switch GraphableMegatronModule.init_backward_dw_wrapper to import
_BackwardDWWrapper from the new common module.
- Relax combined_forward_backward_step's GPTModel-only assert to a duck-type
on build_schedule_plan so any model implementing it can participate.
Part 1/4 of splitting NVIDIA#4798 (original changes by @Wohox).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…unk) Carries over upstream commit f6ea23b from NVIDIA#4798: when a hybrid layer pattern places MTP in a post_process VPP chunk that holds no main HybridStack layers (e.g. trailing pipe before the MTP separator), the EP-overlap schedule never invokes ``_maybe_apply_final_norm`` on the main path, so the unnormalized hidden_states feed straight into the LM head and lm_loss diverges by ~10x. Fix in ``submodule_mtp_pre_dispatch_forward``: run the main decoder's ``final_norm`` just before ``torch.chunk``, gated on ``len(model.decoder.layers) == 0`` and ``isinstance(model, HybridModel)``. The HybridModel import is deferred inside the function so this file does not gain a module-level dependency on the hybrid package — PR 1 remains independently importable. Part 1/4 of splitting NVIDIA#4798 (original changes by @Wohox). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add bracketed HybridStack group syntax (e.g. ``[*-]``, ``M[M*]-``) with nested HybridStack instances, rejecting invalid recursion. Migrate grouped HybridStack checkpoints to Transformer-compatible logical layer keys and make ``HybridModel.sharded_state_dict()`` drop the empty ``output_layer._extra_state`` to match GPT behavior. Extend EP-overlap scheduling to HybridStack: add the hybrid fine-grained callables and ``HybridStackModelChunkSchedulePlan``, expose ``HybridModel.build_schedule_plan`` and add the ``return_schedule_plan`` path in ``pretrain_hybrid.py``. Add Mamba ``backward_dw`` so the hybrid schedule node can register Mamba pre-layer weight grads alongside attention and GDN pre-layers. Fix the MoE TopKRouter MTP layer-number indexing when the MTP block wraps a HybridStack so the aux-loss tracker is not indexed past its size. Part 2/4 of splitting NVIDIA#4798 (original changes by @Wohox). Depends on the common combined-1F1B refactor in part 1/4 (#TBD). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Carries over upstream commit ce6e229 from NVIDIA#4798: in HybridStack's ``_run_moe_combine`` (A2A overlap path), ``layer._forward_post_mlp`` registers a second ``discard_output_and_register_recompute`` hook on ``mlp_output_with_bias[0]``. The hook fires during combine_bwd's autograd backward and triggers the LN recompute ahead of mlp_bwd / pre_dispatch_bwd. In bracketed-hybrid logical layers (``[*E]``), this corrupts gradients in attention's autograd chain (grad_norm explodes from iter 2). Fix: stop calling ``_forward_post_mlp`` from ``_run_moe_combine``; inline the ``bda + offload_mlp_norm + make_viewless_tensor`` steps directly, mirroring GPT's ``submodule_combine_forward``. The first recompute hook on ``expert_output`` (registered in ``_run_moe_experts``) already fires the LN recompute in mlp_bwd, so the second hook is redundant. Part 2/4 of splitting NVIDIA#4798 (original changes by @Wohox). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The recent merge of origin/main introduced `name=(name + f".layers.{i}")`
into every layer-type branch of HybridStack's build loop, but didn't change
the local loop header `for layer_type in self.layer_type_list:` to surface
`i`. Result: `NameError: name 'i' is not defined` at HybridStack init for
all hybrid runs (GPT path unaffected).
Trigger: any hybrid_stack_spec model crashes on init, including the 16-node
Bug 2a repro and the 8-node GPT-vs-Hybrid perf comparison runs.
Fix: convert the loop to `for i, layer_type in enumerate(...)`. Keep the
existing `physical_layer_offset` counter (used for FP8/FP4 contexts and
`layer_number`) because bracket groups count >1 physical layer per logical
entry — these are separate from the logical index `i` used for module names.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5e5dee3 to
91e86bb
Compare
71 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Part 2 of 4 splitting #4798 by @Wohox and @Connor-XY. Original changes by @Wohox and @Connor-XY.
Summary
Add the core HybridStack feature work from #4798:
[*-],[*E],M[M*]-): build bracketed groups as nestedHybridStackinstances; reject invalid recursion (nested brackets inside a group); keep MoE constrained to the last symbol inside a group for EP-overlap scheduling.HybridModel.sharded_state_dict()drops the emptyoutput_layer._extra_stateto match GPT behavior. Verified cross-load (Transformer ↔ HybridModel) in [feat] Hybrid model ep overlapping main #4798.HybridStackModelChunkSchedulePlan,hybrid/fine_grained_callables.py, exposeHybridModel.build_schedule_plan, and wire thereturn_schedule_planpath inpretrain_hybrid.py.backward_dw: register Mamba pre-layer wgrad alongside attention/GDN pre-layers so the schedule node iterates a uniform set of callables.self.layer_numbervia modulo so the aux-loss tracker is not indexed past its size when MTP wraps a HybridStack (e.g.*Efor one depth).pre_mlp_layernormrecompute hook in moe_combine that corrupted attention gradients in bracketed-hybrid logical layers ([*E]).Why this slice
Touches 6 reviewer groups:
core-adlr,core-nemo,hybrid-model,hybrid-mamba,mixture-of-experts-adlr,mixture-of-experts-devtech.Dependencies
_BackwardDWWrapperlocation.GitHub will show #4941's diff in this PR until #4941 merges; expected.
Validation
The full integrated change was validated in #4798:
test_hybrid_layer_allocation,test_hybrid_block,test_hybrid_model::test_grouped_sharded_state_dict_uses_transformer_checkpoint_keys— 86 passed.--dist-ckpt-strictness raise_unexpected: STATUS 0 for save / transformer→hybrid / hybrid→transformer.[*-][*-]|[*-][*-]|[*E][*E]|...pattern + VPP=2 + A2A + recompute layernorm: 100/100 iters identical between baseline and A2A overlap.Issue tracking
Linked issue: part of #4798.
Pre-checks
🤖 Generated with Claude Code