Fix FlashAttention MLA prefill V unpadding #42642

Open. voipmonitor wants to merge 2 commits into vllm-project:main from voipmonitor:codex/upstream-mla-lateunpad-pr

@voipmonitor

Contributor

Purpose

Fix a regression in the FlashAttention MLA prefill path introduced when the prefill implementations were split out in #32623.

Before #32623, the FlashAttention MLA helper padded V when the selected FlashAttention implementation did not support different QK/V head dimensions. The padded output was kept through the context/suffix merge_attn_states path, and only then sliced back to v_head_dim in MLACommonImpl.forward_mha.

After #32623, FlashAttnPrefillBackend._flash_attn_varlen_diff_headdims() slices the output back to v_head_dim inside the backend. That changes the tensor shape contract seen by the chunked-context merge path. On a long-context Kimi/DeepSeek-style MLA setup using the FlashAttention prefill backend with requires_v_padding=True, this produced incorrect long-context generations: the model started continuing unrelated prompt padding text instead of answering the user question. Keeping the old late-unpad behavior restores the expected output.

This PR moves the unpad back to the caller:

  • FlashAttnPrefillBackend now returns the same padded output shape that the old in-file FlashAttention MLA prefill helper returned.
  • MLACommonImpl.forward_mha slices context_output, suffix_output, and no-context output_prefill back to v_head_dim immediately before writing/merging into the final output buffer.
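
As a rough illustration of the late-unpad flow in the two bullets above, the following pure-Python sketch pads V with zeros up to the QK head dimension, computes a context chunk and a suffix chunk separately, slices both outputs back to the V head dimension, and merges them using their log-sum-exp values. The helpers `attend`, `merge_states`, and `pad_v` are hypothetical stand-ins written for this example. They are not vLLM functions, and the sketch does not reproduce the regression itself. It only shows that, under these assumptions, slicing immediately before the merge matches the unpadded reference.

```python
import math

def attend(q, keys, values):
    """Single-query softmax attention; returns (output, log-sum-exp)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)

def merge_states(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention outputs, weighting each by exp(lse)."""
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)
    return [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]

def pad_v(values, width):
    """Zero-pad every V row on the right up to `width` columns."""
    return [v + [0.0] * (width - len(v)) for v in values]

# Toy sizes chosen for the example; real MLA head dims are much larger.
QK_HEAD_DIM, V_HEAD_DIM = 4, 2
q = [0.1, -0.3, 0.5, 0.2]
keys = [[0.2, 0.1, -0.4, 0.3], [0.5, -0.2, 0.1, 0.0],
        [-0.1, 0.4, 0.3, -0.2], [0.3, 0.3, -0.1, 0.1]]
values = [[1.0, 2.0], [0.5, -1.0], [-0.3, 0.7], [2.0, 0.1]]

# Reference: attention over all keys with unpadded V.
reference, _ = attend(q, keys, values)

# Padded path: the "backend" returns outputs that are QK_HEAD_DIM wide.
padded = pad_v(values, QK_HEAD_DIM)
context_output, context_lse = attend(q, keys[:2], padded[:2])
suffix_output, suffix_lse = attend(q, keys[2:], padded[2:])

# Late unpad in the caller, immediately before the merge.
context_output = context_output[:V_HEAD_DIM]
suffix_output = suffix_output[:V_HEAD_DIM]
merged = merge_states(context_output, context_lse, suffix_output, suffix_lse)
```

Here `merged` has `V_HEAD_DIM` entries and agrees with `reference` up to floating-point rounding.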

Notes

This only affects the FlashAttention MLA prefill backend when requires_v_padding=True. Backends that natively support different QK/V head dimensions keep returning v_head_dim already, so the added checks are no-ops for them.

I intentionally kept this as a small compatibility fix rather than changing backend selection or DCP behavior.

Test Plan

  • python3 -m py_compile vllm/model_executor/layers/attention/mla_attention.py vllm/v1/attention/backends/mla/prefill/flash_attn.py
  • Local long-context Kimi-K2.6 validation on Blackwell, TRITON_MLA, DCP=1, MTP disabled, 128k synthetic context. Before this change the model continued unrelated context padding; with this change it answers the requested Sieve of Eratosthenes prompt again.

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the unpadding logic for attention outputs by moving it from the backend-specific Flash Attention implementation to the core MLA attention layer. This change ensures that outputs are correctly sliced to the value head dimension when padding is applied. The reviewer recommended simplifying the newly added conditional check in mla_attention.py by removing a redundant attribute check, as the shape comparison is sufficient to identify when unpadding is required.

Comment on lines +2313 to +2317
prefill_backend = prefill_metadata.prefill_backend
if (
    getattr(prefill_backend, "requires_v_padding", False)
    and context_output.shape[-1] != self.v_head_dim
):

Contributor

high

The check for requires_v_padding via getattr is redundant here because the shape check context_output.shape[-1] != self.v_head_dim is sufficient to determine if unpadding is necessary. Simplifying this condition makes the code cleaner and consistent with the logic used in the else block (lines 2332-2333).

Suggested change
- prefill_backend = prefill_metadata.prefill_backend
- if (
-     getattr(prefill_backend, "requires_v_padding", False)
-     and context_output.shape[-1] != self.v_head_dim
- ):
+ if context_output.shape[-1] != self.v_head_dim:

Contributor Author

Good point. I updated this to rely on the output shape instead of the backend-specific attribute, and made the context and suffix slices independent so each tensor is normalized to v_head_dim before the merge.

Signed-off-by: Martin Vit <martin@voipmonitor.org>
Signed-off-by: Martin Vit <martin@voipmonitor.org>
@voipmonitor voipmonitor force-pushed the codex/upstream-mla-lateunpad-pr branch from fb90f07 to a27fab7 Compare May 14, 2026 13:53

@MatthewBonanni MatthewBonanni left a comment

Member

Thanks for catching this! Just a few small comments

Comment on lines +2313 to +2317
if context_output.shape[-1] != self.v_head_dim:
    context_output = context_output[..., : self.v_head_dim]
if suffix_output.shape[-1] != self.v_head_dim:
    suffix_output = suffix_output[..., : self.v_head_dim]

Member

The if statements aren't necessary: the slice is already a no-op when context_output.shape[-1] == self.v_head_dim.
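
A minimal sketch of that point, using plain Python lists of rows as a stand-in for tensors. The helpers `unpad_guarded` and `unpad_unconditional` are hypothetical names for this example and are not part of the PR. With a real tensor, `x[..., :d]` returns a view rather than a copy, so the unconditional form costs essentially nothing when the width already matches.

```python
def unpad_guarded(rows, v_head_dim):
    """Slice only when the last dimension is wider than v_head_dim."""
    if len(rows[0]) != v_head_dim:
        rows = [row[:v_head_dim] for row in rows]
    return rows

def unpad_unconditional(rows, v_head_dim):
    """Always slice; leaves the values unchanged when the width already matches."""
    return [row[:v_head_dim] for row in rows]

V_HEAD_DIM = 2
padded_rows = [[1.0, 2.0, 0.0, 0.0], [3.0, 4.0, 0.0, 0.0]]  # padded backend output
native_rows = [[1.0, 2.0], [3.0, 4.0]]                      # already v_head_dim wide

# Both forms give the same result for padded and for native-width outputs.
results_match = (
    unpad_guarded(padded_rows, V_HEAD_DIM)
    == unpad_unconditional(padded_rows, V_HEAD_DIM)
    == native_rows
    and unpad_guarded(native_rows, V_HEAD_DIM)
    == unpad_unconditional(native_rows, V_HEAD_DIM)
    == native_rows
)
```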

Comment on lines +2329 to +2330
if output_prefill.shape[-1] != self.v_head_dim:
    output_prefill = output_prefill[..., : self.v_head_dim]

Member

ditto

)

if context_output.shape[-1] != self.v_head_dim:
    context_output = context_output[..., : self.v_head_dim]

@MatthewBonanni MatthewBonanni May 14, 2026

Member

nit: stylistically would prefer context_output[..., :self.v_head_dim] (no space after colon)

@ehfd

ehfd commented Jun 4, 2026

Contributor

@voipmonitor @MatthewBonanni Can we revive this?

@ehfd

ehfd commented Jun 4, 2026

Contributor

I think this is related to #41623 or #42426.

@bbartels

Contributor

Would be curious about this as well :)
