Skip to content

[model] fix: apply attention_value_scale to V in MiMo-V2-Flash#4155

Open
Eisenhower wants to merge 3 commits into
NVIDIA-NeMo:mainfrom
Eisenhower:fix/mimo-v2-flash-value-scale
Open

[model] fix: apply attention_value_scale to V in MiMo-V2-Flash#4155
Eisenhower wants to merge 3 commits into
NVIDIA-NeMo:mainfrom
Eisenhower:fix/mimo-v2-flash-value-scale

Conversation

@Eisenhower

@Eisenhower Eisenhower commented Jun 4, 2026

Copy link
Copy Markdown

What does this PR do?

Fix MiMo-V2-Flash attention to actually apply attention_value_scale to the V tensor.

MiMoV2FlashTEDotProductAttention.__init__ reads attention_value_scale from the HF config and stores it as self._attention_value_scale, but the field is never consumed in forward. As a result the scale (0.707 in the public XiaomiMiMo/MiMo-V2-Flash checkpoint) is silently dropped, leaving the attention output scaled by ~1.414x at every layer relative to the reference implementation.

Changelog

  • MiMoV2FlashTEDotProductAttention.forward: multiply value by self._attention_value_scale (when set) before delegating to TE's attention kernel.

MiMoV2FlashMTPTEDotProductAttention inherits the fix automatically.

Why on V (and only on V)?

attention_value_scale cannot be folded into softmax_scale — softmax is non-linear in its argument, so scaling QK^T changes the attention distribution. Mathematically the scale must land either:

  • on V before the kernel (what HF and vLLM do), or
  • equivalently on the attention output before o_proj.

The bridge's mapping_registry maps linear_proj.weight 1:1 from HF o_proj.weight with no transformation, so the scale is not being absorbed at checkpoint-load time either. Applying it on V is the only correct fix.

References — both upstream implementations apply this on V

  1. HuggingFace reference (XiaomiMiMo/MiMo-V2-Flash):
    https://huggingface.co/XiaomiMiMo/MiMo-V2-Flash/blob/main/modeling_mimo_v2_flash.py

    In MiMoV2Attention:

    self.v_scale = getattr(config, "attention_value_scale", None)
    ...
    value_states = self.v_proj(hidden_states).view(v_hidden_shape).transpose(1, 2)
    if self.v_scale is not None:
        value_states = value_states * self.v_scale
    # then RoPE on Q/K, then attention_interface(...)
  2. vLLM (vllm/model_executor/models/mimo_v2.py):
    https://github.com/vllm-project/vllm/blob/e68988a24807c9dfb2bf6936eb17425ce7812c5f/vllm/model_executor/models/mimo_v2.py#L327-L329

    # Apply v_scale before attention
    if self.v_scale is not None:
        v = v * self.v_scale
    
    attn_output = self.attn(q, k, v)

    Both the SWA and full-attention branches in vLLM pass the same v_scale = getattr(config, "attention_value_scale", None).

Numerical impact

For the released XiaomiMiMo/MiMo-V2-Flash checkpoint with attention_value_scale = 0.707, the omission scales the attention output by ~1.414x at every attention layer. The error propagates into logits/log-probs and is large enough to materially affect both inference quality and downstream RL training stability (where training/inference consistency with vLLM matters).

Tests

This is a two-line behavioral fix. The existing tests/unit_tests/models/mimo_v2_flash/test_mimo_v2_flash_bridge.py covers config/mapping round-trips but not the attention forward path. Happy to add a focused unit test asserting value is scaled when attention_value_scale is set if maintainers would like — let me know.

Before your PR is "Ready for review"

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests? — see note above
  • Did you add or update any necessary documentation? — N/A, no public API change
  • Does the PR affect components that are optional to install? — No

@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

The attention_value_scale field is read from the HF config and stored on
MiMoV2FlashTEDotProductAttention as `self._attention_value_scale`, but
never applied to the value tensor in the forward path. As a result the
factor (0.707 in the public XiaomiMiMo/MiMo-V2-Flash checkpoint) is
silently dropped, leaving the attention output scaled by ~1.414x
relative to the reference implementation.

The scale cannot be folded into `softmax_scale` (softmax is non-linear
in its argument) and the linear_proj weight is mapped 1:1 from HF
`o_proj.weight`, so the only correct place to apply it is on V (or
equivalently on the attention output).

This mirrors the HF reference modeling code and vLLM's MiMo-V2 model
runner, both of which multiply V by `attention_value_scale` before the
attention kernel.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Leo <hitler594588@163.com>
@Eisenhower Eisenhower force-pushed the fix/mimo-v2-flash-value-scale branch from cc0ebf1 to e5fbaad Compare June 4, 2026 12:43
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic bug Something isn't working needs-more-tests Requires additional L0 and L1 test coverage before merge needs-review PR is ready for code review and waiting on a reviewer ready-to-merge PR is approved, current, and only waiting for CI to pass before merge and removed needs-review PR is ready for code review and waiting on a reviewer labels Jun 4, 2026
yaoyu-33
yaoyu-33 previously approved these changes Jun 4, 2026
@yaoyu-33

yaoyu-33 commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

/ok to test e5fbaad

Adds unit tests for the MiMoV2FlashTEDotProductAttention.forward
override that applies attention_value_scale to V before the attention
kernel. Tests verify:

- Scale is applied (V passed to super() equals input * scale)
- None scale leaves V passing through unchanged (same object)
- Caller's V buffer is not mutated in place
- Q and K are forwarded untouched
- Extra kwargs are forwarded to super()
- super()'s return value is propagated

The tests bypass TransformerEngine entirely by using object.__new__ to
build an instance without invoking __init__, then patching the parent
TEDotProductAttention.forward to capture what gets forwarded. This
keeps the tests CPU-only and dependency-free.

Closes the codecov/patch coverage gap on the two-line fix in
src/megatron/bridge/models/mimo_v2_flash/modeling_mimo_v2_flash.py.

Signed-off-by: Leo <hitler594588@163.com>
@Eisenhower

Copy link
Copy Markdown
Author

Pushed 9006932d with unit tests covering the attention_value_scale forward path in MiMoV2FlashTEDotProductAttention.forward, addressing the codecov/patch 0% gap on the previous diff.

The new tests (7 cases, all CPU-only) verify:

  • V is multiplied by the scale before the attention kernel
  • None scale leaves V passing through unchanged (same tensor object)
  • The caller's V buffer is not mutated in place
  • Q and K reach super().forward untouched
  • Extra **kwargs are forwarded
  • The return value from super().forward is propagated

The tests bypass TransformerEngine by constructing the instance via object.__new__ (skipping __init__) and patching the parent's forward, so they don't require a CUDA build of TE.

@yaoyu-33 — could you re-issue /ok to test 9006932d so CI picks up the new commit? Thanks!

Two pre-existing failures (release / build-docs and release-summary) look like they're tied to the release-pipeline workflow rather than this change — happy to dig in if you'd like, but flagging in case they're known infra issues.

@Eisenhower

Copy link
Copy Markdown
Author

Friendly ping @yaoyu-33 — the latest head is 9006932db8ec4c63ad026762b6f17b5f5fb21dd2, which adds the requested CPU-only unit coverage for the attention_value_scale forward path. Could you please re-issue /ok to test 9006932db8ec4c63ad026762b6f17b5f5fb21dd2 and re-approve when you have a chance? Thanks.

@adityavavreNVDA

Copy link
Copy Markdown
Contributor

/ok to test c797336

@gautham-kollu

Copy link
Copy Markdown
Contributor

UT added are failing. Adding the log below.

I think the super().foward() Mock seems to be not working as expected. @Eisenhower can you please look at it ?
May be try passing an actual argument that TEDotProductAttention.forward() takes to see if you get a different error


 =================================== FAILURES ===================================
  __________ TestAttentionValueScaleForward.test_scale_applied_to_value __________
  
  self = <tests.unit_tests.models.mimo_v2_flash.test_modeling_mimo_v2_flash.TestAttentionValueScaleForward object at 0x79d70ac1fd40>
  
      def test_scale_applied_to_value(self):
          scale = 0.707
          v = torch.randn(2, 4, 8, 64)
  >       captured, _, _ = _invoke_forward(scale, v)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^
  
  tests/unit_tests/models/mimo_v2_flash/test_modeling_mimo_v2_flash.py:98: 
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  tests/unit_tests/models/mimo_v2_flash/test_modeling_mimo_v2_flash.py:75: in _invoke_forward
      out = instance.forward(
  src/megatron/bridge/models/mimo_v2_flash/modeling_mimo_v2_flash.py:273: in forward
      return super().forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  <string>:2: in forward
      ???
  /usr/lib/python3.12/unittest/mock.py:192: in checksig
      sig.bind(*args, **kwargs)
  /usr/lib/python3.12/inspect.py:3242: in bind
      return self._bind(args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
  
  self = <Signature (self, query: 'Tensor', key: 'Tensor', value: 'Tensor', attention_mask: 'Optional[Tensor]', attn_mask_type:...' = None, packed_seq_params: 'Optional[PackedSeqParams]' = None, num_splits: 'Optional[int]' = None) -> 'torch.Tensor'>
  args = (<[AttributeError("'MiMoV2FlashTEDotProductAttention' object has no attribute '_modules'") raised in repr()] MiMoV2Fla...e-01,  8.8249e-01,  ..., -7.7397e-02,
             -1.3442e+00,  6.1891e-01]]]]), None, <MagicMock id='133991313938896'>)
  kwargs = {'extra_kwarg': 'passthrough'}
  
      def _bind(self, args, kwargs, *, partial=False):
          """Private method. Don't use directly."""
      
          arguments = {}
      
          parameters = iter(self.parameters.values())
          parameters_ex = ()
          arg_vals = iter(args)
      
          while True:
              # Let's iterate through the positional arguments and corresponding
              # parameters
              try:
                  arg_val = next(arg_vals)
              except StopIteration:
                  # No more positional arguments
                  try:
                      param = next(parameters)
                  except StopIteration:
                      # No more parameters. That's it. Just need to check that
                      # we have no `kwargs` after this while loop
                      break
                  else:
                      if param.kind == _VAR_POSITIONAL:
                          # That's OK, just empty *args.  Let's start parsing
                          # kwargs
                          break
                      elif param.name in kwargs:
                          if param.kind == _POSITIONAL_ONLY:
                              msg = '{arg!r} parameter is positional only, ' \
                                    'but was passed as a keyword'
                              msg = msg.format(arg=param.name)
                              raise TypeError(msg) from None
                          parameters_ex = (param,)
                          break
                      elif (param.kind == _VAR_KEYWORD or
                                                  param.default is not _empty):
                          # That's fine too - we have a default value for this
                          # parameter.  So, lets start parsing `kwargs`, starting
                          # with the current parameter
                          parameters_ex = (param,)
                          break
                      else:
                          # No default, not VAR_KEYWORD, not VAR_POSITIONAL,
                          # not in `kwargs`
                          if partial:
                              parameters_ex = (param,)
                              break
                          else:
                              if param.kind == _KEYWORD_ONLY:
                                  argtype = ' keyword-only'
                              else:
                                  argtype = ''
                              msg = 'missing a required{argtype} argument: {arg!r}'
                              msg = msg.format(arg=param.name, argtype=argtype)
                              raise TypeError(msg) from None
              else:
                  # We have a positional argument to process
                  try:
                      param = next(parameters)
                  except StopIteration:
                      raise TypeError('too many positional arguments') from None
                  else:
                      if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY):
                          # Looks like we have no parameter for this positional
                          # argument
                          raise TypeError(
                              'too many positional arguments') from None
      
                      if param.kind == _VAR_POSITIONAL:
                          # We have an '*args'-like argument, let's fill it with
                          # all positional arguments we have left and move on to
                          # the next phase
                          values = [arg_val]
                          values.extend(arg_vals)
                          arguments[param.name] = tuple(values)
                          break
      
                      if param.name in kwargs and param.kind != _POSITIONAL_ONLY:
                          raise TypeError(
                              'multiple values for argument {arg!r}'.format(
                                  arg=param.name)) from None
      
                      arguments[param.name] = arg_val
      
          # Now, we iterate through the remaining parameters to process
          # keyword arguments
          kwargs_param = None
          for param in itertools.chain(parameters_ex, parameters):
              if param.kind == _VAR_KEYWORD:
                  # Memorize that we have a '**kwargs'-like parameter
                  kwargs_param = param
                  continue
      
              if param.kind == _VAR_POSITIONAL:
                  # Named arguments don't refer to '*args'-like parameters.
                  # We only arrive here if the positional arguments ended
                  # before reaching the last parameter before *args.
                  continue
      
              param_name = param.name
              try:
                  arg_val = kwargs.pop(param_name)
              except KeyError:
                  # We have no value for this parameter.  It's fine though,
                  # if it has a default value, or it is an '*args'-like
                  # parameter, left alone by the processing of positional
                  # arguments.
                  if (not partial and param.kind != _VAR_POSITIONAL and
                                                      param.default is _empty):
                      raise TypeError('missing a required argument: {arg!r}'. \
                                      format(arg=param_name)) from None
      
              else:
                  if param.kind == _POSITIONAL_ONLY:
                      # This should never happen in case of a properly built
                      # Signature object (but let's have this check here
                      # to ensure correct behaviour just in case)
                      raise TypeError('{arg!r} parameter is positional only, '
                                      'but was passed as a keyword'. \
                                      format(arg=param.name))
      
                  arguments[param_name] = arg_val
      
          if kwargs:
              if kwargs_param is not None:
                  # Process our '**kwargs'-like parameter
                  arguments[kwargs_param.name] = kwargs
              else:
  >               raise TypeError(
                      'got an unexpected keyword argument {arg!r}'.format(
                          arg=next(iter(kwargs))))
  E               TypeError: got an unexpected keyword argument 'extra_kwarg'
  
  /usr/lib/python3.12/inspect.py:3231: TypeError
  ________ TestAttentionValueScaleForward.test_scale_various_values[0.5] _________

@gautham-kollu

Copy link
Copy Markdown
Contributor

Dropping ready-to-merge label as the UTs added in this PR are failing

@gautham-kollu gautham-kollu removed the ready-to-merge PR is approved, current, and only waiting for CI to pass before merge label Jun 12, 2026
@yaoyu-33 yaoyu-33 added the waiting-on-customer Waiting on the original author to respond label Jun 12, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:model Model implementations and HF bridge logic bug Something isn't working community-request needs-more-tests Requires additional L0 and L1 test coverage before merge waiting-on-customer Waiting on the original author to respond

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants