Apply MIMO SP/CP sharding with explicit groups and enable THD in non-colocated path#5150
Merged
yashaswikarnati merged 1 commit intoJun 9, 2026
Conversation
yashaswikarnati
commented
Jun 4, 2026
yashaswikarnati
commented
Jun 4, 2026
645bfa6 to
b70b356
Compare
Contributor
Author
|
Updated per review:
Validated on 1 node × 8 GPUs ( |
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
yashaswikarnati
commented
Jun 8, 2026
Make MIMO language-model input sharding correct for non-colocated / heterogeneous parallelism, where the language model owns its own process-group grid. PartitionAdapter.shard() owns a single layout contract: - Sequence-first input. shard() consumes (S, B, H) embeddings -- the layout align_embeddings_by_token_positions produces and the LM consumes -- and returns (S/(cp*tp), B, H). It transposes to batch-first only inside the CP block (get_batch_on_this_cp_rank requires batch-first); the SP-only path scatters dim 0 directly with no transpose round-trip. - Explicit TP group. scatter_to_sequence_parallel_region is called with the adapter's tp_group instead of the global parallel_state group. - Loss mask threaded through the LM forward and CP-sharded alongside labels, so the terminal per-token loss aligns with the CP-local hidden states. MimoModel: - Build the adapter only on language-module ranks (encoder-only ranks never shard and must not read process groups they do not own). - Thread packing_kwargs/packed_seq_params through the non-colocated language path so packed (THD) sequences are CP-sharded and reach the LM, matching the colocated path (shared _build_packed_seq_params helper). - Reject a dense attention_mask under context parallelism: a dense [B, S] mask cannot line up with the CP-sharded sequence, so MIMO masks via a causal attn_mask_type or packed_seq_params. position_ids are passed full-length, as required by mRoPE (the rotary module CP-shards positions internally). Tests: real 8-GPU sharding tests (SP-only, CP-only, CP+SP) assert per-rank shape and content; plus the attention-mask-under-CP guard, loss-mask CP-sharding, and packed/THD handling. 122 passed, 2 skipped on 1 node x 8 GPUs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
b70b356 to
383cd54
Compare
Contributor
Author
|
Revised per review. All inline comments addressed; verified on 1 node × 8 GPUs ( Design changes
Inline comments
TestsReal 8-GPU sharding tests (SP-only, CP-only, CP+SP) assert per-rank shape and content; added the attention-mask-under-CP guard test; kept loss-mask CP-sharding coverage. |
yashaswikarnati
commented
Jun 8, 2026
yaoyu-33
approved these changes
Jun 8, 2026
Contributor
Author
|
/ok to test 383cd54 |
kvareddy
approved these changes
Jun 9, 2026
|
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27224354552 |
71 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Make MIMO language-model input sharding correct for non-colocated / heterogeneous parallelism, where the LM owns its own process-group grid.
PartitionAdapter.shard()owns the layout contract — sequence-first(S, B, H)in and out:parallel_stategroup).get_batch_on_this_cp_rank), and SP-only does no transpose. Previously the SP path scattered[B, S, H]along the batch dim.loss_maskandpacked_seq_paramsthreaded through the LM forward and CP-sharded, so per-token loss and packed/THD sequences line up with the CP-local output on the non-colocated path.attention_maskin the adapter; fast-fail if a dense mask is combined with CP (mask via a causalattn_mask_type/packed_seq_params).shard()returns(embeddings[S/(cp*tp),B,H], labels[B,S/cp], loss_mask[B,S/cp], packed_seq_params). Single in-core caller (MimoModel).Backward compatibility
Colocated path unchanged (still returns
(output, loss_mask)). Under CP,position_idsare passed full-length — the rotary module shards them internally, so MIMO must not.Testing
Real 8-GPU tests assert per-rank shape and content for SP-only / CP-only / CP+SP, plus loss-mask CP-sharding and the attention-mask-under-CP guard. 122 passed, 2 skipped on 1 node × 8 GPUs.
No new direct
parallel_state.get_*_group()reads.