Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ff3705b
[ckpt] refactor: Consolidate fused expert mappings and fix MTP inference
yaoyu-33 Mar 6, 2026
cf6391d
renmove deprecated providers
yaoyu-33 Feb 28, 2026
755b07d
fix: resolve import errors from removed model providers
yaoyu-33 Feb 28, 2026
b2382d1
[model] feat: add MiniMax-M2 MoE bridge
yaoyu-33 Mar 1, 2026
40a9ca8
[model] feat: add full-dimension QK norm, FP8 dequantization, and mul…
yaoyu-33 Mar 2, 2026
a86dc26
[model] fix: Add GLMExpertGateUpProjMapping alias and clean up MiniMa…
yaoyu-33 Mar 6, 2026
f86b7ad
[model] chore: clean up MiniMax-M2 example scripts
yaoyu-33 Mar 6, 2026
57363f0
[model] fix: srun-native multi-node support and MiniMax-M2 example cl…
yaoyu-33 Mar 7, 2026
206c4fb
test: skip nemotronh 4b ckpt tests and PP=2 conversion (pleasefixme f…
yaoyu-33 Mar 7, 2026
5cfc87b
Revert "test: skip nemotronh 4b ckpt tests and PP=2 conversion (pleas…
yaoyu-33 Mar 7, 2026
8d66144
ci: re-trigger CI
yaoyu-33 Mar 10, 2026
2dfbb99
[model] fix: add missing GLMExpertGateUpProjMapping alias in glm_moe_…
yaoyu-33 Mar 10, 2026
3cc9686
[ckpt] fix: fix ruff I001 lint error in glm_moe_mappings.py
yaoyu-33 Mar 10, 2026
29616bc
[model, test] fix: add PROVIDER_CLASS and null_attr guards for Qwen3N…
yaoyu-33 Mar 10, 2026
46da10c
[model] chore: merge main into refactor-fused-expert-mappings
yaoyu-33 Mar 16, 2026
c27268b
Merge branch 'main' into yuya/add-minimax-m2-bridge
yaoyu-33 Mar 16, 2026
0e4d6ed
[model, ckpt] fix: fix GPT-OSS MXFP4 expert weight layout and GLM MTP…
yaoyu-33 Mar 16, 2026
efb9c0f
[ckpt] fix: Address PR review comments on fused expert mappings
yaoyu-33 Mar 17, 2026
b8c05de
[ckpt] merge: Merge latest main into yuya/refactor-fused-expert-mappings
yaoyu-33 Mar 17, 2026
400ef5a
[test] fix: Add missing Llama32ModelProvider1B definition in test_pre…
yaoyu-33 Mar 17, 2026
55bb430
[model] refactor: address PR review comments for MTP disable and expe…
yaoyu-33 Mar 17, 2026
a24f701
[ci] fix: Fix ruff lint failures
yaoyu-33 Mar 17, 2026
dfa47e4
[test] merge: Resolve conflict in test_pretrain.py with main
yaoyu-33 Mar 17, 2026
c4d4484
Merge remote-tracking branch 'origin/main' into yuya/refactor-fused-e…
yaoyu-33 Mar 17, 2026
b85c549
Merge remote-tracking branch 'origin/yuya/refactor-fused-expert-mappi…
yaoyu-33 Mar 17, 2026
d04a47e
Merge origin/main into yuya/add-minimax-m2-bridge
yaoyu-33 Mar 18, 2026
e6a2ca9
[misc] merge: Resolve conflicts with origin/yuya/add-minimax-m2-bridge
yaoyu-33 Mar 18, 2026
5b68f2f
[build] chore: sync 3rdparty/Megatron-LM submodule to main
yaoyu-33 Mar 18, 2026
9f7b2b9
Merge remote-tracking branch 'origin/main' into yuya/add-minimax-m2-b…
yaoyu-33 Mar 18, 2026
cf38f1e
[model] revert: restore nemotron_h_provider.py accidentally deleted i…
yaoyu-33 Mar 18, 2026
cab0056
[model, test, doc] fix: address PR review comments for MiniMax-M2 bridge
yaoyu-33 Mar 18, 2026
c31c4a9
[doc] fix: address remaining README review comments for MiniMax-M2
yaoyu-33 Mar 18, 2026
29a5629
[doc] fix: remove CONTAINER_MOUNTS from MiniMax-M2 README
yaoyu-33 Mar 18, 2026
070266a
[misc] fix: revert remaining PR review requests from yaoyu-33
yaoyu-33 Mar 18, 2026
a84ee7e
[ci,test] feat: add L0 bash script and unit tests for MiniMax-M2 bridge
yaoyu-33 Mar 19, 2026
54346c8
[ci] feat: add L0_Launch_models_minimax_m2 to cicd-main.yml matrix
yaoyu-33 Mar 19, 2026
c77d9cb
[test] fix: save surrogate gpt2 tokenizer in minimax_m2 toy model fix…
yaoyu-33 Mar 19, 2026
80bafee
[ci] fix: collect in-process coverage for minimax_m2 L0 bash script
yaoyu-33 Mar 19, 2026
2c692fe
[misc] fix: address cuichenx PR review comments for MiniMax-M2
yaoyu-33 Mar 20, 2026
02d739b
[ci] fix: resolve merge conflict with main by adopting dynamic test m…
yaoyu-33 Mar 20, 2026
e955302
[model] fix: correct MiniMax-M2 HF parameter names in bridge
yaoyu-33 Mar 23, 2026
c599b9c
[model] fix: auto-detect MiniMax-M2 HF checkpoint format in bridge
yaoyu-33 Mar 23, 2026
3724d87
[model] fix: drop legacy MiniMax-M2 format, use native transformers l…
yaoyu-33 Mar 23, 2026
53f2d66
[model, test] fix: use on-disk block_sparse_moe keys for MiniMax-M2 b…
yaoyu-33 Mar 23, 2026
c1bfdf2
[test] fix: use absolute coverage path in MiniMax-M2 functional test
yaoyu-33 Mar 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@

from megatron.bridge import AutoBridge
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0
from megatron.bridge.utils.common_utils import disable_mtp_for_inference, get_last_rank, print_rank_0


# Cosine similarity threshold: require at least 98% similarity (2% tolerance)
Expand Down Expand Up @@ -618,8 +618,7 @@ def _load_megatron_model(args):

# Workaround: disable MTP for inference (causes hangs on NCCL collectives)
for m in megatron_model:
m.config.mtp_num_layers = None
m.config.grad_scale_func = None
disable_mtp_for_inference(m)

model_components = [m.eval() for m in megatron_model]

Expand Down Expand Up @@ -732,6 +731,11 @@ def compare_models_one_step(args) -> None:

# Broadcast HF results to all ranks
if torch.distributed.is_initialized():
# Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model,
# so all ranks must use the same dtype for NCCL broadcast to work correctly.
if hf_logits is not None:
hf_logits = hf_logits.float()
Comment thread
yaoyu-33 marked this conversation as resolved.

# Create tensors for broadcasting if they don't exist on non-rank-0
if hf_next_token is None:
hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long)
Expand All @@ -748,6 +752,8 @@ def compare_models_one_step(args) -> None:
# Broadcast from rank 0 to all ranks
torch.distributed.broadcast(hf_next_token, 0)
torch.distributed.broadcast(hf_logits, 0)
torch.distributed.barrier()
print_rank_0("HF results broadcast complete.")

# Run Megatron model forward pass
print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===")
Expand Down Expand Up @@ -812,27 +818,26 @@ def compare_models_one_step(args) -> None:
top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")

# Megatron may pad vocab_size for GPU kernel efficiency — truncate
# to the HF vocab size so logits are directly comparable.
hf_vocab_size = hf_logits.shape[0]
megatron_logits_cmp = megatron_logits[:hf_vocab_size]
megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

# Compare outputs (only where we have valid Megatron results)
print("=== COMPARISON ===")
token_match = hf_next_token.item() == megatron_next_token.item()
token_match = hf_next_token.item() == megatron_next_token_cmp.item()
token_status_emoji = "✅" if token_match else "❌"
print(f"Token match: {token_match} {token_status_emoji}")

# Compare logits if shapes match
if hf_logits.shape == megatron_logits.shape:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this check removed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because we now always truncate Megatron logits to the HF vocab size before comparing (lines 822-824 in the current diff), which makes the old shape-mismatch branch unreachable. Megatron pads vocab size for GPU kernel efficiency (must be divisible by 64/128), so Megatron vocab ≥ HF vocab is always true — we just slice off the padding before the diff/cosine-similarity. The old else branch would silently skip the entire comparison; the new code always compares correctly.

diff = (hf_logits - megatron_logits).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(
f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)"
)
else:
print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}")
print("Cannot compare logits directly due to shape mismatch")
diff = (hf_logits - megatron_logits_cmp).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)")

print("=== COMPARISON COMPLETE ===")
else:
Expand Down
10 changes: 4 additions & 6 deletions examples/conversion/hf_megatron_roundtrip_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,12 @@ def main(
original_param = bridge.hf_pretrained.state[name]
compare_param = param
compare_original = original_param
# Cast to float32 for params with known dtype mismatches between Megatron and HF
# (e.g. Megatron keeps expert_bias in float32 while HF may use bfloat16)
if any(p in name for p in IGNORE_PRECISION_PARAMS) or compare_param.dtype != compare_original.dtype:
# Cast to float32 for params with known precision mismatches.
# Any new known mismatch should be recorded in IGNORE_PRECISION_PARAMS above.
if any(p in name for p in IGNORE_PRECISION_PARAMS):
compare_param = param.float()
compare_original = original_param.float()
match = torch.allclose(
compare_param, compare_original.to(compare_param.device), atol=1e-1
) # Increased tolerance for bfloat16
match = torch.allclose(compare_param, compare_original.to(compare_param.device), atol=1e-1)
all_match = all_match and match
table.add_row(
name,
Expand Down
18 changes: 2 additions & 16 deletions examples/conversion/hf_to_megatron_generate_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from megatron.bridge import AutoBridge
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo
from megatron.bridge.utils.common_utils import get_last_rank, print_rank_0
from megatron.bridge.utils.common_utils import disable_mtp_for_inference, get_last_rank, print_rank_0


class SingleBatchIterator:
Expand Down Expand Up @@ -167,24 +167,10 @@ def main(args) -> None:
model_provider.initialize_model_parallel(seed=0)
model = model_provider.provide_distributed_model(wrap_with_ddp=False)

# Disable MTP for inference (MTP is only used during training)
def _disable_mtp(m):
"""Disable MTP on a model by clearing mtp_process on the language model."""
m.config.mtp_num_layers = None
inner = m.module if hasattr(m, "module") else m
lang = getattr(inner, "language_model", inner)
if hasattr(lang, "mtp_process"):
lang.mtp_process = False

model = [m.cuda() for m in model]
for m in model:
m.eval()
_disable_mtp(m)

# Set grad_scale_func to None on the model's config for inference
Comment thread
yaoyu-33 marked this conversation as resolved.
for m in model:
if hasattr(m, "config"):
m.config.grad_scale_func = None
disable_mtp_for_inference(m)

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
Expand Down
76 changes: 76 additions & 0 deletions examples/models/minimax_m2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# MiniMax-M2 Examples

This directory contains example scripts for [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2), a large sparse MoE model with 230B total parameters (10B active), 256 experts, and FP8 quantization.

## Hardware Requirements

MiniMax-M2 requires **at least 2 nodes (16 GPUs)** for inference and conversion. The model cannot fit on a single 8-GPU node because:

- TEGroupedMLP workspace is proportional to `num_experts / EP`; with EP=8 on 1 node, workspace alone OOMs.
- TP does **not** reduce expert memory — use EP instead.
- Minimum recommended config: `TP=1, EP=16, PP=1` (2 nodes × 8 GPUs).

## Checkpoint Conversion

[slurm_conversion.sh](slurm_conversion.sh) sweeps multiple TP/PP/EP configs to verify HF ↔ Megatron round-trip conversion.

### Setup

Edit the variables at the top of `slurm_conversion.sh`:

```bash
CONTAINER_IMAGE="/path/to/container.sqsh"
# Optional:
export HF_TOKEN="hf_your_token_here"
export HF_HOME="/path/to/shared/HF_HOME"
```

### Submit

```bash
sbatch examples/models/minimax_m2/slurm_conversion.sh
```

### Expected output (per config)

```
MiniMax-M2 Round-Trip Conversion Sweep
Job: <JOB_ID> | Nodes: 2
Parallelism configs: 2,1,8 1,2,8 2,2,4
======================================
Config 1/3: TP=2, PP=1, EP=8
...
All parameters match: True ✅
[OK] Config 1: TP=2, PP=1, EP=8 passed
...
All 3 configs passed
======================================
```

## Inference

[slurm_inference.sh](slurm_inference.sh) runs text generation on the full FP8 checkpoint with `TP=1, EP=16`.

### Setup

Edit the variables at the top of `slurm_inference.sh`:

```bash
CONTAINER_IMAGE="/path/to/container.sqsh"
export HF_TOKEN="hf_your_token_here"
```

### Submit

```bash
sbatch examples/models/minimax_m2/slurm_inference.sh
```

### Expected output

```
======== GENERATED TEXT OUTPUT ========
Comment thread
yaoyu-33 marked this conversation as resolved.
Prompt: What is 2+2?
Generated: What is 2+2? The answer is 4.
=======================================
```
121 changes: 121 additions & 0 deletions examples/models/minimax_m2/slurm_conversion.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
#!/bin/bash
Comment thread
yaoyu-33 marked this conversation as resolved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ==============================================================================
# MiniMax-M2 Conversion Round-Trip Verification (Multi-Node via Slurm)
#
# MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8)
# The full model OOMs on a single 8-GPU node — minimum 2 nodes (16 GPUs)
# with EP >= 8.
#
# Sweeps multiple parallelism configs (TP,PP,EP) to verify HF <-> Megatron
# round-trip conversion. Each config runs sequentially.
#
# Usage:
# 1. Fill in CONTAINER_IMAGE, CONTAINER_MOUNTS, and token exports
# 2. Adjust PARALLELISM_CONFIGS if needed
# 3. Submit: sbatch slurm_conversion.sh
# ==============================================================================

#SBATCH --job-name=minimax-m2-roundtrip
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8
#SBATCH --time=4:00:00
#SBATCH --account=<your-account>
#SBATCH --partition=batch
#SBATCH --output=logs/minimax_m2_roundtrip_%j.log
#SBATCH --exclusive

# ── Container ────────────────────────────────────────────────────────────
# Image and bind mounts forwarded to srun. Both start out empty and are meant
# to be filled in before submitting the job, for example:
#   CONTAINER_IMAGE="/path/to/container.sqsh"
#   CONTAINER_MOUNTS="/path/to/shared/storage:/mnt/storage,/path/to/project:/opt/Megatron-Bridge"
CONTAINER_IMAGE=''
CONTAINER_MOUNTS=''
# Directory the job changes into before launching the conversion script.
WORKDIR='/opt/Megatron-Bridge'

# ── Tokens / Caches ──────────────────────────────────────────────────────
# Optional exports; uncomment and fill in the ones you need.
# export HF_TOKEN="hf_your_token_here"
# export HF_HOME="/path/to/shared/HF_HOME"
# export UV_CACHE_DIR="/path/to/shared/uv_cache"

# ── Parallelism configs: "TP,PP,EP" per entry ────────────────────────────
# Each entry is one sweep step, run in the order listed.
#   - TP*PP*EP must equal total GPUs (NODES * GPUS_PER_NODE).
#   - EP must divide 256 (number of experts).
#   - Note: TP=2,EP=4 on 8 GPUs (1 node) OOMs — use EP >= 8 with 2+ nodes.
PARALLELISM_CONFIGS=(
  '2,1,8'
  '1,2,8'
  '2,2,4'
)

# ── Model ─────────────────────────────────────────────────────────────────
# Hugging Face repository id, derived from the model name.
MODEL_NAME='MiniMax-M2'
HF_MODEL_ID="MiniMaxAI/${MODEL_NAME}"

# ── Environment ───────────────────────────────────────────────────────────
# Runtime settings exported so every process launched by the job inherits them.
TORCH_NCCL_AVOID_RECORD_STREAMS=1
NCCL_NVLS_ENABLE=0
export TORCH_NCCL_AVOID_RECORD_STREAMS NCCL_NVLS_ENABLE

# ==============================================================================
# Job Execution
# ==============================================================================

# Sweep driver: for every "TP,PP,EP" entry in PARALLELISM_CONFIGS, launch the
# HF <-> Megatron round-trip script through srun inside the container. The
# sweep stops at the first failing config and exits with that config's status.
#
# Reads:   CONTAINER_IMAGE, CONTAINER_MOUNTS, WORKDIR, PARALLELISM_CONFIGS,
#          HF_MODEL_ID, SLURM_JOB_ID, SLURM_JOB_NUM_NODES
# Outputs: progress banners on stdout, error diagnostics on stderr.

echo "======================================"
echo "MiniMax-M2 Round-Trip Conversion Sweep"
echo "Job: $SLURM_JOB_ID | Nodes: $SLURM_JOB_NUM_NODES"
echo "Parallelism configs: ${PARALLELISM_CONFIGS[*]}"
echo "======================================"

# NOTE(review): the #SBATCH --output path points into logs/, and sbatch is not
# expected to create that directory — confirm logs/ exists at submission time;
# this mkdir only helps later steps and re-submissions.
mkdir -p logs

if [[ -z "$CONTAINER_IMAGE" ]]; then
  echo "ERROR: CONTAINER_IMAGE must be set." >&2
  exit 1
fi

# Build the launcher as an argv array so that image/mount paths containing
# spaces or glob characters reach srun as single, unmodified arguments.
srun_cmd=(srun --mpi=pmix "--container-image=${CONTAINER_IMAGE}")
if [[ -n "$CONTAINER_MOUNTS" ]]; then
  srun_cmd+=("--container-mounts=${CONTAINER_MOUNTS}")
fi

# Gate evaluated by every task inside the container: local rank 0 on each node
# runs `uv sync`, all other local ranks just wait 10 seconds.
# NOTE(review): this is a fixed delay, not a barrier — if `uv sync` takes
# longer than 10 s, the other ranks start `uv run --no-sync` before the
# environment is ready.
sync_gate='if [ "$SLURM_LOCALID" -eq 0 ]; then uv sync; else sleep 10; fi'

# Remote script: $1 is the working directory, the remaining arguments are the
# command to execute. Passing them as positional parameters keeps each argument
# intact instead of re-parsing a concatenated string.
inner_script="cd \"\$1\" && shift && ${sync_gate} && \"\$@\""

config_total=${#PARALLELISM_CONFIGS[@]}
config_index=0
for config in "${PARALLELISM_CONFIGS[@]}"; do
  IFS=',' read -r TP PP EP <<< "$config"
  config_index=$((config_index + 1))

  echo ""
  echo "======================================"
  echo "Config $config_index/$config_total: TP=$TP, PP=$PP, EP=$EP"
  echo "======================================"

  # Sync dependencies once per node, then run the roundtrip
  run_cmd=(
    uv run --no-sync python examples/conversion/hf_megatron_roundtrip_multi_gpu.py
    --hf-model-id "$HF_MODEL_ID"
    --tp "$TP" --pp "$PP" --ep "$EP"
    --trust-remote-code
  )

  echo "Executing: ${sync_gate} && ${run_cmd[*]}"

  # Capture the status explicitly so the failure branch can report and
  # propagate the exact srun exit code.
  run_exit=0
  "${srun_cmd[@]}" bash -c "$inner_script" bash "$WORKDIR" "${run_cmd[@]}" || run_exit=$?
  if (( run_exit != 0 )); then
    echo "ERROR: Config TP=$TP, PP=$PP, EP=$EP failed (exit $run_exit)" >&2
    exit "$run_exit"
  fi
  echo "[OK] Config $config_index: TP=$TP, PP=$PP, EP=$EP passed"
done

echo ""
echo "======================================"
echo "All $config_total configs passed"
echo "======================================"
Loading
Loading