feat(mimo): phase 3 - MIMO data loading scaffolding #2007
Signed-off-by: Ali Roshan Ghias <aliroshanghias@nvidia.com>
Force-pushed from 233b268 to bcd4b61.
    For testing with synthetic data, use MockMimoProvider instead.

    Args:
        seq_length: Total sequence length (text + encoder tokens).
How would this seq_length arg be used for HF datasets? Would something like padding to this max seq_len happen?
seq_length is the total sequence budget for the model. It is used for:
- Placeholder insertion: we insert encoder_seq_lengths[modality] placeholder tokens per modality (e.g., 577 for CLIP ViT-L/14 at 336 px, i.e. 24 × 24 patches plus one CLS token). These are later replaced 1:1 with encoder hidden states during the forward pass.
- Text truncation: text is truncated to fit, with max_text_tokens = seq_length - sum(encoder_seq_lengths.values()).
- Padding: shorter sequences are padded to exactly seq_length.
I've added an encoder_seq_lengths parameter to specify how many placeholder tokens each encoder needs. This matches the prototype pattern in Megatron-LM (image_seq_length in MockVLMDataset).
Updated the docstrings to clarify this.
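A minimal plain-Python sketch of how one sample fits the seq_length budget described in the reply above. Assumptions: the helper name build_mimo_sequence and its signature are hypothetical, placeholders are placed before the text in dict insertion order (matching the prefix-insertion behavior discussed later in this review), and plain lists stand in for the torch tensors the real MimoDataset returns.

```python
from typing import Dict, List, Tuple


def build_mimo_sequence(
    text_ids: List[int],
    encoder_seq_lengths: Dict[str, int],
    placeholder_ids: Dict[str, int],
    seq_length: int,
    pad_token_id: int,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: placeholders + truncated text + padding = seq_length."""
    # Reserve one placeholder slot per encoder output token.
    num_placeholders = sum(encoder_seq_lengths.values())
    max_text_tokens = seq_length - num_placeholders
    if max_text_tokens < 0:
        raise ValueError(
            f"seq_length={seq_length} is smaller than the {num_placeholders} "
            "encoder placeholder tokens"
        )

    # 1) Placeholder insertion, in encoder_seq_lengths insertion order.
    input_ids: List[int] = []
    for modality, num_tokens in encoder_seq_lengths.items():
        input_ids.extend([placeholder_ids[modality]] * num_tokens)

    # 2) Text truncation: keep at most max_text_tokens text tokens.
    input_ids.extend(text_ids[:max_text_tokens])

    # 3) Padding up to exactly seq_length; the mask marks real tokens with 1.
    num_real = len(input_ids)
    input_ids.extend([pad_token_id] * (seq_length - num_real))
    attention_mask = [1] * num_real + [0] * (seq_length - num_real)
    return input_ids, attention_mask
```

The mask here is derived from the unpadded length rather than from input_ids != pad_token_id; the comparison form would also zero out real tokens whenever a tokenizer reuses its EOS ID as the pad ID.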
    Returns:
        A tuple (train_dataloader, valid_dataloader, test_dataloader).
    """
    # Check for MIMO path
Out of curiosity, how does the dataloader builder function work for VLM currently?
Current VLM uses vlm_datasets/ providers (e.g., MockVLMConversationProvider) with manual build_train_valid_test_datasets_provider functions. MIMO adds auto-dispatch because it needs heterogeneous DP-aware data loading that the standard path doesn't support.
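The auto-dispatch mentioned in this reply (the PR description lists the conditions: cfg.model is a MimoModelProvider and cfg.dataset is a MimoDatasetProvider) can be sketched as a type check at the top of the generic builder. Everything here is a stand-in: the empty classes, the Config dataclass, and the string return values are placeholders for illustration, not Bridge's real signatures, and falling back to the standard path on a partial match is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Tuple


class MimoModelProvider:  # stand-in for the real model provider class
    pass


class MimoDatasetProvider:  # stand-in for the real dataset provider class
    pass


@dataclass
class Config:  # hypothetical minimal config container
    model: Any
    dataset: Any


def build_mimo_data_loaders(cfg: Config) -> Tuple[str, str, str]:
    # Placeholder: the real builder returns DP-aware loaders for the module layout.
    return ("mimo_train", "mimo_valid", "mimo_test")


def build_standard_data_loaders(cfg: Config) -> Tuple[str, str, str]:
    # Placeholder for the existing (homogeneous DP) data loading path.
    return ("train", "valid", "test")


def build_train_valid_test_data_loaders(cfg: Config) -> Tuple[str, str, str]:
    # Take the MIMO path only when both the model and the dataset are MIMO-aware.
    is_mimo = isinstance(cfg.model, MimoModelProvider) and isinstance(
        cfg.dataset, MimoDatasetProvider
    )
    return build_mimo_data_loaders(cfg) if is_mimo else build_standard_data_loaders(cfg)
```

Keying the dispatch on types leaves existing VLM configs on their current path without any new flags.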
Force-pushed from 679b9bf to 0e15082.
📝 Walkthrough
This PR introduces comprehensive MIMO (Multi-Input Multi-Output) model and data loading infrastructure to Megatron, enabling heterogeneous multi-module parallel training with configurable deployment modes. The implementation spans dataset handling, model provider abstractions, parallelism configuration with validation, distributed data loading with DP-aware sharding, and training integration with comprehensive test coverage.
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 11
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/data/mimo/collate.py`:
- Around line 87-107: Update the function docstring that contains the logic
using modality_batch_items / first_non_empty (the collating routine that builds
modality_inputs) to explicitly document sparse-modality behavior: state that
some batch items may omit a modality key (sparse modalities are supported), that
when stacking tensors the resulting batch dimension equals the number of items
that include that modality (not the original full batch size), and that the
model's forward pass must handle variable/unequal batch sizes across different
modalities; keep this note near the existing "stacked along batch dimension"
sentence so callers understand this designed behavior.
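To make the sparse-modality behavior described above concrete, here is a plain-Python sketch of grouping modality inputs when some samples omit a modality. Assumptions: each sample stores its per-modality inputs under a modality_inputs key, and the batch_indices bookkeeping is a hypothetical addition (not something mimo_collate_fn is stated to return) that would let a model map each stacked row back to its sample. In a tensor version, torch.stack would be applied to each values list.

```python
from typing import Any, Dict, List


def group_sparse_modalities(batch: List[Dict[str, Any]]) -> Dict[str, Dict[str, List[Any]]]:
    """Group per-sample modality inputs; a sample may omit any modality."""
    grouped: Dict[str, Dict[str, List[Any]]] = {}
    for sample_index, item in enumerate(batch):
        for modality, value in item.get("modality_inputs", {}).items():
            entry = grouped.setdefault(modality, {"values": [], "batch_indices": []})
            entry["values"].append(value)
            entry["batch_indices"].append(sample_index)
    # Each modality's batch dimension is len(entry["values"]), which can be
    # smaller than len(batch) when some samples lack that modality.
    return grouped
```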
In `@src/megatron/bridge/data/mimo/dataset.py`:
- Around line 158-166: The attention_mask is currently created as ones
(attention_mask) and ignores padding in input_ids; change it to mask out padding
by creating attention_mask = (input_ids != pad_token_id).long() (or boolean) so
padding tokens are 0 and real tokens 1, and ensure position_ids still derived
from token positions (position_ids = torch.arange(len(input_ids)) or use
cumulative non-padding positions if needed); update any consumers expecting
int/long dtype accordingly and reference attention_mask, input_ids,
position_ids, and pad_token_id when applying the fix.
In `@src/megatron/bridge/data/mimo/loaders.py`:
- Line 107: The code computes micro_batch_size = cfg.train.global_batch_size //
dp_size without validating divisibility or non-zero result; update the logic in
loaders.py around the micro_batch_size calculation to (1) check that
cfg.train.global_batch_size >= dp_size and (2) check that
cfg.train.global_batch_size % dp_size == 0, and if either check fails raise a
clear ValueError (mentioning cfg.train.global_batch_size and dp_size) so callers
know to adjust configuration; only after validation compute micro_batch_size =
cfg.train.global_batch_size // dp_size.
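The validation requested above could look like the following sketch (the function name is hypothetical). One consequence worth noting: Megatron's usual accounting is global_batch_size = micro_batch_size × data_parallel_size × num_microbatches, so setting micro_batch_size = global_batch_size // dp_size implies a single microbatch per step, i.e. no gradient accumulation, on this path.

```python
def compute_mimo_micro_batch_size(global_batch_size: int, dp_size: int) -> int:
    """Hypothetical helper: per-DP-rank batch size for one training step."""
    if dp_size <= 0:
        raise ValueError(f"dp_size must be positive, got {dp_size}")
    # Reject configs where a DP rank would get zero samples or an uneven share.
    if global_batch_size < dp_size or global_batch_size % dp_size != 0:
        raise ValueError(
            f"cfg.train.global_batch_size ({global_batch_size}) must be a positive "
            f"multiple of the data-parallel size ({dp_size})."
        )
    return global_batch_size // dp_size
```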
In `@src/megatron/bridge/data/mimo/mock_provider.py`:
- Around line 95-113: MockMimoProvider is missing the trust_remote_code
attribute referenced in _load_processors; add a trust_remote_code boolean field
to MockMimoProvider (e.g., an __init__ parameter defaulting to False and stored
on self) so is_safe_repo can read it, or set a class-level default bool; update
the MockMimoProvider constructor to accept and assign trust_remote_code and
ensure any tests or instantiations pass the intended value when creating
processor_paths and calling _load_processors.
In `@tests/unit_tests/data/mimo/test_dp_utils.py`:
- Around line 40-50: The helper _make_mimo_cfg constructs
ModuleParallelismConfig using incorrect field names; update the
ModuleParallelismConfig instantiations inside _make_mimo_cfg to use
tensor_model_parallel_size and data_parallel_size instead of tensor_parallel and
data_parallel, keeping rank_offset and the surrounding
MimoParallelismConfig(llm_module_name="llm",
module_parallelisms=module_parallelisms, deployment_mode=deployment_mode) intact
so tests reference the correct attributes.
- Around line 71-86: The test modifies the MIMO config using the wrong field
name: in test_get_mimo_dp_info_colocated_llm_first_pp (and any similar tests
using _make_mimo_cfg) you should set
module_parallelisms["vision"].data_parallel_size and
module_parallelisms["llm"].data_parallel_size instead of data_parallel; update
those assignments so the config uses the correct attribute name
(data_parallel_size) before calling get_mimo_dp_info so the FakeGrid sizes match
the config.
In `@tests/unit_tests/training/mimo/test_mimo_config.py`:
- Around line 78-91: The test constructs ModuleParallelismConfig with incorrect
field names; replace tensor_parallel and data_parallel with the actual dataclass
fields tensor_model_parallel_size and data_parallel_size when creating
module_parallelisms for the MimoParallelismConfig in
test_mimo_heterogeneous_gap_in_middle_raises_error so the
ModuleParallelismConfig instances match the class definition and then call
mimo_parallelism_config.finalize(world_size=None) as before to assert the
ValueError.
- Around line 94-107: The test uses incorrect keyword argument names when
constructing ModuleParallelismConfig; update the two
ModuleParallelismConfig(...) calls in test_mimo_heterogeneous_leading_gap_warns
to use the correct field names (tensor_model_parallel_size and data_parallel_size
instead of tensor_parallel and data_parallel) so the objects are created
correctly and then call MimoParallelismConfig(...).finalize(world_size=None) as
before.
- Around line 110-125: The test test_mimo_heterogeneous_contiguous_no_warning
uses incorrect constructor field names for ModuleParallelismConfig; update the
module_parallelisms entries to use the real parameter names expected by
ModuleParallelismConfig (tensor_model_parallel_size and data_parallel_size,
alongside rank_offset) so MimoParallelismConfig(...) and subsequent
mimo_parallelism_config.finalize(...) exercise the contiguous allocation path
without raising due to mismatched keys.
In `@tests/unit_tests/training/mimo/test_mimo_ddp.py`:
- Around line 194-196: Remove the unused local variable assignments named result
where wrap_mimo_model_distributed(...) is called (the calls to
wrap_mimo_model_distributed with arguments mimo_model, ddp_config,
mimo_parallelism_config, grids, pg_collections and the later identical call
around lines 229-231); instead invoke wrap_mimo_model_distributed(...) without
assigning its return (or assign to _ if you intentionally ignore it) so the
result variable is not created and Ruff/Flake8 F841 is resolved.
- Line 6: Remove the unused top-level import "import pytest" from the
test_mimo_ddp module; locate the import statement in the test_mimo_ddp.py file
and delete it so the linter no longer flags an unused pytest import.
🧹 Nitpick comments (15)
src/megatron/bridge/data/mimo/collate.py (1)
56-60: Standard field stacking assumes uniform sequence lengths.
torch.stack will raise a RuntimeError if items have different sequence lengths. Unlike the modality inputs handling (which catches RuntimeError), these standard fields will propagate the exception without a helpful message. Consider adding a try/except with a more descriptive error, or validating shapes upfront.
♻️ Optional: Add descriptive error handling
  # Stack standard fields
- input_ids = torch.stack([item["input_ids"] for item in batch])
- labels = torch.stack([item["labels"] for item in batch])
- attention_mask = torch.stack([item["attention_mask"] for item in batch])
- position_ids = torch.stack([item["position_ids"] for item in batch])
+ try:
+     input_ids = torch.stack([item["input_ids"] for item in batch])
+     labels = torch.stack([item["labels"] for item in batch])
+     attention_mask = torch.stack([item["attention_mask"] for item in batch])
+     position_ids = torch.stack([item["position_ids"] for item in batch])
+ except RuntimeError as e:
+     raise RuntimeError(
+         f"Failed to stack batch items. Ensure all items have the same sequence length. "
+         f"Original error: {e}"
+     ) from e
tests/unit_tests/data/mimo/test_collate.py (1)
10-24: Consider explicit Optional type hint. Per PEP 484, implicit Optional via a = None default is discouraged. Static analysis flagged this (RUF013).
♻️ Optional fix
+from typing import Optional
+
 def make_sample(
     seq_length: int = 64,
-    modalities: dict = None,
+    modalities: Optional[dict] = None,
 ) -> dict:
tests/unit_tests/data/mimo/test_dataset.py (1)
371-410: Multi-modality placeholder ordering test relies on dict insertion order. This test assumes placeholders appear in dict insertion order (vision before audio). While Python 3.7+ guarantees dict ordering, this implicit dependency could be fragile if the ordering contract isn't documented in MimoDataset. Consider adding a brief comment in the test or verifying the ordering is documented in the dataset implementation.
src/megatron/bridge/data/mimo/loaders.py (3)
6-6: Remove unused import.
Dict is imported but not used in this file.
🔧 Proposed fix
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
20-27: Consider documenting or removing unused train_state parameter. The train_state parameter is declared but not used in the function body. If this is for API consistency with other data loader builders or planned future use, consider adding a brief comment. Otherwise, it can be removed.
66-67: Use TypeError for type validation. When validating that an argument is of a specific type, TypeError is semantically more appropriate than ValueError.
🔧 Proposed fix
 if not isinstance(cfg.model, MimoModelProvider):
-    raise ValueError("cfg.model must be MimoModelProvider for MIMO data loading.")
+    raise TypeError("cfg.model must be MimoModelProvider for MIMO data loading.")
src/megatron/bridge/training/mimo_ddp.py (1)
78-94: Edge case: empty encoders dictionary. The condition if hasattr(submodule, 'encoders') and submodule.encoders passes if encoders is a non-empty dict, but next(iter(submodule.encoders.keys())) would fail with StopIteration if called on an empty dict. The current check is correct, but consider adding a more explicit check or comment for clarity.
🔧 Suggested clarification
-    if hasattr(submodule, 'encoders') and submodule.encoders:
+    # Skip submodules without encoders or with empty encoder dict
+    if not (hasattr(submodule, 'encoders') and submodule.encoders):
+        continue
+
+    if True:  # Encoders present
         encoder_key = next(iter(submodule.encoders.keys()))
Or simply add a comment to clarify the truthiness check covers the empty dict case.
src/megatron/bridge/data/mimo/hf_provider.py (2)
8-8: Remove unused imports.
List and Union are imported but not used in this file.
🔧 Proposed fix
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple
122-140: Consider narrowing exception handling. Catching ValueError broadly may inadvertently suppress errors unrelated to missing splits (e.g., configuration issues). Consider checking the exception message or using a more specific approach.
🔧 Proposed approach
  try:
      dataset = load_dataset(
          self.hf_dataset_path,
          name=self.hf_dataset_name,
          split=split,
          streaming=self.streaming,
          trust_remote_code=is_safe_repo(
              trust_remote_code=self.trust_remote_code,
              hf_path=self.hf_dataset_path,
          ),
      )
      return dataset
- except ValueError:
-     # Split doesn't exist
+ except ValueError as e:
+     # Check if this is specifically a missing split error
+     if "split" in str(e).lower():
+         return None
+     raise
-     return None
src/megatron/bridge/data/mimo/dataset.py (2)
6-6: Remove unused imports.
List and Union are imported but not used in this file.
🔧 Proposed fix
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Optional
189-212: Redundant double truncation of text. The text is first tokenized with max_length=self.seq_length (line 193), then truncated again after adding prefix tokens (line 211). If the prefix tokens take up N positions, the final text is cut to seq_length - N tokens either way, so no text that could have fit is lost; however, the tokenizer does more work than needed, and slicing after tokenization can cut off special tokens the tokenizer appends at the end (such as EOS). Consider pre-computing the available text space:
🔧 Proposed optimization
+    # Pre-compute space available for text tokens
+    total_prefix_len = sum(
+        self.encoder_seq_lengths.get(m, 1)
+        for m in modality_inputs.keys()
+        if m in self.special_token_ids
+    )
+    max_text_length = self.seq_length - total_prefix_len
+
     # Tokenize the text
     encoded = self.tokenizer(
         text,
         truncation=True,
-        max_length=self.seq_length,
+        max_length=max_text_length,
         return_tensors="pt",
     )
     input_ids = encoded["input_ids"].squeeze(0)

     # Insert placeholder tokens for each modality at the beginning
-    # The order follows the order of modality_inputs (Python 3.7+ dict ordering)
     prefix_tokens = []
     for modality_name in modality_inputs.keys():
         if modality_name in self.special_token_ids:
             token_id = self.special_token_ids[modality_name]
             num_tokens = self.encoder_seq_lengths.get(modality_name, 1)
             prefix_tokens.extend([token_id] * num_tokens)

     if prefix_tokens:
         prefix = torch.tensor(prefix_tokens, dtype=input_ids.dtype)
-        # Truncate text tokens to make room for placeholders
-        max_text_len = self.seq_length - len(prefix_tokens)
-        input_ids = input_ids[:max_text_len]
         input_ids = torch.cat([prefix, input_ids])
src/megatron/bridge/models/mimo/mimo_builder.py (1)
61-73: Return type annotation uses forward reference but static analysis can't see the lazy import. The return type "ColocatedCommConfig" on Line 63 is correctly handled via a string annotation (forward reference), and the actual import happens on Line 65 before use. This is a valid pattern for lazy imports. However, Line 67 creates a redundant copy of the grids dict:
♻️ Minor simplification
 def build_colocated_comm_config(
     mimo_parallelism_config: MimoParallelismConfig, grids: Dict[str, "HyperCommGrid"]
 ) -> "ColocatedCommConfig":
     """Build ColocatedCommConfig with default encoder-to-LLM topology."""
     from megatron.core.models.mimo.config.base_configs import ColocatedCommConfig

-    module_to_grid_map = {name: grid for name, grid in grids.items()}
     topology = _default_topology(mimo_parallelism_config)
     return ColocatedCommConfig(
-        module_to_grid_map=module_to_grid_map,
+        module_to_grid_map=grids,
         topology=topology,
         dim_mapping={"b": 0, "s": 1, "h": 2},
     )
tests/unit_tests/models/mimo/test_mimo_provider.py (1)
355-376: Unused variable model from static analysis is valid but intentional. The model variable on Line 372 is assigned but never used because the test focuses on verifying the side effect (parameter freezing) via the mock. While technically unused, the assignment makes the test intent clearer. Consider prefixing with underscore to silence the warning:
♻️ Silence unused variable warning
-    model = provider.provide()
+    _model = provider.provide()
src/megatron/bridge/models/mimo/mimo_provider.py (2)
22-22: Unused import DistributedDataParallel flagged by static analysis. The import at Line 22 is flagged as unused. While DistributedDataParallelConfig is used, DistributedDataParallel appears to not be directly referenced in this file (it's used in mimo_ddp.py instead).
♻️ Remove unused import
-from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
+from megatron.core.distributed import DistributedDataParallelConfig
294-299: Unused loop variable encoder_name can be renamed. The loop at Line 296 iterates over encoder items but only uses encoder_spec, not encoder_name.
♻️ Rename unused loop variable
-    for encoder_name, encoder_spec in spec.submodules["encoders"].items():
+    for _encoder_name, encoder_spec in spec.submodules["encoders"].items():
Force-pushed from 97b3478 to 5ac9802.
Add DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.

Key changes:
- Add wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
- Add embedding group helpers to mimo_builder.py: populate_embedding_and_position_groups(), is_pp_first_stage(), is_pp_last_stage(), is_current_rank_in_grid()
- Improve gap detection in MimoParallelismConfig._validate_heterogeneous()
- Extend _get_pg_collections_from_grids() to populate pos_embd and embd process groups
- Set variable_seq_lengths=True in provide_distributed_model()
- Update copyright headers to 2026
- Add comprehensive unit tests for all new functionality

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
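The gap detection mentioned in this commit (and exercised by the test_mimo_heterogeneous_* tests flagged earlier: a gap in the middle raises, a leading gap warns, a contiguous layout passes silently) can be sketched over (rank_offset, num_ranks) pairs. This is an illustrative reimplementation, not MimoParallelismConfig._validate_heterogeneous(); the overlap check, the message wording, and treating num_ranks as a module's total rank count are assumptions.

```python
import warnings
from typing import Dict, Tuple


def validate_heterogeneous_layout(modules: Dict[str, Tuple[int, int]]) -> None:
    """modules maps module name -> (rank_offset, num_ranks)."""
    # Half-open rank intervals [start, end), sorted by starting rank.
    spans = sorted((offset, offset + size, name) for name, (offset, size) in modules.items())
    if not spans:
        return
    if spans[0][0] > 0:
        # Leading gap: ranks before the first module are unused, so only warn.
        warnings.warn(f"ranks [0, {spans[0][0]}) are not assigned to any module")
    for (_, prev_end, prev_name), (start, _, name) in zip(spans, spans[1:]):
        if start < prev_end:
            raise ValueError(f"modules {prev_name!r} and {name!r} overlap at rank {start}")
        if start > prev_end:
            raise ValueError(
                f"ranks [{prev_end}, {start}) between {prev_name!r} and {name!r} are unassigned"
            )
```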
Force-pushed from 5ac9802 to 45517d0.
/ok to test faf91b0
/claude review
    num_workers=mimo_provider.num_workers,
    collate_fn=collate_fn,
    pin_memory=mimo_provider.pin_memory,
    drop_last=mimo_provider.drop_last,
Bug: mimo_provider.drop_last will raise AttributeError at runtime. drop_last is not defined on DataloaderConfig, DatasetProvider, HFMimoDatasetProvider, or MockMimoProvider.
Either add drop_last: bool = True to DataloaderConfig, or hardcode/parameterize it here.
Added drop_last to DataloaderConfig.
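A sketch of this fix: a drop_last field on the dataloader config, forwarded to the DataLoader with the other options shown in the diff. DataloaderConfigSketch is a hypothetical stand-in for DataloaderConfig (its other fields and defaults are assumptions; drop_last=True is the default the reviewer suggested). The batch-count helper shows what the flag changes: with drop_last=True every batch has exactly batch_size samples, so a short trailing batch never reaches code that assumes a fixed micro batch size.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class DataloaderConfigSketch:  # hypothetical stand-in for DataloaderConfig
    num_workers: int = 2
    pin_memory: bool = True
    drop_last: bool = True  # drop a trailing partial batch


def dataloader_kwargs(cfg: DataloaderConfigSketch, collate_fn: Callable) -> Dict[str, Any]:
    # All four keys are standard torch.utils.data.DataLoader keyword arguments.
    return {
        "num_workers": cfg.num_workers,
        "collate_fn": collate_fn,
        "pin_memory": cfg.pin_memory,
        "drop_last": cfg.drop_last,
    }


def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    # Number of batches a DataLoader with a fixed batch_size yields per epoch.
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)
```

Usage would be along the lines of DataLoader(dataset, batch_size=micro_batch_size, sampler=sampler, **dataloader_kwargs(cfg, mimo_collate_fn)). For 10 samples and batch_size=4, drop_last=True yields 2 batches, while drop_last=False yields 3, the last holding only 2 samples.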
    self.mimo.finalize(world_size if world_size and world_size > 1 else None)

    llm_parallelism = self.mimo.get_parallelism(self.mimo.llm_module_name)
    parallelism_checks = {
Design concern: MIMO fields belong on MimoModelProvider, not ConfigContainer
Adding mimo and encoder_providers to ConfigContainer means every user of GPT/Llama/T5/Mamba/… sees these fields, even though ~95% of models have no use for them. It also creates an implicit invariant that users must keep cfg.model and cfg.mimo in sync manually — a footgun that _validate_mimo() is trying to paper over.
Both fields are logically MIMO-specific and could live entirely inside MimoModelProvider:
@dataclass
class MimoModelProvider(ModelProviderMixin[MimoModel]):
    ...
    mimo_parallelism_config: Optional[MimoParallelismConfig] = None
    encoder_providers: Optional[dict[str, EncoderProvider]] = None  # move here
    # (MIMOConfig's deployment_mode / module_parallelisms could fold into
    # mimo_parallelism_config or replace it, keeping everything in one place)

Downstream call sites that currently read cfg.mimo could instead read cfg.model.mimo_parallelism_config (or a renamed equivalent), and dispatch guards like:

# config.py get_data_parallel_size()
if isinstance(self.model, MimoModelProvider):
    return self.model.get_data_parallel_size(world_size)

_validate_mimo() would similarly move into MimoModelProvider.validate() or __post_init__, keeping MIMO invariants collocated with MIMO state.
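The reviewer's proposal (MIMO state and its invariants owned by the provider, with only a type-based guard on the container side) can be sketched as below. All class and field names here are hypothetical simplifications, since the real MimoModelProvider and MimoParallelismConfig carry more fields, and reporting the LLM module's DP size from get_data_parallel_size is one plausible choice, not confirmed behavior.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class ModuleParallelismSketch:  # hypothetical per-module parallelism settings
    tensor_model_parallel_size: int = 1
    data_parallel_size: int = 1


@dataclass
class MimoProviderSketch:  # hypothetical provider that owns its MIMO config
    llm_module_name: str = "llm"
    module_parallelisms: Dict[str, ModuleParallelismSketch] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # MIMO invariants are checked where the MIMO state lives.
        if self.llm_module_name not in self.module_parallelisms:
            raise ValueError(f"module_parallelisms has no entry for {self.llm_module_name!r}")

    def get_data_parallel_size(self, world_size: int) -> int:
        # world_size is accepted to mirror the container API; unused in this sketch.
        return self.module_parallelisms[self.llm_module_name].data_parallel_size


def get_data_parallel_size(model: object, world_size: int, default_dp: int) -> int:
    # Container-side dispatch guard: the container itself holds no MIMO fields.
    if isinstance(model, MimoProviderSketch):
        return model.get_data_parallel_size(world_size)
    return default_dp
```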
There's also a naming overlap worth clarifying: MIMOConfig (training/mimo_config.py) and MimoParallelismConfig (models/mimo/mimo_config.py) cover very similar ground (module parallelisms, deployment mode, special tokens). Are these intentionally separate, or could one replace the other?
Happy to be corrected if there's a reason ConfigContainer needs to see these directly (e.g. Hydra instantiation, schema introspection), but the motivation isn't obvious from the PR.
Good catch. This should have been removed. Cleaned up. ConfigContainer no longer carries mimo and other outdated fields.
    def get_data_parallel_size(self, world_size: int) -> int:
        """Calculate the data parallel size based on the model configuration."""
        if self.mimo is not None:
why not use the checker util?
Should all of the checks below use checkers, e.g. is_colocated_mimo?
Good point — updated these checks to use the existing is_mimo checker for consistency.
        return self.mimo is not None

    @property
    def is_colocated_mimo(self) -> bool:
Yup, it was dead code.
Remove ConfigContainer-level MIMO helper and validation paths, and keep training MIMO tests focused on provider-owned mimo_config behavior to preserve phase2-style boundaries. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
/ok to test 9502e7a
/ok to test 5d39f33
/claude review
    Args:
        cfg: Configuration container with MimoModelProvider as cfg.model.
        train_state: Current training state.
        mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider)
Typo: the class is named MockMimoProvider, not MockMimoDatasetProvider.
Suggested change
-        mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider)
+        mimo_provider: MIMO dataset provider (e.g., MockMimoProvider)
Fixed, thanks. Updated the docstring example to use MockMimoProvider (same behavior, correct class name).
Review Notes
Mostly LGTM: clean scaffolding with good unit tests for collate, dataset, and dp_utils. Two inline comments posted (typo + dead parameter). Missing test coverage:
Fix the MIMO loader/docstring nits by removing an unused dp-utils parameter and add targeted unit tests for HF provider and MIMO loader wiring to cover guard clauses and happy path behavior. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
/ok to test c580879
Apply ruff/ruff-format cleanup requested by CI in dp_utils typing imports and test_loaders line wrapping so lint and downstream aggregate checks pass. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
/ok to test 8e36f4c
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
/ok to test e98b742 |
What does this PR do?
Implements MIMO-specific data loading that integrates with Bridge's standard training flow while supporting per-module parallelism.
Key Components
Core Classes
| File | Class / function |
|---|---|
| base_provider.py | MimoDatasetProvider |
| dataset.py | MimoDataset |
| collate.py | mimo_collate_fn |
| dp_utils.py | get_mimo_dp_info |
| loaders.py | build_mimo_data_loaders |

Providers
- HFMimoDatasetProvider
- MockMimoProvider

Integration
build_train_valid_test_data_loaders() auto-dispatches to the MIMO path when:
- cfg.model is MimoModelProvider
- cfg.dataset is MimoDatasetProvider

MimoModelProvider._grids after build_model() - reused by data loading

Output Format
Usage
Tests
- test_dataset.py - MimoDataset construction, getitem, modalities
- test_collate.py - mimo_collate_fn batching
- test_dp_utils.py - DP info for different deployment modes

All 70 tests passing.
Summary by CodeRabbit
Release Notes
New Features
Tests