[model] fix: apply attention_value_scale to V in MiMo-V2-Flash#4155
[model] fix: apply attention_value_scale to V in MiMo-V2-Flash#4155Eisenhower wants to merge 3 commits into
Conversation
The attention_value_scale field is read from the HF config and stored on MiMoV2FlashTEDotProductAttention as `self._attention_value_scale`, but never applied to the value tensor in the forward path. As a result the factor (0.707 in the public XiaomiMiMo/MiMo-V2-Flash checkpoint) is silently dropped, leaving the attention output scaled by ~1.414x relative to the reference implementation. The scale cannot be folded into `softmax_scale` (softmax is non-linear in its argument) and the linear_proj weight is mapped 1:1 from HF `o_proj.weight`, so the only correct place to apply it is on V (or equivalently on the attention output). This mirrors the HF reference modeling code and vLLM's MiMo-V2 model runner, both of which multiply V by `attention_value_scale` before the attention kernel. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> Signed-off-by: Leo <hitler594588@163.com>
cc0ebf1 to
e5fbaad
Compare
|
/ok to test e5fbaad |
Adds unit tests for the MiMoV2FlashTEDotProductAttention.forward override that applies attention_value_scale to V before the attention kernel. Tests verify: - Scale is applied (V passed to super() equals input * scale) - None scale leaves V passing through unchanged (same object) - Caller's V buffer is not mutated in place - Q and K are forwarded untouched - Extra kwargs are forwarded to super() - super()'s return value is propagated The tests bypass TransformerEngine entirely by using object.__new__ to build an instance without invoking __init__, then patching the parent TEDotProductAttention.forward to capture what gets forwarded. This keeps the tests CPU-only and dependency-free. Closes the codecov/patch coverage gap on the two-line fix in src/megatron/bridge/models/mimo_v2_flash/modeling_mimo_v2_flash.py. Signed-off-by: Leo <hitler594588@163.com>
|
Pushed The new tests (7 cases, all CPU-only) verify:
The tests bypass TransformerEngine by constructing the instance via @yaoyu-33 — could you re-issue Two pre-existing failures ( |
|
Friendly ping @yaoyu-33 — the latest head is |
|
/ok to test c797336 |
|
UT added are failing. Adding the log below. I think the super().foward() Mock seems to be not working as expected. @Eisenhower can you please look at it ? |
|
Dropping ready-to-merge label as the UTs added in this PR are failing |
What does this PR do?
Fix MiMo-V2-Flash attention to actually apply
attention_value_scaleto the V tensor.MiMoV2FlashTEDotProductAttention.__init__readsattention_value_scalefrom the HF config and stores it asself._attention_value_scale, but the field is never consumed inforward. As a result the scale (0.707 in the publicXiaomiMiMo/MiMo-V2-Flashcheckpoint) is silently dropped, leaving the attention output scaled by ~1.414x at every layer relative to the reference implementation.Changelog
MiMoV2FlashTEDotProductAttention.forward: multiplyvaluebyself._attention_value_scale(when set) before delegating to TE's attention kernel.MiMoV2FlashMTPTEDotProductAttentioninherits the fix automatically.Why on V (and only on V)?
attention_value_scalecannot be folded intosoftmax_scale— softmax is non-linear in its argument, so scaling QK^T changes the attention distribution. Mathematically the scale must land either:o_proj.The bridge's
mapping_registrymapslinear_proj.weight1:1 from HFo_proj.weightwith no transformation, so the scale is not being absorbed at checkpoint-load time either. Applying it on V is the only correct fix.References — both upstream implementations apply this on V
HuggingFace reference (
XiaomiMiMo/MiMo-V2-Flash):https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash/blob/main/modeling_mimo_v2_flash.py
In
MiMoV2Attention:vLLM (
vllm/model_executor/models/mimo_v2.py):https://github.com/vllm-project/vllm/blob/e68988a24807c9dfb2bf6936eb17425ce7812c5f/vllm/model_executor/models/mimo_v2.py#L327-L329
Both the SWA and full-attention branches in vLLM pass the same
v_scale = getattr(config, "attention_value_scale", None).Numerical impact
For the released
XiaomiMiMo/MiMo-V2-Flashcheckpoint withattention_value_scale = 0.707, the omission scales the attention output by ~1.414x at every attention layer. The error propagates into logits/log-probs and is large enough to materially affect both inference quality and downstream RL training stability (where training/inference consistency with vLLM matters).Tests
This is a two-line behavioral fix. The existing
tests/unit_tests/models/mimo_v2_flash/test_mimo_v2_flash_bridge.pycovers config/mapping round-trips but not the attention forward path. Happy to add a focused unit test assertingvalueis scaled whenattention_value_scaleis set if maintainers would like — let me know.Before your PR is "Ready for review"