
[Feature] Xiaomi MiMo-V2.5-Pro day0 support#23808

Merged
ispobock merged 3 commits into sgl-project:main from JoyFuture:feat/support-mimov2-pro
Apr 28, 2026

Conversation

@JoyFuture
Contributor

@JoyFuture JoyFuture commented Apr 27, 2026

Motivation

MiMo-V2.5-Pro is a Mixture-of-Experts (MoE) language model with 1.02T total parameters and 42B active parameters. It uses the hybrid attention architecture and the 3-layer Multi-Token Prediction (MTP) described in MiMo-V2-Flash. The context length extends up to 1M tokens.

Modifications

Accuracy Tests

Speed Tests and Profiling

Benchmark command

python3 -m sglang.bench_serving \
  --backend sglang \
  --model XiaomiMiMo/MiMo-V2.5-Pro \
  --host 0.0.0.0 \
  --port 9001 \
  --dataset-name random \
  --random-input-len 8192 \
  --random-output-len 1 \
  --random-range-ratio 1.0 \
  --flush-cache \
  --seed 12345 \
  --num-prompts 10000 

Prefill performance

Test setting: EP16, chunk_size=32K, output length = 1 token, cache flushed.

| Input length | Output length | Single-node prefill throughput (cache miss) |
|---|---|---|
| 4K | 1 | 30.8K tok/s |
| 8K | 1 | 30.65K tok/s |
| 16K | 1 | 29.85K tok/s |
| 32K | 1 | 28.6K tok/s |
| 64K | 1 | 26.65K tok/s |
| 128K | 1 | 23.0K tok/s |
| 256K | 1 | 17.9K tok/s |
| 512K | 1 | 11.3K tok/s |
| 768K | 1 | 9.4K tok/s |
| 1M | 1 | 7.3K tok/s |

For input lengths up to 256K, we tested with the benchmark command above and checked the final output and logs to confirm that the requests were processed correctly.

For input lengths >= 512K, we sent two requests and ensured that they were routed to two different DP ranks. We then checked the bench_serving output and recorded the single-node prefill throughput.

Decode performance

Test setting: fixed 16K input and 1K output. We compared single-node decode throughput with and without 3-layer MTP.

| BS per DP rank | MTP | MTP accept length | TPS | Single-node decode throughput |
|---|---|---|---|---|
| 64 | disabled | - | 29.3 | 1875 tok/s |
| 64 | 3-layer | 3 | 60.5 | 3873 tok/s |
| 64 | 3-layer | 4 | 79.7 | 5103 tok/s |
| 96 | disabled | - | 26.7 | 2564 tok/s |
| 96 | 3-layer | 3 | 50.4 | 4840 tok/s |
| 96 | 3-layer | 4 | 64.8 | 6225 tok/s |

With 3-layer MTP enabled, decode throughput improves significantly:

| BS per DP rank | Without MTP | 3-layer MTP, accept length = 3 | 3-layer MTP, accept length = 4 |
|---|---|---|---|
| 64 | 1875 tok/s | 3873 tok/s | 5103 tok/s |
| 96 | 2564 tok/s | 4840 tok/s | 6225 tok/s |

These results show that the model can run correctly with long-context prefill up to 1M tokens, and 3-layer MTP provides a clear decode-side throughput improvement under the tested 16K-input / 1K-output setting.
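As a quick sanity check on the decode table (this script is not part of the PR, just an illustration), the reported single-node decode throughput is consistent with (batch size per DP rank) × (per-request TPS), up to rounding in the reported figures:

```python
# Verify that each reported throughput figure is within 1% of BS x TPS.
# Values are copied from the decode performance table above.
rows = [
    # (bs_per_dp_rank, tps, reported_throughput_tok_s)
    (64, 29.3, 1875),
    (64, 60.5, 3873),
    (64, 79.7, 5103),
    (96, 26.7, 2564),
    (96, 50.4, 4840),
    (96, 64.8, 6225),
]
for bs, tps, reported in rows:
    estimate = bs * tps
    # The table values are rounded, so allow a small relative tolerance.
    assert abs(estimate - reported) / reported < 0.01, (bs, tps, reported)
```

This also makes it easy to see that the per-request TPS gain from MTP (not just the aggregate throughput gain) is roughly 2x at accept length 3 and higher at accept length 4.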

Launch Command example

SGLANG_ENABLE_SPEC_V2=1 \
SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \
python3 -m sglang.launch_server \
              --model-path XiaomiMiMo/MiMo-V2.5-Pro \
              --trust-remote-code \
              --pp-size 1 \
              --dp-size 2 \
              --ep-size 16 \
              --tp-size 16 \
              --moe-dense-tp-size 1 \
              --enable-dp-attention \
              --moe-a2a-backend deepep \
              --dist-init-addr ${LWS_LEADER_IP}:20000 \
              --node-rank ${LWS_WORKER_INDEX} \
              --nnodes ${LWS_GROUP_SIZE} \
              --page-size 64 \
              --attention-backend fa3 \
              --quantization fp8 \
              --mem-fraction-static 0.7 \
              --max-running-requests 128 \
              --cuda-graph-max-bs 64 \
              --chunked-prefill-size 32768 \
              --context-length 1048576 \
              --tokenizer-worker-num 64 \
              --speculative-algorithm EAGLE \
              --speculative-num-steps 3 \
              --speculative-eagle-topk 1 \
              --speculative-num-draft-tokens 4 \
              --enable-multi-layer-eagle \
              --host 0.0.0.0 \
              --port 9001 \
              --reasoning-parser mimo \
              --tool-call-parser mimo \
              --watchdog-timeout 3600 \
              --model-loader-extra-config '{"enable_multithread_load": "true","num_threads": 64}' 

The MiMo-V2.5-Pro FP8 checkpoint uses a fused QKV projection layout exported with attention TP=8.

The FP8 quantization is performed independently on each attention TP shard before the shards are concatenated into the HF checkpoint. Therefore, the fused qkv_proj checkpoint is TP-rank-interleaved rather than a flat [Q_all | K_all | V_all] layout.

For MiMo-V2.5-Pro:

  • hidden_size = 6144
  • num_attention_heads = 128
  • head_dim = 192
  • num_key_value_heads = 8
  • v_head_dim = 128

With attention TP=8, each attention TP shard has:

  • Q: 128 / 8 * 192 = 3072 rows
  • K: 1 * 192 = 192 rows
  • V: 1 * 128 = 128 rows
  • Total QKV output per shard = 3392 rows

Since FP8 block-wise quantization uses 128x128 blocks, each TP shard's row dimension is padded independently for scale generation:

ceil(3392 / 128) = 27

So each shard has:

  • qkv_proj.weight shard shape: [3392, 6144]
  • qkv_proj.weight_scale_inv shard shape: [27, 48]

After concatenating 8 shards along the row dimension, the HF checkpoint stores:

  • qkv_proj.weight: [27136, 6144]
  • qkv_proj.weight_scale_inv: [216, 48]

Note that the scale rows are 8 * ceil(3392 / 128) = 216, not ceil((3392 * 8) / 128) = 212, because quantization is done independently per TP shard.

Therefore, this checkpoint requires runtime attention TP=8 for the fused QKV loading path. With DP attention enabled, this means the derived attention TP size should be 8, e.g. --tp-size 16 --dp-size 2 gives attention TP=8.
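The shard arithmetic above can be reproduced with a short script (a sketch using the configuration values from this description, not SGLang's actual loading code):

```python
import math

# Model configuration values quoted in the PR description.
hidden_size = 6144
num_attention_heads = 128
head_dim = 192
num_key_value_heads = 8
v_head_dim = 128
attn_tp = 8    # attention TP used when exporting the checkpoint
block = 128    # FP8 block-wise quantization uses 128x128 blocks

# Per-shard row counts for the fused QKV projection.
q_rows = num_attention_heads // attn_tp * head_dim       # 3072
k_rows = num_key_value_heads // attn_tp * head_dim       # 192
v_rows = num_key_value_heads // attn_tp * v_head_dim     # 128
shard_rows = q_rows + k_rows + v_rows                    # 3392

# Scales are generated per shard, so each shard pads its rows independently.
scale_rows_per_shard = math.ceil(shard_rows / block)     # 27
scale_cols = math.ceil(hidden_size / block)              # 48

# After concatenating all shards along the row dimension:
ckpt_rows = shard_rows * attn_tp                         # 27136
ckpt_scale_rows = scale_rows_per_shard * attn_tp         # 216

# What a flat [Q_all | K_all | V_all] layout would have produced instead:
flat_scale_rows = math.ceil(ckpt_rows / block)           # 212
```

The gap between 216 and 212 scale rows is the observable fingerprint of per-shard quantization, and is why the loader must treat the checkpoint as TP-rank-interleaved.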

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@JoyFuture JoyFuture force-pushed the feat/support-mimov2-pro branch from 8a3c83b to 4af6985 Compare April 27, 2026 06:11
@seindum

seindum commented Apr 27, 2026

I didn't see a 1T-parameter model coming. When is the release?

@acelyc111
Collaborator

/rerun-failed-ci

@JoyFuture
Contributor Author

/rerun-failed-ci

@JoyFuture JoyFuture force-pushed the feat/support-mimov2-pro branch from 4bae823 to 7cb04d2 Compare April 27, 2026 08:11
@JoyFuture JoyFuture changed the title [Feature] Xiaomi MiMo-V2-Pro day0 support [Feature] Xiaomi MiMo-V2.5-Pro day0 support Apr 27, 2026
@JoyFuture
Contributor Author

/rerun-failed-ci

@ispobock
Collaborator

/rerun-test test_mimo_models.py

@github-actions
Contributor

8-gpu-h200 (1 test):

cd test/ && python3 registered/8-gpu-models/test_mimo_models.py

⚠️ Could not retrieve workflow run URL. Check the Actions tab for progress.

@JoyFuture JoyFuture force-pushed the feat/support-mimov2-pro branch from c6f8c05 to d11318c Compare April 28, 2026 01:02
@ispobock ispobock merged commit 1a55646 into sgl-project:main Apr 28, 2026
95 of 152 checks passed
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026