Skip to content

refactor(mimo): unified MimoModelProvider with ModuleSpec-based API#2040

Merged
yashaswikarnati merged 77 commits into
NVIDIA-NeMo:mainfrom
aroshanghias-nvd:phase1/mimo-provider-refactor
Feb 13, 2026
Merged

refactor(mimo): unified MimoModelProvider with ModuleSpec-based API#2040
yashaswikarnati merged 77 commits into
NVIDIA-NeMo:mainfrom
aroshanghias-nvd:phase1/mimo-provider-refactor

Conversation

@aroshanghias-nvd

@aroshanghias-nvd aroshanghias-nvd commented Jan 22, 2026

Copy link
Copy Markdown
Contributor

Summary

This PR introduces a unified MimoModelProvider that accepts ModuleSpecs directly, removing the EncoderProvider abstraction layer.

Key Changes

  • Add MimoModelProvider accepting language_model_spec and modality_submodules_spec
  • Remove EncoderProvider abstraction
  • Rename MIMOConfigMimoParallelismConfig
  • Fix ColocatedCommConfig logic for colocated and homogeneous modes

Tests

✅ All MIMO tests passing

Summary by CodeRabbit

Release Notes

  • New Features

    • Added support for multi-module MIMO (vision-language) models with heterogeneous parallelism configurations enabling different parallel strategies per module.
    • Introduced pre-configured LLaVA vision-language model provider for integrated vision-language training.
    • Extended training configuration system to support MIMO models.
  • Tests

    • Added comprehensive unit tests for MIMO configuration and provider functionality.

✏️ Tip: You can customize this high-level summary in your review settings.

Ali Roshan Ghias added 2 commits January 15, 2026 12:49
@copy-pr-bot

copy-pr-bot Bot commented Jan 22, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment thread src/megatron/bridge/models/mimo/mimo_provider.py Outdated
Comment thread src/megatron/bridge/training/config.py Outdated
Comment thread src/megatron/bridge/training/config.py Outdated
Comment thread src/megatron/bridge/models/mimo/mimo_builder.py
@yashaswikarnati

Copy link
Copy Markdown
Contributor

lgtm

@yashaswikarnati

Copy link
Copy Markdown
Contributor

@yaoyu-33 / @maanug-nv could one of you help review

Comment thread src/megatron/bridge/training/config.py Outdated
Comment thread src/megatron/bridge/models/mimo/mimo_config.py
Comment thread src/megatron/bridge/models/mimo/mimo_provider.py
@coderabbitai

coderabbitai Bot commented Jan 27, 2026

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

This PR introduces a comprehensive MIMO (Multi-Input Multi-Output) provider system enabling heterogeneous multi-module training with per-module parallelism support. It includes configuration management with deployment mode validation, process-group grid building, and concrete implementations for vision-language models.

Changes

Cohort / File(s) Summary
Core MIMO Configuration
src/megatron/bridge/models/mimo/mimo_config.py
Introduces ModuleParallelismConfig (tracks tensor, pipeline, context, expert, data parallelism) and MimoParallelismConfig (manages multi-module configs with finalization, validation for colocated/homogeneous/heterogeneous deployment modes). Validates consistency constraints and rank_offset ranges.
MIMO Provider Infrastructure
src/megatron/bridge/models/mimo/mimo_provider.py
Implements MimoModelProvider (constructs MIMO models with per-module parallelism, grid/topology creation, pg-collection extraction and injection) and MimoModelProviderResult (container for model, grids, topology, process-groups). Supports per-encoder parallelism, optional colocated communication config, and freezing controls.
MIMO Builder Utilities
src/megatron/bridge/models/mimo/mimo_builder.py
Provides build_hypercomm_grids() (creates HyperCommGrid per module with nccl backend and per-dimension groups), build_colocated_comm_config() (constructs ColocatedCommConfig with topology mapping), and _default_topology() helper for encoder-to-LLM mapping.
LLaVA Provider Implementation
src/megatron/bridge/models/mimo/llava_provider.py
Adds LlavaMimoProvider dataclass for preconfigured LLaVA-style vision-language models with Vicuna-7B language model, CLIP-like vision encoder, and 2-layer MLP projector; includes default config generation and vision submodule spec building.
Package Initialization
src/megatron/bridge/models/mimo/__init__.py
Re-exports public API: MimoParallelismConfig, ModuleParallelismConfig, MimoModelProvider, MimoModelProviderResult, LlavaMimoProvider.
Training Config Integration
src/megatron/bridge/training/config.py
Updates ConfigContainer.model type annotation to include MimoModelProvider alongside existing providers (GPTModelProvider, T5ModelProvider, MambaModelProvider).
Unit Tests
tests/unit_tests/models/mimo/test_mimo_provider.py
Comprehensive test suite covering MimoModelProvider initialization, provide() behavior with/without parallelism config, pg-collection injection into specs, freezing logic, per-encoder parallelism, and non-participating rank scenarios.
Configuration Tests
tests/unit_tests/training/mimo/test_mimo_config.py
Tests ModuleParallelismConfig finalization (data_parallel computation, world_size validation), MimoParallelismConfig validation paths (colocated rank mismatch, homogeneous parallelism consistency, heterogeneous rank_offset overlap detection).
Test Package Init
tests/unit_tests/models/mimo/__init__.py
Adds copyright header to test package initializer.

Sequence Diagram

sequenceDiagram
    participant User as MimoModelProvider.provide()
    participant Grid as build_hypercomm_grids()
    participant PG as _get_pg_collections_from_grids()
    participant Inject as _inject_pg_collection_*()
    participant Model as MimoModel constructor
    
    User->>Grid: Create grids per module from parallelism config
    Grid-->>User: Dict[module_name, HyperCommGrid]
    
    User->>PG: Extract pg_collections from grids per rank
    PG-->>User: Dict[module_name, ProcessGroupCollection]
    
    User->>Inject: Deep copy language spec and inject pg_collection
    Inject-->>User: Augmented language_model_spec
    
    User->>Inject: Deep copy modality specs and inject pg_collection per encoder
    Inject-->>User: Augmented modality_submodules_spec
    
    User->>Model: Construct MimoModel with injected specs
    Model-->>User: MimoModel instance
    
    User->>User: Apply freezing and device/dtype transfer
    User-->>User: Return MimoModelProviderResult (model, grids, topology, pg_collections)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 65.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Test Results For Major Changes ⚠️ Warning PR describes major architectural refactoring but lacks specific test counts, test names, convergence regression results, or performance impact analysis for topology/process group changes. Document specific test counts and names, confirm backward compatibility, provide convergence regression results or numeric equivalence confirmation, and include performance impact analysis.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'refactor(mimo): unified MimoModelProvider with ModuleSpec-based API' accurately describes the main change: unifying MimoModelProvider to accept ModuleSpec objects directly, which is the core objective of this refactor.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/mimo/llava_provider.py`:
- Around line 50-82: The subclass __post_init__ in LlavaMimoProvider must invoke
the parent initializer to preserve MimoModelProvider setup: at the start of
LlavaMimoProvider.__post_init__, call super().__post_init__() before accessing
or assigning self.language_config, then proceed with the existing logic that
sets defaults and builds language_model_spec, modality_submodules_spec and
special_token_ids; ensure no assumptions about parent-initialized fields are
made before the super call so parent validations/fields run first.
- Around line 109-113: The projection TransformerConfig used for the
MultimodalProjector (see projection_config in llava_provider.py) includes an
unnecessary num_attention_heads=1; remove that keyword so the config only sets
fields relevant to the MLP projector (e.g., num_layers and hidden_size) to avoid
implying attention is used—update the projection_config instantiation in the
code path that constructs the MultimodalProjector so it no longer passes
num_attention_heads.

In `@tests/unit_tests/models/mimo/test_mimo_provider.py`:
- Around line 210-230: The test_freezing_language_model creates a `result =
provider.provide()` that's not used; fix by either asserting something about
`result` or marking it as intentionally unused (e.g., rename to `_result`) —
locate the test function `test_freezing_language_model`, the `result` assignment
after calling `MimoModelProvider(...).provide()`, and update it to `_result` or
add an assertion referencing `result` (for example ensuring `result` is the
provider or a non-None value) while keeping the existing assertion that
`mock_param.requires_grad` is False.
🧹 Nitpick comments (7)
src/megatron/bridge/models/mimo/mimo_config.py (1)

126-131: Consider validating data_parallel before finalize() in heterogeneous mode.

In heterogeneous mode, parallelism.finalize(None) is called (Line 131), which will fail with "world_size must be provided" if data_parallel is None. While _validate_heterogeneous() later checks for this, the error message from finalize() may be confusing. Consider moving the data_parallel is None check earlier:

♻️ Suggested improvement
         if self.deployment_mode in ("colocated", "homogeneous"):
             for parallelism in self.module_parallelisms.values():
                 parallelism.finalize(world_size)
         else:
+            # Heterogeneous requires data_parallel to be pre-set
+            for name, parallelism in self.module_parallelisms.items():
+                if parallelism.data_parallel is None:
+                    raise ValueError(
+                        f"data_parallel must be set for module '{name}' in heterogeneous deployment."
+                    )
             for parallelism in self.module_parallelisms.values():
                 parallelism.finalize(None)
src/megatron/bridge/models/mimo/mimo_builder.py (2)

3-3: Remove unused import Optional.

Optional is imported but never used in this file.

🧹 Suggested fix
-from typing import Dict, List, Optional
+from typing import Dict, List

62-62: Simplify redundant dict comprehension.

The dict comprehension {name: grid for name, grid in grids.items()} creates an identical copy of grids. If a copy is intended, use grids.copy(); otherwise, pass grids directly.

♻️ Suggested fix
-    module_to_grid_map = {name: grid for name, grid in grids.items()}
     topology = _default_topology(mimo_parallelism_config)
     return ColocatedCommConfig(
-        module_to_grid_map=module_to_grid_map,
+        module_to_grid_map=grids,
         topology=topology,
         dim_mapping={"b": 0, "s": 1, "h": 2},
     )
src/megatron/bridge/models/mimo/mimo_provider.py (1)

201-204: Rename unused loop variable encoder_name to _encoder_name.

The loop variable encoder_name is not used within the loop body. Following Python conventions, prefix it with underscore to indicate it's intentionally unused.

🧹 Suggested fix
-            for encoder_name, encoder_spec in spec.submodules["encoders"].items():
+            for _encoder_name, encoder_spec in spec.submodules["encoders"].items():
tests/unit_tests/training/mimo/test_mimo_config.py (1)

14-17: Use raw string for regex pattern with metacharacters.

The pattern "world_size .* not divisible" contains regex metacharacters (.*). Use a raw string prefix for clarity and to avoid potential escaping issues.

🧹 Suggested fix
 def test_module_parallelism_finalize_invalid_world_size():
     parallelism = ModuleParallelismConfig(tensor_parallel=3, pipeline_parallel=2)
-    with pytest.raises(ValueError, match="world_size .* not divisible"):
+    with pytest.raises(ValueError, match=r"world_size .* not divisible"):
         parallelism.finalize(world_size=10)
src/megatron/bridge/models/mimo/llava_provider.py (1)

130-137: Encoder key mismatch: "clip_encoder" vs modality key "images".

The encoder is registered under "clip_encoder" in the submodules dict (line 134), but the modality itself is keyed as "images" (line 77-78). This is likely intentional for semantic clarity, but ensure downstream code correctly maps between modality names and encoder names when applying per-encoder parallelism.

tests/unit_tests/models/mimo/test_mimo_provider.py (1)

4-11: Unused imports: pytest and torch.

Static analysis correctly identifies that pytest and torch are imported but not used. While pytest is typically needed for fixtures and markers, neither is used in this file. Consider removing them or adding a # noqa: F401 comment if they're kept for future use.

Suggested fix
-from unittest.mock import MagicMock, Mock, patch
-
-import pytest
-import torch
+from unittest.mock import MagicMock, Mock, patch
 
 from megatron.core.transformer.spec_utils import ModuleSpec

Comment thread src/megatron/bridge/models/mimo/llava_provider.py
Comment thread src/megatron/bridge/models/mimo/llava_provider.py
Comment thread tests/unit_tests/models/mimo/test_mimo_provider.py Outdated
Comment thread src/megatron/bridge/models/mimo/mimo_config.py Outdated
class MimoParallelismConfig:
"""Configuration for multi-module (MIMO) heterogeneous parallelism."""

llm_module_name: str

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why llm is special?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The canonical MiMo always has an LLM as the middle component, hence the special treatment for topology.

# Step 8: Move to device/dtype
mimo_model.to(device).to(dtype)

return MimoModelProviderResult(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want to keep provide only yield a model, so we can blend in train loop.
Maybe have another method to yield meta data and keep provider only yield model instance?
Otherwise it's confusing design.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completely rewrote mimo_provider.py. Now it has mixin inheritance, provide() returns model directly, new build_infra(), provide_distributed_model() override, MimoStubModel for non-participating ranks.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please review, and let me know if it is now compatible with the train loop.

Comment thread src/megatron/bridge/models/mimo/mimo_builder.py Outdated

# Model specs (user provides, like llava_vlm.py example)
language_model_spec: ModuleSpec
modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would like to understand the design choice of using a string label for each module. seems this way the module state/components (spec, parallelism config, etc.) are broken up into separate dicts.
to me, it would make more sense to bundle the contents of a single module into a MimoModuleWrapper object that contains spec, parallelism config, commgrid, etc.

not asking for a refactor, just trying to discuss for my understanding.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string-keyed dict follows Megatron-Core's existing MimoModelConfig.modality_submodules_spec: Dict[str, ModuleSpec] API - the provider passes it directly without translation.
Extending this pattern to MimoParallelismConfig keeps the interface consistent. A MimoModuleWrapper could work as a convenience layer on top if needed, but the current design keeps components decoupled (e.g., specs can be used without parallelism config for testing).



@dataclass
class MimoModelProvider(ModelProviderMixin[MimoModel]):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im quite ignorant on MIMO design, but is there no need for a dict of TransformerConfigs per module?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each ModuleSpec carries its config in params['config'] (e.g., language_model_spec.params['config'] holds the LLM's TransformerConfig). The projector spec has its own config too. Some modules like HF wrappers manage config internally. So there's no need for a separate Dict[str, TransformerConfig] at the provider level - configs live within their respective specs.

Comment thread src/megatron/bridge/models/mimo/mimo_provider.py
maanug-nv
maanug-nv previously approved these changes Jan 29, 2026

@maanug-nv maanug-nv left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall don't see any issues. but left several comments to hopefully improve my understanding

Add extensive unit test coverage for the MIMO (Multi-Input Multi-Output)
model provider refactor to meet codecov requirements. This ensures the new
ModuleSpec-based API is thoroughly tested.

Test coverage includes:
- MimoModelProvider: initialization, properties, infrastructure building,
  provide methods, distributed model handling, DDP wrapping, device/dtype
  handling, freezing, and parallelism validation
- LlavaMimoProvider: initialization, spec generation, configuration,
  freezing, and error handling
- mimo_builder: HyperCommGrid building with various parallelism configs
  and topology generation
- MimoParallelismConfig: edge cases for heterogeneous rank validation,
  properties, and finalization

Fixes:
- Make LlavaMimoProvider.language_model_spec optional (built in __post_init__)
- Add MockModule class for proper torch.nn.Module mocking in tests
- Add comprehensive CUDA mocking to prevent GPU access in unit tests

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@aroshanghias-nvd

Copy link
Copy Markdown
Contributor Author

/ok to test 1527a1b

Remove unused variables to fix ruff F841 errors:
- Remove unused 'grids' variable in test_mimo_builder.py (4 instances)
- Remove unused 'result' variable in test_mimo_provider.py (6 instances)

These variables were assigned but never used in the test assertions.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@aroshanghias-nvd

Copy link
Copy Markdown
Contributor Author

/ok to test a7c46c9

Fix trailing whitespace and apply ruff formatting to pass CI linting checks.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd

Copy link
Copy Markdown
Contributor Author

/ok to test 1ed7296

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.