Skip to content

[recipe] feat: Add DeepSeek-V4-Flash pretraining recipes#3893

Merged
cuichenx merged 6 commits into
mainfrom
chcui/dsv4-train-pr3562-pr4518
May 31, 2026
Merged

[recipe] feat: Add DeepSeek-V4-Flash pretraining recipes#3893
cuichenx merged 6 commits into
mainfrom
chcui/dsv4-train-pr3562-pr4518

Conversation

@weijiac0619

@weijiac0619 weijiac0619 commented May 19, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Adds pretraining recipes for DeepSeek-V4-Flash on Blackwell, plus a Slurm
launcher example, unit tests, and a functional test. The base recipe targets
TP=1 / PP=4 / EP=8 with selective activation recompute, an MTP-aware pipeline
layout, and BF16. Two variants extend the base for the two supported
optimizer + precision combinations.

Changelog

  • src/megatron/bridge/recipes/deepseek/deepseek_v4.py (new): three pretrain
    configs and a pipeline-layout helper.
    • deepseek_v4_flash_pretrain_config() — BF16 base; TP=1, PP=4, EP=8,
      selective recompute (moe_act, mhc), MTP placed on the last PP stage
      via pipeline_model_parallel_layout.
    • deepseek_v4_flash_pretrain_mxfp8_config() — Adam + MXFP8 training,
      BF16 MTP / validation eval, quant_recipe selects MXFP8 for TE linears.
    • deepseek_v4_flash_pretrain_muon_config() — Muon optimizer + BF16,
      non-layer-wise dispatch.
    • set_deepseek_v4_pipeline_model_parallel_layout() helper builds the
      even decoder layout with MTP and loss on the last PP rank.
  • src/megatron/bridge/recipes/deepseek/__init__.py: re-export the new
    configs and helper.
  • examples/models/deepseek_v4/README.md: document the new recipes and
    Slurm launcher.
  • examples/models/deepseek_v4/slurm_pretrain.sh: Slurm sbatch script for
    the new recipes.
  • tests/unit_tests/recipes/test_deepseek_recipes.py: extend coverage to
    DSv4 configs.
  • tests/functional_tests/test_groups/recipes/test_deepseek_recipes_pretrain.py
    (new): L0/L1 pretrain smoke for DSv4 variants.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

@copy-pr-bot

copy-pr-bot Bot commented May 19, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@weijiac0619 weijiac0619 force-pushed the chcui/dsv4-train-pr3562-pr4518 branch from 1c4ed05 to 6f7273a Compare May 28, 2026 23:09
@weijiac0619 weijiac0619 marked this pull request as ready for review May 28, 2026 23:12
@weijiac0619 weijiac0619 requested a review from cuichenx May 28, 2026 23:15
from .deepseek_v4 import (
deepseek_v4_flash_pretrain_muon_config,
deepseek_v4_flash_pretrain_mxfp8_config,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_deepseek_v4_pipeline_model_parallel_layout is a public function (no underscore prefix) but is not re-exported here or added to __all__, unlike the V3 equivalent set_deepseek_v3_pipeline_model_parallel_layout. If it's intended for user customization of pipeline layouts, it should be exported. If it's internal-only, prefix it with _.

Suggested change
)
from .deepseek_v4 import (
deepseek_v4_flash_pretrain_muon_config,
deepseek_v4_flash_pretrain_mxfp8_config,
set_deepseek_v4_pipeline_model_parallel_layout,
)

Comment on lines +227 to +229
clip_grad=1.0,
)
opt_cfg.optimizer = "muon"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: distributed_muon_with_cosine_annealing sets optimizer="dist_muon", then line 229 immediately overwrites it to "muon". This controls layer_wise_distributed_optimizer in optim.py (line 105). The intent (non-distributed Muon with no_shard) is correct, but calling a function named distributed_muon_* only to undo the "distributed" part is confusing. Consider adding a brief comment here explaining why it's overridden, e.g. "DSv4 Muon uses non-layer-wise optimizer dispatch".

@claude

claude Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Light Code Review - No critical bugs found. Two inline comments posted (missing init export, confusing optimizer override). Main feedback is on test coverage gaps. Test coverage gaps: (1) set_deepseek_v4_pipeline_model_parallel_layout is untested - the V3 equivalent has dedicated tests but the V4 function with real logic (divmod layer distribution, embedding/MTP/loss placement) has none. Currently FakeModelCfg lacks num_layers and mtp_num_layers so the function silently bails out to None. (2) Error paths are untested - ValueError for Muon+MXFP8 and invalid optimizer type are both one-line pytest.raises tests. (3) Mixed-precision assertions for Muon recipe - the MXFP8 test verifies fp8_recipe and fp8_param_gather but the Muon test does not assert anything about mixed_precision. Suggested test cases: No perf tests impacted.

@claude

claude Bot commented May 28, 2026

Copy link
Copy Markdown
Contributor

Light Code Review - No critical bugs found. Two inline comments posted (missing init export, confusing optimizer override). Main feedback is on test coverage gaps. Test coverage gaps: (1) set_deepseek_v4_pipeline_model_parallel_layout is untested - the V3 equivalent has dedicated tests but the V4 function with real logic (divmod layer distribution, embedding/MTP/loss placement) has none, since _FakeModelCfg lacks num_layers and mtp_num_layers so the function silently bails to None. (2) Error paths are untested - ValueError for Muon+MXFP8 and invalid optimizer type are both one-line pytest.raises tests. (3) Mixed-precision assertions for Muon recipe - the MXFP8 test verifies fp8_recipe and fp8_param_gather but the Muon test does not assert anything about mixed_precision (should be plain BF16, no FP8). Suggested test cases: No perf tests impacted.

@weijiac0619 weijiac0619 force-pushed the chcui/dsv4-train-pr3562-pr4518 branch from 6f7273a to 502890d Compare May 28, 2026 23:24
Signed-off-by: weijiac <weijiac@nvidia.com>
@weijiac0619 weijiac0619 force-pushed the chcui/dsv4-train-pr3562-pr4518 branch from 502890d to 24ccffa Compare May 28, 2026 23:25

CASE_NAME="${CASE_NAME:-${RECIPE_NAME}}"
JOB_ID="${SLURM_JOB_ID:-manual}"
OUTDIR="${OUTDIR:-${WORKSPACE}/results/${MODEL_NAME}_${CASE_NAME}_${JOB_ID}}"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we simplify these -- only keep the parameters that users would possibly change. Keep the other ones at a default value

from megatron.bridge.training.mixed_precision import bf16_mixed, bf16_with_mxfp8_mixed


DSV4_CSA_BACKEND = Literal["unfused", "cudnn_dsa", "tilelang_official", "flashmla_official"]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove tilelang_official and flashmla_official now, they were added while debugging forward pass parity

@weijiac0619 weijiac0619 force-pushed the chcui/dsv4-train-pr3562-pr4518 branch 2 times, most recently from 9c4d972 to f0efc48 Compare May 29, 2026 00:29
return cfg


def deepseek_v4_flash_pretrain_mxfp8_config(hf_path: str = "deepseek-ai/DeepSeek-V4-Flash") -> ConfigContainer:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems way too many levels created by agent, just use a single at most 2 levels of config system, flatten as much as you can.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

@weijiac0619 weijiac0619 force-pushed the chcui/dsv4-train-pr3562-pr4518 branch from f0efc48 to 24ccffa Compare May 29, 2026 02:51
@yaoyu-33 yaoyu-33 added area:recipe Training recipes and launch configs feature New capabilities, enhancements, or enablement work needs-more-tests Requires additional L0 and L1 test coverage before merge waiting-on-customer Waiting on the original author to respond labels May 29, 2026
Signed-off-by: weijiac <weijiac@nvidia.com>
Comment thread examples/models/deepseek_v4/README.md
Signed-off-by: weijiac <weijiac@nvidia.com>
Signed-off-by: weijiac <weijiac@nvidia.com>
@cuichenx cuichenx changed the title Chcui/dsv4 train pr3562 pr4518 [recipe] feat: Add DeepSeek-V4-Flash pretraining recipes May 30, 2026
cuichenx
cuichenx previously approved these changes May 30, 2026
@cuichenx cuichenx removed the needs-more-tests Requires additional L0 and L1 test coverage before merge label May 30, 2026
@cuichenx

Copy link
Copy Markdown
Contributor

/ok to test 8ecd047

Signed-off-by: Chen Cui <chcui@nvidia.com>
@cuichenx

Copy link
Copy Markdown
Contributor

/ok to test 0bfb0eb

@cuichenx cuichenx merged commit 0eb1932 into main May 31, 2026
173 of 175 checks passed
@cuichenx cuichenx deleted the chcui/dsv4-train-pr3562-pr4518 branch May 31, 2026 20:28
vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
…#3893)

Signed-off-by: weijiac <weijiac@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:recipe Training recipes and launch configs feature New capabilities, enhancements, or enablement work high-priority waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants