Skip to content

[feat] HybridStack grouped syntax + checkpoint compat + EP-overlap (2/4 of #4798)#4942

Draft
Connor-XY wants to merge 10 commits into
NVIDIA:mainfrom
Connor-XY:pr4798-2-hybrid-stack-grouped
Draft

[feat] HybridStack grouped syntax + checkpoint compat + EP-overlap (2/4 of #4798)#4942
Connor-XY wants to merge 10 commits into
NVIDIA:mainfrom
Connor-XY:pr4798-2-hybrid-stack-grouped

Conversation

@Connor-XY

@Connor-XY Connor-XY commented May 22, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Part 2 of 4 splitting #4798 by @Wohox and @Connor-XY. Original changes by @Wohox and @Connor-XY.

Summary

Add the core HybridStack feature work from #4798:

  • Bracketed HybridStack group syntax (e.g. [*-], [*E], M[M*]-): build bracketed groups as nested HybridStack instances; reject invalid recursion (nested brackets inside a group); keep MoE constrained to the last symbol inside a group for EP-overlap scheduling.
  • Transformer-compatible sharded checkpoint keys for grouped HybridStack: HybridModel.sharded_state_dict() drops the empty output_layer._extra_state to match GPT behavior. Verified cross-load (Transformer ↔ HybridModel) in [feat] Hybrid model ep overlapping main #4798.
  • HybridStack EP-overlap: add HybridStackModelChunkSchedulePlan, hybrid/fine_grained_callables.py, expose HybridModel.build_schedule_plan, and wire the return_schedule_plan path in pretrain_hybrid.py.
  • Mamba backward_dw: register Mamba pre-layer wgrad alongside attention/GDN pre-layers so the schedule node iterates a uniform set of callables.
  • MoE TopKRouter MTP fix: collapse self.layer_number via modulo so the aux-loss tracker is not indexed past its size when MTP wraps a HybridStack (e.g. *E for one depth).
  • Carries the ce6e229 fix from [feat] Hybrid model ep overlapping main #4798: drop the redundant pre_mlp_layernorm recompute hook in moe_combine that corrupted attention gradients in bracketed-hybrid logical layers ([*E]).

Why this slice

Touches 6 reviewer groups: core-adlr, core-nemo, hybrid-model, hybrid-mamba, mixture-of-experts-adlr, mixture-of-experts-devtech.

Dependencies

GitHub will show #4941's diff in this PR until #4941 merges; expected.

Validation

The full integrated change was validated in #4798:

  • Unit tests: test_hybrid_layer_allocation, test_hybrid_block, test_hybrid_model::test_grouped_sharded_state_dict_uses_transformer_checkpoint_keys — 86 passed.
  • Checkpoint cross-load (Transformer ↔ grouped HybridModel) with --dist-ckpt-strictness raise_unexpected: STATUS 0 for save / transformer→hybrid / hybrid→transformer.
  • DeepSeek-V3 deterministic Transformer/Hybrid baseline + EP-overlap smoke tests.
  • ce6e229 bitwise verification on lite-deter-hybrid + [*-][*-]|[*-][*-]|[*E][*E]|... pattern + VPP=2 + A2A + recompute layernorm: 100/100 iters identical between baseline and A2A overlap.

Issue tracking

Linked issue: part of #4798.

Pre-checks

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot Bot commented May 22, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Connor-XY and others added 10 commits June 3, 2026 11:48
Move model-agnostic schedule-plan helpers out of gpt/fine_grained_callables.py
into megatron/core/models/common/ so non-GPT models (HybridStack) can build the
same combined-1F1B / EP-overlap schedule plans. No behavior change for GPT/MTP.

- Add megatron/core/models/common/{utils.py, fine_grained_callables.py,
  model_chunk_schedule_plan.py} with the shared abstractions.
- Reduce megatron/core/models/gpt/fine_grained_callables.py to GPT-specific
  pieces; reuse the new common base classes.
- Switch GraphableMegatronModule.init_backward_dw_wrapper to import
  _BackwardDWWrapper from the new common module.
- Relax combined_forward_backward_step's GPTModel-only assert to a duck-type
  on build_schedule_plan so any model implementing it can participate.

Part 1/4 of splitting NVIDIA#4798 (original changes by @Wohox).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…unk)

Carries over upstream commit f6ea23b from NVIDIA#4798: when a hybrid layer
pattern places MTP in a post_process VPP chunk that holds no main
HybridStack layers (e.g. trailing pipe before the MTP separator), the
EP-overlap schedule never invokes ``_maybe_apply_final_norm`` on the main
path, so the unnormalized hidden_states feed straight into the LM head and
lm_loss diverges by ~10x.

Fix in ``submodule_mtp_pre_dispatch_forward``: run the main decoder's
``final_norm`` just before ``torch.chunk``, gated on
``len(model.decoder.layers) == 0`` and ``isinstance(model, HybridModel)``.
The HybridModel import is deferred inside the function so this file does
not gain a module-level dependency on the hybrid package — PR 1 remains
independently importable.

Part 1/4 of splitting NVIDIA#4798 (original changes by @Wohox).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add bracketed HybridStack group syntax (e.g. ``[*-]``, ``M[M*]-``) with
nested HybridStack instances, rejecting invalid recursion. Migrate grouped
HybridStack checkpoints to Transformer-compatible logical layer keys and
make ``HybridModel.sharded_state_dict()`` drop the empty
``output_layer._extra_state`` to match GPT behavior.

Extend EP-overlap scheduling to HybridStack: add the hybrid fine-grained
callables and ``HybridStackModelChunkSchedulePlan``, expose
``HybridModel.build_schedule_plan`` and add the ``return_schedule_plan``
path in ``pretrain_hybrid.py``. Add Mamba ``backward_dw`` so the hybrid
schedule node can register Mamba pre-layer weight grads alongside attention
and GDN pre-layers. Fix the MoE TopKRouter MTP layer-number indexing when
the MTP block wraps a HybridStack so the aux-loss tracker is not indexed
past its size.

Part 2/4 of splitting NVIDIA#4798 (original changes by @Wohox). Depends on
the common combined-1F1B refactor in part 1/4 (#TBD).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Carries over upstream commit ce6e229 from NVIDIA#4798: in HybridStack's
``_run_moe_combine`` (A2A overlap path), ``layer._forward_post_mlp``
registers a second ``discard_output_and_register_recompute`` hook on
``mlp_output_with_bias[0]``. The hook fires during combine_bwd's autograd
backward and triggers the LN recompute ahead of mlp_bwd / pre_dispatch_bwd.
In bracketed-hybrid logical layers (``[*E]``), this corrupts gradients in
attention's autograd chain (grad_norm explodes from iter 2).

Fix: stop calling ``_forward_post_mlp`` from ``_run_moe_combine``; inline
the ``bda + offload_mlp_norm + make_viewless_tensor`` steps directly,
mirroring GPT's ``submodule_combine_forward``. The first recompute hook on
``expert_output`` (registered in ``_run_moe_experts``) already fires the
LN recompute in mlp_bwd, so the second hook is redundant.

Part 2/4 of splitting NVIDIA#4798 (original changes by @Wohox).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The recent merge of origin/main introduced `name=(name + f".layers.{i}")`
into every layer-type branch of HybridStack's build loop, but didn't change
the local loop header `for layer_type in self.layer_type_list:` to surface
`i`. Result: `NameError: name 'i' is not defined` at HybridStack init for
all hybrid runs (GPT path unaffected).

Trigger: any hybrid_stack_spec model crashes on init, including the 16-node
Bug 2a repro and the 8-node GPT-vs-Hybrid perf comparison runs.

Fix: convert the loop to `for i, layer_type in enumerate(...)`. Keep the
existing `physical_layer_offset` counter (used for FP8/FP4 contexts and
`layer_number`) because bracket groups count >1 physical layer per logical
entry — these are separate from the logical index `i` used for module names.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Connor-XY Connor-XY force-pushed the pr4798-2-hybrid-stack-grouped branch from 5e5dee3 to 91e86bb Compare June 3, 2026 19:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants