-
Notifications
You must be signed in to change notification settings - Fork 362
[model] feat: add MiniMax-M2 MoE bridge #2602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ff3705b
cf6391d
755b07d
b2382d1
40a9ca8
a86dc26
f86b7ad
57363f0
206c4fb
5cfc87b
8d66144
2dfbb99
3cc9686
29616bc
46da10c
c27268b
0e4d6ed
efb9c0f
b8c05de
400ef5a
55bb430
a24f701
dfa47e4
c4d4484
b85c549
d04a47e
e6a2ca9
5b68f2f
9f7b2b9
cf38f1e
cab0056
c31c4a9
29a5629
070266a
a84ee7e
54346c8
c77d9cb
80bafee
2c692fe
02d739b
e955302
c599b9c
3724d87
53f2d66
c1bfdf2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,7 +121,7 @@ | |
|
|
||
| from megatron.bridge import AutoBridge | ||
| from megatron.bridge.models.hf_pretrained.utils import is_safe_repo | ||
| from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0 | ||
| from megatron.bridge.utils.common_utils import disable_mtp_for_inference, get_last_rank, print_rank_0 | ||
|
|
||
|
|
||
| # Cosine similarity threshold: require at least 98% similarity (2% tolerance) | ||
|
|
@@ -618,8 +618,7 @@ def _load_megatron_model(args): | |
|
|
||
| # Workaround: disable MTP for inference (causes hangs on NCCL collectives) | ||
| for m in megatron_model: | ||
| m.config.mtp_num_layers = None | ||
| m.config.grad_scale_func = None | ||
| disable_mtp_for_inference(m) | ||
|
|
||
| model_components = [m.eval() for m in megatron_model] | ||
|
|
||
|
|
@@ -732,6 +731,11 @@ def compare_models_one_step(args) -> None: | |
|
|
||
| # Broadcast HF results to all ranks | ||
| if torch.distributed.is_initialized(): | ||
| # Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model, | ||
| # so all ranks must use the same dtype for NCCL broadcast to work correctly. | ||
| if hf_logits is not None: | ||
| hf_logits = hf_logits.float() | ||
|
|
||
| # Create tensors for broadcasting if they don't exist on non-rank-0 | ||
| if hf_next_token is None: | ||
| hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long) | ||
|
|
@@ -748,6 +752,8 @@ def compare_models_one_step(args) -> None: | |
| # Broadcast from rank 0 to all ranks | ||
| torch.distributed.broadcast(hf_next_token, 0) | ||
| torch.distributed.broadcast(hf_logits, 0) | ||
| torch.distributed.barrier() | ||
| print_rank_0("HF results broadcast complete.") | ||
|
|
||
| # Run Megatron model forward pass | ||
| print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===") | ||
|
|
@@ -812,27 +818,26 @@ def compare_models_one_step(args) -> None: | |
| top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids] | ||
| print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}") | ||
|
|
||
| # Megatron may pad vocab_size for GPU kernel efficiency — truncate | ||
| # to the HF vocab size so logits are directly comparable. | ||
| hf_vocab_size = hf_logits.shape[0] | ||
| megatron_logits_cmp = megatron_logits[:hf_vocab_size] | ||
| megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1) | ||
|
|
||
| # Compare outputs (only where we have valid Megatron results) | ||
| print("=== COMPARISON ===") | ||
| token_match = hf_next_token.item() == megatron_next_token.item() | ||
| token_match = hf_next_token.item() == megatron_next_token_cmp.item() | ||
| token_status_emoji = "✅" if token_match else "❌" | ||
| print(f"Token match: {token_match} {token_status_emoji}") | ||
|
|
||
| # Compare logits if shapes match | ||
| if hf_logits.shape == megatron_logits.shape: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this check removed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Removed because we now always truncate Megatron logits to the HF vocab size before comparing (lines 822-824 in the current diff), which makes the old shape-mismatch branch unreachable. Megatron pads vocab size for GPU kernel efficiency (must be divisible by 64/128), so Megatron vocab ≥ HF vocab is always true — we just slice off the padding before the diff/cosine-similarity. The old |
||
| diff = (hf_logits - megatron_logits).abs() | ||
| print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") | ||
| cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0)) | ||
| cos_val = cosine_sim.item() | ||
| percent = cos_val * 100.0 | ||
| status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" | ||
| tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" | ||
| print( | ||
| f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)" | ||
| ) | ||
| else: | ||
| print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}") | ||
| print("Cannot compare logits directly due to shape mismatch") | ||
| diff = (hf_logits - megatron_logits_cmp).abs() | ||
| print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") | ||
| cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0)) | ||
| cos_val = cosine_sim.item() | ||
| percent = cos_val * 100.0 | ||
| status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" | ||
| tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" | ||
| print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)") | ||
|
|
||
| print("=== COMPARISON COMPLETE ===") | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| # MiniMax-M2 Examples | ||
|
|
||
| This directory contains example scripts for [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2), a large sparse MoE model with 230B total parameters (10B active), 256 experts, and FP8 quantization. | ||
|
|
||
| ## Hardware Requirements | ||
|
|
||
| MiniMax-M2 requires **at least 2 nodes (16 GPUs)** for inference and conversion. The model cannot fit on a single 8-GPU node because: | ||
|
|
||
| - TEGroupedMLP workspace is proportional to `num_experts / EP`; with EP=8 on 1 node, workspace alone OOMs. | ||
| - TP does **not** reduce expert memory — use EP instead. | ||
| - Minimum recommended config: `TP=1, EP=16, PP=1` (2 nodes × 8 GPUs). | ||
|
|
||
| ## Checkpoint Conversion | ||
|
|
||
| [slurm_conversion.sh](slurm_conversion.sh) sweeps multiple TP/PP/EP configs to verify HF ↔ Megatron round-trip conversion. | ||
|
|
||
| ### Setup | ||
|
|
||
| Edit the variables at the top of `slurm_conversion.sh`: | ||
|
|
||
| ```bash | ||
| CONTAINER_IMAGE="/path/to/container.sqsh" | ||
| # Optional: | ||
| export HF_TOKEN="hf_your_token_here" | ||
| export HF_HOME="/path/to/shared/HF_HOME" | ||
| ``` | ||
|
|
||
| ### Submit | ||
|
|
||
| ```bash | ||
| sbatch examples/models/minimax_m2/slurm_conversion.sh | ||
| ``` | ||
|
|
||
| ### Expected output (per config) | ||
|
|
||
| ``` | ||
| MiniMax-M2 Round-Trip Conversion Sweep | ||
| Job: <JOB_ID> | Nodes: 2 | ||
| Parallelism configs: 2,1,8 1,2,8 2,2,4 | ||
| ====================================== | ||
| Config 1/3: TP=2, PP=1, EP=8 | ||
| ... | ||
| All parameters match: True ✅ | ||
| [OK] Config 1: TP=2, PP=1, EP=8 passed | ||
| ... | ||
| All 3 configs passed | ||
| ====================================== | ||
| ``` | ||
|
|
||
| ## Inference | ||
|
|
||
| [slurm_inference.sh](slurm_inference.sh) runs text generation on the full FP8 checkpoint with `TP=1, EP=16`. | ||
|
|
||
| ### Setup | ||
|
|
||
| Edit the variables at the top of `slurm_inference.sh`: | ||
|
|
||
| ```bash | ||
| CONTAINER_IMAGE="/path/to/container.sqsh" | ||
| export HF_TOKEN="hf_your_token_here" | ||
| ``` | ||
|
|
||
| ### Submit | ||
|
|
||
| ```bash | ||
| sbatch examples/models/minimax_m2/slurm_inference.sh | ||
| ``` | ||
|
|
||
| ### Expected output | ||
|
|
||
| ``` | ||
| ======== GENERATED TEXT OUTPUT ======== | ||
|
yaoyu-33 marked this conversation as resolved.
|
||
| Prompt: What is 2+2? | ||
| Generated: What is 2+2? The answer is 4. | ||
| ======================================= | ||
| ``` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| #!/bin/bash | ||
|
yaoyu-33 marked this conversation as resolved.
|
||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # ============================================================================== | ||
| # MiniMax-M2 Conversion Round-Trip Verification (Multi-Node via Slurm) | ||
| # | ||
| # MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8) | ||
| # The full model OOMs on a single 8-GPU node — minimum 2 nodes (16 GPUs) | ||
| # with EP >= 8. | ||
| # | ||
| # Sweeps multiple parallelism configs (TP,PP,EP) to verify HF <-> Megatron | ||
| # round-trip conversion. Each config runs sequentially. | ||
| # | ||
| # Usage: | ||
| # 1. Fill in CONTAINER_IMAGE, CONTAINER_MOUNTS, and token exports | ||
| # 2. Adjust PARALLELISM_CONFIGS if needed | ||
| # 3. Submit: sbatch slurm_conversion.sh | ||
| # ============================================================================== | ||
|
|
||
| #SBATCH --job-name=minimax-m2-roundtrip | ||
| #SBATCH --nodes=2 | ||
| #SBATCH --ntasks-per-node=8 | ||
| #SBATCH --gpus-per-node=8 | ||
| #SBATCH --time=4:00:00 | ||
| #SBATCH --account=<your-account> | ||
| #SBATCH --partition=batch | ||
| #SBATCH --output=logs/minimax_m2_roundtrip_%j.log | ||
| #SBATCH --exclusive | ||
|
|
||
| # ── Container ──────────────────────────────────────────────────────────── | ||
| CONTAINER_IMAGE="" | ||
| # CONTAINER_IMAGE="/path/to/container.sqsh" | ||
| CONTAINER_MOUNTS="" | ||
| # CONTAINER_MOUNTS="/path/to/shared/storage:/mnt/storage,/path/to/project:/opt/Megatron-Bridge" | ||
| WORKDIR="/opt/Megatron-Bridge" | ||
|
|
||
| # ── Tokens / Caches ────────────────────────────────────────────────────── | ||
| # export HF_TOKEN="hf_your_token_here" | ||
| # export HF_HOME="/path/to/shared/HF_HOME" | ||
| # export UV_CACHE_DIR="/path/to/shared/uv_cache" | ||
|
|
||
| # ── Parallelism configs: "TP,PP,EP" per entry ──────────────────────────── | ||
| # TP*PP*EP must equal total GPUs (NODES * GPUS_PER_NODE). | ||
| # EP must divide 256 (number of experts). | ||
| # Note: TP=2,EP=4 on 8 GPUs (1 node) OOMs — use EP >= 8 with 2+ nodes. | ||
| PARALLELISM_CONFIGS=("2,1,8" "1,2,8" "2,2,4") | ||
|
|
||
| # ── Model ───────────────────────────────────────────────────────────────── | ||
| MODEL_NAME=MiniMax-M2 | ||
| HF_MODEL_ID=MiniMaxAI/$MODEL_NAME | ||
|
|
||
| # ── Environment ─────────────────────────────────────────────────────────── | ||
| export TORCH_NCCL_AVOID_RECORD_STREAMS=1 | ||
| export NCCL_NVLS_ENABLE=0 | ||
|
|
||
| # ============================================================================== | ||
| # Job Execution | ||
| # ============================================================================== | ||
|
|
||
| echo "======================================" | ||
| echo "MiniMax-M2 Round-Trip Conversion Sweep" | ||
| echo "Job: $SLURM_JOB_ID | Nodes: $SLURM_JOB_NUM_NODES" | ||
| echo "Parallelism configs: ${PARALLELISM_CONFIGS[*]}" | ||
| echo "======================================" | ||
|
|
||
| mkdir -p logs | ||
|
|
||
| if [ -z "$CONTAINER_IMAGE" ]; then | ||
| echo "ERROR: CONTAINER_IMAGE must be set." | ||
| exit 1 | ||
| fi | ||
|
|
||
| SRUN_CMD="srun --mpi=pmix --container-image=$CONTAINER_IMAGE" | ||
| if [ -n "$CONTAINER_MOUNTS" ]; then | ||
| SRUN_CMD="$SRUN_CMD --container-mounts=$CONTAINER_MOUNTS" | ||
| fi | ||
|
|
||
| CONFIG_INDEX=0 | ||
| for CONFIG in "${PARALLELISM_CONFIGS[@]}"; do | ||
| IFS=',' read -r TP PP EP <<< "$CONFIG" | ||
| CONFIG_INDEX=$((CONFIG_INDEX + 1)) | ||
|
|
||
| echo "" | ||
| echo "======================================" | ||
| echo "Config $CONFIG_INDEX/${#PARALLELISM_CONFIGS[@]}: TP=$TP, PP=$PP, EP=$EP" | ||
| echo "======================================" | ||
|
|
||
| # Sync dependencies once per node (local rank 0 runs `uv sync`; the other local ranks only wait a fixed 10 s), then run the roundtrip | ||
| CMD="if [ \"\$SLURM_LOCALID\" -eq 0 ]; then uv sync; else sleep 10; fi && " | ||
| CMD="${CMD}uv run --no-sync python examples/conversion/hf_megatron_roundtrip_multi_gpu.py" | ||
| CMD="$CMD --hf-model-id $HF_MODEL_ID" | ||
| CMD="$CMD --tp $TP --pp $PP --ep $EP" | ||
| CMD="$CMD --trust-remote-code" | ||
|
|
||
| echo "Executing: $CMD" | ||
|
|
||
| $SRUN_CMD bash -c "cd $WORKDIR && $CMD" | ||
| RUN_EXIT=$? | ||
| if [ $RUN_EXIT -ne 0 ]; then | ||
| echo "ERROR: Config TP=$TP, PP=$PP, EP=$EP failed (exit $RUN_EXIT)" | ||
| exit $RUN_EXIT | ||
| fi | ||
| echo "[OK] Config $CONFIG_INDEX: TP=$TP, PP=$PP, EP=$EP passed" | ||
| done | ||
|
|
||
| echo "" | ||
| echo "======================================" | ||
| echo "All ${#PARALLELISM_CONFIGS[@]} configs passed" | ||
| echo "======================================" | ||
Uh oh!
There was an error while loading. Please reload this page.