[dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract#4518
Conversation
|
/claude strict-review |
| if self.config.enable_hyper_connections: | ||
| hc_mult = self.config.num_residual_streams | ||
| hc_dim = self.config.hidden_size * hc_mult | ||
| self.hc_head_fn = nn.Parameter(torch.randn(hc_mult, hc_dim)) | ||
| self.hc_head_base = nn.Parameter(torch.zeros(hc_mult)) | ||
| self.hc_head_scale = nn.Parameter(torch.ones(1)) | ||
| nn.init.xavier_uniform_(self.hc_head_fn) | ||
| if self.config.sequence_parallel: | ||
| setattr(self.hc_head_fn, 'sequence_parallel', True) | ||
| setattr(self.hc_head_base, 'sequence_parallel', True) | ||
| setattr(self.hc_head_scale, 'sequence_parallel', True) |
There was a problem hiding this comment.
[IMPORTANT Correctness] hc_head_fn, hc_head_base, and hc_head_scale are bare nn.Parameter objects, not nn.Module children. TransformerBlock.sharded_state_dict() has custom logic that only iterates self.named_children() (layers + final_layernorm, etc.) — it does not call self._save_to_state_dict() like the base MegatronModule.sharded_state_dict() does. As a result, these three learned parameters will be silently dropped during distributed checkpoint save and won't be restored on resume, causing training divergence after checkpoint reload.
The same parameters on MultiTokenPredictionLayer are fine because its sharded_state_dict calls super().sharded_state_dict(), which does handle bare parameters.
Fix: Add explicit handling for these parameters in TransformerBlock.sharded_state_dict(). For example, after the named_children loop, call self._save_to_state_dict(...) for the standalone parameters, or wrap them in a small nn.Module so they're picked up by named_children().
There was a problem hiding this comment.
@hxbai this one make sense and I think without wrapping those are not handled correctly by dist ckpting. These parameters are meant to be loaded on restart, right?
| if mhc_multistream is not None: | ||
| mhc_chunks.append(hidden_states) | ||
| hidden_states_list.append(self.layers[layer_idx]._postprocess(hidden_states)) |
There was a problem hiding this comment.
[SUGGESTION Simplification] _postprocess is called here to contract each MTP layer's multi-stream output for the loss list. This works correctly, but it means learned_output_contract + final_layernorm are computed for each MTP depth even though the multi-stream tensor was already computed in the same iteration.
Consider caching the contracted result inside _proj_and_transformer_layer (compute it but don't apply it to the returned tensor) to avoid the redundant forward through _postprocess. Not a correctness issue, but calling _postprocess externally on a "raw" tensor that the layer itself chose not to postprocess creates a subtle coupling between the caller and the layer's internal mhc_enabled flag.
| def learned_output_contract( | ||
| hidden_states: Tensor, head_fn: Tensor, base: Tensor, scale: Tensor, n: int, eps: float | ||
| ) -> Tensor: | ||
| """Learned output contraction: n-stream → 1-stream via sigmoid-gated weighted sum.""" | ||
| dtype = hidden_states.dtype | ||
| hidden_states = hidden_states.to(torch.float32) | ||
| head_fn = head_fn.to(torch.float32) | ||
| base = base.to(torch.float32) | ||
| scale = scale.to(torch.float32) | ||
| rsqrt = torch.rsqrt(hidden_states.square().mean(-1, keepdim=True) + eps) | ||
| mixes = F.linear(hidden_states, head_fn) * rsqrt | ||
| pre = torch.sigmoid(mixes * scale + base) + 1e-6 | ||
| y = torch.sum(pre.unsqueeze(-1) * hidden_states.view(*hidden_states.shape[:-1], n, -1), dim=-2) | ||
| return y.to(dtype) |
There was a problem hiding this comment.
[SUGGESTION Simplification] The 1e-6 offset added to pre (line 107) prevents the sigmoid gating weights from being exactly zero. However, this also means the minimum contribution per stream is always positive (1e-6), which accumulates to n * 1e-6 over all streams. This is negligible for typical hidden sizes, but it's worth noting that eps (the RMS epsilon) and this 1e-6 serve different numerical stability purposes — consider making the sigmoid floor configurable or documenting its purpose in the docstring.
Also, head_fn has shape [n, n*h] which means the F.linear projects the full concatenated multi-stream vector down to n mixing weights. This is an O(n² * h) operation per token. For large n, this could be meaningful. A comment noting this design choice (full-rank vs. per-stream factored projection) would help future readers.
| # DSv4 reference (DS-Inf) RoPE is pure rotation (norm-preserving). Yarn's | ||
| # concentration factor (mscale) is NOT part of the DSv4 model contract -- | ||
| # the model relies on Q/KV RMS-norm + unit-magnitude rotation. Force 1.0. | ||
| mscale = 1.0 |
There was a problem hiding this comment.
[SUGGESTION Simplification] This mscale = 1.0 override is duplicated in three places across two files (csa.py and deepseek_v4_hybrid_attention.py ×2). All three share the identical comment and logic pattern. Consider handling this at a single point — for example, in the rotary embedding module's constructor or a config-driven flag — rather than patching the returned value at every call site. This would also prevent future call sites from forgetting the override.
There was a problem hiding this comment.
This is the special behavior of DSv4 and no need to put it into a standalone position.
| if self.mhc_enabled: | ||
| hc_mult = self.config.num_residual_streams | ||
| hc_dim = self.config.hidden_size * hc_mult | ||
| self.hc_head_fn = nn.Parameter(torch.randn(hc_mult, hc_dim)) | ||
| self.hc_head_base = nn.Parameter(torch.zeros(hc_mult)) | ||
| self.hc_head_scale = nn.Parameter(torch.ones(1)) | ||
| nn.init.xavier_uniform_(self.hc_head_fn) | ||
| if self.config.sequence_parallel: | ||
| setattr(self.hc_head_fn, 'sequence_parallel', True) | ||
| setattr(self.hc_head_base, 'sequence_parallel', True) | ||
| setattr(self.hc_head_scale, 'sequence_parallel', True) |
There was a problem hiding this comment.
[SUGGESTION Simplification] This hc_head_* parameter initialization block is duplicated verbatim in TransformerBlock.__init__ (lines 388–398). Consider extracting a small helper (e.g., _init_learned_contraction_params(config) returning a ParameterDict or small nn.Module) to avoid the duplication and ensure both sites stay in sync.
There was a problem hiding this comment.
Review Summary — PR #4518: [dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract
Findings: CRITICAL: 0 | IMPORTANT: 3 | SUGGESTION: 4
What this PR does
- mHC + MTP integration: Enables Multi-Token Prediction layers to work with manifold-constrained Hyper-Connections (mHC). MTP layers get separate
e_proj/h_proj(instead of the concatenatedeh_proj) to operate per-stream, and multi-stream state flows through the MTP loop without premature contraction. - New learned output contraction: Replaces the simple stream-averaging
HyperConnectionModule.output_contractwithlearned_output_contract— a sigmoid-gated weighted sum with learnablehc_head_fn,hc_head_base,hc_head_scaleparameters. This applies to all HC models, not only HC+MTP. - DSv4 mscale fix: Forces
mscale = 1.0in the non-fused RoPE path for CSA and DSv4 hybrid attention, since DSv4 uses norm-preserving rotation without Yarn's concentration factor. - Removes the HC+MTP incompatibility validation from
TransformerConfig.__post_init__.
Most impactful findings
-
TransformerBlockcheckpoint bug (IMPORTANT): The newhc_head_fn/base/scaleparameters are barenn.Parameterobjects onTransformerBlock. The block's customsharded_state_dict()only iteratesnamed_children()and never calls_save_to_state_dict(), so these parameters will be silently dropped from distributed checkpoints. Training will resume with randomly-initialized contraction weights. The same parameters onMultiTokenPredictionLayerare fine (itssharded_state_dictcallssuper()). -
Breaking change for existing HC models (IMPORTANT): The contraction method change (averaging → learned) applies to all
enable_hyper_connections=Truemodels. Old checkpoints won't have the new parameters. This should be documented as a breaking change, or gated behind a flag. -
hybrid_model.pytuple return not handled (IMPORTANT):TransformerBlock.forward()now returns a tuple when HC+MTP are both active, buthybrid_model.pyassigns the decoder output directly tohidden_stateswithout unpacking.
Overall assessment
The algorithmic design is sound — the multi-stream dataflow through the MTP loop, the per-stream projection split, and the learned contraction function are all correctly implemented. The unit tests are thorough and cover constructor shapes, block return types, and E2E forward+backward. The checkpoint save/load bug is the highest-priority issue and should be fixed before merge. The backward compatibility concern is worth discussing given this targets dev.
|
/ok to test 726ab33 |
|
/ok to test c441614 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25589952024 |
|
/ok to test 60ff377 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25852541556 |
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25854373953 |
### PR Category <!-- One of [ Train | Inference | Compress | Serve | RL | Core | Hardware | CICD | Tools | Others ] --> [Train] Most of codes are copied from Megatron-LM Dev branch. The dev branch is different with main branch or release version. Megatron LM PR: DeepSeek-V4: NVIDIA#4458 NVIDIA#4481 NVIDIA#4518 mHC: NVIDIA#2943 ### PR Types <!-- One of [ User Experience | New Features | Bug Fixes | Improvements | Performance | Breaking Change| Deprecations | Test Case | Docs | Others ] --> [New features] ### PR Description <!-- Describe what you’ve done --> Add DeepSeek V4 model into FlagScale and Megatron-FL Supported: 1. CSA and HCA 2. Hash Router 3. mHC 4. Engram(optional) Unsupported: 1. Sqrtsoftpuls router score function. ✅ 2. mHC recompute. ✅ 3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅ 4. Any infra optimizations. ### NOTE: This is only a draft pr, please reivew to give more suggestions. such as: 1. File structure. - All modules are moved into Megatron-FL ### Next plan: 1. Distributed training. ✅ 3. Muon optimizer with Zero 1 adaptation. 🚧 4. Low precision is out of scope of this pr, limited by resource. 5. Maybe context parallel for sparse attention. 6. Welcome to give more suggestions. --------- Co-authored-by: Hongxiao Bai <hongxiaob@nvidia.com> Co-authored-by: Yuzhong Wang <yuzhongw@nvidia.com>
### PR Category <!-- One of [ Train | Inference | Compress | Serve | RL | Core | Hardware | CICD | Tools | Others ] --> [Train] Most of codes are copied from Megatron-LM Dev branch. The dev branch is different with main branch or release version. Megatron LM PR: DeepSeek-V4: NVIDIA/Megatron-LM#4458 NVIDIA/Megatron-LM#4481 NVIDIA/Megatron-LM#4518 mHC: NVIDIA/Megatron-LM#2943 ### PR Types <!-- One of [ User Experience | New Features | Bug Fixes | Improvements | Performance | Breaking Change| Deprecations | Test Case | Docs | Others ] --> [New features] ### PR Description <!-- Describe what you’ve done --> Add DeepSeek V4 model into FlagScale and Megatron-FL Supported: 1. CSA and HCA 2. Hash Router 3. mHC 4. Engram(optional) Unsupported: 1. Sqrtsoftpuls router score function. ✅ 2. mHC recompute. ✅ 3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅ 4. Any infra optimizations. ### NOTE: This is only a draft pr, please reivew to give more suggestions. such as: 1. File structure. - **All modules are moved to Megatron-FL. Only model_builder is left in Flagscale.** - Delete Engram related CI or not? ### Next plan: 1. Distributed training. ✅ 3. Muon optimizer with Zero 1 adaptation. 😢 4. Low precision is out of scope of this pr, limited by resource. 5. Maybe context parallel for sparse attention. 6. Welcome to give more suggestions. --------- Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
What does this PR do ?
We will create several PRs to functionally support DeepSeek-v4 training. This is the third one.
Add DeepSeek-v4 MTP support with mHC and new mHC contract.
Issue tracking
For PRs from open-source community contributors:
Linked issue:
Contribution process
Pre-checks
Code review
Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
.github/CODEOWNERS.Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change
megatron/core, once all expert reviewers have approved, theFinal Reviewlabel is applied automatically and final reviewers are assigned.For PRs outside
megatron/core, this step is skipped.Step 3: Approved
Once all required reviewers have approved, the
Approvedlabel is applied automatically.Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for `dev` branch is under active discussion.MRs are mergable after one approval by either
eharper@nvidia.comorzijiey@nvidia.com.