Add MIMO hetero topology + distributed bootstrap (examples/mimo training-loop folder) #5260

Merged
yashaswikarnati merged 1 commit into NVIDIA:main from yashaswikarnati:ykarnati/upstream-trainloop-mm1-topology
Jun 11, 2026
@yashaswikarnati

Contributor

Adds per-module HyperCommGrid topology and a distributed bootstrap for hetero MIMO training under examples/mimo/training/, building each module's process groups (TP/CP/PP/DP, expert views, and language embedding groups) and packaging them into a MultiModuleProcessGroupCollection.

Why

Enables MIMO to run on the stock megatron/training loop via an explicit ProcessGroupCollection rather than parallel_state globals. Process groups are owned by HyperCommGrid, not parallel_state. The default None preserves current behavior; there are no changes to megatron/core and no homogeneous/non-MIMO behavior changes.

Testing

Real cog 8-GPU distributed unit test (no mocks), run name mm1-topology-test: 4 passed (test_grids_partition_world, test_pgc_group_sizes, test_validate_rejects_overlapping_not_equal, test_validate_rejects_gap_in_world_coverage). The homogeneous goldens remain the next CI gate.

Validation now enforces that the module grids cover the world [0, world_size) with no gaps, in addition to the invariant that any two grids are either pairwise disjoint or share exactly the same ranks (disjoint XOR fully shared).
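For readers who want the two invariants spelled out, here is a minimal sketch in plain Python. It is not the code in topology.py: the function name `validate_module_ranks` and the representation of each grid as a set of global ranks are assumptions made for illustration only.

```python
from itertools import combinations
from typing import Dict, Set


def validate_module_ranks(module_ranks: Dict[str, Set[int]], world_size: int) -> None:
    """Hypothetical validator: raise ValueError if a layout breaks either invariant."""
    # Invariant 1: any two modules are disjoint XOR occupy exactly the same ranks.
    for (name_a, ranks_a), (name_b, ranks_b) in combinations(module_ranks.items(), 2):
        overlap = ranks_a & ranks_b
        if overlap and ranks_a != ranks_b:
            raise ValueError(
                f"{name_a} and {name_b} overlap on ranks {sorted(overlap)} but are not equal"
            )

    # Invariant 2: the union of all module ranks is exactly [0, world_size).
    covered = set().union(*module_ranks.values())
    expected = set(range(world_size))
    if covered != expected:
        missing = sorted(expected - covered)
        out_of_range = sorted(covered - expected)
        raise ValueError(
            f"world coverage mismatch: missing={missing}, out_of_range={out_of_range}"
        )
```

Under these assumptions, a non-colocated layout such as `{"encoder": {0, 1, 2, 3}, "language": {4, 5, 6, 7}}` on 8 ranks passes, as does a fully shared layout where both modules use all 8 ranks. A layout where the two sets partially overlap, or where some rank in [0, 8) is unassigned, raises `ValueError`.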

Stacking

Standalone on origin/main (no stacking).

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot Bot commented Jun 10, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-trainloop-mm1-topology branch from 2d2bc6e to 75e2d04 Compare June 10, 2026 06:39
@yashaswikarnati

Contributor Author

Review comment resolutions

Mapping each review comment to its resolution in the latest revision (75e2d04):

  1. Distributed init kept (why stock cannot replace it): distributed.py retains a torch.distributed + global-memory-buffer bootstrap. Stock mpu.initialize_model_parallel cannot be used because it materializes the model-parallel globals for a single homogeneous world tiling; the hetero MIMO topology needs per-module grids with disjoint rank offsets, so we bring up torch.distributed WITHOUT initializing MPU and assert the MPU globals stay uninitialized.
  2. Assert simplified: assert_parallel_state_uninitialized is now a compact loop over the model-parallel globals (DP/TP/PP/CP/embd/pos_embd) rather than a call to model_parallel_is_initialized, so a partial/leaked init is caught.
  3. print_rank_0 reused: now imported from megatron.training.utils instead of redefined locally.
  4. specs[-1] replaced: the language module is now selected via the explicit is_language_module flag on the spec; create_topology validates exactly one spec sets it.
  5. language_module_name + side dict deleted: removed; the layout is driven by is_language_module / RankRole / ModuleLayout.
  6. Embedding groups in PGC: language embedding groups now live in PGC.embd / PGC.pos_embd, built collectively via parallel_state.default_embedding_ranks / default_position_embedding_ranks as real ProcessGroups (encoder modules get None).
  7. is_current_rank_in_grid filter kept + explained: retained to scope per-module construction to ranks actually in the grid; commented to explain why.
  8. GroupMember sentinel replaced: membership is now checked via get_rank(group) >= 0 (handles -1 for non-members) instead of the GroupMember sentinel.
  9. Validation preserved: the world-tiling invariant (grids tile the world disjointly XOR fully share ranks) is preserved and exercised by the test, including rejection of overlapping-but-not-equal and gap-in-coverage layouts.
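Item 4 above (selecting the language module by an explicit flag instead of by position) can be sketched as follows. This is an illustration under stated assumptions, not the code in topology.py: `SpecSketch` and `select_language_spec` are hypothetical names, and only the `is_language_module` flag comes from the comment above.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for a module spec; only the fields this check needs."""

    name: str
    is_language_module: bool = False


def select_language_spec(specs: Sequence[SpecSketch]) -> SpecSketch:
    """Return the single spec flagged as the language module, regardless of order."""
    flagged = [spec for spec in specs if spec.is_language_module]
    if len(flagged) != 1:
        names = [spec.name for spec in flagged]
        raise ValueError(
            f"expected exactly one spec with is_language_module=True, got {len(flagged)}: {names}"
        )
    return flagged[0]
```

Compared with indexing `specs[-1]`, the flag makes the choice independent of list order, and a layout with zero or several flagged specs fails at construction time instead of silently picking the wrong module.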

Validation

8-GPU real-distributed run on cw-dfw of tests/unit_tests/test_mimo_hetero_topology.py: 5 passed (no skips):
test_grids_partition_world, test_pgc_group_sizes, test_embedding_groups, test_validate_rejects_overlapping_not_equal, test_validate_rejects_gap_in_world_coverage.

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-trainloop-mm1-topology branch 2 times, most recently from 75e2d04 to a2902bb Compare June 10, 2026 16:28
@yashaswikarnati

Contributor Author

Round-2 review comments addressed

Docstrings (4 trimmed for concision while preserving architectural meaning):

  • distributed.py module docstring and initialize_distributed collapsed to a one-line summary plus a one-line justification.
  • topology.py module docstring and _build_language_embedding_groups trimmed; the essential invariants (per-module grids, embedding-group construction) are kept.

Expert dims (expt_tp / expt_dp):

  • These are now resolved to concrete ints and validated inside ModuleGridSpec.__post_init__ (no None sentinels left to flow downstream). expt_tp defaults to tp; expt_dp is derived as size // (expt_tp * ep * pp), with divisibility and product checks (expt_tp * ep * expt_dp * pp == size) enforced at construction.
  • _build_grid is simplified to read these already-resolved concrete values directly, with no None-fallback branching.

dp kept explicit (rationale):

  • The module-view size invariant is size = tp * cp * pp * dp, so dp is a first-class field of the module grid. expt_dp is its expert-view analog and is derived in __post_init__ from the expert factorization; keeping dp explicit mirrors the two distinct views (module vs. expert) rather than overloading one field.
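The resolution rules described above can be sketched in a few lines. This is a hedged illustration, not the actual ModuleGridSpec: the class name `GridSpecSketch`, the default values, and the choice to compute `size` from the module view (rather than take it as an input) are assumptions. Only the formulas `size = tp * cp * pp * dp`, `expt_tp` defaulting to `tp`, `expt_dp = size // (expt_tp * ep * pp)`, and the product check come from the comment.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GridSpecSketch:
    """Hypothetical spec: module view (tp, cp, pp, dp) plus expert view (expt_tp, ep, expt_dp)."""

    tp: int = 1
    cp: int = 1
    pp: int = 1
    dp: int = 1
    ep: int = 1
    expt_tp: Optional[int] = None  # resolved to a concrete int in __post_init__
    expt_dp: Optional[int] = None  # resolved to a concrete int in __post_init__
    size: int = field(init=False)  # assumption: derived here, not passed in

    def __post_init__(self) -> None:
        # Module view: size = tp * cp * pp * dp.
        self.size = self.tp * self.cp * self.pp * self.dp

        # Expert view: expt_tp falls back to tp when not given.
        if self.expt_tp is None:
            self.expt_tp = self.tp

        # expt_dp is derived from the expert factorization, with a divisibility check.
        denom = self.expt_tp * self.ep * self.pp
        if self.size % denom != 0:
            raise ValueError(f"size {self.size} is not divisible by expt_tp*ep*pp = {denom}")
        if self.expt_dp is None:
            self.expt_dp = self.size // denom

        # Product check: both views must describe the same number of ranks.
        if self.expt_tp * self.ep * self.expt_dp * self.pp != self.size:
            raise ValueError("expt_tp * ep * expt_dp * pp must equal size")
```

For example, `GridSpecSketch(tp=2, dp=4, ep=2)` has `size == 8`, resolves `expt_tp` to 2, and derives `expt_dp == 2`. Because both expert fields are concrete ints after construction, downstream code in this sketch would never need a None-fallback branch.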

Verified on an 8-GPU real-distributed run: all 8 tests in tests/unit_tests/test_mimo_hetero_topology.py pass (including new TestModuleGridSpecResolution cases for implicit/explicit expert resolution and invalid factorization).

@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-trainloop-mm1-topology branch 2 times, most recently from dfab655 to c37c93c Compare June 10, 2026 18:47
@yashaswikarnati yashaswikarnati marked this pull request as ready for review June 10, 2026 18:51
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team June 10, 2026 18:51
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-trainloop-mm1-topology branch from c37c93c to a212d18 Compare June 10, 2026 19:39
Add a production training-loop folder examples/mimo/training/ with two
modules ported from the hetero prototype and cleaned to production quality:

- topology.py: builds per-module HyperCommGrid(s) from a layout-general
  ModuleGridSpec (rank offsets come from the spec, not hardcoded), adapts
  each grid into a ProcessGroupCollection, assembles a
  MultiModuleProcessGroupCollection, and creates the language embedding
  groups. Uses the on-main named-view API (register_view with shared_dims)
  for the expert factorization, routes colocated/non-colocated detection
  through RankRole.build, and validates the invariant that grids either
  tile the world disjointly XOR fully share ranks. Ships the non-colocated
  configuration without precluding a colocated one.
- distributed.py: torch.distributed + global-memory-buffer bootstrap that
  does not call mpu.initialize_model_parallel and asserts the parallel_state
  model-parallel globals are uninitialized.

Add an 8-GPU real-distributed unit test asserting the two grids partition
the world, the per-module PGC group sizes match the factorization, and the
validation rejects an overlapping-but-not-equal layout.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@yashaswikarnati yashaswikarnati force-pushed the ykarnati/upstream-trainloop-mm1-topology branch from a212d18 to c82f09f Compare June 10, 2026 20:15
@yashaswikarnati

Contributor Author

/ok to test c82f09f

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27311132496

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27312935448

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27313542518

@svcnvidia-nemo-ci


🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/27315861359

Merged via the queue into NVIDIA:main with commit aa10571 Jun 11, 2026
94 of 96 checks passed
@yashaswikarnati yashaswikarnati deleted the ykarnati/upstream-trainloop-mm1-topology branch June 11, 2026 02:02
lauradang pushed a commit to lauradang/Megatron-LM that referenced this pull request Jun 11, 2026
…ing-loop folder) (NVIDIA#5260)

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>