Fix FlashAttention MLA prefill V unpadding #42642
Code Review
This pull request refactors the unpadding logic for attention outputs by moving it from the backend-specific Flash Attention implementation to the core MLA attention layer. This change ensures that outputs are correctly sliced to the value head dimension when padding is applied. The reviewer recommended simplifying the newly added conditional check in mla_attention.py by removing a redundant attribute check, as the shape comparison is sufficient to identify when unpadding is required.
prefill_backend = prefill_metadata.prefill_backend
if (
    getattr(prefill_backend, "requires_v_padding", False)
    and context_output.shape[-1] != self.v_head_dim
):
The check for requires_v_padding via getattr is redundant here because the shape check context_output.shape[-1] != self.v_head_dim is sufficient to determine if unpadding is necessary. Simplifying this condition makes the code cleaner and consistent with the logic used in the else block (lines 2332-2333).
- prefill_backend = prefill_metadata.prefill_backend
- if (
-     getattr(prefill_backend, "requires_v_padding", False)
-     and context_output.shape[-1] != self.v_head_dim
- ):
+ if context_output.shape[-1] != self.v_head_dim:
Good point. I updated this to rely on the output shape instead of the backend-specific attribute, and made the context and suffix slices independent so each tensor is normalized to v_head_dim before the merge.
Signed-off-by: Martin Vit <martin@voipmonitor.org>
fb90f07 to a27fab7
MatthewBonanni left a comment
Thanks for catching this! Just a few small comments
if context_output.shape[-1] != self.v_head_dim:
    context_output = context_output[..., : self.v_head_dim]
if suffix_output.shape[-1] != self.v_head_dim:
    suffix_output = suffix_output[..., : self.v_head_dim]
The if statements aren't necessary because this will be a no-op when context_output.shape[-1] == self.v_head_dim
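A minimal sketch of why the guard is redundant, using NumPy arrays as a stand-in for the torch tensors in the PR (basic slicing on the last axis behaves the same way for this purpose). The helper name unpad_last_dim and the toy shapes are hypothetical and chosen only for illustration; they are not part of vLLM.

```python
import numpy as np


def unpad_last_dim(x, v_head_dim):
    """Slice the last axis down to v_head_dim.

    When the last axis already has size v_head_dim, the slice covers the
    whole axis, so the result is a view with the same shape and contents.
    """
    return x[..., :v_head_dim]


v_head_dim = 4  # toy value, not a real model dimension

# Output from a backend that already returns v_head_dim columns.
exact = np.arange(2 * 3 * v_head_dim, dtype=np.float32).reshape(2, 3, v_head_dim)

# Output from a backend that returns a padded last dimension.
padded = np.zeros((2, 3, 6), dtype=np.float32)

exact_unpadded = unpad_last_dim(exact, v_head_dim)
padded_unpadded = unpad_last_dim(padded, v_head_dim)

print(exact_unpadded.shape, padded_unpadded.shape)
```

Both calls yield a last dimension of size v_head_dim, so the unconditional slice handles the padded and unpadded cases alike.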
if output_prefill.shape[-1] != self.v_head_dim:
    output_prefill = output_prefill[..., : self.v_head_dim]
)

if context_output.shape[-1] != self.v_head_dim:
    context_output = context_output[..., : self.v_head_dim]
nit: stylistically would prefer context_output[..., :self.v_head_dim] (no space after colon)
@voipmonitor @MatthewBonanni Can we revive this?
Would be curious about this as well :)
Purpose
Fix a regression in the FlashAttention MLA prefill path introduced when the prefill implementations were split out in #32623.
Before #32623, the FlashAttention MLA helper padded V when the selected FlashAttention implementation did not support different QK/V head dimensions. The padded output was kept through the context/suffix merge_attn_states path, and only then sliced back to v_head_dim in MLACommonImpl.forward_mha.
After #32623, FlashAttnPrefillBackend._flash_attn_varlen_diff_headdims() slices the output back to v_head_dim inside the backend. That changes the tensor shape contract seen by the chunked-context merge path. On a long-context Kimi/DeepSeek-style MLA setup using the FlashAttention prefill backend with requires_v_padding=True, this produced incorrect long-context generations: the model started continuing unrelated prompt padding text instead of answering the user question. Keeping the old late-unpad behavior restores the expected output.
This PR moves the unpad back to the caller:
- FlashAttnPrefillBackend now returns the same padded output shape that the old in-file FlashAttention MLA prefill helper returned.
- MLACommonImpl.forward_mha slices context_output, suffix_output, and no-context output_prefill back to v_head_dim immediately before writing/merging into the final output buffer.
Notes
This only affects the FlashAttention MLA prefill backend when requires_v_padding=True. Backends that natively support different QK/V head dimensions keep returning v_head_dim already, so the added checks are no-ops for them.
I intentionally kept this as a small compatibility fix rather than changing backend selection or DCP behavior.
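The pad-then-late-unpad pattern described above can be sketched as follows. This is a toy single-head illustration in NumPy, not the vLLM code path: the function names, the head dimensions (8 for QK, 4 for V), and the plain softmax attention are all assumptions made for the example. It only shows that zero-padding V to the QK head dimension produces extra all-zero output columns, which the caller can slice off afterwards.

```python
import numpy as np


def attention(q, k, v):
    """Plain softmax attention for one head: softmax(q k^T / sqrt(d)) v."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def attention_with_v_padding(q, k, v):
    """Hypothetical backend that needs V to share the QK head dimension.

    V is zero-padded on its last axis and the padded output is returned
    as-is, leaving the unpad to the caller.
    """
    pad = q.shape[-1] - v.shape[-1]
    v_padded = np.pad(v, ((0, 0), (0, pad)))
    return attention(q, k, v_padded)


qk_head_dim, v_head_dim = 8, 4  # toy sizes, not real model dimensions
rng = np.random.default_rng(0)
q = rng.standard_normal((5, qk_head_dim))
k = rng.standard_normal((7, qk_head_dim))
v = rng.standard_normal((7, v_head_dim))

padded_out = attention_with_v_padding(q, k, v)  # last dim is qk_head_dim
out = padded_out[..., :v_head_dim]              # caller-side late unpad
reference = attention(q, k, v)                  # attention without padding

print(padded_out.shape, out.shape, np.allclose(out, reference))
```

In this toy setting the sliced result matches unpadded attention, and the columns removed by the slice are exactly the ones produced by the zero padding.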
Test Plan
python3 -m py_compile vllm/model_executor/layers/attention/mla_attention.py vllm/v1/attention/backends/mla/prefill/flash_attn.py
TRITON_MLA, DCP=1, MTP disabled, 128k synthetic context. Before this change the model continued unrelated context padding; with this change it answers the requested Sieve of Eratosthenes prompt again.