
feat(mimo): phase 3 - MIMO data loading scaffolding #2007

Merged
yaoyu-33 merged 21 commits into NVIDIA-NeMo:main from aroshanghias-nvd:mimo/phase3-data-loading on Mar 20, 2026

Conversation

@aroshanghias-nvd

@aroshanghias-nvd aroshanghias-nvd commented Jan 20, 2026

Contributor

What does this PR do?

Implements MIMO-specific data loading that integrates with Bridge's standard training flow while supporting per-module parallelism.

Key Components

Core Classes

| File | Class | Purpose |
|---|---|---|
| base_provider.py | MimoDatasetProvider | ABC for all MIMO dataset providers |
| dataset.py | MimoDataset | Unified dataset with per-modality HF preprocessing |
| collate.py | mimo_collate_fn | Batches modality_inputs for model forward |
| dp_utils.py | get_mimo_dp_info | DP rank/size for heterogeneous parallelism |
| loaders.py | build_mimo_data_loaders | Creates DataLoaders with DP-aware sampling |

Providers

| Provider | Purpose |
|---|---|
| HFMimoDatasetProvider | Loads real HuggingFace datasets with processors |
| MockMimoProvider | Synthetic data + real HF processors for testing |

Integration

  • build_train_valid_test_data_loaders() auto-dispatches to MIMO path when:
    • cfg.model is MimoModelProvider
    • cfg.dataset is MimoDatasetProvider
  • Grids cached in MimoModelProvider._grids after build_model() - reused by data loading
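
The dispatch rule above can be sketched roughly as follows. This is an illustration only, not the code in loaders.py: the two provider classes are empty stand-ins for the real `MimoModelProvider` and `MimoDatasetProvider`, `Config` and `uses_mimo_path` are hypothetical names, and the sketch assumes both conditions must hold for the MIMO path to be taken.

```python
from dataclasses import dataclass
from typing import Any


class MimoModelProvider:
    """Stand-in for the real model provider class (assumption)."""


class MimoDatasetProvider:
    """Stand-in for the real dataset provider ABC (assumption)."""


@dataclass
class Config:
    """Hypothetical config container holding only the two fields used here."""

    model: Any
    dataset: Any


def uses_mimo_path(cfg: Config) -> bool:
    """Return True only when both the model and the dataset are MIMO providers."""
    return isinstance(cfg.model, MimoModelProvider) and isinstance(
        cfg.dataset, MimoDatasetProvider
    )
```

Any other combination would fall through to the standard (non-MIMO) loader path under this assumption.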

Output Format

batch = {
    "input_ids": (batch_size, seq_length),
    "labels": (batch_size, seq_length),
    "attention_mask": (batch_size, seq_length),
    "position_ids": (batch_size, seq_length),
    "modality_inputs": {
        "vision": {"pixel_values": (batch_size, 3, 224, 224)},
    },
}

Usage

cfg.model = MimoModelProvider(...)
cfg.dataset = HFMimoDatasetProvider(
    seq_length=2048,
    hf_dataset_path="liuhaotian/LLaVA-Instruct-150K",
    hf_tokenizer_path="meta-llama/Llama-2-7b-hf",
    processor_paths={"vision": "openai/clip-vit-large-patch14"},
    special_token_ids={"vision": 32000},
    modality_columns={"vision": "image"},
)
# DataLoaders built automatically via Bridge's training setup

Tests

  • test_dataset.py - MimoDataset construction, getitem, modalities
  • test_collate.py - mimo_collate_fn batching
  • test_dp_utils.py - DP info for different deployment modes

All 70 tests passing.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added multi-modal model training capabilities, enabling vision-language model training (e.g., LLaVA-style architectures)
    • Introduced HuggingFace-based dataset provider for multi-modal data loading with unified batching across modalities
    • Added flexible parallelism configuration supporting colocated, homogeneous, and heterogeneous deployment modes
  • Tests

    • Added comprehensive unit test coverage for multi-modal data loading and model training components


Ali Roshan Ghias added 2 commits January 15, 2026 12:49
@copy-pr-bot

copy-pr-bot Bot commented Jan 20, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aroshanghias-nvd aroshanghias-nvd changed the title from "Mimo/phase3 data loading" to "feat(mimo): phase 3 - data loading" on Jan 20, 2026
@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/phase3-data-loading branch from 233b268 to bcd4b61 Compare January 27, 2026 01:36
@aroshanghias-nvd aroshanghias-nvd changed the title from "feat(mimo): phase 3 - data loading" to "feat(mimo): phase 3 - MIMO data loading scaffolding" on Jan 27, 2026
Comment thread src/megatron/bridge/data/mimo/dataset.py
For testing with synthetic data, use MockMimoProvider instead.

Args:
seq_length: Total sequence length (text + encoder tokens).

Contributor

How would this seq_length arg get used for HF datasets? Would something like padding to this max seq_len happen?

Contributor Author

seq_length is the total sequence budget for the model, used for:

  1. Placeholder insertion - We insert encoder_seq_lengths[modality] placeholder tokens per modality (e.g., 577 for CLIP ViT-L/14). These are later replaced 1:1 with encoder hidden states during forward.

  2. Text truncation - Text is truncated to fit: max_text_tokens = seq_length - sum(encoder_seq_lengths.values())

  3. Padding - Shorter sequences are padded to exactly seq_length

I've added encoder_seq_lengths parameter to specify how many placeholder tokens each encoder needs. This matches the prototype pattern in Megatron-LM (image_seq_length in MockVLMDataset).

Updated the docstrings to clarify this.
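
The three steps above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the `MimoDataset` implementation: `build_input_ids` is a hypothetical helper, plain lists stand in for tensors, and placeholders are assumed to be prefixed before the text tokens.

```python
from typing import Dict, List


def build_input_ids(
    text_ids: List[int],
    encoder_seq_lengths: Dict[str, int],
    special_token_ids: Dict[str, int],
    seq_length: int,
    pad_token_id: int,
) -> List[int]:
    """Prefix placeholders, truncate text to the remaining budget, pad to seq_length."""
    # 1. Placeholder insertion: one run of placeholder tokens per modality.
    prefix: List[int] = []
    for modality, num_tokens in encoder_seq_lengths.items():
        prefix.extend([special_token_ids[modality]] * num_tokens)

    # 2. Text truncation: the text only gets what the placeholders leave over.
    max_text_tokens = seq_length - sum(encoder_seq_lengths.values())
    if max_text_tokens < 0:
        raise ValueError("encoder_seq_lengths exceed seq_length")
    ids = prefix + text_ids[:max_text_tokens]

    # 3. Padding: fill up to exactly seq_length.
    return ids + [pad_token_id] * (seq_length - len(ids))


# Toy numbers: 4 vision placeholders leave room for 4 of the 10 text tokens.
ids = build_input_ids(
    text_ids=list(range(100, 110)),
    encoder_seq_lengths={"vision": 4},
    special_token_ids={"vision": 32000},
    seq_length=8,
    pad_token_id=0,
)
```

With the 577-token CLIP example and `seq_length=2048`, the same arithmetic would leave 2048 - 577 = 1471 positions for text.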

Comment thread src/megatron/bridge/data/mimo/hf_provider.py Outdated
Comment thread src/megatron/bridge/data/loaders.py Outdated
Returns:
A tuple (train_dataloader, valid_dataloader, test_dataloader).
"""
# Check for MIMO path

Contributor

Out of curiosity, how is the dataloader builder function for VLM handled currently?

Contributor Author

Current VLM uses vlm_datasets/ providers (e.g., MockVLMConversationProvider) with manual build_train_valid_test_datasets_provider functions. MIMO adds auto-dispatch because it needs heterogeneous DP-aware data loading that the standard path doesn't support.

@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/phase3-data-loading branch from 679b9bf to 0e15082 Compare January 28, 2026 19:09
@coderabbitai

coderabbitai Bot commented Jan 28, 2026

Contributor
📝 Walkthrough

This PR introduces comprehensive MIMO (Multi-Input Multi-Output) model and data loading infrastructure to Megatron, enabling heterogeneous multi-module parallel training with configurable deployment modes. The implementation spans dataset handling, model provider abstractions, parallelism configuration with validation, distributed data loading with DP-aware sharding, and training integration with comprehensive test coverage.

Changes

| Cohort / File(s) | Summary |
|---|---|
| MIMO Data Core: `src/megatron/bridge/data/mimo/__init__.py`, `dataset.py`, `collate.py`, `base_provider.py` | Introduces dataset wrapper, collation logic, and abstract provider interface for MIMO data handling with placeholder token insertion and per-modality preprocessing. |
| MIMO Data Providers: `src/megatron/bridge/data/mimo/hf_provider.py`, `mock_provider.py` | Implements HuggingFace-based and mock providers for MIMO datasets with lazy processor/tokenizer loading and synthetic data generation for testing. |
| MIMO Data Loading: `src/megatron/bridge/data/mimo/loaders.py`, `dp_utils.py` | Builds DP-aware data loaders with DistributedSampler-based sharding; computes per-rank DP configuration across deployment modes. |
| Loader Integration: `src/megatron/bridge/data/loaders.py` | Adds early MIMO detection path that delegates to MIMO-specific loader construction when MimoModelProvider is detected. |
| MIMO Parallelism Config: `src/megatron/bridge/models/mimo/mimo_config.py` | Introduces ModuleParallelismConfig and MimoParallelismConfig with multi-mode validation (colocated, homogeneous, heterogeneous) and comprehensive finalization logic. |
| MIMO Model Infrastructure: `src/megatron/bridge/models/mimo/mimo_builder.py`, `mimo_provider.py` | Constructs HyperCommGrid instances, process groups, and per-module communication topology; implements MimoModelProvider with infrastructure caching and per-module DDP wrapping. |
| MIMO Model Specialization: `src/megatron/bridge/models/mimo/llava_provider.py`, `__init__.py` | Introduces LlavaMimoProvider for vision-language models with preset Vicuna config and vision encoder/projector specifications; exposes public module API. |
| Training Integration: `src/megatron/bridge/training/config.py`, `mimo_ddp.py` | Extends ConfigContainer to support MimoModelProvider; provides selective DDP wrapping for language and modality submodules based on rank participation. |
| Test Suite: `tests/unit_tests/data/mimo/*.py`, `tests/unit_tests/models/mimo/*.py`, `tests/unit_tests/training/mimo/*.py` | Comprehensive unit tests covering collation, datasets, data utilities, model provider, configuration validation, and DDP wrapping across deployment modes and rank scenarios. |

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • PR #2011: Introduces decentralized process-group and HyperCommGrid infrastructure that this MIMO implementation directly depends on for per-module grid construction and distributed communication topology.

Suggested reviewers

  • cuichenx
  • skyw
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Test Results For Major Changes | ⚠️ Warning | PR introduces major architectural changes requiring documented test results, but test files contain multiple blocking bugs (incorrect field names, unused imports/variables) preventing execution. | Fix all identified test bugs and provide pytest output logs demonstrating all tests passing to validate the claimed 70+ passing tests. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly summarizes the main change: introduction of MIMO data loading infrastructure. It directly relates to the primary objective of the PR. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.61% which is sufficient. The required threshold is 80.00%. |


@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 11

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/data/mimo/collate.py`:
- Around line 87-107: Update the function docstring that contains the logic
using modality_batch_items / first_non_empty (the collating routine that builds
modality_inputs) to explicitly document sparse-modality behavior: state that
some batch items may omit a modality key (sparse modalities are supported), that
when stacking tensors the resulting batch dimension equals the number of items
that include that modality (not the original full batch size), and that the
model's forward pass must handle variable/unequal batch sizes across different
modalities; keep this note near the existing "stacked along batch dimension"
sentence so callers understand this designed behavior.
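
A rough sketch of the sparse-modality behavior described above, under stated assumptions: `group_modality_inputs` is a hypothetical helper, strings stand in for per-item tensors, and a real collate function would stack each list with `torch.stack` instead of returning it.

```python
from typing import Any, Dict, List


def group_modality_inputs(
    batch: List[Dict[str, Any]],
) -> Dict[str, Dict[str, List[Any]]]:
    """Collect per-modality inputs, skipping items that lack a modality."""
    grouped: Dict[str, Dict[str, List[Any]]] = {}
    for item in batch:
        for modality, inputs in item.get("modality_inputs", {}).items():
            for key, value in inputs.items():
                grouped.setdefault(modality, {}).setdefault(key, []).append(value)
    return grouped


batch = [
    {"modality_inputs": {"vision": {"pixel_values": "img0"}}},
    {"modality_inputs": {}},  # text-only item: no vision key at all
    {"modality_inputs": {"vision": {"pixel_values": "img2"}}},
]
grouped = group_modality_inputs(batch)
# Three items went in, but only two carry vision inputs, so the vision
# "batch dimension" is 2 rather than the full batch size of 3.
```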

In `@src/megatron/bridge/data/mimo/dataset.py`:
- Around line 158-166: The attention_mask is currently created as ones
(attention_mask) and ignores padding in input_ids; change it to mask out padding
by creating attention_mask = (input_ids != pad_token_id).long() (or boolean) so
padding tokens are 0 and real tokens 1, and ensure position_ids still derived
from token positions (position_ids = torch.arange(len(input_ids)) or use
cumulative non-padding positions if needed); update any consumers expecting
int/long dtype accordingly and reference attention_mask, input_ids,
position_ids, and pad_token_id when applying the fix.
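
As an alternative to comparing against `pad_token_id`, the mask can be derived from the unpadded length at the moment padding is applied. The sketch below shows that variant; it is a swapped-in technique, not the fix proposed above, `pad_with_mask` is a hypothetical helper, and plain lists stand in for tensors.

```python
from typing import List, Tuple


def pad_with_mask(
    ids: List[int], seq_length: int, pad_token_id: int
) -> Tuple[List[int], List[int]]:
    """Pad ids to seq_length and return a mask built from the real length."""
    ids = ids[:seq_length]
    num_real = len(ids)
    num_pad = seq_length - num_real
    padded = ids + [pad_token_id] * num_pad
    # 1 marks real tokens, 0 marks padding, regardless of token values.
    mask = [1] * num_real + [0] * num_pad
    return padded, mask
```

The length-based mask stays correct even when a real token happens to share its id with the pad token, a case where an `input_ids != pad_token_id` comparison would mask real content.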

In `@src/megatron/bridge/data/mimo/loaders.py`:
- Line 107: The code computes micro_batch_size = cfg.train.global_batch_size //
dp_size without validating divisibility or non-zero result; update the logic in
loaders.py around the micro_batch_size calculation to (1) check that
cfg.train.global_batch_size >= dp_size and (2) check that
cfg.train.global_batch_size % dp_size == 0, and if either check fails raise a
clear ValueError (mentioning cfg.train.global_batch_size and dp_size) so callers
know to adjust configuration; only after validation compute micro_batch_size =
cfg.train.global_batch_size // dp_size.
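
The two checks requested above could look roughly like this. `per_rank_batch_size` is a hypothetical helper name, and the error wording is illustrative rather than taken from the PR.

```python
def per_rank_batch_size(global_batch_size: int, dp_size: int) -> int:
    """Validate divisibility, then split the global batch across DP ranks."""
    if dp_size <= 0:
        raise ValueError(f"dp_size must be positive, got {dp_size}")
    if global_batch_size < dp_size:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be >= dp_size ({dp_size})"
        )
    if global_batch_size % dp_size != 0:
        raise ValueError(
            f"global_batch_size ({global_batch_size}) must be divisible by "
            f"dp_size ({dp_size})"
        )
    return global_batch_size // dp_size
```

Without the checks, a configuration such as `global_batch_size=2, dp_size=4` would silently produce a batch size of 0 from the integer division.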

In `@src/megatron/bridge/data/mimo/mock_provider.py`:
- Around line 95-113: MockMimoProvider is missing the trust_remote_code
attribute referenced in _load_processors; add a trust_remote_code boolean field
to MockMimoProvider (e.g., an __init__ parameter defaulting to False and stored
on self) so is_safe_repo can read it, or set a class-level default bool; update
the MockMimoProvider constructor to accept and assign trust_remote_code and
ensure any tests or instantiations pass the intended value when creating
processor_paths and calling _load_processors.

In `@tests/unit_tests/data/mimo/test_dp_utils.py`:
- Around line 40-50: The helper _make_mimo_cfg constructs
ModuleParallelismConfig using incorrect field names; update the
ModuleParallelismConfig instantiations inside _make_mimo_cfg to use
tensor_model_parallel_size and data_parallel_size instead of tensor_parallel and
data_parallel, keeping rank_offset and the surrounding
MimoParallelismConfig(llm_module_name="llm",
module_parallelisms=module_parallelisms, deployment_mode=deployment_mode) intact
so tests reference the correct attributes.
- Around line 71-86: The test modifies the MIMO config using the wrong field
name: in test_get_mimo_dp_info_colocated_llm_first_pp (and any similar tests
using _make_mimo_cfg) you should set
module_parallelisms["vision"].data_parallel_size and
module_parallelisms["llm"].data_parallel_size instead of data_parallel; update
those assignments so the config uses the correct attribute name
(data_parallel_size) before calling get_mimo_dp_info so the FakeGrid sizes match
the config.

In `@tests/unit_tests/training/mimo/test_mimo_config.py`:
- Around line 78-91: The test constructs ModuleParallelismConfig with incorrect
field names; replace tensor_parallel and data_parallel with the actual dataclass
fields tensor_model_parallel_size and data_parallel_size when creating
module_parallelisms for the MimoParallelismConfig in
test_mimo_heterogeneous_gap_in_middle_raises_error so the
ModuleParallelismConfig instances match the class definition and then call
mimo_parallelism_config.finalize(world_size=None) as before to assert the
ValueError.
- Around line 94-107: The test uses incorrect keyword argument names when
constructing ModuleParallelismConfig; update the two
ModuleParallelismConfig(...) calls in test_mimo_heterogeneous_leading_gap_warns
to use the correct field names (tensor_model_parallel_size and
data_parallel_size instead of tensor_parallel and data_parallel) so the objects
are created correctly and then call
MimoParallelismConfig(...).finalize(world_size=None) as before.
- Around line 110-125: The test test_mimo_heterogeneous_contiguous_no_warning
uses incorrect constructor field names for ModuleParallelismConfig; update the
module_parallelisms entries to use the real parameter names expected by
ModuleParallelismConfig (tensor_model_parallel_size and data_parallel_size,
keeping rank_offset intact) so MimoParallelismConfig(...) and subsequent
mimo_parallelism_config.finalize(...) exercise the contiguous allocation path
without raising due to mismatched keys.

In `@tests/unit_tests/training/mimo/test_mimo_ddp.py`:
- Around line 194-196: Remove the unused local variable assignments named result
where wrap_mimo_model_distributed(...) is called (the calls to
wrap_mimo_model_distributed with arguments mimo_model, ddp_config,
mimo_parallelism_config, grids, pg_collections and the later identical call
around lines 229-231); instead invoke wrap_mimo_model_distributed(...) without
assigning its return (or assign to _ if you intentionally ignore it) so the
result variable is not created and Ruff/Flake8 F841 is resolved.
- Line 6: Remove the unused top-level import "import pytest" from the
test_mimo_ddp module; locate the import statement in the test_mimo_ddp.py file
and delete it so the linter no longer flags an unused pytest import.
🧹 Nitpick comments (15)
src/megatron/bridge/data/mimo/collate.py (1)

56-60: Standard field stacking assumes uniform sequence lengths.

torch.stack will raise a RuntimeError if items have different sequence lengths. Unlike the modality inputs handling (which catches RuntimeError), these standard fields will propagate the exception without a helpful message.

Consider adding a try/except with a more descriptive error, or validating shapes upfront.

♻️ Optional: Add descriptive error handling
     # Stack standard fields
-    input_ids = torch.stack([item["input_ids"] for item in batch])
-    labels = torch.stack([item["labels"] for item in batch])
-    attention_mask = torch.stack([item["attention_mask"] for item in batch])
-    position_ids = torch.stack([item["position_ids"] for item in batch])
+    try:
+        input_ids = torch.stack([item["input_ids"] for item in batch])
+        labels = torch.stack([item["labels"] for item in batch])
+        attention_mask = torch.stack([item["attention_mask"] for item in batch])
+        position_ids = torch.stack([item["position_ids"] for item in batch])
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Failed to stack batch items. Ensure all items have the same sequence length. "
+            f"Original error: {e}"
+        ) from e
tests/unit_tests/data/mimo/test_collate.py (1)

10-24: Consider explicit Optional type hint.

Per PEP 484, implicit Optional via = None default is discouraged. Static analysis flagged this (RUF013).

♻️ Optional fix
+from typing import Optional
+
 def make_sample(
     seq_length: int = 64,
-    modalities: dict = None,
+    modalities: Optional[dict] = None,
 ) -> dict:
tests/unit_tests/data/mimo/test_dataset.py (1)

371-410: Multi-modality placeholder ordering test relies on dict insertion order.

This test assumes placeholders appear in dict insertion order (vision before audio). While Python 3.7+ guarantees dict ordering, this implicit dependency could be fragile if the ordering contract isn't documented in MimoDataset.

Consider adding a brief comment in the test or verifying the ordering is documented in the dataset implementation.

src/megatron/bridge/data/mimo/loaders.py (3)

6-6: Remove unused import.

Dict is imported but not used in this file.

🔧 Proposed fix
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

20-27: Consider documenting or removing unused train_state parameter.

The train_state parameter is declared but not used in the function body. If this is for API consistency with other data loader builders or planned future use, consider adding a brief comment. Otherwise, it can be removed.


66-67: Use TypeError for type validation.

When validating that an argument is of a specific type, TypeError is semantically more appropriate than ValueError.

🔧 Proposed fix
     if not isinstance(cfg.model, MimoModelProvider):
-        raise ValueError("cfg.model must be MimoModelProvider for MIMO data loading.")
+        raise TypeError("cfg.model must be MimoModelProvider for MIMO data loading.")
src/megatron/bridge/training/mimo_ddp.py (1)

78-94: Edge case: empty encoders dictionary.

The condition if hasattr(submodule, 'encoders') and submodule.encoders passes if encoders is a non-empty dict, but next(iter(submodule.encoders.keys())) would fail with StopIteration if called on an empty dict. The current check is correct, but consider adding a more explicit check or comment for clarity.

🔧 Suggested clarification
-            if hasattr(submodule, 'encoders') and submodule.encoders:
+            # Skip submodules without encoders or with empty encoder dict
+            if not (hasattr(submodule, 'encoders') and submodule.encoders):
+                continue
+            
+            if True:  # Encoders present
                 encoder_key = next(iter(submodule.encoders.keys()))

Or simply add a comment to clarify the truthiness check covers the empty dict case.

src/megatron/bridge/data/mimo/hf_provider.py (2)

8-8: Remove unused imports.

List and Union are imported but not used in this file.

🔧 Proposed fix
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple

122-140: Consider narrowing exception handling.

Catching ValueError broadly may inadvertently suppress errors unrelated to missing splits (e.g., configuration issues). Consider checking the exception message or using a more specific approach.

🔧 Proposed approach
         try:
             dataset = load_dataset(
                 self.hf_dataset_path,
                 name=self.hf_dataset_name,
                 split=split,
                 streaming=self.streaming,
                 trust_remote_code=is_safe_repo(
                     trust_remote_code=self.trust_remote_code,
                     hf_path=self.hf_dataset_path,
                 ),
             )
             return dataset
-        except ValueError:
-            # Split doesn't exist
+        except ValueError as e:
+            # Check if this is specifically a missing split error
+            if "split" in str(e).lower():
+                return None
+            raise
-            return None
src/megatron/bridge/data/mimo/dataset.py (2)

6-6: Remove unused imports.

List and Union are imported but not used in this file.

🔧 Proposed fix
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Optional

189-212: Potential over-truncation of text.

The text is first tokenized with max_length=self.seq_length (line 193), then truncated again after adding prefix tokens (line 211). This means if prefix tokens take up N positions, the final text could be truncated to seq_length - N tokens, but the initial tokenization may have already truncated content that could have fit.

Consider pre-computing the available text space:

🔧 Proposed optimization
+        # Pre-compute space available for text tokens
+        total_prefix_len = sum(
+            self.encoder_seq_lengths.get(m, 1)
+            for m in modality_inputs.keys()
+            if m in self.special_token_ids
+        )
+        max_text_length = self.seq_length - total_prefix_len
+        
         # Tokenize the text
         encoded = self.tokenizer(
             text,
             truncation=True,
-            max_length=self.seq_length,
+            max_length=max_text_length,
             return_tensors="pt",
         )
         input_ids = encoded["input_ids"].squeeze(0)
         
         # Insert placeholder tokens for each modality at the beginning
-        # The order follows the order of modality_inputs (Python 3.7+ dict ordering)
         prefix_tokens = []
         for modality_name in modality_inputs.keys():
             if modality_name in self.special_token_ids:
                 token_id = self.special_token_ids[modality_name]
                 num_tokens = self.encoder_seq_lengths.get(modality_name, 1)
                 prefix_tokens.extend([token_id] * num_tokens)
         
         if prefix_tokens:
             prefix = torch.tensor(prefix_tokens, dtype=input_ids.dtype)
-            # Truncate text tokens to make room for placeholders
-            max_text_len = self.seq_length - len(prefix_tokens)
-            input_ids = input_ids[:max_text_len]
             input_ids = torch.cat([prefix, input_ids])
src/megatron/bridge/models/mimo/mimo_builder.py (1)

61-73: Return type annotation uses forward reference but static analysis can't see the lazy import.

The return type "ColocatedCommConfig" on Line 63 is correctly handled via a string annotation (forward reference), and the actual import happens on Line 65 before use. This is a valid pattern for lazy imports.

However, Line 67 creates a redundant copy of the grids dict:

♻️ Minor simplification
 def build_colocated_comm_config(
     mimo_parallelism_config: MimoParallelismConfig, grids: Dict[str, "HyperCommGrid"]
 ) -> "ColocatedCommConfig":
     """Build ColocatedCommConfig with default encoder-to-LLM topology."""
     from megatron.core.models.mimo.config.base_configs import ColocatedCommConfig

-    module_to_grid_map = {name: grid for name, grid in grids.items()}
     topology = _default_topology(mimo_parallelism_config)
     return ColocatedCommConfig(
-        module_to_grid_map=module_to_grid_map,
+        module_to_grid_map=grids,
         topology=topology,
         dim_mapping={"b": 0, "s": 1, "h": 2},
     )
tests/unit_tests/models/mimo/test_mimo_provider.py (1)

355-376: Unused variable model from static analysis is valid but intentional.

The model variable on Line 372 is assigned but never used because the test focuses on verifying the side effect (parameter freezing) via the mock. While technically unused, the assignment makes the test intent clearer.

Consider prefixing with underscore to silence the warning:

♻️ Silence unused variable warning
-        model = provider.provide()
+        _model = provider.provide()
src/megatron/bridge/models/mimo/mimo_provider.py (2)

22-22: Unused import DistributedDataParallel flagged by static analysis.

The import at Line 22 is flagged as unused. While DistributedDataParallelConfig is used, DistributedDataParallel appears to not be directly referenced in this file (it's used in mimo_ddp.py instead).

♻️ Remove unused import
-from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
+from megatron.core.distributed import DistributedDataParallelConfig

294-299: Unused loop variable encoder_name can be renamed.

The loop at Line 296 iterates over encoder items but only uses encoder_spec, not encoder_name.

♻️ Rename unused loop variable
-        for encoder_name, encoder_spec in spec.submodules["encoders"].items():
+        for _encoder_name, encoder_spec in spec.submodules["encoders"].items():

Comment thread src/megatron/bridge/data/mimo/collate.py
Comment thread src/megatron/bridge/data/mimo/dataset.py
Comment thread src/megatron/bridge/data/mimo/loaders.py Outdated
Comment thread src/megatron/bridge/data/mimo/mock_provider.py
Comment thread tests/unit_tests/data/mimo/test_dp_utils.py Outdated
Comment thread tests/unit_tests/training/mimo/test_mimo_config.py Outdated
Comment thread tests/unit_tests/training/mimo/test_mimo_config.py Outdated
Comment thread tests/unit_tests/training/mimo/test_mimo_config.py Outdated
Comment thread tests/unit_tests/models/mimo/test_mimo_ddp.py Outdated
Comment thread tests/unit_tests/models/mimo/test_mimo_ddp.py Outdated
@aroshanghias-nvd aroshanghias-nvd force-pushed the mimo/phase3-data-loading branch 5 times, most recently from 97b3478 to 5ac9802 Compare January 30, 2026 22:08
Add DDP wrapping utilities, embedding group support for PP > 1, and improved validation for heterogeneous deployment.

Key changes:
- Add wrap_mimo_model_distributed() for rank-aware DDP wrapping of MIMO submodules
- Add embedding group helpers to mimo_builder.py: populate_embedding_and_position_groups(), is_pp_first_stage(), is_pp_last_stage(), is_current_rank_in_grid()
- Improve gap detection in MimoParallelismConfig._validate_heterogeneous()
- Extend _get_pg_collections_from_grids() to populate pos_embd and embd process groups
- Set variable_seq_lengths=True in provide_distributed_model()
- Update copyright headers to 2026
- Add comprehensive unit tests for all new functionality

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd

Contributor Author

/ok to test faf91b0

@yaoyu-33

Contributor

/claude review

num_workers=mimo_provider.num_workers,
collate_fn=collate_fn,
pin_memory=mimo_provider.pin_memory,
drop_last=mimo_provider.drop_last,

Contributor

Bug: mimo_provider.drop_last will raise AttributeError at runtime. drop_last is not defined on DataloaderConfig, DatasetProvider, HFMimoDatasetProvider, or MockMimoProvider.

Either add drop_last: bool = True to DataloaderConfig, or hardcode/parameterize it here.

Contributor Author

Added drop_last to DataloaderConfig.
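
A sketch of what that fix might look like. The classes below are trimmed-down, hypothetical stand-ins (the real `DataloaderConfig` and `MockMimoProvider` have other fields), and the inheritance relationship and `loader_kwargs` helper are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class DataloaderConfig:
    """Hypothetical stand-in for the real DataloaderConfig."""

    num_workers: int = 2
    pin_memory: bool = True
    drop_last: bool = True  # default suggested in the review comment


@dataclass
class MockMimoProvider(DataloaderConfig):
    """Stand-in provider that inherits the loader fields (assumption)."""

    seq_length: int = 2048


def loader_kwargs(provider: DataloaderConfig) -> Dict[str, Any]:
    """Collect the DataLoader keyword arguments read from the provider."""
    return {
        "num_workers": provider.num_workers,
        "pin_memory": provider.pin_memory,
        "drop_last": provider.drop_last,
    }
```

Once the field lives on the base config, every provider that inherits from it exposes `drop_last`, so the attribute access in the loader no longer fails.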

Comment thread src/megatron/bridge/training/config.py Outdated
self.mimo.finalize(world_size if world_size and world_size > 1 else None)

llm_parallelism = self.mimo.get_parallelism(self.mimo.llm_module_name)
parallelism_checks = {

Contributor

Design concern: MIMO fields belong on MimoModelProvider, not ConfigContainer

Adding mimo and encoder_providers to ConfigContainer means every user of GPT/Llama/T5/Mamba/… sees these fields, even though ~95% of models have no use for them. It also creates an implicit invariant that users must keep cfg.model and cfg.mimo in sync manually — a footgun that _validate_mimo() is trying to paper over.

Both fields are logically MIMO-specific and could live entirely inside MimoModelProvider:

@dataclass
class MimoModelProvider(ModelProviderMixin[MimoModel]):
    ...
    mimo_parallelism_config: Optional[MimoParallelismConfig] = None
    encoder_providers: Optional[dict[str, EncoderProvider]] = None   # move here
    # (MIMOConfig's deployment_mode / module_parallelisms could fold into
    #  mimo_parallelism_config or replace it, keeping everything in one place)

Downstream call sites that currently read cfg.mimo could instead read cfg.model.mimo_parallelism_config (or a renamed equivalent), and dispatch guards could look like:

# config.py get_data_parallel_size()
if isinstance(self.model, MimoModelProvider):
    return self.model.get_data_parallel_size(world_size)

_validate_mimo() would similarly move into MimoModelProvider.validate() or __post_init__, keeping MIMO invariants collocated with MIMO state.

There's also a naming overlap worth clarifying: MIMOConfig (training/mimo_config.py) and MimoParallelismConfig (models/mimo/mimo_config.py) cover very similar ground (module parallelisms, deployment mode, special tokens). Are these intentionally separate, or could one replace the other?

Happy to be corrected if there's a reason ConfigContainer needs to see these directly (e.g. Hydra instantiation, schema introspection), but the motivation isn't obvious from the PR.

Contributor Author

Good catch. This should have been removed. Cleaned up. ConfigContainer no longer carries mimo and other outdated fields.

Comment thread src/megatron/bridge/training/config.py Outdated

def get_data_parallel_size(self, world_size: int) -> int:
"""Calculate the data parallel size based on the model configuration."""
if self.mimo is not None:

Contributor

why not use the checker util?

Contributor

Should all the checks below use the checkers too, e.g. is_colocated_mimo?

Contributor Author

Good point — updated these checks to use the existing is_mimo checker for consistency.

Comment thread src/megatron/bridge/training/config.py Outdated
return self.mimo is not None

@property
def is_colocated_mimo(self) -> bool:

Contributor

this seems never used

Contributor Author

Yup, it was dead code.

Remove ConfigContainer-level MIMO helper and validation paths, and keep training MIMO tests focused on provider-owned mimo_config behavior to preserve phase2-style boundaries.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd

Contributor Author

/ok to test 9502e7a

@aroshanghias-nvd

Contributor Author

/ok to test 5d39f33

@aroshanghias-nvd

Contributor Author

/claude review

Args:
cfg: Configuration container with MimoModelProvider as cfg.model.
train_state: Current training state.
mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider)

Contributor

Typo: the class is named MockMimoProvider, not MockMimoDatasetProvider.

Suggested change
-mimo_provider: MIMO dataset provider (e.g., MockMimoDatasetProvider)
+mimo_provider: MIMO dataset provider (e.g., MockMimoProvider)

Contributor Author

Fixed, thanks. Updated the docstring example to use MockMimoProvider (same behavior, correct class name).

Comment thread src/megatron/bridge/data/mimo/dp_utils.py Outdated
@claude

claude Bot commented Mar 19, 2026

Contributor

Review Notes

Mostly LGTM — clean scaffolding with good unit tests for collate, dataset, and dp_utils. Two inline comments posted (typo + dead parameter).

Missing test coverage:

  • HFMimoDatasetProvider (hf_provider.py) has no unit tests. Even with mocked HF dependencies (load_dataset, AutoProcessor, AutoTokenizer), basic construction and build_datasets flow should be covered.
  • build_mimo_data_loaders (loaders.py) has no unit tests. The two ValueError guard clauses and the happy-path wiring through get_mimo_dp_info → DistributedSampler → DataLoader would benefit from tests.
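
When writing the missing loader tests, it may help to have a reference for what DP-aware sharding should produce. The sketch below is a simplified strided split, not torch's `DistributedSampler` itself: `shard_indices` is a hypothetical helper, there is no shuffling, and the tail that does not divide evenly is always dropped.

```python
from typing import List


def shard_indices(num_samples: int, dp_rank: int, dp_size: int) -> List[int]:
    """Return the sample indices owned by one DP rank (strided, tail dropped)."""
    if not 0 <= dp_rank < dp_size:
        raise ValueError(f"dp_rank {dp_rank} out of range for dp_size {dp_size}")
    # Drop the remainder so every rank gets the same number of samples.
    usable = num_samples - (num_samples % dp_size)
    return list(range(dp_rank, usable, dp_size))


# 10 samples over 4 ranks: 8 are usable, 2 per rank, no overlap between ranks.
shards = [shard_indices(10, rank, 4) for rank in range(4)]
```

A happy-path test could then assert that the shards are disjoint and equally sized for whatever `(dp_rank, dp_size)` pair `get_mimo_dp_info` reports.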

Fix the MIMO loader/docstring nits by removing an unused dp-utils parameter and add targeted unit tests for HF provider and MIMO loader wiring to cover guard clauses and happy path behavior.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd

Contributor Author

/ok to test c580879

Apply ruff/ruff-format cleanup requested by CI in dp_utils typing imports and test_loaders line wrapping so lint and downstream aggregate checks pass.

Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
@aroshanghias-nvd

Contributor Author

/ok to test 8e36f4c

@aroshanghias-nvd

Contributor Author

/ok to test e98b742

3 participants