Apply MIMO SP/CP sharding with explicit groups and enable THD in non-colocated path #5150

Merged
yashaswikarnati merged 1 commit into NVIDIA:main from yashaswikarnati:ykarnati/upstream-mimo-sp-cp-sharding
Jun 9, 2026

@yashaswikarnati yashaswikarnati commented Jun 4, 2026

Contributor

What

Make MIMO language-model input sharding correct for non-colocated / heterogeneous parallelism, where the LM owns its own process-group grid.

PartitionAdapter.shard() owns the layout contract — sequence-first (S, B, H) in and out:

  • Explicit TP group for the SP scatter (was the global parallel_state group).
  • Correct SP dim: scatters the sequence; CP transposes to batch-first only internally (for get_batch_on_this_cp_rank), and SP-only does no transpose. Previously the SP path scattered [B, S, H] along the batch dim.
  • loss_mask and packed_seq_params threaded through the LM forward and CP-sharded, so per-token loss and packed/THD sequences line up with the CP-local output on the non-colocated path.
  • No dense attention_mask in the adapter; fast-fail if a dense mask is combined with CP (mask via a causal attn_mask_type / packed_seq_params).
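
The SP-dim fix in the second bullet can be illustrated without any distributed setup. This is a minimal sketch in plain Python lists, not the adapter's code: `split_dim0` is a hypothetical stand-in for a scatter along dim 0, and the tiny sizes are chosen only for illustration.

```python
def split_dim0(rows, world_size, rank):
    """Return the contiguous dim-0 slice that `rank` keeps (assumes even division)."""
    chunk = len(rows) // world_size
    return rows[rank * chunk:(rank + 1) * chunk]


S, B, tp = 8, 2, 2

# Sequence-first [S, B]: each row is one token position across all samples.
seq_first = [[(s, b) for b in range(B)] for s in range(S)]
# Batch-first [B, S]: each row is one sample across all token positions.
batch_first = [[(s, b) for s in range(S)] for b in range(B)]

good = split_dim0(seq_first, tp, rank=0)    # splits the sequence
bad = split_dim0(batch_first, tp, rank=0)   # splits the batch instead

good_tokens = sorted({s for row in good for s, _ in row})
good_samples = sorted({b for row in good for _, b in row})
bad_tokens = sorted({s for row in bad for s, _ in row})
bad_samples = sorted({b for row in bad for _, b in row})

print("sequence-first rank 0: tokens", good_tokens, "samples", good_samples)
print("batch-first rank 0:    tokens", bad_tokens, "samples", bad_samples)
```

With a sequence-first layout, rank 0 keeps half the token positions for every sample. With a batch-first layout, the same dim-0 split leaves rank 0 with the full sequence of a single sample, which is the old behaviour described above.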

shard() returns (embeddings[S/(cp*tp),B,H], labels[B,S/cp], loss_mask[B,S/cp], packed_seq_params). It has a single in-core caller, MimoModel.
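
To make the return shapes concrete, here is a minimal sketch that computes the per-rank shapes listed above. `expected_shard_shapes` is a hypothetical helper, not a Megatron API, and the divisibility checks are assumptions based on a zigzag CP split (two chunks per CP rank) followed by an even SP scatter.

```python
def expected_shard_shapes(seq_len, batch, hidden, cp_size=1, tp_size=1,
                          sequence_parallel=False):
    """Per-rank shapes for (embeddings, labels, loss_mask).

    Hypothetical helper for illustration only.
    """
    # Assumption: zigzag CP splits the sequence into 2 * cp_size chunks.
    if cp_size > 1 and seq_len % (2 * cp_size) != 0:
        raise ValueError("seq_len must be divisible by 2 * cp_size")
    cp_local = seq_len // cp_size

    # SP scatters only the embeddings, across the TP group.
    sp_size = tp_size if sequence_parallel else 1
    if cp_local % sp_size != 0:
        raise ValueError("CP-local sequence must be divisible by tp_size")

    return {
        "embeddings": (cp_local // sp_size, batch, hidden),  # sequence-first
        "labels": (batch, cp_local),                         # CP-sharded only
        "loss_mask": (batch, cp_local),                      # CP-sharded only
    }


shapes = expected_shard_shapes(4096, 2, 1024, cp_size=2, tp_size=4,
                               sequence_parallel=True)
print(shapes)
```

Note that labels and loss_mask stay at S/cp while embeddings shrink further to S/(cp*tp), since the SP scatter applies to embeddings only.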

Backward compatibility

Colocated path unchanged (still returns (output, loss_mask)). Under CP, position_ids are passed full-length — the rotary module shards them internally, so MIMO must not.

Testing

Real 8-GPU tests assert per-rank shape and content for SP-only / CP-only / CP+SP, plus loss-mask CP-sharding and the attention-mask-under-CP guard. 122 passed, 2 skipped on 1 node × 8 GPUs.

No new direct parallel_state.get_*_group() reads.

copy-pr-bot Bot commented Jun 4, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread megatron/core/models/mimo/model/base.py Outdated
Comment thread megatron/core/models/mimo/model/base.py Outdated
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-mimo-sp-cp-sharding branch from 645bfa6 to b70b356 Compare June 4, 2026 05:17
@yashaswikarnati

Contributor Author

Updated per review:

  • Removed the dense-attention_mask CP guard. It was inconsistent (only _shard_language_inputs had it; _forward_all_modules didn't pass the mask in) and dense masks aren't a supported MIMO path. The adapter now simply does not handle a dense attention_mask — documented that masking goes through causal attention / packed_seq_params. Dropped the two guard tests.
  • Narrowed the docstring. Scoped the caveat to CP-local hidden states (not "whenever CP/SP is active"); clarified SP scatters only the embeddings, and labels/loss_mask are CP-sharded but never SP-scattered.
  • Real 8-GPU tests instead of mocks. Replaced the mock-collective sharding tests with TestPartitionAdapterShardRealDistributed, which builds real tensor/context-parallel groups on 8 ranks and verifies per-rank sharded shape and content (positionally) for SP-only ([S/tp,B,H]), CP-only ([S/cp,B,H] zigzag), and combined CP+SP ([S/(cp*tp),B,H]), plus that labels/loss_mask are CP-sharded but not SP-scattered. Renamed the artificial "both disabled" test to reflect the transpose-only contract. Kept the pure-Python validation tests.
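
For readers unfamiliar with the zigzag layout the tests check, here is a rough sketch of which token positions each rank would hold. It assumes the usual zigzag scheme (the sequence is cut into 2*cp chunks and CP rank r keeps chunks r and 2*cp-1-r) and that the SP scatter then splits the CP-local sequence contiguously across the TP group. Both helper names are hypothetical; the real tests compare tensors on real process groups.

```python
def cp_zigzag_positions(seq_len, cp_size, cp_rank):
    """Token positions kept by one CP rank under a zigzag split."""
    chunk = seq_len // (2 * cp_size)
    mirror = 2 * cp_size - 1 - cp_rank
    first = list(range(cp_rank * chunk, (cp_rank + 1) * chunk))
    second = list(range(mirror * chunk, (mirror + 1) * chunk))
    return first + second


def sp_slice(positions, tp_size, tp_rank):
    """Contiguous slice of the CP-local sequence kept by one TP rank."""
    chunk = len(positions) // tp_size
    return positions[tp_rank * chunk:(tp_rank + 1) * chunk]


seq_len, cp, tp = 16, 2, 2
layout = {}
for cp_rank in range(cp):
    cp_local = cp_zigzag_positions(seq_len, cp, cp_rank)  # labels / loss_mask
    for tp_rank in range(tp):
        layout[(cp_rank, tp_rank)] = sp_slice(cp_local, tp, tp_rank)  # embeddings
        print(f"cp={cp_rank} tp={tp_rank}: {layout[(cp_rank, tp_rank)]}")
```

Every position appears exactly once across the four (cp, tp) ranks, which is the kind of positional property a content check can assert.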

Validated on 1 node × 8 GPUs (--experimental): all green, including the three real-distributed sharding tests.

Comment thread megatron/core/models/mimo/model/base.py Outdated
Comment thread megatron/core/models/mimo/model/base.py
Comment thread megatron/core/models/mimo/model/base.py Outdated
Comment thread megatron/core/models/mimo/partition/utils.py Outdated
Comment thread megatron/core/models/mimo/partition/utils.py
Comment thread megatron/core/models/mimo/partition/utils.py
Comment thread megatron/core/models/mimo/partition/utils.py
Make MIMO language-model input sharding correct for non-colocated /
heterogeneous parallelism, where the language model owns its own
process-group grid.

PartitionAdapter.shard() owns a single layout contract:
- Sequence-first input. shard() consumes (S, B, H) embeddings -- the layout
  align_embeddings_by_token_positions produces and the LM consumes -- and
  returns (S/(cp*tp), B, H). It transposes to batch-first only inside the CP
  block (get_batch_on_this_cp_rank requires batch-first); the SP-only path
  scatters dim 0 directly with no transpose round-trip.
- Explicit TP group. scatter_to_sequence_parallel_region is called with the
  adapter's tp_group instead of the global parallel_state group.
- Loss mask threaded through the LM forward and CP-sharded alongside labels,
  so the terminal per-token loss aligns with the CP-local hidden states.

MimoModel:
- Build the adapter only on language-module ranks (encoder-only ranks never
  shard and must not read process groups they do not own).
- Thread packing_kwargs/packed_seq_params through the non-colocated language
  path so packed (THD) sequences are CP-sharded and reach the LM, matching the
  colocated path (shared _build_packed_seq_params helper).
- Reject a dense attention_mask under context parallelism: a dense [B, S] mask
  cannot line up with the CP-sharded sequence, so MIMO masks via a causal
  attn_mask_type or packed_seq_params. position_ids are passed full-length, as
  required by mRoPE (the rotary module CP-shards positions internally).

Tests: real 8-GPU sharding tests (SP-only, CP-only, CP+SP) assert per-rank
shape and content; plus the attention-mask-under-CP guard, loss-mask
CP-sharding, and packed/THD handling. 122 passed, 2 skipped on 1 node x 8 GPUs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-mimo-sp-cp-sharding branch from b70b356 to 383cd54 Compare June 8, 2026 20:53
@yashaswikarnati

Contributor Author

Revised per review. All inline comments addressed; verified on 1 node × 8 GPUs (--experimental): 122 passed, 2 skipped (the 2 skips are pre-existing 1f1b tests needing 2/4 GPUs).

Design changes

  • shard() is now sequence-first. It consumes (S, B, H) (what align_embeddings_by_token_positions produces and the LM consumes) and transposes to batch-first only inside the CP block (get_batch_on_this_cp_rank needs batch-first). The SP-only path now scatters dim 0 directly with zero transposes — the previous [S,B,H]→[B,S,H]→[S,B,H] round-trip is gone. (re: "how many transposes…")
  • Dense attention_mask under CP is now rejected in _forward_language_module (fast-fail) — a dense [B,S] mask cannot line up with the CP-sharded sequence; masking goes through a causal attn_mask_type or packed_seq_params. This closes the gap where the non-colocated path forwarded the mask unguarded while the colocated path forces None. (re: "what happened to attention mask / why remove it")
  • packed_seq_params is now threaded through the non-colocated path (packing_kwargs → _forward_language_module → shard() → LM) via a shared _build_packed_seq_params helper, so packed/THD sequences are CP-sharded and reach the LM — matching the colocated path (previously unsupported there).
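
The fast-fail in the second bullet could look roughly like the sketch below. The function name, exception type and message are hypothetical; the actual check lives in _forward_language_module and may differ in all three.

```python
def reject_dense_mask_under_cp(attention_mask, cp_size):
    """Raise if a dense mask is combined with context parallelism.

    Hypothetical guard for illustration only.
    """
    if attention_mask is not None and cp_size > 1:
        raise ValueError(
            "A dense attention_mask cannot be aligned with a CP-sharded "
            "sequence; use a causal attn_mask_type or packed_seq_params."
        )


# No mask: allowed with or without CP.
reject_dense_mask_under_cp(None, cp_size=2)

# Dense mask without CP: allowed.
reject_dense_mask_under_cp([[1, 1, 1, 0]], cp_size=1)

# Dense mask with CP: rejected before any sharding happens.
try:
    reject_dense_mask_under_cp([[1, 1, 1, 0]], cp_size=2)
except ValueError as err:
    print("rejected:", err)
```

Failing early keeps the colocated and non-colocated paths consistent: neither can pass a full-length mask to an LM that only sees a CP-local slice of the sequence.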

Inline comments

  • _shard_language_inputs "do we even need this?" — kept, but it's now just the None-guard + shard() call (the transpose moved into shard()); the guard is shared across its 3 call sites. Docstring trimmed.
  • Verbose comments/docstrings (:80, :315, :531, :768, shard() docstring) — trimmed to the load-bearing rationale.
  • Optional type hints — kept; None is reachable on non-first PP stages (embeddings=None, labels-only sharding).
  • Dead is_partitioning_enabled — removed (was test-only) along with its 4 tests.
  • position_ids / mRoPE — left full-length on purpose: the rotary module (RotaryEmbedding/MultimodalRotaryEmbedding) CP-shards positions internally with the same zigzag as the hidden states, so pre-sharding would double-shard. No change.
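
The double-sharding concern in the last bullet can be shown with a small sketch. It assumes the zigzag scheme (rank r keeps chunks r and 2*cp-1-r of 2*cp chunks); `zigzag_select` is a hypothetical helper that stands in for whatever the rotary module does internally.

```python
def zigzag_select(values, cp_size, cp_rank):
    """Keep chunks cp_rank and 2*cp_size-1-cp_rank out of 2*cp_size chunks."""
    chunk = len(values) // (2 * cp_size)
    mirror = 2 * cp_size - 1 - cp_rank
    return (values[cp_rank * chunk:(cp_rank + 1) * chunk]
            + values[mirror * chunk:(mirror + 1) * chunk])


cp_size, cp_rank = 2, 1
position_ids = list(range(16))

# Full-length positions in, sharded once downstream: matches the hidden states.
sharded_once = zigzag_select(position_ids, cp_size, cp_rank)

# Pre-sharded by the caller and then sharded again downstream.
sharded_twice = zigzag_select(sharded_once, cp_size, cp_rank)

print("once: ", sharded_once)
print("twice:", sharded_twice)
```

Sharding twice leaves half as many positions as there are CP-local tokens, and they no longer line up with the hidden states on that rank.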

Tests

Real 8-GPU sharding tests (SP-only, CP-only, CP+SP) assert per-rank shape and content; added the attention-mask-under-CP guard test; kept loss-mask CP-sharding coverage.

@yashaswikarnati yashaswikarnati marked this pull request as ready for review June 8, 2026 21:25
@yashaswikarnati yashaswikarnati requested review from a team as code owners June 8, 2026 21:25
@svcnvidia-nemo-ci svcnvidia-nemo-ci added the "Final Review" (PR is in the "final review" stage) and "complexity: medium" labels Jun 8, 2026
Comment thread tests/unit_tests/models/mimo/test_mimo_partition.py
@yashaswikarnati yashaswikarnati changed the title from "Apply MIMO SP/CP sharding with explicit groups and thread loss mask" to "Apply MIMO SP/CP sharding with explicit groups and enable THD in non-colocated path" Jun 8, 2026
@yashaswikarnati

Contributor Author

/ok to test 383cd54

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the "Approved" (All necessary approvals have been made) label and removed the "Final Review" (PR is in the "final review" stage) label Jun 9, 2026
@yashaswikarnati yashaswikarnati added this pull request to the merge queue Jun 9, 2026
@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27224354552

Merged via the queue into NVIDIA:main with commit 3920476 Jun 9, 2026
92 of 95 checks passed
@yashaswikarnati yashaswikarnati deleted the ykarnati/upstream-mimo-sp-cp-sharding branch June 9, 2026 18:27
Labels

Approved (All necessary approvals have been made), complexity: medium

4 participants