feat(mimo): Phase 5 - checkpoint save/resume, evaluation, e2e tests#2870
…rallelism
Squash of all Phase 4 MiMo work from mimo/phase4-training (47870e4), rebased onto upstream/main at f1fb06a.
Includes:
- MimoModelProvider with ModuleSpec-based API and heterogeneous LLaVA support
- MiMo training loop (pretrain_mimo, train_mimo, mimo_step)
- Heterogeneous TP/PP/DP parallelism plumbing (mimo_parallel_utils)
- MiMo data loading (collate, dataset, loaders, hf_provider, mock_provider)
- Data loader dispatch routing for MIMO models (loaders.py)
- MiMo DDP wrapping and model builder
- Kamran's loss mask and heterogeneous LLaVA dataset support
- Megatron-LM submodule pinned to PR #3212 head
- Full unit test coverage (provider, config, step, collate, pretrain tests)
Phase 5 (checkpointing/evaluation) is stacked in a separate branch. Original commit history preserved in backup/mimo-phase4-training-v0 (47870e4).
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
ee23945 to 9642406
1e7481d to a49c524
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Squash of all Phase 5 MiMo checkpointing/evaluation work from mimo/phase5-checkpointing (989842e), stacked on Phase 4 rebuild.
Includes:
- Checkpoint save/resume wiring for heterogeneous MIMO models
- MiMo evaluation infrastructure (eval.py MIMO extensions)
- Distributed batch slicing for evaluation (dp_utils.slice_batch_for_mimo)
- E2E training tests (test_mimo_training_e2e, test_mimo_training_llava)
- E2E checkpoint resume tests (test_mimo_checkpoint_resume_e2e)
- Parallelism test runner (run_mimo_parallelism_tests.sh)
- Full checkpoint unit test coverage (test_mimo_checkpointing, 1159 lines)
Original commit history preserved in backup/mimo-phase5-checkpointing-v0 (989842e).
Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
…iring The _make_setup_output fixture was missing pg_collections and checkpointing_context attributes needed by the Phase 5 checkpoint code path in pretrain_mimo. Also set checkpoint config fields to None and build_data_iterators_fn return value so the test completes without hitting unrelated code paths. Pre-existing test gap at 989842e. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
…int test Restores the checkpoint resume test wrapper from mimo/wip-phase4-training. Runs save→resume round-trip across multiple parallelism configs. Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com>
cd3f7fc to e4d2fdf
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Signed-off-by: Li Ding <liding@nvidia.com>
Keep current Megatron-LM submodule version (resolve submodule conflict with ours). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Merges phase 4 structural refactor (16 changes) into phase 5 (checkpointing, evaluation, e2e tests). Resolves 7 conflict files and adapts phase 5 code to phase 4's refactored structure.
Key merge decisions:
- pretrain_mimo: phase 4's thin entry point + phase 5's checkpoint load, MPU bridging, scheduler fan-out, deferred data iterators
- checkpointing: phase 4's CheckpointManager refactor + phase 5's MiMo bypass (pg_collection, is_mimo detection, optimizer load skip)
- train.py: both pg_collection (phase 5) and callback_manager (phase 4)
- train_mimo: checkpointing_context migrated to checkpoint_manager API
- setup_mimo: added checkpoint_manager, async worker init, start_time sync
Src fixes:
- setup_mimo.py: global_state.cfg set unconditionally (was inside conditional)
- checkpointing.py: pg_collection added to CheckpointSaveContext
Test fixes:
- 33 unit tests updated for checkpoint_manager API
- e2e tests: removed old params (mimo_provider, opt_config, schedulers), added min_lr=0.0, report_theoretical_memory monkey-patch
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_theoretical_memory)
- Set cfg.ddp explicitly with DDP fields (was only on cfg.train before). setup_mimo.py reads cfg.ddp directly; without this, use_distributed_optimizer defaults to False, causing a CPU/CUDA device mismatch on the vision encoder Conv2d.
- Add report_theoretical_memory monkey-patch (MimoModelProvider has no kv_channels).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Phase 4 moved DDP config from cfg.train to cfg.ddp. Remove redundant assignments on train_cfg (grad_reduce_in_fp32, overlap_grad_reduce, use_distributed_optimizer, check_for_nan_in_grad) and set them directly on DistributedDataParallelConfig instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Li Ding <liding@nvidia.com>
/ok to test 6fa707b
/ok to test ab9b8cd
📝 Walkthrough
This pull request extends MIMO training infrastructure with module-local data-parallel slicing, per-module checkpoint context, and comprehensive E2E test coverage. Key changes introduce
Changes
Sequence Diagram(s)
sequenceDiagram
participant Rank as PP Rank 0
participant DataLoader as Data Loader
participant Model as MIMO Model
participant ForwardStep as Forward Step
participant Comm as MIMO Communicator
Rank->>DataLoader: request batch
DataLoader->>DataLoader: sample via sampler_dp_size (global consistency)
DataLoader-->>Rank: global_batch
Rank->>ForwardStep: fetch global_batch
ForwardStep->>ForwardStep: compute module_dp_rank, module_dp_size
ForwardStep->>ForwardStep: slice_batch_for_mimo(global_batch, module_dp_rank, module_dp_size)
ForwardStep-->>Model: module_local_batch
Model->>Comm: forward with module_local_batch
Comm-->>Model: loss / output tensors
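The slicing step in the diagram above (every rank samples the same global batch, then each module keeps only its own data-parallel share) can be pictured with a minimal sketch. It assumes a batch is a dict of equal-length lists and that each rank owns one contiguous, evenly sized chunk; `slice_batch` is a hypothetical stand-in, not the `slice_batch_for_mimo` implementation from this PR, which is not shown on this page.

```python
from typing import Any, Dict, List


def slice_batch(batch: Dict[str, List[Any]], dp_rank: int, dp_size: int) -> Dict[str, List[Any]]:
    """Return the contiguous slice of each field owned by `dp_rank`.

    Hypothetical helper: assumes every field shares the same leading (batch)
    length and that this length divides evenly by `dp_size`.
    """
    if not 0 <= dp_rank < dp_size:
        raise ValueError(f"dp_rank {dp_rank} out of range for dp_size {dp_size}")
    sliced = {}
    for key, values in batch.items():
        if len(values) % dp_size != 0:
            raise ValueError(f"field {key!r} of length {len(values)} is not divisible by {dp_size}")
        per_rank = len(values) // dp_size
        start = dp_rank * per_rank
        sliced[key] = values[start:start + per_rank]
    return sliced


global_batch = {"tokens": [10, 11, 12, 13], "labels": [0, 1, 0, 1]}
# A module running with DP=2 keeps half of the batch per rank;
# a module running with DP=1 keeps the whole batch.
vision_rank1 = slice_batch(global_batch, dp_rank=1, dp_size=2)
language_rank0 = slice_batch(global_batch, dp_rank=0, dp_size=1)
```

Because both modules start from the same `global_batch`, the samples seen by a DP=2 module and a DP=1 module stay aligned even though their slice sizes differ.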
sequenceDiagram
participant Setup as setup_mimo()
participant GlobalState as GlobalState
participant IterBuilder as build_data_iterators_fn
participant CheckpointMgr as CheckpointManager
participant TrainLoop as pretrain_mimo()
Setup->>CheckpointMgr: create DefaultCheckpointManager(cfg.checkpoint)
Setup->>GlobalState: initialize_async_checkpoint_worker()
Setup->>GlobalState: synchronize start_time across ranks
Setup-->>TrainLoop: return MimoSetupOutput (with checkpoint_manager)
TrainLoop->>IterBuilder: build_data_iterators_fn(cfg, mimo_infra, train_state=None)
IterBuilder-->>TrainLoop: train_iter, valid_iter
TrainLoop->>TrainLoop: load checkpoint if exists (with pg_collection, module_name)
TrainLoop->>TrainLoop: propagate first scheduler state to other schedulers
TrainLoop->>TrainLoop: training loop iteration
TrainLoop->>CheckpointMgr: finalize_async_saves(blocking=False)
TrainLoop->>TrainLoop: checkpoint_and_decide_exit(checkpoint_manager, pg_collection, module_name)
CheckpointMgr-->>TrainLoop: should_exit
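The "propagate first scheduler state to other schedulers" step in the diagram above can be sketched roughly as follows. This is an illustration only: `ToyScheduler` and `fan_out_scheduler_state` are hypothetical names, and the sketch assumes that checkpoint load restores the state of one scheduler which then has to be copied to the per-module schedulers.

```python
from typing import Dict, List


class ToyScheduler:
    """Hypothetical scheduler exposing the usual state_dict/load_state_dict pair."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.num_steps = 0
        self.lr = 1e-4

    def state_dict(self) -> Dict[str, float]:
        return {"num_steps": self.num_steps, "lr": self.lr}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.num_steps = int(state["num_steps"])
        self.lr = float(state["lr"])


def fan_out_scheduler_state(schedulers: List[ToyScheduler]) -> None:
    """Copy the first scheduler's state into every other scheduler."""
    if not schedulers:
        return
    source_state = schedulers[0].state_dict()
    for scheduler in schedulers[1:]:
        # Pass a copy so the schedulers never share one mutable dict.
        scheduler.load_state_dict(dict(source_state))


schedulers = [ToyScheduler("language"), ToyScheduler("vision")]
# Pretend only the first scheduler was restored from the checkpoint.
schedulers[0].load_state_dict({"num_steps": 500, "lr": 5e-5})
fan_out_scheduler_state(schedulers)
```

After the fan-out, every module's scheduler resumes from the same step count and learning rate.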
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 2
Note
Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/megatron/bridge/training/checkpointing.py (1)
2593-2635: ⚠️ Potential issue | 🟠 Major
Carry `is_mimo` through the direct-iteration torch-dist path.
`_load_base_checkpoint()` now accepts `is_mimo`, but the `iteration == _DIRECT_ITERATION_DIR_SENTINEL` branch never forwards it to `_load_global_dist_base_checkpoint()`. Loading an explicit `iter_*` directory therefore re-enables access-integrity validation for the same MiMo+PP>1 case this patch is trying to exempt.
Suggested fix
     if ckpt_format == "torch_dist":
         return _load_global_dist_base_checkpoint(
             load_dir,
             ckpt_cfg,
             rank0,
             sharded_state_dict,
             iteration=None,
             release=False,
             checkpoint_path_override=checkpoint_path,
             checkpointing_context=checkpointing_context,
+            is_mimo=is_mimo,
             pg_collection=pg_collection,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/checkpointing.py` around lines 2593 - 2635, The direct-iteration branch in _load_base_checkpoint doesn't forward the is_mimo flag to _load_global_dist_base_checkpoint, so pass the existing is_mimo parameter through the call in the iteration == _DIRECT_ITERATION_DIR_SENTINEL branch (the torch_dist branch that constructs checkpoint_path and ckpt_format) by adding is_mimo=is_mimo to the argument list of _load_global_dist_base_checkpoint (alongside the existing checkpoint_path_override, checkpointing_context, and pg_collection) to preserve the MiMo handling.
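A dropped keyword argument like this is easy to miss in review, so one option is a small regression check that spies on the inner loader for each branch. The sketch below is a toy model of the dispatch, not the real checkpointing code: `load_base`, `load_global_dist`, and `DIRECT_ITERATION` are hypothetical stand-ins for `_load_base_checkpoint`, `_load_global_dist_base_checkpoint`, and `_DIRECT_ITERATION_DIR_SENTINEL`, and the real functions take many more arguments.

```python
from unittest import mock

DIRECT_ITERATION = -1  # hypothetical sentinel value, for illustration only


def load_global_dist(load_dir, *, is_mimo=False):
    """Hypothetical inner loader."""
    return {"load_dir": load_dir, "is_mimo": is_mimo}


def load_base(load_dir, iteration, *, is_mimo=False, inner=load_global_dist):
    """Toy dispatcher with two branches that must both forward `is_mimo`."""
    if iteration == DIRECT_ITERATION:
        # Direct-iteration branch: the flag has to be passed here as well.
        return inner(load_dir, is_mimo=is_mimo)
    return inner(f"{load_dir}/iter_{iteration:07d}", is_mimo=is_mimo)


def flag_forwarded(iteration):
    """Return True if `is_mimo=True` reaches the inner loader on this branch."""
    spy = mock.Mock(return_value={})
    load_base("/ckpt", iteration, is_mimo=True, inner=spy)
    return spy.call_args.kwargs.get("is_mimo") is True


direct_ok = flag_forwarded(DIRECT_ITERATION)
tracked_ok = flag_forwarded(100)
```

If either branch stops forwarding the flag, the corresponding check turns `False`, which is the failure mode described in this comment.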
🟠 Major comments (22)
src/megatron/bridge/training/setup_mimo.py-24-25 (1)
24-25: ⚠️ Potential issue | 🟠 Major
Honor custom checkpoint managers in the MiMo setup path.
This hard-codes `DefaultCheckpointManager`, so `checkpoint.custom_manager_class` works in the standard path but is ignored for MiMo setup. Please go through the checkpoint-manager factory here and type the field against the interface.
Suggested fix
-from megatron.bridge.training.checkpointing import DefaultCheckpointManager
+from megatron.bridge.training.checkpointing import CheckpointManager, create_checkpoint_manager
@@
-    checkpoint_manager: DefaultCheckpointManager
+    checkpoint_manager: CheckpointManager
@@
-    checkpoint_manager = DefaultCheckpointManager(cfg.checkpoint)
+    checkpoint_manager = create_checkpoint_manager(cfg.checkpoint)
Also applies to: 121-121, 349-350
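For readers unfamiliar with the pattern, the sketch below shows what a factory that honors a configurable manager class might look like. It is a toy under stated assumptions: `create_manager`, `DefaultManager`, `LoggingManager`, and the idea that `custom_manager_class` holds a dotted import path are all hypothetical and not taken from `megatron.bridge.training.checkpointing`.

```python
import importlib
import sys
import types
from dataclasses import dataclass
from typing import Optional, Protocol


class CheckpointManager(Protocol):
    """Hypothetical interface that callers type against."""

    def save(self, iteration: int) -> str: ...


@dataclass
class CheckpointConfig:
    """Toy config; `custom_manager_class` is assumed to be a dotted import path."""
    save: str = "/tmp/ckpt"
    custom_manager_class: Optional[str] = None


class DefaultManager:
    def __init__(self, cfg: CheckpointConfig) -> None:
        self.cfg = cfg

    def save(self, iteration: int) -> str:
        return f"{self.cfg.save}/iter_{iteration:07d}"


def create_manager(cfg: CheckpointConfig) -> CheckpointManager:
    """Instantiate the configured manager class, falling back to the default."""
    if cfg.custom_manager_class is None:
        return DefaultManager(cfg)
    module_name, _, class_name = cfg.custom_manager_class.rpartition(".")
    manager_cls = getattr(importlib.import_module(module_name), class_name)
    return manager_cls(cfg)


class LoggingManager(DefaultManager):
    """Example custom manager; a real one would live in user code."""


# Register a throwaway module so the dotted path resolves in this demo.
custom_module = types.ModuleType("toy_custom_ckpt")
custom_module.LoggingManager = LoggingManager
sys.modules["toy_custom_ckpt"] = custom_module

default_manager = create_manager(CheckpointConfig())
custom_manager = create_manager(CheckpointConfig(custom_manager_class="toy_custom_ckpt.LoggingManager"))
default_path = default_manager.save(42)
```

Typing the field as the interface rather than the default class is what lets both managers flow through the same setup code.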
src/megatron/bridge/models/mimo/mimo_provider.py-547-562 (1)
547-562: ⚠️ Potential issue | 🟠 Major
Move frozen buffers too, not just frozen parameters.
This only migrates `Parameter`s. With `use_cpu_initialization=True`, frozen submodules can still keep CPU buffers, so the first forward can still fail with a device mismatch after this helper runs.
Suggested fix
     if not torch.cuda.is_available():
         return
     device = torch.cuda.current_device()
     for param in model.parameters():
         if not param.requires_grad and param.device.type == "cpu":
             param.data = param.data.to(device)
+    for buffer_name, buffer in model.named_buffers():
+        if buffer.device.type == "cpu":
+            module_path, _, local_name = buffer_name.rpartition(".")
+            target_module = model.get_submodule(module_path) if module_path else model
+            setattr(target_module, local_name, buffer.to(device))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/mimo/mimo_provider.py` around lines 547 - 562, The helper _move_frozen_params_to_device currently only moves frozen torch.nn.Parameter objects; update it to also move frozen buffers so CPU-resident buffers don't cause device-mismatch on first forward. Inside _move_frozen_params_to_device, after handling model.parameters() (or in the same loop), iterate model.buffers() and for any buffer with device.type == "cpu" and not attached to a grad-requiring Parameter (i.e., treat as frozen buffer), call buffer.data = buffer.data.to(device) (or equivalent) so all non-trainable tensors are moved to torch.cuda.current_device() before DDP wrapping.
tests/e2e/mimo/run_hetero_llava.sh-25-27 (1)
25-27: ⚠️ Potential issue | 🟠 Major
Make the required artifact paths configurable.
This runner hard-codes placeholder paths for `--vision-encoder-checkpoint`, `--language-model-checkpoint`, and `--dataset-root`, so it cannot succeed without manual editing.
Suggested fix
+VISION_ENCODER_CHECKPOINT="${VISION_ENCODER_CHECKPOINT:?set VISION_ENCODER_CHECKPOINT}"
+LANGUAGE_MODEL_CHECKPOINT="${LANGUAGE_MODEL_CHECKPOINT:?set LANGUAGE_MODEL_CHECKPOINT}"
+DATASET_ROOT="${DATASET_ROOT:?set DATASET_ROOT}"
+
 uv run torchrun \
     --nproc_per_node "$GPUS_PER_NODE" \
     --nnodes "$NUM_NODES" \
     tests/e2e/mimo/test_mimo_training_llava.py \
@@
-    --vision-encoder-checkpoint /path/to/clip_checkpoint \
-    --language-model-checkpoint /path/to/llm_checkpoint \
-    --dataset-root /path/to/llava/pretrain/dataset
+    --vision-encoder-checkpoint "$VISION_ENCODER_CHECKPOINT" \
+    --language-model-checkpoint "$LANGUAGE_MODEL_CHECKPOINT" \
+    --dataset-root "$DATASET_ROOT"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_hetero_llava.sh` around lines 25 - 27, The script currently hardcodes placeholder paths for --vision-encoder-checkpoint, --language-model-checkpoint, and --dataset-root; update the runner to accept these as configurable inputs by reading environment variables (e.g., VISION_ENCODER_CHECKPOINT, LANGUAGE_MODEL_CHECKPOINT, DATASET_ROOT) or parsing positional/optional CLI arguments, substitute those variables into the flags, and validate that each is set (print a clear error and exit non-zero if any are missing) so the script can run without manual editing; locate the invocation that uses the three flags and replace literal paths with the variable references and add a quick usage/help message.
tests/e2e/mimo/run_hetero_llava_parallelism_tests.sh-84-85 (1)
84-85: ⚠️ Potential issue | 🟠 Major
Use unique names for the asymmetric configs.
Both entries are named `asymmetric_2_6`, so `--config asymmetric_2_6` can only run the first one and the summary collapses two different layouts under the same label.
Suggested fix
-    "asymmetric_2_6|1|2|1|0|2|1|3|2|3"
-    "asymmetric_2_6|2|1|1|0|2|1|3|2|3"
+    "asymmetric_2_6_llm_pp2|1|2|1|0|2|1|3|2|3"
+    "asymmetric_2_6_llm_tp2|2|1|1|0|2|1|3|2|3"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_hetero_llava_parallelism_tests.sh` around lines 84 - 85, The two config entries both use the label "asymmetric_2_6" which causes collisions in selection and summary; update the two strings (the entries shown as "asymmetric_2_6|1|2|1|0|2|1|3|2|3" and "asymmetric_2_6|2|1|1|0|2|1|3|2|3") to use unique names (e.g., "asymmetric_2_6_a|..." and "asymmetric_2_6_b|..." or similar distinct suffixes) so that --config can target each layout individually and summaries do not merge them.
tests/e2e/mimo/run_hetero_llava_parallelism_tests.sh-98-105 (1)
98-105: ⚠️ Potential issue | 🟠 Major
Reject GPU counts other than 2, 4, or 8.
`--gpus 16` still selects `CONFIGS_8GPU`, but line 194 launches 16 workers. The config offsets only map ranks 0-7, so the extra ranks have no module assignment and will fail the MiMo PG assertions.
Suggested fix
-# Select configs based on GPU count
-if [[ $NUM_GPUS -ge 8 ]]; then
-    CONFIGS=("${CONFIGS_8GPU[@]}")
-elif [[ $NUM_GPUS -ge 4 ]]; then
-    CONFIGS=("${CONFIGS_4GPU[@]}")
-else
-    CONFIGS=("${CONFIGS_2GPU[@]}")
-fi
+# Select configs based on exact GPU count
+case "${NUM_GPUS}" in
+    8)
+        CONFIGS=("${CONFIGS_8GPU[@]}")
+        ;;
+    4)
+        CONFIGS=("${CONFIGS_4GPU[@]}")
+        ;;
+    2)
+        CONFIGS=("${CONFIGS_2GPU[@]}")
+        ;;
+    *)
+        echo "Unsupported GPU count: ${NUM_GPUS}. Expected one of: 2, 4, 8."
+        exit 1
+        ;;
+esac
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_hetero_llava_parallelism_tests.sh` around lines 98 - 105, The script currently uses a >= check and maps any GPU count >=8 to CONFIGS_8GPU which allows unsupported counts (e.g., 16) and causes rank mapping failures; update the NUM_GPUS validation so only exact supported GPU counts are accepted (2, 4, or 8): replace the >= branches with explicit checks against 2, 4, and 8 (using NUM_GPUS == 2/4/8 or a membership test), set CONFIGS to the corresponding CONFIGS_2GPU/CONFIGS_4GPU/CONFIGS_8GPU array, and for any other NUM_GPUS value print a clear error referencing NUM_GPUS and exit non-zero before launching workers.
tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py-470-494 (1)
470-494: ⚠️ Potential issue | 🟠 Major
Always tear down NCCL and close the log file in a `finally` block.
If `_run_phase_save()` or `_run_phase_resume()` raises, line 494 is never reached. In `torchrun` that can leave other ranks stuck in collectives and make the next phase flaky.
Suggested fix
-    dist.init_process_group("nccl")
-    rank = dist.get_rank()
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    torch.cuda.set_device(local_rank)
-
-    log_dir = "/tmp/mimo_resume_e2e_logs"
-    os.makedirs(log_dir, exist_ok=True)
-    _rank_log_file = open(f"{log_dir}/rank_{rank}_{args.phase}.log", "w")
-
-    logging.basicConfig(
-        level=logging.INFO,
-        format=f"[Rank {rank}] %(name)s: %(message)s",
-        handlers=[
-            logging.FileHandler(f"{log_dir}/rank_{rank}_{args.phase}_full.log", mode="w"),
-            logging.StreamHandler(sys.stderr),
-        ],
-        force=True,
-    )
-
-    if args.phase == "save":
-        _run_phase_save(args.ckpt_dir)
-    else:
-        _run_phase_resume(args.ckpt_dir)
-
-    dist.destroy_process_group()
+    dist.init_process_group("nccl")
+    try:
+        rank = dist.get_rank()
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+
+        log_dir = "/tmp/mimo_resume_e2e_logs"
+        os.makedirs(log_dir, exist_ok=True)
+        _rank_log_file = open(f"{log_dir}/rank_{rank}_{args.phase}.log", "w")
+
+        logging.basicConfig(
+            level=logging.INFO,
+            format=f"[Rank {rank}] %(name)s: %(message)s",
+            handlers=[
+                logging.FileHandler(f"{log_dir}/rank_{rank}_{args.phase}_full.log", mode="w"),
+                logging.StreamHandler(sys.stderr),
+            ],
+            force=True,
+        )
+
+        if args.phase == "save":
+            _run_phase_save(args.ckpt_dir)
+        else:
+            _run_phase_resume(args.ckpt_dir)
+    finally:
+        if _rank_log_file is not None:
+            _rank_log_file.close()
+        if dist.is_initialized():
+            dist.destroy_process_group()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py` around lines 470 - 494, Wrap the execution of _run_phase_save and _run_phase_resume in a try/finally so that dist.destroy_process_group() and closing of the opened _rank_log_file always run even if an exception occurs; specifically, after calling dist.init_process_group(...) and opening _rank_log_file, execute the phase call inside try, and in the finally block call dist.destroy_process_group() and _rank_log_file.close() (and any other necessary cleanup such as flushing handlers) to guarantee NCCL is torn down and the log file is closed if _run_phase_save or _run_phase_resume raises.
tests/e2e/mimo/test_mimo_training_e2e.py-307-370 (1)
307-370: ⚠️ Potential issue | 🟠 Major
Clean up distributed state on failures too.
If model setup or `pretrain_mimo()` raises, line 370 is never reached. That leaves NCCL initialized and `_rank_log_file` open, which can wedge the rest of the E2E job after the first failing rank.
Suggested fix
-    dist.init_process_group("nccl")
-    rank = dist.get_rank()
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
-    torch.cuda.set_device(local_rank)
-
-    log_dir = "/tmp/mimo_e2e_logs"
-    os.makedirs(log_dir, exist_ok=True)
-    _rank_log_file = open(f"{log_dir}/rank_{rank}.log", "w")
-
-    logging.basicConfig(
-        level=logging.INFO,
-        format=f"[Rank {rank}] %(name)s: %(message)s",
-        handlers=[
-            logging.FileHandler(f"{log_dir}/rank_{rank}_full.log", mode="w"),
-            logging.StreamHandler(sys.stderr),
-        ],
-        force=True,
-    )
-    logging.getLogger("megatron.core.pipeline_parallel.bridge_communicator").setLevel(logging.DEBUG)
-    logging.getLogger("megatron.core.pipeline_parallel.multimodule_communicator").setLevel(logging.DEBUG)
-
-    _log(f"distributed initialized (world_size={dist.get_world_size()})")
-
-    _log("building model specs")
-    language_model_spec, modality_submodules_spec, special_token_ids = _build_model_specs()
-    mimo_parallelism_config = _build_parallelism_config()
-
-    mimo_provider = MimoModelProvider(
-        language_model_spec=language_model_spec,
-        modality_submodules_spec=modality_submodules_spec,
-        special_token_ids=special_token_ids,
-        mimo_parallelism_config=mimo_parallelism_config,
-        topology={"vision": ["language"], "language": []},
-        use_cpu_initialization=True,
-    )
-    if not hasattr(mimo_provider, "num_moe_experts"):
-        mimo_provider.num_moe_experts = None
-
-    _log("building data provider")
-    mock_data_provider = _build_mock_data_provider()
-
-    opt_config = BridgeOptimizerConfig(lr=1e-4, min_lr=0.0)
-
-    _log("building config")
-    cfg = _build_config(
-        mimo_provider,
-        mock_data_provider,
-        opt_config,
-        wandb_project=os.environ.get("WANDB_PROJECT", "Megatron-Bridge-MIMO"),
-        wandb_exp_name=os.environ.get("WANDB_EXP_NAME", "mimo-e2e-test"),
-        wandb_entity=os.environ.get("WANDB_ENTITY"),
-        wandb_save_dir=os.environ.get("WANDB_SAVE_DIR", "/tmp/wandb"),
-    )
-
-    _log("launching pretrain_mimo")
-    pretrain_mimo(
-        cfg=cfg,
-        forward_step_func=mimo_forward_step,
-        build_data_iterators_fn=_build_data_iterators,
-    )
-
-    _log("PASSED")
-
-    dist.destroy_process_group()
+    dist.init_process_group("nccl")
+    try:
+        rank = dist.get_rank()
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+
+        log_dir = "/tmp/mimo_e2e_logs"
+        os.makedirs(log_dir, exist_ok=True)
+        _rank_log_file = open(f"{log_dir}/rank_{rank}.log", "w")
+
+        logging.basicConfig(
+            level=logging.INFO,
+            format=f"[Rank {rank}] %(name)s: %(message)s",
+            handlers=[
+                logging.FileHandler(f"{log_dir}/rank_{rank}_full.log", mode="w"),
+                logging.StreamHandler(sys.stderr),
+            ],
+            force=True,
+        )
+        logging.getLogger("megatron.core.pipeline_parallel.bridge_communicator").setLevel(logging.DEBUG)
+        logging.getLogger("megatron.core.pipeline_parallel.multimodule_communicator").setLevel(logging.DEBUG)
+
+        _log(f"distributed initialized (world_size={dist.get_world_size()})")
+        ...
+        _log("PASSED")
+    finally:
+        if _rank_log_file is not None:
+            _rank_log_file.close()
+        if dist.is_initialized():
+            dist.destroy_process_group()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_training_e2e.py` around lines 307 - 370, Wrap the distributed setup, model/data/config build and call to pretrain_mimo (everything after dist.init_process_group and before dist.destroy_process_group) in a try/finally so failures still run cleanup; in the finally block close the opened _rank_log_file (if it exists/has not been closed) and call dist.destroy_process_group() (guarded with dist.is_initialized() or a try/except) and optionally call logging.shutdown() to flush file handlers so NCCL and log files are not left open after exceptions in pretrain_mimo, _build_model_specs, or related setup.
src/megatron/bridge/training/eval.py-204-207 (1)
204-207: ⚠️ Potential issue | 🟠 Major
Raise an error when multimodule evaluation lacks an explicit `MultiModulePipelineCommunicator`.
If `pg_collection` is multimodule but `p2p_communicator` is None, the code silently creates a single-module `P2PCommunicator` while `is_multimodule` remains True. This causes a type mismatch at line 259 when accessing `eval_p2p_communicator.is_pp_last_stage`, which exists only on `MultiModulePipelineCommunicator`. Raise a clear error instead.
This applies to both locations (lines 204-207 and 322-325):
Suggested fix
     eval_p2p_communicator = p2p_communicator
     if eval_p2p_communicator is None:
+        if is_multimodule:
+            raise ValueError(
+                "Multimodule evaluation requires an explicit MultiModulePipelineCommunicator."
+            )
         eval_p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config)

     non_loss_p2p_communicator = p2p_communicator
     if non_loss_p2p_communicator is None:
+        if is_multimodule:
+            raise ValueError(
+                "Multimodule evaluation requires an explicit MultiModulePipelineCommunicator."
+            )
         non_loss_p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/eval.py` around lines 204 - 207, The code currently defaults to creating a P2PCommunicator when p2p_communicator is None, which silently allows a P2PCommunicator to be used even if the pipeline is multimodule; update both places where eval_p2p_communicator is set (the block using p2p_communicator and the analogous block later) to check pg_collection.is_multimodule (or equivalent multimodule flag) and if True and p2p_communicator is None raise a clear ValueError stating that a MultiModulePipelineCommunicator must be provided; otherwise (when not multimodule) keep the existing behavior of creating a P2PCommunicator. Ensure the error references the types (MultiModulePipelineCommunicator vs P2PCommunicator) and the variable eval_p2p_communicator so the caller can fix the injection.
tests/e2e/mimo/test_mimo_training_llava.py-647-797 (1)
647-797: ⚠️ Potential issue | 🟠 Major
Guarantee distributed teardown on failures.
If checkpoint loading or `pretrain_mimo()` raises, this function skips both `_rank_log_file` cleanup and `dist.destroy_process_group()`. In a multi-rank E2E test that usually leaves surviving ranks blocked until timeout. Wrap the training body in `try/finally` and clean up conditionally.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_training_llava.py` around lines 647 - 797, Wrap the main training body (from after dist.init_process_group(...) through the call to pretrain_mimo(...) and subsequent logs) in a try/finally block so teardown always runs; in finally, if _rank_log_file is set and not closed close it, and call dist.destroy_process_group() only if torch.distributed.is_initialized() to guarantee the process group is torn down on exceptions, and optionally flush/close any per-rank logging handlers created around logging.basicConfig to avoid dangling file handles (refer to main(), pretrain_mimo, _rank_log_file, and dist.destroy_process_group).
tests/e2e/mimo/test_mimo_training_llava.py-689-698 (1)
689-698: ⚠️ Potential issue | 🟠 Major
Apply the PP-size patch to the vision encoder too.
Lines 696-697 fix the language side, but `MIMO_VISION_PP > 1` still leaves the CLIP `TransformerConfig.pipeline_model_parallel_size` at `1`. That makes each vision PP stage build all vision layers instead of its local shard, which is the same failure mode your language-side patch is avoiding.
Suggested fix
     llm_pp_size = mimo_parallelism_config.module_parallelisms["language"].pipeline_model_parallel_size
     language_model_spec.params["config"].pipeline_model_parallel_size = llm_pp_size
+    vision_pp_size = mimo_parallelism_config.module_parallelisms["images"].pipeline_model_parallel_size
+    modality_submodules_spec["images"].submodules["encoders"]["clip"].params[
+        "transformer_config"
+    ].pipeline_model_parallel_size = vision_pp_size
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_training_llava.py` around lines 689 - 698, The vision encoder's TransformerConfig.pipeline_model_parallel_size is not being set, so when MIMO_VISION_PP > 1 each vision PP stage builds all vision layers; after building mimo_parallelism_config (from _build_parallelism_config()) read the vision module's pipeline size (e.g. mimo_parallelism_config.module_parallelisms["vision"].pipeline_model_parallel_size) and assign it into the vision model spec config (e.g. vision_model_spec.params["config"].pipeline_model_parallel_size) just like you do for language_model_spec so each vision PP stage only builds its local shard.
tests/e2e/mimo/test_mimo_training_llava.py-777-783 (1)
777-783: ⚠️ Potential issue | 🟠 Major
This disables the RNG-resume path the PR just added.
The comment here is stale with the checkpointing changes in this PR: `src/megatron/bridge/training/checkpointing.py:743-787` and `src/megatron/bridge/training/checkpointing.py:1658-1688` now take `module_name` specifically to namespace per-module RNG shards. Forcing `cfg.checkpoint.save_rng = False` makes `--load-checkpoint` resumes non-reproducible and prevents this harness from covering the new functionality.
Suggested fix
-    # MiMo RNG save is not yet supported: each module produces ShardedObject
-    # with key "rng_state" using module-local PP/TP/DP ranks, causing
-    # duplicate shard keys across modules. Disable until upstream fix.
-    cfg.checkpoint.save_rng = False
+    cfg.checkpoint.save_rng = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_training_llava.py` around lines 777 - 783, The test is forcing RNG save off by setting cfg.checkpoint.save_rng = False which disables the new per-module RNG sharding; remove that override (or set cfg.checkpoint.save_rng = True) so the test exercises the RNG-resume path, relying on the checkpointing implementation that now accepts module_name to namespace RNG shards (the checkpointing code that takes module_name for per-module RNG shard keys should handle duplicates).
tests/e2e/mimo/test_mimo_training_llava.py-641-643 (1)
641-643: ⚠️ Potential issue | 🟠 Major
Don't use `type=bool` for these CLI toggles: they cannot be disabled from the command line.
With `argparse`, `type=bool` converts any non-empty string (including `"False"`) to `True`, making it impossible to disable these flags via CLI. Use `argparse.BooleanOptionalAction` instead to enable both `--freeze-vision` and `--no-freeze-vision` forms.
Suggested fix
-    parser.add_argument("--freeze-vision", type=bool, default=True, help="Freeze the vision encoder (default: True)")
-    parser.add_argument("--freeze-llm", type=bool, default=True, help="Freeze the language model (default: True)")
-    parser.add_argument("--freeze-projector", type=bool, default=False, help="Freeze the projector (default: False)")
+    parser.add_argument(
+        "--freeze-vision",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Freeze the vision encoder (default: True)",
+    )
+    parser.add_argument(
+        "--freeze-llm",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Freeze the language model (default: True)",
+    )
+    parser.add_argument(
+        "--freeze-projector",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Freeze the projector (default: False)",
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/test_mimo_training_llava.py` around lines 641 - 643, The CLI flags for freeze toggles use type=bool which treats any non-empty string as True; update the three parser.add_argument calls for "--freeze-vision", "--freeze-llm", and "--freeze-projector" to use action=argparse.BooleanOptionalAction (remove type=bool) so users can pass both --freeze-vision / --no-freeze-vision (and similarly for the other two) and keep the same default values; locate these calls in tests/e2e/mimo/test_mimo_training_llava.py (the parser.add_argument lines shown) and replace the action/flag configuration accordingly.
tests/e2e/mimo/convert_hf_llama_to_megatron.py-198-205 (1)
198-205: ⚠️ Potential issue | 🟠 Major
Validate TP divisibility before chunking.
`torch.chunk()` will happily return uneven shards, so a value like `--tensor-parallel-size 3` can generate per-rank tensors whose shapes no longer match Megatron's partitioned layers. This should fail fast before writing checkpoints, both in the generic chunking path and in the SwiGLU `linear_fc1` special case.
Suggested guard
     elif suffix == "mlp.gate_proj.weight":
         up_key = name.replace("gate_proj", "up_proj")
         gate = new_tensor  # [ffn_hidden, hidden]
         up = state_dict[up_key].float()  # [ffn_hidden, hidden]
+        if gate.size(0) % tensor_parallel_size != 0:
+            raise ValueError(
+                f"{name} cannot be split evenly across TP={tensor_parallel_size}: "
+                f"size {gate.size(0)} on dim 0"
+            )
         new_name = f"{base}.mlp.linear_fc1.weight"
         # SwiGLU TP: chunk gate and up independently so each rank
         # gets [gate_chunk; up_chunk] — Megatron splits the activation
         # output in half (first half = gate, second half = up).
         gate_chunks = torch.chunk(gate, tensor_parallel_size, dim=0)
@@
     # Split for tensor parallelism
     if chunk_dim is None:
         chunks = [new_tensor] * tensor_parallel_size
     else:
+        if new_tensor.size(chunk_dim) % tensor_parallel_size != 0:
+            raise ValueError(
+                f"{new_name} cannot be split evenly across TP={tensor_parallel_size}: "
+                f"size {new_tensor.size(chunk_dim)} on dim {chunk_dim}"
+            )
         chunks = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)
Also applies to: 216-223, 381-386
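The fail-fast idea in the guard above, together with the per-rank `[gate_chunk; up_chunk]` layout described in the code comment, can be shown with a dependency-free sketch. Plain lists of row labels stand in for tensors here; `chunk_evenly` and `fuse_gate_up` are hypothetical helpers written for this illustration and are not part of the conversion script.

```python
from typing import List, Sequence


def chunk_evenly(rows: Sequence[str], num_chunks: int, name: str = "tensor") -> List[List[str]]:
    """Split `rows` into `num_chunks` equal parts, failing fast if that is impossible."""
    if num_chunks <= 0:
        raise ValueError(f"num_chunks must be positive, got {num_chunks}")
    if len(rows) % num_chunks != 0:
        raise ValueError(
            f"{name} cannot be split evenly across TP={num_chunks}: size {len(rows)} on dim 0"
        )
    size = len(rows) // num_chunks
    return [list(rows[i * size:(i + 1) * size]) for i in range(num_chunks)]


def fuse_gate_up(gate: Sequence[str], up: Sequence[str], tp_size: int) -> List[List[str]]:
    """For each rank, stack that rank's gate chunk on top of its up chunk."""
    gate_chunks = chunk_evenly(gate, tp_size, "gate_proj")
    up_chunks = chunk_evenly(up, tp_size, "up_proj")
    return [g + u for g, u in zip(gate_chunks, up_chunks)]


gate = ["g0", "g1", "g2", "g3"]
up = ["u0", "u1", "u2", "u3"]
shards = fuse_gate_up(gate, up, tp_size=2)

# An indivisible TP size is rejected before any shard is produced.
try:
    fuse_gate_up(gate, up, tp_size=3)
    uneven_error = None
except ValueError as exc:
    uneven_error = str(exc)
```

Chunking gate and up separately and then concatenating per rank keeps each rank's first half as gate rows and second half as up rows, which is the layout the comment in the diff describes.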
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/convert_hf_llama_to_megatron.py` around lines 198 - 205, Before calling torch.chunk (including the SwiGLU special-case that splits gate and up into gate_chunks/up_chunks and the generic chunking path used elsewhere, e.g., around linear_fc1 handling), validate that the tensor's first-dimension size is exactly divisible by tensor_parallel_size and raise/abort with a clear error if not; locate the splitting code that builds new_state_dicts[tp]["model"][new_name] (and the code that sets extra_key when use_te is true) and add the divisibility check there (and mirror it in the generic chunking routine) so uneven shards are detected and the process fails fast rather than producing mismatched per-rank tensors.
tests/e2e/mimo/run_mimo_parallelism_tests.sh-98-106 (1)
98-106: ⚠️ Potential issue | 🟠 Major
These per-config environment overrides are currently a no-op.
`test_mimo_training_e2e.py` still builds a fixed `MimoParallelismConfig` and its `main()` does not read any of the `MIMO_*` variables exported here, so every entry in this matrix currently exercises the same hardcoded topology instead of the config named in the shell runner.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_mimo_parallelism_tests.sh` around lines 98 - 106, The shell runner exports per-config env vars like MIMO_LLM_TP/MIMO_VISION_TP but the test entrypoint test_mimo_training_e2e.py currently ignores them and always builds a hardcoded MimoParallelismConfig; update the test to read the environment variables in its main() (or factory used by main) and construct MimoParallelismConfig from MIMO_LLM_*/MIMO_VISION_* and MIMO_*_OFFSET if set (falling back to existing defaults), or alternatively modify the test's config builder to accept overrides via os.environ (or argparse flags passed from the shell) so each matrix entry actually exercises the intended topology; ensure references to MimoParallelismConfig and main() in test_mimo_training_e2e.py are the places you change.
tests/e2e/mimo/run_hetero_llava_parallelism_tests_unfrozen_llm.sh-114-145 (1)
114-145: ⚠️ Potential issue | 🟠 Major
The checkpoint cache key is missing the model identity.
These cache directories only vary by TP size, so changing `HF_VISION_MODEL`, `HF_LLM_MODEL`, or `MEGATRON_VOCAB_SIZE` will silently reuse stale conversions from a different source checkpoint. That can make this E2E runner train against the wrong weights while still looking “cached.”
A simple fix is to incorporate the source model name (or a sanitized hash of it) and the vocab-size setting into the cache path, or write/read a small metadata file before reusing the directory.
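The cache-key idea can be sketched as follows. The runner itself is a shell script, so this Python version only illustrates the naming scheme; the function `converted_ckpt_dir`, its arguments, and the directory layout are assumptions made for this example, not the script's actual behavior.

```python
import hashlib
from pathlib import Path
from typing import Optional


def converted_ckpt_dir(
    base_dir: str,
    kind: str,
    hf_model: str,
    tp_size: int,
    vocab_size: Optional[int] = None,
) -> Path:
    """Build a cache directory name that changes when the source model changes.

    Hypothetical helper: the digest folds the source model name and the
    vocab-size setting into the path, so a different checkpoint never maps
    onto an existing converted directory.
    """
    identity = f"{kind}|{hf_model}|vocab={vocab_size}"
    digest = hashlib.sha256(identity.encode("utf-8")).hexdigest()[:12]
    return Path(base_dir) / f"{kind}_tp{tp_size}_{digest}"


if __name__ == "__main__":
    # Placeholder model names, used only to show that the paths differ.
    print(converted_ckpt_dir("/tmp/ckpts", "llm", "org/model-a", 4, 32000))
    print(converted_ckpt_dir("/tmp/ckpts", "llm", "org/model-b", 4, 32000))
```

A metadata file inside the directory would give the same protection while keeping the path short; the hash approach was chosen here only because it needs no extra read step.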
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_hetero_llava_parallelism_tests_unfrozen_llm.sh` around lines 114 - 145, The cache dirs (clip_ckpt_dir and llm_ckpt_dir built from CHECKPOINT_BASE_DIR + TP size) omit source model identity and vocab size, causing stale conversions to be reused; update the directory naming or validation logic in the conversion block that sets clip_ckpt_dir and llm_ckpt_dir (and returns CONVERTED_CLIP_CKPT / CONVERTED_LLM_CKPT) to include a sanitized HF_VISION_MODEL and HF_LLM_MODEL identifier (or a short hash) and the MEGATRON_VOCAB_SIZE, or alternatively write/read a small metadata file (containing HF_* model names and vocab size) into the cache folder and refuse reuse when metadata mismatches before running convert_hf_*_to_megatron.py.
tests/e2e/mimo/run_mimo_parallelism_tests.sh-65-72 (1)
65-72: ⚠️ Potential issue | 🟠 Major
Reject unsupported `--gpus` values instead of falling through to the nearest bucket.
With the current `>=` logic, `--gpus 6` selects `CONFIGS_4GPU` but still launches `--nproc_per_node=6`, so the rank layout no longer matches the config offsets and DP sizes.
Suggested guard

    -if [[ $NUM_GPUS -ge 8 ]]; then
    +if [[ $NUM_GPUS -eq 8 ]]; then
         CONFIGS=("${CONFIGS_8GPU[@]}")
    -elif [[ $NUM_GPUS -ge 4 ]]; then
    +elif [[ $NUM_GPUS -eq 4 ]]; then
         CONFIGS=("${CONFIGS_4GPU[@]}")
    -else
    +elif [[ $NUM_GPUS -eq 2 ]]; then
         CONFIGS=("${CONFIGS_2GPU[@]}")
    +else
    +    echo "Unsupported --gpus value: ${NUM_GPUS}. Supported values: 2, 4, 8."
    +    exit 1
     fi

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_mimo_parallelism_tests.sh` around lines 65 - 72, The GPU-config selection block currently falls through for unsupported NUM_GPUS (e.g., 6) causing mismatched rank layouts; add an explicit validation before that block to accept only supported values (2,4,8) and reject others by printing an error and exiting non-zero; keep the existing variables and selection logic (NUM_GPUS, CONFIGS_8GPU, CONFIGS_4GPU, CONFIGS_2GPU, CONFIGS assignment) but validate NUM_GPUS first (e.g., if not one of 2/4/8 -> echo error mentioning supported --gpus values and exit 1) so the script fails fast instead of choosing the nearest bucket.
tests/e2e/mimo/verify_llama_conversion.py-95-136 (1)
95-136: ⚠️ Potential issue | 🟠 Major
Derive the Megatron model shape from the source checkpoint instead of hardcoding 7B.
This verifier still bakes in the Vicuna/Llama-2 7B dimensions (`32` layers, `4096` hidden size, `32` query groups), but `--hf-model` is exposed as a generic input and the converter already reads those values dynamically. Any non-7B or GQA checkpoint will construct the wrong `GPTModel` here before `load_megatron_llm_weights()` even runs.
Also applies to: 223-231
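A minimal sketch of deriving the shape fields from a Hugging Face Llama-style config instead of hardcoding them. The function name `megatron_shape_from_hf` is hypothetical, and a `SimpleNamespace` stands in for the real config object so the example runs without downloading a model; the attribute names on the left-hand side follow the Hugging Face Llama config, and the returned keys follow the `TransformerConfig` fields named in the review prompt.

```python
from types import SimpleNamespace


def megatron_shape_from_hf(hf_config) -> dict:
    """Map HF Llama-style config attributes onto Megatron shape fields."""
    num_heads = hf_config.num_attention_heads
    # Configs without num_key_value_heads use plain multi-head attention,
    # so the number of query groups equals the number of heads.
    num_kv_heads = getattr(hf_config, "num_key_value_heads", None) or num_heads
    return {
        "num_layers": hf_config.num_hidden_layers,
        "hidden_size": hf_config.hidden_size,
        "num_attention_heads": num_heads,
        "num_query_groups": num_kv_heads,
        "ffn_hidden_size": hf_config.intermediate_size,
    }


if __name__ == "__main__":
    # Stand-in for a 7B-shaped config (the values the verifier hardcodes today).
    cfg_7b = SimpleNamespace(
        num_hidden_layers=32,
        hidden_size=4096,
        num_attention_heads=32,
        intermediate_size=11008,
    )
    # Stand-in for a GQA config, with illustrative values only.
    cfg_gqa = SimpleNamespace(
        num_hidden_layers=80,
        hidden_size=8192,
        num_attention_heads=64,
        num_key_value_heads=8,
        intermediate_size=28672,
    )
    print(megatron_shape_from_hf(cfg_7b))
    print(megatron_shape_from_hf(cfg_gqa))
```

In the verifier, the resulting dictionary would feed the config builder in place of the constants, so a GQA checkpoint gets the right `num_query_groups` before any weights are loaded.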
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/verify_llama_conversion.py` around lines 95 - 136, The _make_language_config function currently hardcodes Vicuna/Llama-2-7B dimensions (num_layers=32, hidden_size=4096, num_query_groups=32, ffn_hidden_size=11008) which breaks conversion for non-7B checkpoints; instead, derive these shape fields from the source HF/Megatron checkpoint metadata that the converter already reads (e.g., use the checkpoint's config/state dict values) and populate the TransformerConfig accordingly before creating the GPTModel so load_megatron_llm_weights sees matching shapes; update _make_language_config (and the analogous code region around the other occurrence) to accept or pull num_layers, hidden_size, num_attention_heads/num_query_groups, ffn_hidden_size, and gating/activation flags from the source checkpoint rather than using hardcoded constants.
tests/e2e/mimo/run_hetero_llava_parallelism_tests_unfrozen_llm.sh-95-102 (1)
95-102: ⚠️ Potential issue | 🟠 Major
Fail fast on unsupported `--gpus` values.
Like the other launchers in this PR, this one buckets with `>=` but launches the exact `NUM_GPUS` value. `--gpus 6` would therefore select `CONFIGS_4GPU` and still spawn 6 ranks, which breaks the encoded offsets and parallelism assumptions.
Suggested guard

    -if [[ $NUM_GPUS -ge 8 ]]; then
    +if [[ $NUM_GPUS -eq 8 ]]; then
         CONFIGS=("${CONFIGS_8GPU[@]}")
    -elif [[ $NUM_GPUS -ge 4 ]]; then
    +elif [[ $NUM_GPUS -eq 4 ]]; then
         CONFIGS=("${CONFIGS_4GPU[@]}")
    -else
    +elif [[ $NUM_GPUS -eq 2 ]]; then
         CONFIGS=("${CONFIGS_2GPU[@]}")
    +else
    +    echo "Unsupported --gpus value: ${NUM_GPUS}. Supported values: 2, 4, 8."
    +    exit 1
     fi

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_hetero_llava_parallelism_tests_unfrozen_llm.sh` around lines 95 - 102, The GPU bucketing currently uses >= which can select a config set that doesn't match the exact NUM_GPUS launched (symbols: NUM_GPUS, CONFIGS, CONFIGS_8GPU, CONFIGS_4GPU, CONFIGS_2GPU), causing wrong rank/layout for values like 6; change the selection logic to require exact matches (e.g., if [[ $NUM_GPUS -eq 8 ]]; elif [[ $NUM_GPUS -eq 4 ]]; elif [[ $NUM_GPUS -eq 2 ]]; else print an error and exit) or add a post-check that NUM_GPUS is in the supported set {2,4,8} and fail-fast with a clear message if not.
tests/e2e/mimo/run_mimo_checkpoint_resume.sh-53-59 (1)
53-59: ⚠️ Potential issue | 🟠 Major
Validate `--gpus` against the exact matrix sizes.
The config tables are defined for 8, 4, and 2 GPUs only, but the current `>=` selection will happily run `--gpus 6` with the 4-GPU layouts. That produces a world size that does not match the encoded TP/PP/DP/offset assumptions.
Suggested guard

    -if [[ $NUM_GPUS -ge 8 ]]; then
    +if [[ $NUM_GPUS -eq 8 ]]; then
         CONFIGS=("${CONFIGS_8GPU[@]}")
    -elif [[ $NUM_GPUS -ge 4 ]]; then
    +elif [[ $NUM_GPUS -eq 4 ]]; then
         CONFIGS=("${CONFIGS_4GPU[@]}")
    -else
    +elif [[ $NUM_GPUS -eq 2 ]]; then
         CONFIGS=("${CONFIGS_2GPU[@]}")
    +else
    +    echo "Unsupported --gpus value: ${NUM_GPUS}. Supported values: 2, 4, 8."
    +    exit 1
     fi

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/run_mimo_checkpoint_resume.sh` around lines 53 - 59, The selection logic for CONFIGS based on NUM_GPUS is too permissive (using >=) and allows invalid values like 6 to pick the 4-GPU matrix; change the guard to only accept exact GPU counts that have defined matrices (NUM_GPUS == 8, == 4, == 2) and otherwise print a clear error and exit; update the branch that assigns CONFIGS (and any place that parses the --gpus input) to validate NUM_GPUS against CONFIGS_8GPU, CONFIGS_4GPU, CONFIGS_2GPU and refuse unsupported sizes so world-size/TP/PP/DP assumptions remain consistent.
src/megatron/bridge/training/pretrain_mimo.py-148-163 (1)
148-163: ⚠️ Potential issue | 🟠 Major
Don’t reject iterator builders that accept `train_state` via `**kwargs`.
The current signature check only allows a literal `train_state` parameter name, so a perfectly compatible builder like `def build_data_iterators_fn(cfg, mimo_infra, **kwargs): ...` will now fail resume with the new `RuntimeError`. The call site already passes `train_state` as a keyword; the guard should treat `VAR_KEYWORD` as compatible too.
Suggested fix

     sig = inspect.signature(build_data_iterators_fn)
    -if "train_state" in sig.parameters:
    +accepts_train_state = "train_state" in sig.parameters or any(
    +    param.kind is inspect.Parameter.VAR_KEYWORD for param in sig.parameters.values()
    +)
    +if accepts_train_state:
         train_data_iterator, valid_data_iterator = build_data_iterators_fn(
             cfg,
             setup_output.mimo_infra,
             train_state=train_state,
         )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/pretrain_mimo.py` around lines 148 - 163, The resume guard incorrectly rejects iterator builders that accept train_state via **kwargs; update the check in the is_resuming branch that inspects the signature of build_data_iterators_fn to treat a VAR_KEYWORD parameter as compatible: after getting sig = inspect.signature(build_data_iterators_fn), allow resume if "train_state" is in sig.parameters OR if any parameter in sig.parameters.values() has kind == inspect.Parameter.VAR_KEYWORD, and only raise the RuntimeError when neither condition holds; references: build_data_iterators_fn, is_resuming, train_state, sig, inspect.Parameter.VAR_KEYWORD.
src/megatron/bridge/data/mimo/loaders.py-80-90 (1)
80-90: ⚠️ Potential issue | 🟠 Major
Use an explicit exception for the micro-batch divisibility guard.
Line 86 relies on `assert`, which disappears under `python -O`. That turns a required config check into a debug-only check and pushes the failure to a later slicing/load path. Please raise `ValueError` here instead.
Suggested fix

     micro_batch_size = cfg.train.micro_batch_size
     for mod_name, mod_cfg in cfg.model.mimo_parallelism_config.module_parallelisms.items():
         dp = mod_cfg.data_parallel_size
    -    assert micro_batch_size % dp == 0, (
    -        f"micro_batch_size ({micro_batch_size}) must be divisible by "
    -        f"data_parallel_size ({dp}) of module '{mod_name}'. "
    -        f"slice_batch_for_mimo requires an evenly divisible micro-batch."
    -    )
    +    if micro_batch_size % dp != 0:
    +        raise ValueError(
    +            f"micro_batch_size ({micro_batch_size}) must be divisible by "
    +            f"data_parallel_size ({dp}) of module '{mod_name}'. "
    +            f"slice_batch_for_mimo requires an evenly divisible micro-batch."
    +        )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/mimo/loaders.py` around lines 80 - 90, Replace the assert-based guard that checks micro_batch_size divisibility with an explicit exception: instead of using assert micro_batch_size % dp == 0, raise a ValueError when the condition fails so the check is enforced even under python -O; update the block that iterates over cfg.model.mimo_parallelism_config.module_parallelisms (using micro_batch_size = cfg.train.micro_batch_size and dp = mod_cfg.data_parallel_size) to raise ValueError with the same descriptive message referencing the module name and slice_batch_for_mimo requirement.
tests/e2e/mimo/convert_hf_clip_to_megatron.py-48-53 (1)
48-53: ⚠️ Potential issue | 🟠 Major
Reject unsupported TP sizes before sharding or re-sharding.
This accepts any `tensor_parallel_size`, but these paths assume exact partitions for the sharded dims. For example, `TP=3` leaves some 1024/4096-derived weights uneven, so the converter can write shards that later fail to load into the target Megatron module. Please validate divisibility up front and before the re-split path in `load_megatron_clip_weights()`.
Also applies to: 180-187, 286-295
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/convert_hf_clip_to_megatron.py` around lines 48 - 53, Validate and reject unsupported tensor parallel sizes early: in convert_hf_clip_to_megatron(hf_model_name, output_path, tensor_parallel_size, use_te) check that tensor_parallel_size cleanly divides all sharded dimensions used by the converter (e.g., hidden_size, mlp intermediate dims like 1024/4096-derived sizes) and raise a clear ValueError if not; likewise add the same divisibility check at the start of load_megatron_clip_weights() before attempting any re-splitting/resharding path so the function aborts with a descriptive error when TP would produce uneven shards.
🟡 Minor comments (2)
src/megatron/bridge/training/train.py-1180-1180 (1)
1180-1180: ⚠️ Potential issue | 🟡 Minor
This condition treats `fp8=False` as enabled.
`getattr(..., None) is not None` is true whenever the attribute is set to anything other than `None`, including the common `False` case, so non-FP8 runs now do a full `gc.collect()` after every checkpoint.
Suggested fix

    -if getattr(state.cfg.model, "fp8", None) is not None:
    +if getattr(state.cfg.model, "fp8", False):

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/train.py` at line 1180, The condition incorrectly treats fp8=False as enabled because it only checks for existence; change the check in the train loop to test the boolean value explicitly (e.g., use getattr(state.cfg.model, "fp8", False) or compare with True) so that only truthy fp8 triggers the gc.collect() path; update the conditional that references state.cfg.model.fp8 accordingly (the branch that currently calls gc.collect() after checkpoints).
tests/unit_tests/training/mimo/test_mimo_checkpointing.py-705-719 (1)
705-719: ⚠️ Potential issue | 🟡 Minor
Reset `_PIPELINE_MODEL_PARALLEL_GROUP` in the shared pretrain helper too.
`pretrain_mimo()` now writes the PP global for the active module, but this helper only patches the TP/DP globals. That makes the load-side tests order-dependent if a prior test leaves `_PIPELINE_MODEL_PARALLEL_GROUP` populated.
Suggested fix

     patch("megatron.core.parallel_state._TENSOR_MODEL_PARALLEL_GROUP", None),
     patch("megatron.core.parallel_state._DATA_PARALLEL_GROUP", None),
     patch("megatron.core.parallel_state._DATA_PARALLEL_GROUP_WITH_CP", None),
    +patch("megatron.core.parallel_state._PIPELINE_MODEL_PARALLEL_GROUP", None),

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/mimo/test_mimo_checkpointing.py` around lines 705 - 719, The shared pretrain helper must also reset the pipeline parallel global; add a patch for megatron.core.parallel_state._PIPELINE_MODEL_PARALLEL_GROUP (set to None) in the same with(...) context that currently patches _TENSOR_MODEL_PARALLEL_GROUP, _DATA_PARALLEL_GROUP, and _DATA_PARALLEL_GROUP_WITH_CP so tests are order-independent; update the block used by the pretrain_mimo() tests to include patch("megatron.core.parallel_state._PIPELINE_MODEL_PARALLEL_GROUP", None).
🧹 Nitpick comments (5)
tests/e2e/mimo/verify_llama_conversion.py (1)
167-182: Please switch these diagnostics to `logging` or `print_rank_0()`.
The script currently relies on bare `print()` throughout the distributed path, which does not match the repo-wide Python logging convention.
As per coding guidelines, `**/*.py`: Use `logging.getLogger(__name__)` or `print_rank_0()` instead of bare `print()` statements
Also applies to: 221-258
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/verify_llama_conversion.py` around lines 167 - 182, Replace the bare print calls in the verification block (the prints that show shapes, diffs, and Status) with the repo logging pattern: obtain a module logger via logging.getLogger(__name__) (or use the existing logger) and emit logger.info(...) for each message, or call print_rank_0(...) where distributed-only rank-0 output is required; ensure the same change is applied to the later diagnostic block around lines 221-258 and keep the message text identical to preserve test output expectations.
tests/e2e/mimo/convert_hf_llama_to_megatron.py (1)
97-115: Please route these status messages through `logging` or `print_rank_0()`.
The converter uses bare `print()` throughout, which does not match the repo-wide Python logging convention.
As per coding guidelines, `**/*.py`: Use `logging.getLogger(__name__)` or `print_rank_0()` instead of bare `print()` statements
Also applies to: 213-243, 307-357, 400-407
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/convert_hf_llama_to_megatron.py` around lines 97 - 115, Replace bare print() calls (e.g., the "Loading HuggingFace model: {hf_model_name}" print and the subsequent print of hidden_size/num_heads/head_dim/target_vocab) with the project logging convention: import logging if missing, add logger = logging.getLogger(__name__), and use logger.info(...) or use print_rank_0(...) for rank-0-only messages; apply the same replacement for other bare prints in the same module (the blocks around the ranges you noted). Ensure message strings and interpolation remain identical and that imports (logging or print_rank_0) are added/used consistently.
tests/e2e/mimo/verify_clip_conversion.py (1)
129-142: Make the verification output rank-aware.
This runs under `torchrun`, so bare `print()` calls get noisy quickly once multiple ranks are involved. Swapping these to `print_rank_0()` or a logger would keep the diagnostics readable while staying consistent with the rest of the repo.
As per coding guidelines, "Use `logging.getLogger(__name__)` or `print_rank_0()` instead of bare `print()` statements".
Also applies to: 177-177, 190-190, 196-196
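For illustration, a rank-aware logging helper could look roughly like the sketch below. It is not the repo's `print_rank_0()`; `is_rank_0` and `log_rank_0` are hypothetical names, and the only outside assumption is that the launcher exports a `RANK` environment variable, as `torchrun` does.

```python
import logging
import os

logger = logging.getLogger(__name__)


def is_rank_0(environ=None) -> bool:
    """Return True on global rank 0, or when no RANK variable is set."""
    env = os.environ if environ is None else environ
    return int(env.get("RANK", "0")) == 0


def log_rank_0(message: str, environ=None) -> bool:
    """Log at INFO level on rank 0 only; report whether the message was emitted."""
    if is_rank_0(environ):
        logger.info(message)
        return True
    return False


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Only the process with RANK=0 (or a single-process run) emits this line.
    log_rank_0("mean_diff and max_diff would be reported here")
```

Passing the environment as an argument keeps the helper easy to unit test without spawning multiple processes.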
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/verify_clip_conversion.py` around lines 129 - 142, Replace the bare print() calls in the verification block with a rank-aware output (use print_rank_0() or a module logger from logging.getLogger(__name__)); specifically change the prints that emit label, HF/Megatron shapes (hf_out, meg_out), mean_diff, max_diff, cos, and status to use print_rank_0() (or logger.info) so only rank 0 outputs these diagnostics; apply the same change to the other print sites noted (the other verification/diagnostic prints around the same function that reference hf_out/meg_out/mean_diff/max_diff/cos/status) to keep output consistent under torchrun.
src/megatron/bridge/training/mimo_step.py (1)
30-54: Reuse the DP-resolution helper instead of re-implementing it here.
`_get_module_dp_info()` is now a second copy of the rank→module/DP lookup that already lives in `src/megatron/bridge/data/mimo/dp_utils.py`. Keeping both paths in sync will be easy to miss as the heterogeneous DP rules evolve. I’d pull this behind one shared helper and call it from both loader setup and `forward_step`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/mimo_step.py` around lines 30 - 54, This function _get_module_dp_info duplicates the rank→module/DP lookup; replace its body by importing and delegating to the centralized helper in src/megatron/bridge/data/mimo/dp_utils.py (use that module's exported lookup function rather than reimplementing rank checks), keeping the same signature (mimo_model: MimoModel) so callers like slice_batch_for_mimo and forward_step are unchanged; remove the local torch.distributed logic and any duplicated grid iteration, and ensure the import references the dp_utils helper so loader setup and forward_step both share the single implementation.
tests/e2e/mimo/convert_hf_clip_to_megatron.py (1)
63-64: Use a logger for script progress instead of raw `print()`.
This file is all stdout side effects today, which makes E2E output harder to control and filter in automation. A module logger would fit the repo conventions better.
As per coding guidelines, "Use `logging.getLogger(__name__)` or `print_rank_0()` instead of bare `print()` statements".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/e2e/mimo/convert_hf_clip_to_megatron.py` around lines 63 - 64, Replace the bare print() calls in this script with a module logger: add logger = logging.getLogger(__name__) (and import logging) near the top, then change each print(...) that reports progress — especially the one that prints hf_model_name before calling CLIPVisionModel.from_pretrained and the other prints used around model conversion/save steps — to logger.info(...) (or logger.debug as appropriate); ensure messages remain the same and keep any variable interpolations (e.g., hf_model_name) so behavior is identical but uses the repo logging convention.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0c03044f-6f93-4906-be3c-e9420b384739
📒 Files selected for processing (32)
src/megatron/bridge/data/mimo/__init__.py
src/megatron/bridge/data/mimo/dp_utils.py
src/megatron/bridge/data/mimo/loaders.py
src/megatron/bridge/models/mimo/mimo_ddp.py
src/megatron/bridge/models/mimo/mimo_provider.py
src/megatron/bridge/training/checkpointing.py
src/megatron/bridge/training/eval.py
src/megatron/bridge/training/mimo_step.py
src/megatron/bridge/training/pretrain_mimo.py
src/megatron/bridge/training/setup_mimo.py
src/megatron/bridge/training/train.py
src/megatron/bridge/training/train_mimo.py
src/megatron/bridge/training/utils/theoretical_memory_utils.py
tests/e2e/__init__.py
tests/e2e/mimo/__init__.py
tests/e2e/mimo/convert_hf_clip_to_megatron.py
tests/e2e/mimo/convert_hf_llama_to_megatron.py
tests/e2e/mimo/run_conversion_verification.sh
tests/e2e/mimo/run_hetero_llava.sh
tests/e2e/mimo/run_hetero_llava_parallelism_tests.sh
tests/e2e/mimo/run_hetero_llava_parallelism_tests_unfrozen_llm.sh
tests/e2e/mimo/run_mimo_checkpoint_resume.sh
tests/e2e/mimo/run_mimo_parallelism_tests.sh
tests/e2e/mimo/test_mimo_checkpoint_resume_e2e.py
tests/e2e/mimo/test_mimo_training_e2e.py
tests/e2e/mimo/test_mimo_training_llava.py
tests/e2e/mimo/verify_clip_conversion.py
tests/e2e/mimo/verify_llama_conversion.py
tests/unit_tests/data/mimo/test_loaders.py
tests/unit_tests/training/mimo/test_mimo_checkpointing.py
tests/unit_tests/training/mimo/test_pretrain_mimo.py
tests/unit_tests/training/test_checkpointing.py
Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com>
/claude review
Signed-off-by: Li Ding <liding@nvidia.com>
/ok to test e23fb22
Signed-off-by: Li Ding <liding@nvidia.com>
/ok to test e5ad53b
cuichenx left a comment
LGTM, just one small comment
why do we need to create this tests/e2e folder?
I think files here are best suited for the example folder, what do you think?
yea absolutely... i was planning to do all the code structure changes in the follow-up PR (rename, moving files, etc)... the main reason for not doing it now was that some folks have been using scripts under the e2e folder, so i don't want them to change things twice.. but lemme know if you think we should do it in this PR...
We can do it in a follow up PR, totally
…VIDIA-NeMo#2870) Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com> Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com> Signed-off-by: Li Ding <liding@nvidia.com> Co-authored-by: Yashaswi Karnati <144376261+yashaswikarnati@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Kamran Jafari <kjafarisadeg@nvidia.com> Co-authored-by: Li Ding <liding@nvidia.com>
…VIDIA-NeMo#2870) Signed-off-by: Ali Roshan Ghias <aroshanghias@nvidia.com> Signed-off-by: Kamran Jafari <kjafarisadeg@nvidia.com> Signed-off-by: Li Ding <liding@nvidia.com> Co-authored-by: Yashaswi Karnati <144376261+yashaswikarnati@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-authored-by: Kamran Jafari <kjafarisadeg@nvidia.com> Co-authored-by: Li Ding <liding@nvidia.com> Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Summary
Adds checkpoint save/resume and evaluation support for MiMo models, stacked on the Phase 4 training PR (#2869).
- `checkpointing.py`, `pretrain_mimo`, and `train_mimo`, with `torch_dist` format support and access-pattern validation bypass for nested DDP language model tensors in PP>1
- `eval.py` and distributed batch slicing (`dp_utils.slice_batch_for_mimo`) for evaluation across heterogeneous DP groups
- `set_input_tensor` method proxying on DDP-wrapped language model for correct PP decoder input wiring during checkpoint resume
Validation
- `baseline_dp_only`, 8 GPUs)
- `dp4_both`, `tp4_both`, `tp2_dp2_both`, `pp2_llm_dp4_vision`
Stack
Summary by CodeRabbit
New Features
Bug Fixes
Tests