[model] feat: add MiMo-V2-Flash model support #3163

Merged
yaoyu-33 merged 12 commits into NVIDIA-NeMo:main from beccohov:beccohov/mimo-v2-flash on May 31, 2026

Conversation

@beccohov

@beccohov beccohov commented Apr 5, 2026

Contributor

What does this PR do ?

Add Megatron Bridge support for MiMo-V2-Flash (Xiaomi), a 309B / ~15B active parameter LLM with hybrid attention, fine-grained MoE, value scaling, asymmetric V head dims, and Multi-Token Prediction.

Changelog

  • Add MiMoV2FlashBridge with FP8 block-wise dequantization (supports non-uniform block sizes). Return fp32 weights to allow internal type cast.
  • Add MiMoV2FlashModelProvider with dual-base RoPE, per-layer KV head switching, and asymmetric V head dim
  • Add custom MiMoV2FlashSelfAttention that rebuilds linear_qkv/linear_proj for V head dim ≠ K head dim (128 vs 192)
  • Add custom MiMoV2FlashTEDotProductAttention with per-layer sliding window, attention sink bias, and TE k/v channel support
  • Add MiMoV2FlashQKVMapping for asymmetric QKV merge/split during checkpoint conversion
  • Add MTP support: dense MLP spec (not MoE), SWA attention, with MTP layer count auto-detected from safetensor keys
  • Add TP/CP assertions for known limitations (TP ≤ min KV groups, CP unsupported due to TE learnable softmax + CP). I decided not to replicate kv heads because it's not efficient.
  • Add unit tests: provider bridge config mapping, MTP detection, config round-trip, mapping registry coverage, QKV round-trip, FP8 dequant, weight loading hooks
  • Validated: per-layer fp32 numerics match the HF reference, end-to-end generation produces correct output, parallelism tested (TP, EP, SP, TP+EP combinations). Also tested generation on 8 GPUs.
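
For readers unfamiliar with block-wise FP8 checkpoints, the dequantization mentioned in the first bullet can be sketched roughly as follows. This is a minimal pure-Python illustration and not the code in MiMoV2FlashBridge: the function name `dequantize_blockwise`, its arguments, and the reading of "non-uniform" as "edge tiles may be smaller than the nominal block" are assumptions made here.

```python
import math


def dequantize_blockwise(weight, scale_inv, block_rows, block_cols):
    """Multiply every tile of `weight` by its per-tile scale (hypothetical helper).

    weight:    list of rows, values already decoded from FP8 to Python floats
    scale_inv: one scale per tile, shape ceil(R / block_rows) x ceil(C / block_cols)
    Tiles on the bottom/right edge may be smaller than the nominal block size.
    """
    rows, cols = len(weight), len(weight[0])
    assert len(scale_inv) == math.ceil(rows / block_rows)
    assert len(scale_inv[0]) == math.ceil(cols / block_cols)
    return [
        [weight[r][c] * scale_inv[r // block_rows][c // block_cols] for c in range(cols)]
        for r in range(rows)
    ]


# 3x5 weight with nominal 2x2 blocks: the last row block and column block are partial.
w = [[1.0] * 5 for _ in range(3)]
s = [[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]]
out = dequantize_blockwise(w, s, 2, 2)
```

Indexing the scale grid with integer division handles partial edge tiles without padding, which is one simple way to cope with shapes that are not multiples of the block size.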

GitHub Actions CI

The CI requires approval from an NVIDIA developer to run for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

@copy-pr-bot

copy-pr-bot Bot commented Apr 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@beccohov

beccohov commented Apr 5, 2026

Contributor Author

Drafted this WIP PR for MiMoV2 Flash support according to this issue

@beccohov

beccohov commented Apr 5, 2026

Contributor Author

@sbhavani Hey,
MiMo-V2-Flash uses asymmetric attention head dimensions: Q/K use head_dim=192 but V uses v_head_dim=128. This is currently not supported by MCore's standard SelfAttention for two reasons:

  1. linear_proj input size — Attention.__init__ sizes linear_proj input as kv_channels * num_attention_heads = 192 * 64 = 12288 (here), but the correct size is v_head_dim * num_attention_heads = 128 * 64 = 8192.
  2. QKV split in forward pass — get_query_key_value_tensors splits the fused QKV output using hidden_size_per_attention_head (= kv_channels = 192) for both K and V, so V would be extracted with the wrong size (here).

Should I expose support similar to MLA but in standard SelfAttention for decoupled V size in megatron? Alternatively, is there a recommended pattern for bridging models with asymmetric V dims without modifying MCore?
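
The two size mismatches described above can be checked with a little arithmetic. The sketch below is illustrative only: `NUM_KV_GROUPS = 8` is a hypothetical value (the comment does not state it), the helper names are invented here, and the per-group `[Q heads, K, V]` layout is an assumption about how a fused QKV output is arranged.

```python
NUM_HEADS = 64      # from the comment above
QK_HEAD_DIM = 192   # from the comment above
V_HEAD_DIM = 128    # from the comment above
NUM_KV_GROUPS = 8   # hypothetical, for illustration only


def proj_input_size(num_heads, per_head_dim):
    """Input width of the output projection: one value vector per attention head."""
    return num_heads * per_head_dim


def group_widths(num_heads, num_kv_groups, qk_head_dim, v_head_dim):
    """Widths of the Q, K and V slices inside one KV group of a fused QKV output."""
    heads_per_group = num_heads // num_kv_groups
    return (heads_per_group * qk_head_dim, qk_head_dim, v_head_dim)


def split_qkv(group_chunk, widths):
    """Split one group's fused output using separate widths for K and V."""
    q_w, k_w, v_w = widths
    assert len(group_chunk) == q_w + k_w + v_w
    q = group_chunk[:q_w]
    k = group_chunk[q_w:q_w + k_w]
    v = group_chunk[q_w + k_w:]
    return q, k, v


widths = group_widths(NUM_HEADS, NUM_KV_GROUPS, QK_HEAD_DIM, V_HEAD_DIM)
q, k, v = split_qkv(list(range(sum(widths))), widths)

wrong_proj_in = proj_input_size(NUM_HEADS, QK_HEAD_DIM)  # sized from kv_channels
right_proj_in = proj_input_size(NUM_HEADS, V_HEAD_DIM)   # sized from v_head_dim
```

A split that reuses the K width for V would read 192 elements where only 128 exist, which is the second problem in the list.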

@yaoyu-33

yaoyu-33 commented Apr 6, 2026

Contributor

Should I expose support similar to MLA but in standard SelfAttention for decoupled V size in megatron? Alternatively, is there a recommended pattern for bridging models with asymmetric V dims without modifying MCore?

@beccohov : usually we patch the config / implementation / fwd in bridge directly and use custom layer spec to specify the customized version of it.

Here's the GitHub link to the OLMoE provider file:
https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/src/megatron/bridge/models/olmoe/olmoe_provider.py
Layer spec builder — L41-45
Custom OLMoESelfAttention class — L98-188

@beccohov

Contributor Author

Hey, @yaoyu-33, I have a few questions before I mark this PR as "ready for review":

  1. Currently my implementation of MiMo-V2-Flash asserts that CP is disabled, because TE lacks a backend for CP + learnable softmax (attention sink bias). SWA layers use softmax_type="learnable", which is rejected by all TE backends when CP is enabled. However, it seems relatively easy to fix in TE (here we go through the MLA branch, and v.shape is wrong because of GQA). I tried to patch it in the provider, but there are lots of TE assertions that strictly limit such hacks. Is asserting acceptable, or should we implement a workaround?

  2. The HF config has attention_value_scale=0.707, but the released HF modeling code does not apply it in the forward pass (i.e. the released weights rescale V back). We currently skip it too (matching HF behavior). For training, however, I think we should re-enable the scale (i.e. value = value * 0.707) to fully match MiMo-V2-Flash. What do you think? This is basically a minor fix, and I think it matters only for large-scale pretraining with fp8. For finetuning I think it's better to skip this scaling, because a naive implementation would consume more memory on activations.

  3. Should mimo_v2_flash/ be merged into the existing mimo/ directory? The existing mimo/ contains both the original Xiaomi MiMo bridge (Qwen2-based) and the unrelated multimodal MIMO provider. I kept it separate to avoid confusion, but happy to merge if preferred.
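
As background for point 1, the "attention sink bias" can be pictured with a tiny sketch. It assumes the common formulation in which a learnable per-head logit joins the softmax denominator but contributes no value vector; the function name `softmax_with_sink` is made up for this illustration and is not a TE or MCore API.

```python
import math


def softmax_with_sink(scores, sink_logit):
    """Softmax over `scores` with an extra sink logit in the denominator only.

    The returned weights sum to less than 1; the missing mass is what the
    sink absorbs. Assumed formulation, for illustration only.
    """
    m = max(max(scores), sink_logit)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps) + math.exp(sink_logit - m)
    return [e / denom for e in exps]


probs = softmax_with_sink([0.0, 0.0], 0.0)    # each key gets 1/3, the sink keeps 1/3
vanilla_like = softmax_with_sink([0.0, 0.0], -1e9)  # a very negative sink recovers plain softmax
```

Because the denominator depends on a per-head parameter in addition to the key scores, a context-parallel kernel that merges partial softmax results across ranks has to account for the sink term as well, which may be why CP needs dedicated backend support here.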

@sbhavani

Contributor

@beccohov I'd recommend creating an issue in TE to track CP support. I think the incomplete CP coverage for non-vanilla softmax should be fixed eventually.

@beccohov

Contributor Author

Created issue in TE.
Apart from this issue, what about v_scale and directory ?

@beccohov beccohov marked this pull request as ready for review April 24, 2026 09:53
@beccohov

Contributor Author

The TE issue with CP is fixed in TE 2.15.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the waiting-on-maintainers Waiting on maintainers to respond label Apr 26, 2026
@beccohov

beccohov commented May 3, 2026

Contributor Author

@sbhavani hey, would it be possible to have review please?
Thanks!

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-maintainers Waiting on maintainers to respond label May 3, 2026
@sbhavani sbhavani added the waiting-on-maintainers Waiting on maintainers to respond label May 4, 2026
@sbhavani

sbhavani commented May 4, 2026

Contributor

CC @snowmanwwg

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-maintainers Waiting on maintainers to respond label May 4, 2026
@beccohov

beccohov commented May 4, 2026

Contributor Author

@sbhavani the label is removed automatically anyways for some reason

@sbhavani

sbhavani commented May 4, 2026

Contributor

@beccohov thanks! we need to fix that automation bug

CC @yaoyu-33

@cuichenx cuichenx left a comment

Contributor

Could you add a README page and run scripts following the example here? https://github.com/NVIDIA-NeMo/Megatron-Bridge/tree/main/examples/models/minimax_m2

Can you also add a functional test following this example https://github.com/NVIDIA-NeMo/Megatron-Bridge/pull/2602/changes

@@ -0,0 +1,22 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

Contributor

Suggested change
- # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

Contributor Author

Thanks, changed in all places

Comment thread src/megatron/bridge/models/mimo_v2_flash/mimo_v2_flash_provider.py Outdated
Comment thread src/megatron/bridge/models/mimo_v2_flash/mimo_v2_flash_provider.py Outdated
@svcnvidia-nemo-ci svcnvidia-nemo-ci added waiting-on-customer Waiting on the original author to respond and removed waiting-on-maintainers Waiting on maintainers to respond labels May 11, 2026
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic feature New capabilities, enhancements, or enablement work labels May 12, 2026
@beccohov

Contributor Author

@cuichenx I believe I've addressed all your comments — would appreciate a re-review!

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer Waiting on the original author to respond label May 17, 2026
@kamran-nvidia

Contributor

@beccohov We are working on caching this model on the CI server. I will re-run the failed CI test when the fix is in place. No action needed from your side. Thanks.

@beccohov

Contributor Author

@kamran-nvidia can you please share the PR / something where I can track the progress, so that I can understand when I'll be able to finish this PR?

@kamran-nvidia

Contributor

@JRD971000 Can you help @beccohov please? Thanks.

@yaoyu-33

Contributor

/ok to test 017062b

@beccohov

Contributor Author

AFAIU, the tests failed due to (2) from my comment here. So what should I do? Should I add trust_remote_code?
cc @yaoyu-33 @kamran-nvidia

@yaoyu-33

Contributor

@beccohov nvm, I will force merge this. It's a CI caching issue.

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33

Contributor

/ok to test 341094f

@yaoyu-33

Contributor

Will follow up to add back the functional tests after the caching issue is resolved.

@yaoyu-33 yaoyu-33 merged commit 38803fd into NVIDIA-NeMo:main May 31, 2026
13 checks passed
@Eisenhower

Hi @beccohov @yaoyu-33 — small follow-up while reading this merged work.

MiMoV2FlashTEDotProductAttention.__init__ reads attention_value_scale from config and stores it as self._attention_value_scale (modeling_mimo_v2_flash.py#L256), but the value is never consumed in forward (L270-L271). For the released XiaomiMiMo/MiMo-V2-Flash checkpoint (attention_value_scale = 0.707) this silently scales the attention output by ~1.414× at every layer relative to the reference implementations.

The scale can't be folded into softmax_scale (softmax is non-linear) and linear_proj.weight is mapped 1:1 from HF o_proj.weight, so it has to land on V (or equivalently on the attention output). Both upstream references do this on V before the kernel:

I opened a two-line follow-up: #4155. Would appreciate a quick look when you have a moment — happy to add a focused unit test if that helps the review.
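
The linearity argument above can be checked numerically. The sketch below is a toy single-query attention in pure Python with made-up inputs; `attend` and its parameters are hypothetical names, not the Megatron-Bridge, HF, or vLLM code.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values, softmax_scale, value_scale=1.0):
    """Single-query attention; `value_scale` multiplies V before the weighted sum."""
    scores = [softmax_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    probs = softmax(scores)
    dim = len(values[0])
    return [sum(p * value_scale * v[d] for p, v in zip(probs, values)) for d in range(dim)]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
scale = 0.707  # attention_value_scale from the released config

unscaled = attend(q, keys, values, softmax_scale=1.0)
scaled = attend(q, keys, values, softmax_scale=1.0, value_scale=scale)
folded = attend(q, keys, values, softmax_scale=1.0 * scale)  # not equivalent
```

Scaling V is the same as scaling the attention output (`scaled[i] == scale * unscaled[i]`), so omitting it leaves the output larger by a factor of 1 / 0.707, roughly 1.414. Folding the factor into `softmax_scale` changes the attention weights themselves and produces a different result.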

@beccohov

beccohov commented Jun 4, 2026

Copy link
Copy Markdown
Contributor Author

Hi @Eisenhower,
thanks for your comment. Actually, the reason I did not scale is that the HF weights (V weights) seem to be statically scaled in advance. I noticed this through quality degradation while implementing it.
So you don't need to scale to reproduce results if you load from the pretrained model, because it was already trained with suitable magnitudes.
However, for training dynamics you may want to scale (this affects the gradient magnitude for V, and if you retrain from scratch it may change the magnitude of attention scores). But naive scaling increases memory consumption (which would be notable on a model this large), so ideally the scaling should be implemented in a way that avoids the memory overhead.

So, one way is to upscale the HF weights back while loading them and then use this scaling in training. This is especially valuable when you train from scratch.
Another way is to skip this scaling completely (as in my implementation) and keep in mind that gradients for V would be slightly different. Since we initialize from a pretrained model, I believe this won't significantly change model convergence (since we use a smaller lr and the model is more stable in post-training). But it will save you memory (if you don't use full AC).

@Eisenhower

Thanks for the context. I think the HF implementation you checked was likely the old version. I raised this issue with the Xiaomi MiMo team afterwards, and they have since updated the HF modeling code to apply attention_value_scale on value_states in forward. vLLM also applies the same scale before attention.

I agree the extra allocation from a naive value = value * scale is worth optimizing. My concern is mainly RL consistency: rollout usually follows the vLLM/HF inference path, while training/logprob/KL are computed by the actor path. If rollout applies the scale but training does not, the two forward paths are no longer equivalent. For that reason I would prefer Megatron-Bridge to match the HF/vLLM semantics first, and then optimize the implementation or fold the scale only if we can keep the two paths numerically equivalent.

@beccohov

Contributor Author

@Eisenhower Sure, if the HF weights are updated then the situation changes. You can then add a follow-up PR with this small change!

vasunvidia pushed a commit to vasunvidia/Megatron-Bridge that referenced this pull request Jun 10, 2026
Signed-off-by: Arkadii Be <beccohov@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Yu Yao <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

Labels

  • area:model: Model implementations and HF bridge logic
  • community-request
  • feature: New capabilities, enhancements, or enablement work
  • ready-to-merge: PR is approved, current, and only waiting for CI to pass before merge

7 participants