Skip to content

[dev] [DeepSeek-v4] Part 2: Hash MoE and SwiGLU clamp#4481

Merged
hxbai merged 9 commits into
NVIDIA:devfrom
hxbai:dsv4_moe
Apr 30, 2026
Merged

[dev] [DeepSeek-v4] Part 2: Hash MoE and SwiGLU clamp#4481
hxbai merged 9 commits into
NVIDIA:devfrom
hxbai:dsv4_moe

Conversation

@hxbai

@hxbai hxbai commented Apr 27, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

We will create several PRs to functionally support DeepSeek-v4 training. This is the second one.

Add DeepSeek-v4 Hash MoE and SwiGLU clamp.

  • Add new argument --moe-n-hash-layers.
  • Add SwiGLU support to --activation-func-clamp-value.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@copy-pr-bot

copy-pr-bot Bot commented Apr 27, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@hxbai hxbai self-assigned this Apr 27, 2026
@hxbai hxbai added the dev branch Dev branch related issues and development label Apr 27, 2026

@Victarry Victarry left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks good to me. Co-reviewd with AI, please take a look~

Comment thread megatron/core/transformer/transformer_block.py Outdated
Comment thread megatron/core/transformer/moe/router.py
Comment thread megatron/core/fusions/fused_bias_swiglu.py
Comment thread megatron/core/transformer/transformer_block.py Outdated
Comment thread megatron/core/transformer/moe/router.py Outdated
Comment thread megatron/core/transformer/transformer_layer.py
Comment thread megatron/core/transformer/moe/router.py
Comment on lines 200 to 202

activation_func_clamp_value: Optional[float] = None
"""Clamp the output of the linear_fc1 in the activation function. Only used when activation_func

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] The docstring now claims activation_func_clamp_value works for quick_gelu or swiglu, but in this PR clamp is only wired through the weighted SwiGLU path (weighted_bias_swiglu_impl -> WeightedSwiGLUFunction). The dense path bias_swiglu_impl (used by non-MoE / non-token-weighted SwiGLU MLP) does not accept clamp_value and silently ignores it.

So a user setting activation_func_clamp_value on a model with dense SwiGLU MLP layers will see zero effect, with no warning.

Suggestion: either

  • extend bias_swiglu_impl / BiasSwiGLUFunction / SwiGLUFunction to also accept and respect clamp_value, or
  • narrow the docstring to "weighted SwiGLU (MoE) only" and add a runtime check that warns or asserts when activation_func_clamp_value > 0 is set on a dense-SwiGLU configuration.

@hxbai hxbai Apr 29, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed the docstring

Comment thread megatron/core/transformer/moe/router.py Outdated
@hxbai hxbai mentioned this pull request Apr 29, 2026
3 tasks
@hxbai hxbai marked this pull request as ready for review April 29, 2026 08:56
@hxbai hxbai requested review from a team as code owners April 29, 2026 08:56
@hxbai hxbai changed the title [dev] [DeepSeek-v4] Part 2: Hash MoE, SwiGLU clamp, and new mHC contract [dev] [DeepSeek-v4] Part 2: Hash MoE and SwiGLU clamp Apr 29, 2026
@hxbai

hxbai commented Apr 29, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test df71a39

@hxbai

hxbai commented Apr 29, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test c51d461

@Victarry

Copy link
Copy Markdown
Contributor

LGTM

@hxbai

hxbai commented Apr 29, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 608a1b2

@hxbai

hxbai commented Apr 29, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test dfadf2e

@hxbai

hxbai commented Apr 30, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test d6a9445

@hxbai hxbai added this pull request to the merge queue Apr 30, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25151919314

Merged via the queue into NVIDIA:dev with commit fe729e9 Apr 30, 2026
64 of 65 checks passed
@hxbai hxbai deleted the dsv4_moe branch April 30, 2026 07:34
hxbai added a commit to hxbai/Megatron-LM that referenced this pull request Apr 30, 2026
LiJunscs pushed a commit to LiJunscs/Megatron-LM-FL that referenced this pull request May 11, 2026
LiJunscs pushed a commit to LiJunscs/Megatron-LM-FL that referenced this pull request May 20, 2026
zhaoyinglia pushed a commit to flagos-ai/Megatron-LM-FL that referenced this pull request May 24, 2026
### PR Category
<!-- One of [ Train | Inference | Compress | Serve | RL | Core |
Hardware | CICD | Tools | Others ] -->
[Train] Most of codes are copied from Megatron-LM Dev branch. The dev
branch is different with main branch or release version.
Megatron LM PR:
DeepSeek-V4:
NVIDIA#4458
NVIDIA#4481
NVIDIA#4518
mHC:
NVIDIA#2943
### PR Types
<!-- One of [ User Experience | New Features | Bug Fixes | Improvements
| Performance | Breaking Change| Deprecations | Test Case | Docs |
Others ] -->
[New features]
### PR Description
<!-- Describe what you’ve done -->
Add DeepSeek V4 model into FlagScale and Megatron-FL
Supported:
1. CSA and HCA
2. Hash Router
3. mHC
4. Engram(optional)

Unsupported:
1. Sqrtsoftpuls router score function. ✅
2. mHC recompute. ✅
3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅
4. Any infra optimizations.

### NOTE: This is only a draft pr, please reivew to give more
suggestions.
such as:
1. File structure.
    - All modules are moved into Megatron-FL

### Next plan:
1. Distributed training. ✅
3. Muon optimizer with Zero 1 adaptation. 🚧
4. Low precision is out of scope of this pr, limited by resource. 
5. Maybe context parallel for sparse attention.
6. Welcome to give more suggestions.

---------

Co-authored-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: Yuzhong Wang <yuzhongw@nvidia.com>
zhaoyinglia added a commit to flagos-ai/FlagScale that referenced this pull request May 24, 2026
### PR Category
<!-- One of [ Train | Inference | Compress | Serve | RL | Core |
Hardware | CICD | Tools | Others ] -->
[Train] Most of codes are copied from Megatron-LM Dev branch. The dev
branch is different with main branch or release version.
Megatron LM PR:
DeepSeek-V4:
NVIDIA/Megatron-LM#4458
NVIDIA/Megatron-LM#4481
NVIDIA/Megatron-LM#4518
mHC:
NVIDIA/Megatron-LM#2943
### PR Types
<!-- One of [ User Experience | New Features | Bug Fixes | Improvements
| Performance | Breaking Change| Deprecations | Test Case | Docs |
Others ] -->
[New features]
### PR Description
<!-- Describe what you’ve done -->
Add DeepSeek V4 model into FlagScale and Megatron-FL
Supported:
1. CSA and HCA
2. Hash Router
3. mHC
4. Engram(optional)

Unsupported:
1. Sqrtsoftpuls router score function. ✅
2. mHC recompute. ✅
3. Overlap_grad_reduce and overlap_param_gather when Zero 1. ✅
4. Any infra optimizations.

### NOTE: This is only a draft pr, please reivew to give more
suggestions.
such as:
1. File structure.  
- **All modules are moved to Megatron-FL. Only model_builder is left in
Flagscale.**
    - Delete Engram related CI or not?

### Next plan:
1. Distributed training. ✅
3. Muon optimizer with Zero 1 adaptation. 😢
4. Low precision is out of scope of this pr, limited by resource.
5. Maybe context parallel for sparse attention.
6. Welcome to give more suggestions.

---------

Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

complexity: medium dev branch Dev branch related issues and development

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants