
Minor improvements for Dynamic-cp #4226

Merged
yuzhongw-nvidia merged 9 commits into NVIDIA:dev from xiaoyao0115:dynamic-cp-gdn
Jun 8, 2026


@xiaoyao0115 xiaoyao0115 commented Apr 9, 2026


What does this PR do?

Dynamic CP support for GDN and MTP, data scheduling adaptations

Summary

  1. Add Dynamic CP support for Gated Delta Net (GDN), including Triton autotune padding to avoid repeated kernel recompilation
  2. Add Dynamic CP support for MTP
  3. Add MLA + Dynamic CP support, including MultiLatentAttention, AbsorbedMLA, and DSv4HybridAttention paths
  4. Adapt data scheduling to the new data_iterator creation policy introduced by upstream PR #3564 (Fix quantize.py script and support packed sequences in pretrain_gpt.py): middle PP stages now create their own data_iterators, so broadcast_to_pp_group is removed and per-stage data field stripping is added

Changes

1. GDN Dynamic CP support (gated_delta_net.py)

  • All CP-related operations (tensor_a2a_cp2hp/tensor_a2a_hp2cp, get_parameter_local_cp, conv1d groups, QKV split dimensions, etc.) now use the dynamic CP group resolved from packed_seq_params
  • Triton autotune padding: FLA's causal_conv1d Triton kernel includes NB = cdiv(total_tokens, 1024) in its autotune key. Dynamic CP causes total_tokens to vary per microbatch, triggering repeated Triton autotuning (hundreds of ms each). Input is now padded to _CONV_PAD_ALIGNMENT = 4096 boundaries to collapse NB into far fewer buckets. Implementation: zero-pad the input tensor along the sequence dimension, adjust cu_seqlens[-1] to include padding, run conv1d, then strip padding from the output
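
The effect of the padding on the autotune key can be checked with plain arithmetic. The sketch below is not the code from gated_delta_net.py: it uses Python lists in place of tensors, and the helper names (cdiv, padded_length, pad_cu_seqlens) and the sample token counts are hypothetical. Only the 1024 divisor and the 4096 alignment come from the description above.

```python
CONV_PAD_ALIGNMENT = 4096  # value of _CONV_PAD_ALIGNMENT quoted in the PR description
AUTOTUNE_BLOCK = 1024      # divisor in NB = cdiv(total_tokens, 1024)


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def padded_length(total_tokens: int, alignment: int = CONV_PAD_ALIGNMENT) -> int:
    """Round total_tokens up to the next multiple of alignment."""
    return cdiv(total_tokens, alignment) * alignment


def pad_cu_seqlens(cu_seqlens, alignment: int = CONV_PAD_ALIGNMENT):
    """Extend the last boundary so the final sequence absorbs the padding.

    Returns the adjusted boundaries and the number of padded tokens, which
    the caller needs later to strip the padding from the conv output.
    """
    total = cu_seqlens[-1]
    pad_n = padded_length(total, alignment) - total
    return cu_seqlens[:-1] + [total + pad_n], pad_n


# Hypothetical per-microbatch token counts under Dynamic CP.
token_counts = [5000, 5500, 6100, 7000, 8000]

# Distinct NB values seen by the autotuner without and with padding.
nb_unpadded = {cdiv(n, AUTOTUNE_BLOCK) for n in token_counts}
nb_padded = {cdiv(padded_length(n), AUTOTUNE_BLOCK) for n in token_counts}

new_cu_seqlens, pad_n = pad_cu_seqlens([0, 3000, 5000])
```

With these sample counts the unpadded inputs produce four distinct NB values, while the padded inputs all land in a single bucket, so the kernel would be autotuned once instead of four times.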

2. MTP dynamic CP group fix (gpt_model.py, multi_token_prediction.py)

  • MultiTokenPredictionLayer._roll_input_ids_and_position_ids: roll_tensor now uses resolve_cp_group() to resolve the dynamic CP group from packed_seq_params, instead of the static self.cp_group
  • GPTModel.forward: process_mtp_loss now passes the dynamic CP group resolved from packed_seq_params instead of the static self.pg_collection.cp
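
The fallback behaviour described in these two bullets can be sketched as follows. The actual signature of resolve_cp_group in packed_seq_params.py is not shown in this thread, so the two-argument form here is an assumption, PackedSeqParamsStub is a hypothetical stand-in for PackedSeqParams, and plain strings take the place of process-group objects.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PackedSeqParamsStub:
    """Hypothetical stand-in for PackedSeqParams; only cp_group matters here."""

    cp_group: Optional[Any] = None


def resolve_cp_group(packed_seq_params, default_cp_group):
    """Prefer the per-microbatch CP group; otherwise fall back to the static one.

    Assumed signature, for illustration only.
    """
    if packed_seq_params is not None and packed_seq_params.cp_group is not None:
        return packed_seq_params.cp_group
    return default_cp_group


static_group = "static-cp-group"    # placeholder for self.cp_group
dynamic_group = "dynamic-cp-group"  # placeholder for a per-microbatch group

from_dynamic = resolve_cp_group(PackedSeqParamsStub(dynamic_group), static_group)
from_unset = resolve_cp_group(PackedSeqParamsStub(), static_group)
from_none = resolve_cp_group(None, static_group)
```

Only the first call returns the dynamic group; the other two fall back to the static group, which keeps non-packed and static-CP runs on their previous code path.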

3. Data scheduling adaptations (data_schedule.py, data_schedule_utils.py)

Upstream PR #3564 (Fix quantize.py script and support packed sequences in pretrain_gpt.py) changed is_dataset_built_on_rank and get_batch in pretrain_gpt.py so that in packed-sequence (SFT) mode, all PP stages (including middle ones) build datasets and receive data_iterators. Middle stages get a PackedSeqParams with cu_seqlens/max_seqlen directly from their own iterator. Previously only first/last PP stages had data_iterators, and middle stages relied on broadcast_to_pp_group to receive metadata from the first PP stage.

Adaptations:

  • Remove broadcast_to_pp_group: Middle PP stages now participate in scheduling and have their own data_iterators, making the PP-group metadata broadcast unnecessary. The function (~95 lines) and its numpy dependency are removed
  • Per-stage data field stripping: Before the all-to-all reroute, data fields not needed by the current PP stage are stripped from samples. First PP keeps only tokens/position_ids, last PP keeps only labels/loss_mask, MTP stages keep all four. Middle stages only carry metadata (cu_seqlens, max_seqlen, etc.) and do not participate in data tensor all-to-all
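
As a rough illustration of the stripping rule, the sketch below wraps the key-selection logic in a standalone function. The function name strip_fields_for_stage, the make_sample helper, and the sample values are hypothetical; the six sample keys and the keep/drop rule follow the description above.

```python
def strip_fields_for_stage(batch, is_first_pp, is_last_pp, mtp_on_this_pp):
    """Drop, in place, the per-sample fields this PP stage does not need."""
    # Length metadata is kept on every stage.
    keys_to_keep = {"original_seq_len", "padded_seq_len"}
    if is_first_pp or mtp_on_this_pp:
        keys_to_keep.update(["tokens", "position_ids"])
    if is_last_pp or mtp_on_this_pp:
        keys_to_keep.update(["labels", "loss_mask"])
    for sample in batch:
        # Copy the key list first because the dict is mutated while iterating.
        for key in list(sample.keys()):
            if key not in keys_to_keep:
                del sample[key]
    return batch


def make_sample():
    """Hypothetical unpacked sample; lists stand in for tensors."""
    return {
        "tokens": [1, 2, 3],
        "position_ids": [0, 1, 2],
        "labels": [2, 3, 4],
        "loss_mask": [1, 1, 1],
        "original_seq_len": 3,
        "padded_seq_len": 4,
    }


first = strip_fields_for_stage([make_sample()], True, False, False)[0]
middle = strip_fields_for_stage([make_sample()], False, False, False)[0]
last = strip_fields_for_stage([make_sample()], False, True, False)[0]
mtp = strip_fields_for_stage([make_sample()], False, False, True)[0]
```

A middle stage ends up with only the two length fields, which is why it has no data tensors to contribute to the all-to-all.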

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch

The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@xiaoyao0115 xiaoyao0115 requested review from a team as code owners April 9, 2026 08:46
@copy-pr-bot

copy-pr-bot Bot commented Apr 9, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@xiaoyao0115 xiaoyao0115 changed the title from "support dcp+gdn and fix for vpp+dcp" to "Minor improvements for Dynamic-cp" Apr 9, 2026
@xiaoyao0115


/ok to test b169e13

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Apr 14, 2026
@Victarry


/claude strict-review

Comment thread megatron/core/ssm/gated_delta_net.py Outdated
Comment on lines +471 to +472
            if _pad_n > 0:
                qkv = qkv[:, :_orig_seq, :]


[CRITICAL Implementation] Related to the same issue above — this output stripping also needs to be guarded against the case where cu_seqlens_q is None and _pad_n > 0. Without the guard on lines 458-461, this line is unreachable (the code would have already crashed), but once the guard is applied above, you need to ensure the stripping only happens when padding was actually applied.

Consider:

            if _pad_n > 0 and cu_seqlens_q is not None:
                qkv = qkv[:, :_orig_seq, :]

Or, more cleanly, track whether padding was actually applied (since without cu_seqlens the padding itself is still harmless—only the cu_seqlens mutation is the problem). If you keep the padding unconditionally, this stripping is still needed. But if you guard the entire padding block on cu_seqlens_q is not None, then this stripping should be guarded the same way.
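
One way to read the "track whether padding was actually applied" suggestion is a single flag that gates both the padding and the stripping. The sketch below is an assumption-laden illustration, not the PR's fix: Python lists stand in for tensors, and run_conv_with_padding, identity_conv, and the small alignment of 4 are hypothetical.

```python
def run_conv_with_padding(tokens, cu_seqlens, conv_fn, alignment=4096):
    """Pad to an alignment boundary only when cu_seqlens is available."""
    orig_len = len(tokens)
    pad_n = (-orig_len) % alignment
    # One flag drives both the padding and the later stripping, so the two
    # steps cannot get out of sync when cu_seqlens is None.
    padding_applied = pad_n > 0 and cu_seqlens is not None
    if padding_applied:
        tokens = tokens + [0] * pad_n
        # Build a new list rather than mutating the caller's cu_seqlens.
        cu_seqlens = cu_seqlens[:-1] + [cu_seqlens[-1] + pad_n]
    out = conv_fn(tokens, cu_seqlens)
    if padding_applied:
        out = out[:orig_len]
    return out


def identity_conv(tokens, cu_seqlens):
    """Hypothetical placeholder for the conv1d kernel call."""
    return list(tokens)


# Packed-sequence mode: padded to 8 tokens internally, stripped back to 5.
packed = run_conv_with_padding([1, 2, 3, 4, 5], [0, 2, 5], identity_conv, alignment=4)
# Non-packed mode: cu_seqlens is None, so padding is skipped entirely.
unpacked = run_conv_with_padding([1, 2, 3, 4, 5], None, identity_conv, alignment=4)
```

This picks the second option from the comment (guard the whole padding block), at the cost of leaving the non-packed path without the autotune-bucket benefit.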

Comment on lines +259 to +267
    keys_to_keep = {'original_seq_len', 'padded_seq_len'}
    if is_first_pp or mtp_on_this_pp:
        keys_to_keep.update(['tokens', 'position_ids'])
    if is_last_pp or mtp_on_this_pp:
        keys_to_keep.update(['labels', 'loss_mask'])
    for sample in batch:
        for key in list(sample.keys()):
            if key not in keys_to_keep:
                del sample[key]


[SUGGESTION Simplification] The stripping logic is correct and well-structured. One minor robustness note: this assumes the only keys in the raw (unpacked) samples are {tokens, labels, loss_mask, position_ids, original_seq_len, padded_seq_len}. If a custom dataset adds extra metadata keys beyond these (e.g., a sample_id or language_tag), they would be silently dropped.

This is unlikely to be a problem in practice since _unpack_batch only produces those six keys, but a brief inline comment noting the assumption (that only these keys exist post-unpack) would help future maintainers.

Comment thread megatron/core/datasets/data_schedule.py Outdated
Comment thread megatron/core/packed_seq_params.py
Comment thread megatron/core/ssm/gated_delta_net.py
@claude

claude Bot commented Apr 14, 2026


Code Review Summary

PR: Minor improvements for Dynamic-cp — Adds dynamic CP support for GDN and MTP, adapts data scheduling to the new per-stage data_iterator policy from PR #3564.

Findings

| Severity   | Count |
| ---------- | ----- |
| CRITICAL   | 1     |
| IMPORTANT  | 0     |
| SUGGESTION | 4     |

Critical

  • Conv1d padding crash in gated_delta_net.py: The new Triton autotune padding logic calls cu_seqlens_q.clone() unconditionally, but cu_seqlens_q is None when packed_seq_params is None (non-packed-sequence mode). If deterministic_mode is also False (the default) and the sequence length is not a multiple of 4096, this crashes with AttributeError. This is a regression — the code worked before this PR. Fix: guard the cu_seqlens manipulation on cu_seqlens_q is not None.

What looks good

  • resolve_cp_group helper (packed_seq_params.py): Clean centralization of the repeated if packed_seq_params is not None and packed_seq_params.cp_group is not None pattern used across GPTModel, GDN, MambaMixer, and MTP layers.
  • Data scheduling adaptation: The removal of broadcast_to_pp_group and addition of per-stage data field stripping is well-designed. The stripping correctly preserves original_seq_len/padded_seq_len (needed by build_packed_microbatches) while removing unnecessary data tensors before the all-to-all reroute. The _unpack_batch → strip → reroute → _pack_sequences pipeline handles missing data keys correctly throughout.
  • VPP vpp_needs_data logic: The semantic-based approach (first/last/MTP stages need data) is more correct than the old vpp_has_data approach (which was based on iterator presence). The independent metadata creation per non-data VPP stage prevents shared-reference mutation bugs.
  • MTP dynamic CP: resolve_cp_group is correctly used in both _get_embeddings (for roll_tensor) and forward (for the temporary self.cp_group override), and _orig_cp_group is properly restored at line 1181.
  • MambaMixer: The identity check _resolved_cp_group is not _orig_cp_group avoids unnecessary set_context_parallel_group calls — a nice micro-optimization.
  • Head-parallel CP assertions in GDN __init__: Good addition catching configurations that would have silently produced incorrect results.
  • CUDA graph scope guard in utils.py: Correctly prevents dynamic-shape cu_seqlens broadcasts inside full_iteration captured regions.

Risk Assessment

Medium risk due to the critical conv1d padding bug. The data scheduling changes are well-structured but affect a complex subsystem (dynamic CP × VPP × MTP × per-stage data routing), so thorough integration testing with multi-stage pipelines, VPP, and MTP is recommended. The GDN and MTP dynamic CP changes are straightforward resolve_cp_group substitutions and look correct.

@xiaoyao0115


/ok to test b850cee

@xiaoyao0115


/ok to test b642842

@xiaoyao0115


/ok to test 570ff74

@xiaoyao0115


/ok to test daf2e0a

@xiaoyao0115


/ok to test daf2e0a

@xiaoyao0115


/ok to test 0bede0d

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27022438250

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Jun 5, 2026
@xiaoyao0115 xiaoyao0115 added this pull request to the merge queue Jun 6, 2026
@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27066685670

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Jun 6, 2026
@yuzhongw-nvidia yuzhongw-nvidia added this pull request to the merge queue Jun 8, 2026
@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27112839305

xiaoyao0115 and others added 9 commits June 7, 2026 19:38
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: tailaim <tailaim@nvidia.com>
When token counts differ across ranks (e.g. Dynamic CP) or microbatches,
locally dividing the loss by num_tokens before all-reducing produces a
biased per-token loss in the logs. Switch MTPLossLoggingHelper to store
raw loss sums and token counts, perform a single all-reduce over the
packed [sums, counts] tensor, then compute sum/sum afterwards. This
yields the correct weighted-average per-token loss regardless of
per-rank token-count imbalance.

Drop the 1/num_microbatches scaling in training.py since the tracker
already returns the per-token loss aggregated across all ranks and
microbatches; no further scaling is needed.

Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
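
The bias described in the MTPLossLoggingHelper commit message above can be reproduced with two ranks and made-up numbers. This is a small arithmetic sketch, not the helper's code: the function names and values are hypothetical, and Python's sum() over per-rank lists stands in for the all-reduce.

```python
def mean_of_local_means(loss_sums, token_counts):
    """Old behaviour: each rank divides locally, then the results are averaged."""
    local_means = [s / n for s, n in zip(loss_sums, token_counts)]
    return sum(local_means) / len(local_means)


def global_per_token_loss(loss_sums, token_counts):
    """New behaviour: reduce raw sums and counts first, then divide once."""
    return sum(loss_sums) / sum(token_counts)


# Hypothetical imbalance: rank 0 holds 10 tokens, rank 1 holds 90 tokens.
loss_sums = [10.0, 270.0]  # per-rank summed loss
token_counts = [10, 90]    # per-rank token counts

biased = mean_of_local_means(loss_sums, token_counts)
correct = global_per_token_loss(loss_sums, token_counts)
```

Here the mean of local means weights the 10-token rank as heavily as the 90-token rank, so it underreports the per-token loss; dividing the reduced sum by the reduced count weights every token equally.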
@xiaoyao0115


/ok to test 3de8fbe

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to no response for status checks Jun 8, 2026
@yuzhongw-nvidia yuzhongw-nvidia enabled auto-merge June 8, 2026 09:24
@yuzhongw-nvidia yuzhongw-nvidia added this pull request to the merge queue Jun 8, 2026
@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27135364923

Merged via the queue into NVIDIA:dev with commit 959a542 Jun 8, 2026
65 of 66 checks passed
5 participants