Skip to content

[dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract#4518

Merged
hxbai merged 17 commits into
NVIDIA:devfrom
hxbai:dsv4_mtp
May 14, 2026
Merged

[dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract#4518
hxbai merged 17 commits into
NVIDIA:devfrom
hxbai:dsv4_mtp

Conversation

@hxbai

@hxbai hxbai commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

We will create several PRs to functionally support DeepSeek-v4 training. This is the third one.

Add DeepSeek-v4 MTP support with mHC and new mHC contract.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@hxbai hxbai self-assigned this Apr 29, 2026
@hxbai hxbai added the dev branch Dev branch related issues and development label Apr 29, 2026
@copy-pr-bot

copy-pr-bot Bot commented Apr 29, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaox12

yaox12 commented May 6, 2026

Copy link
Copy Markdown
Member

/claude strict-review

Comment on lines +388 to +398
if self.config.enable_hyper_connections:
hc_mult = self.config.num_residual_streams
hc_dim = self.config.hidden_size * hc_mult
self.hc_head_fn = nn.Parameter(torch.randn(hc_mult, hc_dim))
self.hc_head_base = nn.Parameter(torch.zeros(hc_mult))
self.hc_head_scale = nn.Parameter(torch.ones(1))
nn.init.xavier_uniform_(self.hc_head_fn)
if self.config.sequence_parallel:
setattr(self.hc_head_fn, 'sequence_parallel', True)
setattr(self.hc_head_base, 'sequence_parallel', True)
setattr(self.hc_head_scale, 'sequence_parallel', True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Correctness] hc_head_fn, hc_head_base, and hc_head_scale are bare nn.Parameter objects, not nn.Module children. TransformerBlock.sharded_state_dict() has custom logic that only iterates self.named_children() (layers + final_layernorm, etc.) — it does not call self._save_to_state_dict() like the base MegatronModule.sharded_state_dict() does. As a result, these three learned parameters will be silently dropped during distributed checkpoint save and won't be restored on resume, causing training divergence after checkpoint reload.

The same parameters on MultiTokenPredictionLayer are fine because its sharded_state_dict calls super().sharded_state_dict(), which does handle bare parameters.

Fix: Add explicit handling for these parameters in TransformerBlock.sharded_state_dict(). For example, after the named_children loop, call self._save_to_state_dict(...) for the standalone parameters, or wrap them in a small nn.Module so they're picked up by named_children().

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hxbai this one make sense and I think without wrapping those are not handled correctly by dist ckpting. These parameters are meant to be loaded on restart, right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread megatron/core/transformer/transformer_block.py
Comment thread megatron/core/transformer/transformer_block.py
Comment on lines +1603 to +1605
if mhc_multistream is not None:
mhc_chunks.append(hidden_states)
hidden_states_list.append(self.layers[layer_idx]._postprocess(hidden_states))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION Simplification] _postprocess is called here to contract each MTP layer's multi-stream output for the loss list. This works correctly, but it means learned_output_contract + final_layernorm are computed for each MTP depth even though the multi-stream tensor was already computed in the same iteration.

Consider caching the contracted result inside _proj_and_transformer_layer (compute it but don't apply it to the returned tensor) to avoid the redundant forward through _postprocess. Not a correctness issue, but calling _postprocess externally on a "raw" tensor that the layer itself chose not to postprocess creates a subtle coupling between the caller and the layer's internal mhc_enabled flag.

Comment on lines +96 to +109
def learned_output_contract(
hidden_states: Tensor, head_fn: Tensor, base: Tensor, scale: Tensor, n: int, eps: float
) -> Tensor:
"""Learned output contraction: n-stream → 1-stream via sigmoid-gated weighted sum."""
dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
head_fn = head_fn.to(torch.float32)
base = base.to(torch.float32)
scale = scale.to(torch.float32)
rsqrt = torch.rsqrt(hidden_states.square().mean(-1, keepdim=True) + eps)
mixes = F.linear(hidden_states, head_fn) * rsqrt
pre = torch.sigmoid(mixes * scale + base) + 1e-6
y = torch.sum(pre.unsqueeze(-1) * hidden_states.view(*hidden_states.shape[:-1], n, -1), dim=-2)
return y.to(dtype)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION Simplification] The 1e-6 offset added to pre (line 107) prevents the sigmoid gating weights from being exactly zero. However, this also means the minimum contribution per stream is always positive (1e-6), which accumulates to n * 1e-6 over all streams. This is negligible for typical hidden sizes, but it's worth noting that eps (the RMS epsilon) and this 1e-6 serve different numerical stability purposes — consider making the sigmoid floor configurable or documenting its purpose in the docstring.

Also, head_fn has shape [n, n*h] which means the F.linear projects the full concatenated multi-stream vector down to n mixing weights. This is an O(n² * h) operation per token. For large n, this could be meaningful. A comment noting this design choice (full-rank vs. per-stream factored projection) would help future readers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replaced with eps

Comment on lines +123 to +126
# DSv4 reference (DS-Inf) RoPE is pure rotation (norm-preserving). Yarn's
# concentration factor (mscale) is NOT part of the DSv4 model contract --
# the model relies on Q/KV RMS-norm + unit-magnitude rotation. Force 1.0.
mscale = 1.0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION Simplification] This mscale = 1.0 override is duplicated in three places across two files (csa.py and deepseek_v4_hybrid_attention.py ×2). All three share the identical comment and logic pattern. Consider handling this at a single point — for example, in the rotary embedding module's constructor or a config-driven flag — rather than patching the returned value at every call site. This would also prevent future call sites from forgetting the override.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the special behavior of DSv4 and no need to put it into a standalone position.

Comment on lines +927 to +937
if self.mhc_enabled:
hc_mult = self.config.num_residual_streams
hc_dim = self.config.hidden_size * hc_mult
self.hc_head_fn = nn.Parameter(torch.randn(hc_mult, hc_dim))
self.hc_head_base = nn.Parameter(torch.zeros(hc_mult))
self.hc_head_scale = nn.Parameter(torch.ones(1))
nn.init.xavier_uniform_(self.hc_head_fn)
if self.config.sequence_parallel:
setattr(self.hc_head_fn, 'sequence_parallel', True)
setattr(self.hc_head_base, 'sequence_parallel', True)
setattr(self.hc_head_scale, 'sequence_parallel', True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION Simplification] This hc_head_* parameter initialization block is duplicated verbatim in TransformerBlock.__init__ (lines 388–398). Consider extracting a small helper (e.g., _init_learned_contraction_params(config) returning a ParameterDict or small nn.Module) to avoid the duplication and ensure both sites stay in sync.

@claude claude Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review Summary — PR #4518: [dev] [DeepSeek-v4] Part 3: MTP support with mHC and new mHC contract

Findings: CRITICAL: 0 | IMPORTANT: 3 | SUGGESTION: 4

What this PR does

  1. mHC + MTP integration: Enables Multi-Token Prediction layers to work with manifold-constrained Hyper-Connections (mHC). MTP layers get separate e_proj/h_proj (instead of the concatenated eh_proj) to operate per-stream, and multi-stream state flows through the MTP loop without premature contraction.
  2. New learned output contraction: Replaces the simple stream-averaging HyperConnectionModule.output_contract with learned_output_contract — a sigmoid-gated weighted sum with learnable hc_head_fn, hc_head_base, hc_head_scale parameters. This applies to all HC models, not only HC+MTP.
  3. DSv4 mscale fix: Forces mscale = 1.0 in the non-fused RoPE path for CSA and DSv4 hybrid attention, since DSv4 uses norm-preserving rotation without Yarn's concentration factor.
  4. Removes the HC+MTP incompatibility validation from TransformerConfig.__post_init__.

Most impactful findings

  1. TransformerBlock checkpoint bug (IMPORTANT): The new hc_head_fn/base/scale parameters are bare nn.Parameter objects on TransformerBlock. The block's custom sharded_state_dict() only iterates named_children() and never calls _save_to_state_dict(), so these parameters will be silently dropped from distributed checkpoints. Training will resume with randomly-initialized contraction weights. The same parameters on MultiTokenPredictionLayer are fine (its sharded_state_dict calls super()).

  2. Breaking change for existing HC models (IMPORTANT): The contraction method change (averaging → learned) applies to all enable_hyper_connections=True models. Old checkpoints won't have the new parameters. This should be documented as a breaking change, or gated behind a flag.

  3. hybrid_model.py tuple return not handled (IMPORTANT): TransformerBlock.forward() now returns a tuple when HC+MTP are both active, but hybrid_model.py assigns the decoder output directly to hidden_states without unpacking.

Overall assessment

The algorithmic design is sound — the multi-stream dataflow through the MTP loop, the per-stream projection split, and the learned contraction function are all correctly implemented. The unit tests are thorough and cover constructor shapes, block return types, and E2E forward+backward. The checkpoint save/load bug is the highest-priority issue and should be fixed before merge. The backward compatibility concern is worth discussing given this targets dev.

@hxbai

hxbai commented May 8, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 726ab33

@hxbai

hxbai commented May 9, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test c441614

@hxbai hxbai added this pull request to the merge queue May 9, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25589952024

@hxbai hxbai removed this pull request from the merge queue due to a manual request May 9, 2026
@hxbai

hxbai commented May 14, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 60ff377

@hxbai hxbai enabled auto-merge May 14, 2026 07:49
@hxbai hxbai added this pull request to the merge queue May 14, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25852541556

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25854373953

Merged via the queue into NVIDIA:dev with commit 2e55168 May 14, 2026
66 checks passed
@hxbai hxbai deleted the dsv4_mtp branch May 14, 2026 11:31
zhaoyinglia pushed a commit to flagos-ai/Megatron-LM-FL that referenced this pull request May 24, 2026
### PR Category
<!-- One of [ Train | Inference | Compress | Serve | RL | Core |
Hardware | CICD | Tools | Others ] -->
[Train] Most of codes are copied from Megatron-LM Dev branch. The dev
branch is different with main branch or release version.
Megatron LM PR:
DeepSeek-V4:
NVIDIA#4458
NVIDIA#4481
NVIDIA#4518
mHC:
NVIDIA#2943
### PR Types
<!-- One of [ User Experience | New Features | Bug Fixes | Improvements
| Performance | Breaking Change| Deprecations | Test Case | Docs |
Others ] -->
[New features]
### PR Description
<!-- Describe what you’ve done -->
Add DeepSeek V4 model into FlagScale and Megatron-FL
Supported:
1. CSA and HCA
2. Hash Router
3. mHC
4. Engram(optional)

Unsupported:
1. Sqrtsoftpuls router score function. ✅
2. mHC recompute. ✅
3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅
4. Any infra optimizations.

### NOTE: This is only a draft pr, please reivew to give more
suggestions.
such as:
1. File structure.
    - All modules are moved into Megatron-FL

### Next plan:
1. Distributed training. ✅
3. Muon optimizer with Zero 1 adaptation. 🚧
4. Low precision is out of scope of this pr, limited by resource. 
5. Maybe context parallel for sparse attention.
6. Welcome to give more suggestions.

---------

Co-authored-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: Yuzhong Wang <yuzhongw@nvidia.com>
zhaoyinglia added a commit to flagos-ai/FlagScale that referenced this pull request May 24, 2026
### PR Category
<!-- One of [ Train | Inference | Compress | Serve | RL | Core |
Hardware | CICD | Tools | Others ] -->
[Train] Most of codes are copied from Megatron-LM Dev branch. The dev
branch is different with main branch or release version.
Megatron LM PR:
DeepSeek-V4:
NVIDIA/Megatron-LM#4458
NVIDIA/Megatron-LM#4481
NVIDIA/Megatron-LM#4518
mHC:
NVIDIA/Megatron-LM#2943
### PR Types
<!-- One of [ User Experience | New Features | Bug Fixes | Improvements
| Performance | Breaking Change| Deprecations | Test Case | Docs |
Others ] -->
[New features]
### PR Description
<!-- Describe what you’ve done -->
Add DeepSeek V4 model into FlagScale and Megatron-FL
Supported:
1. CSA and HCA
2. Hash Router
3. mHC
4. Engram(optional)

Unsupported:
1. Sqrtsoftpuls router score function. ✅
2. mHC recompute. ✅
3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅
4. Any infra optimizations.

### NOTE: This is only a draft pr, please reivew to give more
suggestions.
such as:
1. File structure.  
- **All modules are moved to Megatron-FL. Only model_builder is left in
Flagscale.**
    - Delete Engram related CI or not?

### Next plan:
1. Distributed training. ✅
3. Muon optimizer with Zero 1 adaptation. 😢
4. Low precision is out of scope of this pr, limited by resource.
5. Maybe context parallel for sparse attention.
6. Welcome to give more suggestions.

---------

Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

complexity: medium dev branch Dev branch related issues and development

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants