
[Feature] Xiaomi MiMo-V2-Flash day0 support #15207

Merged
hnyls2002 merged 104 commits into main from xiaomi-mimo-v2-flash on Dec 19, 2025

Conversation

acelyc111 (Collaborator) commented Dec 15, 2025

Motivation

MiMo-V2-Flash is a Mixture-of-Experts (MoE) language model with 309B total parameters and 15B active parameters. Designed for high-speed reasoning and agentic workflows, it utilizes a novel hybrid attention architecture and Multi-Token Prediction (MTP) to achieve state-of-the-art performance while significantly reducing inference costs.

See it on HF: https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash
LMSys blog: https://lmsys.org/blog/2025-12-16-mimo-v2-flash/

Modifications

  • Basic model adaptation
  • Enhance Sliding Window Attention (SWA)
  • Introduce multi-layer MTP
  • Fix bugs related to PDD and SWA KV cache management

The remaining improvements are tracked in #15263.
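The decode benchmarks below report an MTP accept length of 3.6, i.e. on average about 3.6 tokens are committed per decoding step. As background, a toy sketch of greedy draft-and-verify acceptance in speculative decoding (illustrative only, not sglang's actual implementation):

```python
def accepted_prefix_len(draft_tokens, target_tokens):
    # Greedy verification: accept draft tokens while they match the target
    # model's predictions at the same positions; stop at the first mismatch.
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    # The target model always contributes one token per step (the corrected
    # token on mismatch, or one bonus token if all drafts were accepted).
    return n + 1

# All three drafts accepted -> 4 tokens committed this step.
assert accepted_prefix_len([5, 9, 2], [5, 9, 2]) == 4
# Mismatch at position 2 -> the two matches plus the target's correction.
assert accepted_prefix_len([5, 9, 2], [5, 9, 7]) == 3
```

Averaging this per-step count over a workload gives the reported accept length; higher values mean more tokens per target-model forward pass and thus cheaper decoding.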

Benchmarking and Profiling

MiMo-V2-Flash Prefill Benchmark (Radix Cache Disabled):
(benchmark chart image)

MiMo-V2-Flash Decode Benchmark (DP 2, TP 4, EP 8, MTP Accept Length 3.6, Input Token Length 16k, Varying Batch Size):
(benchmark chart image)

MiMo-V2-Flash Decode Benchmark (DP 2, TP 4, EP 8, MTP Accept Length 3.6, Per-DP-Rank Batch Size 16, Varying Input Token Length):
(benchmark chart image)

The full performance results can be reproduced with the branch in PR #15208.

We will merge all the performance and accuracy improvements in follow-up patches.

Launch Command example

SGLANG_ENABLE_SPEC_V2=1 python3 -m sglang.launch_server \
        --model-path XiaomiMiMo/MiMo-V2-Flash \
        --dp-size 2 \
        --enable-dp-attention \
        --tp-size 8 \
        --trust-remote-code \
        --mem-fraction-static 0.75 \
        --max-running-requests 128 \
        --chunked-prefill-size 16384 \
        --reasoning-parser qwen3 \
        --tool-call-parser mimo \
        --model-loader-extra-config '{"enable_multithread_load": "true","num_threads": 64}' \
        --attention-backend fa3 \
        --speculative-algorithm EAGLE \
        --speculative-num-steps=3 \
        --speculative-eagle-topk=1 \
        --speculative-num-draft-tokens=4 \
        --enable-mtp
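Once launched, the server exposes an OpenAI-compatible HTTP API (by default on port 30000, since the command above does not set --port). A minimal client sketch using only the Python standard library; the prompt and sampling parameters are illustrative assumptions:

```python
import json
from urllib import request

def build_chat_payload(model: str, user_msg: str, max_tokens: int = 256) -> bytes:
    # OpenAI-compatible chat-completion request body.
    return json.dumps({
        "model": model,
        "messages": [{"role": "user", "content": user_msg}],
        "max_tokens": max_tokens,
    }).encode()

def query_server(payload: bytes, base_url: str = "http://127.0.0.1:30000") -> dict:
    # POST to the server's OpenAI-compatible chat endpoint.
    req = request.Request(
        f"{base_url}/v1/chat/completions",
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        return json.load(resp)

# Usage (requires a running server):
#   payload = build_chat_payload("XiaomiMiMo/MiMo-V2-Flash",
#                                "Explain MoE in one sentence.")
#   reply = query_server(payload)["choices"][0]["message"]["content"]
```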

Co-authors

@JoyFuture
@Jumbo0715
@TZHelloWorld
@acelyc111
@hnyls2002
@ispobock
@lshmouse
@ollybbmonster
@sitabulaixizawaluduo
@yetlinghao
@zhannngchen

Checklist

acelyc111 and others added 30 commits December 14, 2025 15:01
yhyang201 (Collaborator) commented Dec 18, 2025

I tested the latest commit (36ef1e9) locally by running `python3 test/srt/test_vision_openai_server_a.py TestDeepseekOCRServer`, and it still runs into an OOM issue.
I suspect this CI failure is related to MHATokenToKVPool or to deepseek-ocr itself, rather than being a flaky CI problem.
The DeepSeek OCR tests were passing before and after PR #15277.
Is it possibly related to profile_max_num_token?


# For Multi-Layer MTP
# FIXME: rename -> enable_multi_layer_mtp
enable_mtp: bool = False
Contributor
Delete this argument. If the model is mimo, turn this on.

Contributor
rename `multi_layer_eagle_worker.py`

TZHelloWorld (Contributor) commented Dec 18, 2025

> (Quoting @yhyang201's comment above.)

While debugging this CI failure, I found that the K and V buffers may use 180+ GiB, and that the function profile_max_num_token was modified in this PR. When using the model deepseek-ai/DeepSeek-OCR, self.model_config.v_head_dim is 0.
(screenshot of the modified code)

I also checked the deepseek-ocr config.json and confirmed that v_head_dim is 0:
(screenshot of config.json)
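To illustrate the failure mode, here is a simplified, hypothetical sketch (not sglang's actual profile_max_num_token): if the per-token KV footprint is computed from k_head_dim + v_head_dim, a config reporting v_head_dim = 0 underestimates the per-token cost, inflating the token budget until the actual K/V allocation OOMs.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, k_head_dim, v_head_dim,
                             dtype_bytes=2):
    # Simplified per-token KV footprint: K and V buffers across all layers.
    return num_layers * num_kv_heads * (k_head_dim + v_head_dim) * dtype_bytes

def max_num_tokens(available_bytes, *, num_layers, num_kv_heads,
                   k_head_dim, v_head_dim):
    # Token budget derived from the (possibly wrong) per-token estimate.
    per_token = kv_cache_bytes_per_token(num_layers, num_kv_heads,
                                         k_head_dim, v_head_dim)
    return available_bytes // per_token

# With a sane v_head_dim, the estimate is reasonable:
ok = max_num_tokens(80 * 2**30, num_layers=32, num_kv_heads=8,
                    k_head_dim=128, v_head_dim=128)

# If config.json reports v_head_dim = 0, per-token cost is halved here,
# so the profiler budgets twice as many tokens as memory can hold.
bad = max_num_tokens(80 * 2**30, num_layers=32, num_kv_heads=8,
                     k_head_dim=128, v_head_dim=0)
assert bad == 2 * ok  # inflated budget -> oversized buffers downstream
```

All layer counts and head dimensions above are made up for illustration; the point is only that a zero v_head_dim silently corrupts any size estimate that sums the two head dimensions.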

hnyls2002 (Collaborator)
@TZHelloWorld Yes, we have fixed it. It is a bug on DeepSeek's side.

acelyc111 (Collaborator, Author)
/rerun-failed-ci

acelyc111 (Collaborator, Author)

/rerun-stage unit-test-backend-2-gpu

github-actions (bot)

✅ Triggered unit-test-backend-2-gpu to run independently (skipping dependencies).

It will not be shown in this page. Check the Actions tab for progress.

acelyc111 (Collaborator, Author)

/rerun-failed-ci

hnyls2002 (Collaborator)

/rerun-stage unit-test-deepep-4-gpu

hnyls2002 (Collaborator)

/rerun-stage unit-test-deepep-8-gpu


hnyls2002 force-pushed the xiaomi-mimo-v2-flash branch from bbffab7 to e53dc8c on December 19, 2025, 03:38
hnyls2002 merged commit 160a06c into main on Dec 19, 2025; 85 of 157 checks passed
hnyls2002 deleted the xiaomi-mimo-v2-flash branch on December 19, 2025, 03:40
xiaobaicxy added a commit to xiaobaicxy/sglang that referenced this pull request Dec 19, 2025
* 'main' of https://github.com/sgl-project/sglang: (136 commits)
  fix: unreachable error check in retraction (sgl-project#15433)
  [sgl-kernel] chore: update deepgemm version (sgl-project#13402)
  [diffusion] multi-platform: support diffusion on amd and fix encoder loading on MI325 (sgl-project#13760)
  [amd] Add deterministic all-reduce kernel for AMD (ROCm) (sgl-project#15340)
  [diffusion] refactor: refactor _build_req_from_sampling to use shallow_asdict (sgl-project#13782)
  Add customized sampler registration (sgl-project#15423)
  Update readme (sgl-project#15425)
  Fix Mindspore model import warning (sgl-project#15287)
  [Feature] Xiaomi `MiMo-V2-Flash` day0 support (sgl-project#15207)
  [diffusion] profiling: add bench_serving.py and VBench (sgl-project#15410)
  [DLLM] Fix dLLM regression (sgl-project#15371)
  [Deepseek V3.2] Fix Deepseek MTP in V1 mode (sgl-project#15429)
  chore: update CI_PERMISSIONS (sgl-project#15431)
  [DLLM] Add CI for diffusion LLMs (sgl-project#14723)
  Support using different attention backend for draft decoding. (sgl-project#14843)
  feat(dsv32): better error handling for DeepSeek-v3.2 encoder (sgl-project#14353)
  tiny fix lint on main (sgl-project#15424)
  multimodal: precompute hash for MultimodalDataItem (sgl-project#14354)
  [AMD] Clear pre-built AITER kernels and warmup to prevent segfaults and test timeouts (sgl-project#15318)
  [Performance] optimize NSA backend metadata computation for multi-step speculative decoding (sgl-project#14781)
  ...
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 23, 2025
Co-authored-by: 谢学扬 <xiexueyang@xiaomi.com>
Co-authored-by: tz <tangzhen3@xiaomi.com>
Co-authored-by: 李家乐 <lijiale10@xiaomi.com>
Co-authored-by: 张晨 <zhangchen50@xiaomi.com>
Co-authored-by: Shaohui Liu <liushaohui3@xiaomi.com>
Co-authored-by: 王晨 <wangchen77@xiaomi.com>
Co-authored-by: jiangzihan <jiangzihan@xiaomi.com>
Co-authored-by: xiexueyang <xyxie_wangyi@163.com>
Co-authored-by: Linghao Zhang <zhanglinghao@xiaomi.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Liangsheng Yin <lsyincs@gmail.com>
Co-authored-by: JoyFuture <35593546+JoyFuture@users.noreply.github.com>
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: root <root@bj9-ml-g8h20e-k8s-slave106-20251106.alicn.idc.xiaomi.com>
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
draft_logits_output.topk_index,
)
else:
draft_logits_output, _ = self.draft_runner_list[step].forward(
Collaborator
I'm a bit confused about the logic here. If self.draft_runner_list[step] is a ModelRunner, shouldn't its return value be an object rather than a tuple?

Unfortunately, I'm encountering an error when launching MiMo-V2 with --disable-cuda-graph enabled, so I'm unable to investigate this part firsthand.
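For context, the concern is about a return-shape mismatch: `output, _ = runner.forward(...)` assumes a tuple, while a ModelRunner may return a single object. A hypothetical defensive pattern for bridging the two shapes (names illustrative, not sglang's API):

```python
from typing import Any, Tuple

def unpack_forward_result(result: Any) -> Tuple[Any, Any]:
    # Accept either a (logits_output, extra) tuple or a bare output object,
    # so callers do not crash when a runner's return type differs.
    if isinstance(result, tuple):
        output, extra = result
        return output, extra
    return result, None

# A tuple-shaped result unpacks as-is; a bare object gets a None placeholder.
assert unpack_forward_result(("logits", "meta")) == ("logits", "meta")
assert unpack_forward_result("logits") == ("logits", None)
```

In practice the cleaner fix is to make all runners return one agreed-upon shape; a shim like this only papers over the inconsistency the reviewer is pointing at.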

acelyc111 (Collaborator, Author) commented Feb 2, 2026
@yhyang201 It seems to be a conflict introduced by another patch, #15400; they were merged into the main branch at about the same time.

I'll fix it later.

