Skip to content

refactor(data): consolidate get_batch and enable PP for SFT THD#4103

Merged
ericharper merged 35 commits into
NVIDIA:mainfrom
asolergi-nv:pr1-get-batch-refactoring
May 21, 2026
Merged

refactor(data): consolidate get_batch and enable PP for SFT THD#4103
ericharper merged 35 commits into
NVIDIA:mainfrom
asolergi-nv:pr1-get-batch-refactoring

Conversation

@asolergi-nv

@asolergi-nv asolergi-nv commented Apr 1, 2026

Copy link
Copy Markdown
Contributor

What does this PR do ?

Unifies batch-processing utilities (get_batch, get_batch_on_this_tp_rank, get_batch_on_this_cp_rank) behind a single framework that handles pretraining, SFT (THD packed sequences), and Hybrid Context Parallel, and enables pipeline parallelism for SFT with THD layout.

Changes

Core refactor (megatron/core/utils.py, megatron/training/utils.py)

  • Move get_batch_on_this_tp_rank / get_batch_on_this_cp_rank from megatron/training/utils.py to megatron/core/utils.py.
  • Rewrite get_batch_on_this_tp_rank around an explicit, length-prefixed broadcast protocol for THD metadata (cu_seqlens, cu_seqlens_padded, max_seqlen, local_cp_size), with dedicated branches for PP-first / PP-last / PP-intermediate / MTP stages.
  • Split CP partitioning into three strategies dispatched by get_batch_on_this_cp_rank:
    • get_sft_batch_on_this_cp_rank — THD index-based partitioning via tex.thd_get_partitioned_indices (SFT).
    • get_pretrain_batch_on_this_cp_rank — zigzag load-balanced chunking (pretrain / Hybrid CP).
    • A dispatcher that selects between them based on cu_seqlens presence and is_hybrid_cp.

PP for SFT (THD)

  • get_batch in both entrypoints now broadcasts only the required THD metadata (cu_seqlens, cu_seqlens_padded, max_seqlen) to intermediate PP stages; tokens/labels/loss_mask/position_ids are intentionally left None.
  • get_sft_batch_on_this_cp_rank CP-partitions whichever of tokens / labels / loss_mask / position_ids are present on the current PP stage

MTP API cleanup (megatron/core/transformer/multi_token_prediction.py)

  • mtp_on_this_rank now takes layout: PipelineParallelLayerLayout and mtp_num_layers: int directly instead of a full TransformerConfig.

Entry scripts (pretrain_gpt.py, pretrain_mamba.py)

  • Unified, sorted BATCH_KEYS tuple; single call into the new core utilities.
  • Build PackedSeqParams from the returned tensors, passing Python ints (int(max_seqlen.item()), int(local_cp_size.item())) to match PackedSeqParams' declared types.
  • pretrain_mamba.py: also pass total_tokens=int(cu_seqlens_for_params[-1].item()) so PackedSeqParams.post_init populates seq_idx for the Mamba SSM kernel (prevents state bleeding across packed sequence boundaries). Works at every PP stage and for any CP size.

Tests (tests/unit_tests/data/test_get_batch.py)

  • New suite covering:
    • test_sft_batch — SFT + TP/PP/CP product over {1,2,4}, seq_length ∈ {1024, 4096}, with per-PP-stage shape / dtype / mask / metadata assertions (including the new intermediate-PP branch).
    • test_pretrain_batch — pretrain path across TP/PP/CP × micro_batch_size ∈ {1,4} × create_attention_mask ∈ {True, False}.
    • test_hybrid_cp_batch — Hybrid CP broadcasts and zigzag partitioning.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

asolergi-nv and others added 3 commits April 1, 2026 20:01
…ank, and mtp_on_this_rank

Move batch processing functions from megatron/training/utils.py to megatron/core/utils.py,
introducing a unified batch handling framework that supports pretraining, SFT (THD packed
sequences), and hybrid context parallelism modes. Refactor mtp_on_this_rank to accept
explicit layout and mtp_num_layers parameters instead of the full TransformerConfig.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 1, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test dbdd34c

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 8c4c82b

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 0d5d42b

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test bead714

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test a7da44a

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 9f6af12

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test b309cbc

@yashaswikarnati yashaswikarnati left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@asolergi-nv

Copy link
Copy Markdown
Contributor Author

/ok to test 1bdd865

@ericharper ericharper enabled auto-merge May 21, 2026 16:27
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Approved All necessary approvals have been made label May 21, 2026
@ericharper ericharper added this pull request to the merge queue May 21, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/26239906766

Merged via the queue into NVIDIA:main with commit 5e4fc93 May 21, 2026
143 of 147 checks passed
@xuantengh

Copy link
Copy Markdown
Contributor

It seems that this PR breaks the MBridge:

https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/4c1f95a021959bc333a167a2a2bb48a96c8f2341/src/megatron/bridge/training/gpt_step.py#L184-L187

When running with CP without sequence packing, it fails with:

TypeError: get_batch_on_this_cp_rank() missing 1 required positional argument: 'is_hybrid_cp'

janEbert pushed a commit to janEbert/Megatron-LM that referenced this pull request Jun 2, 2026
…IA#4103)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
mathemakitten pushed a commit to mathemakitten/Megatron-LM that referenced this pull request Jun 12, 2026
…IA#4103)

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Approved All necessary approvals have been made complexity: high Run functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants