
[Feature] Support QwenImageEditPlus series attention mask for NPU #13016

@zhangtao0408

Description


Related problem

The QwenImageEditPlus transformer now builds a joint attention mask once and forwards it to every block:

```python
# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
    # Build joint mask: [text_mask, all_ones_for_image]
    batch_size, image_seq_len = hidden_states.shape[:2]
    image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
    joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
    block_attention_kwargs["attention_mask"] = joint_attention_mask
```

Since PR #12702 introduced this attention mask to the QwenImageEditPlus series, and the current `_native_npu` attention backend does not support being passed an attention mask, running these models on NPU raises an error:

```python
def _native_npu_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for NPU attention")
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
```

We would like to enable NPU support for QwenImageEditPlus. Based on a printed check, the mask currently contains all 1s (i.e., full attention). Is there a workaround to bypass this limitation so that Qwen-Image-Edit-Plus can run normally on the NPU?
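The printed check mentioned above might look like the following sketch, where `joint_attention_mask` is the tensor built in the snippet from the problem description:

```python
# Confirm the joint mask enables every position (full attention).
print(joint_attention_mask.shape, joint_attention_mask.dtype)
print(bool(joint_attention_mask.all()))  # True -> mask is all 1s / all True
```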

Solution

The `npu_fusion_attention` function supports several attention-mask shapes. We therefore need a check that verifies whether a given `attention_mask` is supported by the NPU backend:

  1. If the mask consists entirely of 1s (indicating full attention), we can pass `None` as the `attention_mask`, which realizes full attention in FA.
  2. Otherwise, we should consult the official documentation and add a validation check that determines whether the attention mask is a valid input for `npu_fusion_attention` (see the sketch below).
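A minimal sketch of option 1, assuming the "1 = attend" convention used by the boolean joint mask built in the problem snippet; the helper name `_drop_full_attention_mask` is hypothetical:

```python
import torch
from typing import Optional

def _drop_full_attention_mask(attn_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Hypothetical helper: if every position may attend (full attention),
    # return None so the NPU backend can run without a mask.
    if attn_mask is None:
        return None
    if attn_mask.dtype == torch.bool:
        return None if bool(attn_mask.all()) else attn_mask
    # Numeric masks following the "1 = attend" convention.
    return None if bool((attn_mask == 1).all()) else attn_mask
```

Inside `_native_npu_attention`, the guard could then become (again, a sketch):

```python
attn_mask = _drop_full_attention_mask(attn_mask)
if attn_mask is not None:
    # Option 2 would go here: validate the mask shape against the
    # npu_fusion_attention documentation instead of rejecting it outright.
    raise ValueError("`attn_mask` is not supported for NPU attention")
```

Note that `.all()` forces a host sync, which works against the sync-avoidance goal in the problem snippet; if that matters, the check could run once per forward pass instead of once per block.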
