Skip to content

[Dev] fix(mtp): use padded cu_seqlens in MTP roll for THD with CP#4494

Merged
BestJuly merged 2 commits into
NVIDIA:devfrom
BestJuly:lit/fix_mtp_thd_odd_seqlen_dev
May 12, 2026
Merged

[Dev] fix(mtp): use padded cu_seqlens in MTP roll for THD with CP#4494
BestJuly merged 2 commits into
NVIDIA:devfrom
BestJuly:lit/fix_mtp_thd_odd_seqlen_dev

Conversation

@BestJuly

@BestJuly BestJuly commented Apr 28, 2026

Copy link
Copy Markdown
Contributor

Summary

  • Fixes a correctness bug in _roll_tensor_packed_seq (Multi-Token Prediction, THD packed sequences with Context Parallelism).
  • For CP>1, the local THD layout is produced by tex.thd_get_partitioned_indices(cu_seqlens_padded, ...) and requires every per-sequence padded length to be divisible by 2*cp_size. The roll function indexed local chunks with the unpadded cu_seqlens_q, so when real seqlens are not multiples of 2*cp_size (e.g. odd lengths), // cp_size produced wrong local boundaries: chunk(2) split unevenly, neighbour send/recv tensors had mismatched sizes, and tokens leaked across sequence boundaries (very small seqs hit IndexError).
  • Prefer cu_seqlens_q_padded when provided, with fallback to cu_seqlens_q (matches the convention already used in attention.py).
  • Adds parametrized unit test test_roll_tensor_with_packed_sequences_odd_seqlen covering odd seqlens for CP=1 and CP=2.

Test plan

  • pytest tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_roll_tensor_with_packed_sequences_odd_seqlen — CP=1 PASSED on a single GPU.
  • torchrun --nproc_per_node=2 -m pytest tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_roll_tensor_with_packed_sequences_odd_seqlen — both CP=1 and CP=2 PASSED.
  • Existing test_roll_tensor_with_packed_sequences (CP=1 / CP=2) still PASSES, no regression.

Reproduction (before fix, CP=2)

With odd seqlens [7, 11] padded to [8, 12]:

rank 0 expected = [2, 3, 0, 0, 12, 13, 14, 21, 0, 0]
rank 0 got      = [2, 3, 0, 8,  9,  6, 17, 18, 0, 0]   <-- tokens 6, 8, 17 leak

With sequences smaller than 2*cp_size (e.g. [3] padded to [4]), the un-fixed code raises IndexError.

Cross-reference

PR for main: #4495

`_roll_tensor_packed_seq` indexed local chunks with the unpadded
`cu_seqlens_q`, but with CP the THD layout is produced by
`tex.thd_get_partitioned_indices(cu_seqlens_padded, ...)`, which requires
each per-sequence padded length to be divisible by 2*cp_size. When real
seqlens are not multiples of 2*cp_size (e.g. odd lengths), `// cp_size`
gave the wrong local boundaries: `chunk(2)` split unevenly, neighbour
sends/recvs had mismatched sizes, and tokens leaked across sequence
boundaries (very small seqs hit IndexError).

Prefer `cu_seqlens_q_padded` when provided (with fallback to the
unpadded version), matching the convention already used in
`attention.py`.

Also adds a parametrized unit test covering odd seqlens for CP=1 and
CP=2 (`test_roll_tensor_with_packed_sequences_odd_seqlen`), with the
CP=2 case using padded `[8, 12]` over real `[7, 11]`.

Co-Authored-By: Jingliang Li <jinliangl@nvidia.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Apr 28, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@BestJuly BestJuly marked this pull request as ready for review April 28, 2026 14:13
@BestJuly BestJuly requested review from a team as code owners April 28, 2026 14:13
@wplf wplf self-requested a review April 28, 2026 14:57
@BestJuly BestJuly changed the title fix(mtp): use padded cu_seqlens in MTP roll for THD with CP [Dev] fix(mtp): use padded cu_seqlens in MTP roll for THD with CP Apr 29, 2026
@BestJuly BestJuly added this pull request to the merge queue May 12, 2026
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/25719970914

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants