refactor(mimo): unified MimoModelProvider with ModuleSpec-based API #2040
Signed-off-by: Ali Roshan Ghias <aliroshanghias@nvidia.com>
lgtm
@yaoyu-33 / @maanug-nv could one of you help review
📝 Walkthrough
This PR introduces a comprehensive MIMO (Multi-Input Multi-Output) provider system enabling heterogeneous multi-module training with per-module parallelism support. It includes configuration management with deployment mode validation, process-group grid building, and concrete implementations for vision-language models.
Sequence Diagram
sequenceDiagram
participant User as MimoModelProvider.provide()
participant Grid as build_hypercomm_grids()
participant PG as _get_pg_collections_from_grids()
participant Inject as _inject_pg_collection_*()
participant Model as MimoModel constructor
User->>Grid: Create grids per module from parallelism config
Grid-->>User: Dict[module_name, HyperCommGrid]
User->>PG: Extract pg_collections from grids per rank
PG-->>User: Dict[module_name, ProcessGroupCollection]
User->>Inject: Deep copy language spec and inject pg_collection
Inject-->>User: Augmented language_model_spec
User->>Inject: Deep copy modality specs and inject pg_collection per encoder
Inject-->>User: Augmented modality_submodules_spec
User->>Model: Construct MimoModel with injected specs
Model-->>User: MimoModel instance
User->>User: Apply freezing and device/dtype transfer
User-->>User: Return MimoModelProviderResult (model, grids, topology, pg_collections)
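The flow in the diagram can be sketched in plain Python. This is a minimal illustration only: `Spec`, `build_grids`, `pg_collections_from_grids`, `inject`, and this `provide` are hypothetical stand-ins for `ModuleSpec`, `build_hypercomm_grids()`, `_get_pg_collections_from_grids()`, `_inject_pg_collection_*()`, and `MimoModelProvider.provide()`, and the sketch assumes modality names double as module names in the parallelism config.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class Spec:
    """Hypothetical stand-in for a ModuleSpec: just a params dict."""

    params: Dict[str, Any] = field(default_factory=dict)


def build_grids(parallelism: Dict[str, Dict[str, int]]) -> Dict[str, dict]:
    # Stand-in for build_hypercomm_grids(): one "grid" per module name.
    return {name: {"shape": dict(cfg)} for name, cfg in parallelism.items()}


def pg_collections_from_grids(grids: Dict[str, dict]) -> Dict[str, str]:
    # Stand-in for _get_pg_collections_from_grids(): one collection per module.
    return {name: f"pg[{name}]" for name in grids}


def inject(spec: Spec, pg_collection: str) -> Spec:
    # Deep copy first so the caller's spec is never mutated.
    injected = copy.deepcopy(spec)
    injected.params["pg_collection"] = pg_collection
    return injected


def provide(
    language_spec: Spec,
    modality_specs: Dict[str, Spec],
    parallelism: Dict[str, Dict[str, int]],
    llm_name: str = "llm",
) -> Tuple[dict, dict, dict]:
    grids = build_grids(parallelism)
    pgs = pg_collections_from_grids(grids)
    lm = inject(language_spec, pgs[llm_name])
    mods = {name: inject(spec, pgs[name]) for name, spec in modality_specs.items()}
    # A dict stands in for the constructed MimoModel.
    model = {"language_model": lm, "modality_submodules": mods}
    return model, grids, pgs


lang = Spec(params={"config": "llm-config"})
vision = Spec(params={"config": "vit-config"})
model, grids, pgs = provide(
    lang, {"images": vision}, {"llm": {"tp": 2}, "images": {"tp": 1}}
)
```

Copying before injection keeps the user-supplied specs reusable, which matches the "deep copy ... and inject" steps in the diagram.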
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/mimo/llava_provider.py`:
- Around line 50-82: The subclass __post_init__ in LlavaMimoProvider must invoke
the parent initializer to preserve MimoModelProvider setup: at the start of
LlavaMimoProvider.__post_init__, call super().__post_init__() before accessing
or assigning self.language_config, then proceed with the existing logic that
sets defaults and builds language_model_spec, modality_submodules_spec and
special_token_ids; ensure no assumptions about parent-initialized fields are
made before the super call so parent validations/fields run first.
- Around line 109-113: The projection TransformerConfig used for the
MultimodalProjector (see projection_config in llava_provider.py) includes an
unnecessary num_attention_heads=1; remove that keyword so the config only sets
fields relevant to the MLP projector (e.g., num_layers and hidden_size) to avoid
implying attention is used—update the projection_config instantiation in the
code path that constructs the MultimodalProjector so it no longer passes
num_attention_heads.
In `@tests/unit_tests/models/mimo/test_mimo_provider.py`:
- Around line 210-230: The test_freezing_language_model creates a `result =
provider.provide()` that's not used; fix by either asserting something about
`result` or marking it as intentionally unused (e.g., rename to `_result`) —
locate the test function `test_freezing_language_model`, the `result` assignment
after calling `MimoModelProvider(...).provide()`, and update it to `_result` or
add an assertion referencing `result` (for example ensuring `result` is the
provider or a non-None value) while keeping the existing assertion that
`mock_param.requires_grad` is False.
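The first item above (calling the parent initializer from a dataclass subclass) can be illustrated with a small sketch. `BaseProvider` and `LlavaLikeProvider` are hypothetical stand-ins for `MimoModelProvider` and `LlavaMimoProvider`; the fields and default values are assumptions chosen only to show the ordering.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class BaseProvider:
    """Hypothetical parent; stands in for MimoModelProvider."""

    language_config: Optional[dict] = None
    validated: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        # Parent setup/validation that a subclass should not skip.
        self.validated = True


@dataclass
class LlavaLikeProvider(BaseProvider):
    """Hypothetical subclass; stands in for LlavaMimoProvider."""

    special_token_ids: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Run the parent hook first, before touching any fields.
        super().__post_init__()
        if self.language_config is None:
            self.language_config = {"hidden_size": 4096}  # illustrative default
        if not self.special_token_ids:
            self.special_token_ids = {"images": 32000}  # illustrative token id


provider = LlavaLikeProvider()
```

Dataclasses do not chain `__post_init__` automatically, so a subclass that defines its own hook silently replaces the parent's unless it calls `super().__post_init__()`.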
🧹 Nitpick comments (7)
src/megatron/bridge/models/mimo/mimo_config.py (1)
126-131: Consider validating `data_parallel` before `finalize()` in heterogeneous mode.
In heterogeneous mode, `parallelism.finalize(None)` is called (Line 131), which will fail with "world_size must be provided" if `data_parallel` is `None`. While `_validate_heterogeneous()` later checks for this, the error message from `finalize()` may be confusing. Consider moving the `data_parallel is None` check earlier:
♻️ Suggested improvement
     if self.deployment_mode in ("colocated", "homogeneous"):
         for parallelism in self.module_parallelisms.values():
             parallelism.finalize(world_size)
     else:
+        # Heterogeneous requires data_parallel to be pre-set
+        for name, parallelism in self.module_parallelisms.items():
+            if parallelism.data_parallel is None:
+                raise ValueError(
+                    f"data_parallel must be set for module '{name}' in heterogeneous deployment."
+                )
         for parallelism in self.module_parallelisms.values():
             parallelism.finalize(None)
src/megatron/bridge/models/mimo/mimo_builder.py (2)
3-3: Remove unused import `Optional`.
`Optional` is imported but never used in this file.
🧹 Suggested fix
-from typing import Dict, List, Optional
+from typing import Dict, List
62-62: Simplify redundant dict comprehension.
The dict comprehension `{name: grid for name, grid in grids.items()}` creates an identical copy of `grids`. If a copy is intended, use `grids.copy()`; otherwise, pass `grids` directly.
♻️ Suggested fix
-    module_to_grid_map = {name: grid for name, grid in grids.items()}
     topology = _default_topology(mimo_parallelism_config)
     return ColocatedCommConfig(
-        module_to_grid_map=module_to_grid_map,
+        module_to_grid_map=grids,
         topology=topology,
         dim_mapping={"b": 0, "s": 1, "h": 2},
     )
src/megatron/bridge/models/mimo/mimo_provider.py (1)
201-204: Rename unused loop variable `encoder_name` to `_encoder_name`.
The loop variable `encoder_name` is not used within the loop body. Following Python conventions, prefix it with underscore to indicate it's intentionally unused.
🧹 Suggested fix
-    for encoder_name, encoder_spec in spec.submodules["encoders"].items():
+    for _encoder_name, encoder_spec in spec.submodules["encoders"].items():
tests/unit_tests/training/mimo/test_mimo_config.py (1)
14-17: Use raw string for regex pattern with metacharacters.
The pattern `"world_size .* not divisible"` contains regex metacharacters (`.*`). Use a raw string prefix for clarity and to avoid potential escaping issues.
🧹 Suggested fix
 def test_module_parallelism_finalize_invalid_world_size():
     parallelism = ModuleParallelismConfig(tensor_parallel=3, pipeline_parallel=2)
-    with pytest.raises(ValueError, match="world_size .* not divisible"):
+    with pytest.raises(ValueError, match=r"world_size .* not divisible"):
         parallelism.finalize(world_size=10)
src/megatron/bridge/models/mimo/llava_provider.py (1)
130-137: Encoder key mismatch: `"clip_encoder"` vs modality key `"images"`.
The encoder is registered under `"clip_encoder"` in the submodules dict (line 134), but the modality itself is keyed as `"images"` (line 77-78). This is likely intentional for semantic clarity, but ensure downstream code correctly maps between modality names and encoder names when applying per-encoder parallelism.
tests/unit_tests/models/mimo/test_mimo_provider.py (1)
4-11: Unused imports: `pytest` and `torch`.
Static analysis correctly identifies that `pytest` and `torch` are imported but not used. While `pytest` is typically needed for fixtures and markers, neither is used in this file. Consider removing them or adding a `# noqa: F401` comment if they're kept for future use.
Suggested fix
-from unittest.mock import MagicMock, Mock, patch
-
-import pytest
-import torch
+from unittest.mock import MagicMock, Mock, patch
 
 from megatron.core.transformer.spec_utils import ModuleSpec
    class MimoParallelismConfig:
        """Configuration for multi-module (MIMO) heterogeneous parallelism."""

        llm_module_name: str
The canonical MiMo always has an LLM as the middle component, hence the special treatment for topology.
    # Step 8: Move to device/dtype
    mimo_model.to(device).to(dtype)

    return MimoModelProviderResult(
We want `provide()` to yield only a model, so it can blend into the train loop.
Maybe have another method yield the metadata and keep the provider yielding only the model instance?
Otherwise the design is confusing.
There was a problem hiding this comment.
Completely rewrote mimo_provider.py. Now it has mixin inheritance, provide() returns model directly, new build_infra(), provide_distributed_model() override, MimoStubModel for non-participating ranks.
There was a problem hiding this comment.
Please review, and let me know if it is now compatible with the train loop.
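A rough sketch of the split discussed in this thread: `provide()` returns only the model, while a separate `build_infra()` exposes the metadata. The class names `Infra`, `StubModel`, `Model`, and `Provider`, their fields, and the `participates` flag are hypothetical; the real signatures in mimo_provider.py are not shown in this conversation.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple, Union


@dataclass
class Infra:
    """Hypothetical metadata bundle (grids, topology, process groups)."""

    grids: Dict[str, Any] = field(default_factory=dict)
    topology: Dict[str, list] = field(default_factory=dict)
    pg_collections: Dict[str, Any] = field(default_factory=dict)


class StubModel:
    """Placeholder for ranks that host no module (cf. MimoStubModel)."""


class Model:
    def __init__(self, infra: Infra) -> None:
        self.infra = infra


@dataclass
class Provider:
    module_names: Tuple[str, ...] = ("images", "llm")
    participates: bool = True  # assumption: whether this rank hosts a module

    def build_infra(self) -> Infra:
        # Metadata lives here, so callers that need it ask for it explicitly.
        grids = {name: f"grid[{name}]" for name in self.module_names}
        pgs = {name: f"pg[{name}]" for name in self.module_names}
        topology = {"images": ["llm"], "llm": []}
        return Infra(grids=grids, topology=topology, pg_collections=pgs)

    def provide(self) -> Union[Model, StubModel]:
        # Returns a model and nothing else, so a generic train loop can use it.
        if not self.participates:
            return StubModel()
        return Model(self.build_infra())


model = Provider().provide()
stub = Provider(participates=False).provide()
```

Keeping the return type of `provide()` to a single model object is what lets a training loop treat this provider like any other.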
    # Model specs (user provides, like llava_vlm.py example)
    language_model_spec: ModuleSpec
    modality_submodules_spec: Dict[str, ModuleSpec] = field(default_factory=dict)
I would like to understand the design choice of using a string label for each module. It seems that this way the module state/components (spec, parallelism config, etc.) are broken up into separate dicts.
To me, it would make more sense to bundle the contents of a single module into a MimoModuleWrapper object that contains spec, parallelism config, commgrid, etc.
Not asking for a refactor, just trying to discuss for my understanding.
There was a problem hiding this comment.
The string-keyed dict follows Megatron-Core's existing MimoModelConfig.modality_submodules_spec: Dict[str, ModuleSpec] API - the provider passes it directly without translation.
Extending this pattern to MimoParallelismConfig keeps the interface consistent. A MimoModuleWrapper could work as a convenience layer on top if needed, but the current design keeps components decoupled (e.g., specs can be used without parallelism config for testing).
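To make the trade-off concrete, here is one way a convenience layer over the string-keyed dicts might look. `ModuleBundle` and `bundle_modules` are hypothetical names (not part of the PR or of Megatron-Core), and the spec and parallelism values are plain placeholders.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ModuleBundle:
    """Hypothetical per-module view assembled from the separate dicts."""

    name: str
    spec: Any
    parallelism: Optional[Any] = None


def bundle_modules(
    specs: Dict[str, Any],
    parallelisms: Optional[Dict[str, Any]] = None,
) -> Dict[str, ModuleBundle]:
    parallelisms = parallelisms or {}
    # A parallelism entry without a matching spec is most likely a typo.
    unknown = set(parallelisms) - set(specs)
    if unknown:
        raise KeyError(f"parallelism given for unknown modules: {sorted(unknown)}")
    return {
        name: ModuleBundle(name, spec, parallelisms.get(name))
        for name, spec in specs.items()
    }


# Specs alone are enough, e.g. for tests that do not need parallelism.
bundles = bundle_modules({"llm": "llm-spec", "images": "vit-spec"})
with_parallelism = bundle_modules({"llm": "llm-spec"}, {"llm": {"tensor_parallel": 2}})
```

Because the bundle is built on top of the dicts rather than replacing them, the dicts can still be passed through unchanged to APIs that expect the string-keyed form.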
    @dataclass
    class MimoModelProvider(ModelProviderMixin[MimoModel]):
I'm quite ignorant of the MIMO design, but is there no need for a dict of TransformerConfigs per module?
There was a problem hiding this comment.
Each ModuleSpec carries its config in params['config'] (e.g., language_model_spec.params['config'] holds the LLM's TransformerConfig). The projector spec has its own config too. Some modules like HF wrappers manage config internally. So there's no need for a separate Dict[str, TransformerConfig] at the provider level - configs live within their respective specs.
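As an illustration of configs living inside their specs, the sketch below uses a hypothetical `Spec` dataclass in place of `ModuleSpec` and plain dicts in place of `TransformerConfig`; the module names and config values are made up for the example.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class Spec:
    """Hypothetical stand-in for ModuleSpec: a module label plus params."""

    module: str
    params: Dict[str, Any] = field(default_factory=dict)


def config_of(spec: Spec) -> Optional[Any]:
    # Returns None for modules that manage their config internally.
    return spec.params.get("config")


language_model_spec = Spec(
    "GPTModel", params={"config": {"hidden_size": 4096, "num_layers": 32}}
)
projector_spec = Spec(
    "MultimodalProjector", params={"config": {"hidden_size": 4096, "num_layers": 2}}
)
hf_encoder_spec = Spec("HFWrapper")  # no 'config' entry: handled by the wrapper

llm_config = config_of(language_model_spec)
encoder_config = config_of(hf_encoder_spec)
```

With this layout a provider-level `Dict[str, TransformerConfig]` would duplicate information already reachable through each spec.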
maanug-nv left a comment
Overall I don't see any issues, but I left several comments to hopefully improve my understanding.
2a0cfaf to c293eaa
Add extensive unit test coverage for the MIMO (Multi-Input Multi-Output) model provider refactor to meet codecov requirements. This ensures the new ModuleSpec-based API is thoroughly tested.
Test coverage includes:
- MimoModelProvider: initialization, properties, infrastructure building, provide methods, distributed model handling, DDP wrapping, device/dtype handling, freezing, and parallelism validation
- LlavaMimoProvider: initialization, spec generation, configuration, freezing, and error handling
- mimo_builder: HyperCommGrid building with various parallelism configs and topology generation
- MimoParallelismConfig: edge cases for heterogeneous rank validation, properties, and finalization
Fixes:
- Make LlavaMimoProvider.language_model_spec optional (built in __post_init__)
- Add MockModule class for proper torch.nn.Module mocking in tests
- Add comprehensive CUDA mocking to prevent GPU access in unit tests
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
/ok to test 1527a1b
Remove unused variables to fix ruff F841 errors:
- Remove unused 'grids' variable in test_mimo_builder.py (4 instances)
- Remove unused 'result' variable in test_mimo_provider.py (6 instances)
These variables were assigned but never used in the test assertions.
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
/ok to test a7c46c9
Fix trailing whitespace and apply ruff formatting to pass CI linting checks.
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
/ok to test 1ed7296
Summary
This PR introduces a unified `MimoModelProvider` that accepts `ModuleSpec`s directly, removing the `EncoderProvider` abstraction layer.
Key Changes
- `MimoModelProvider` accepting `language_model_spec` and `modality_submodules_spec`
- `EncoderProvider` abstraction (removed)
- `MIMOConfig` → `MimoParallelismConfig`
- `ColocatedCommConfig` logic for colocated and homogeneous modes
Tests
✅ All MIMO tests passing