dispatch_attention_fn silently ignores attn_mask for certain backends #12605

@zzlol63

Description

Describe the bug

When using the attention backends feature, it's possible to pass an attention mask through dispatch_attention_fn to a backend that does not support masks (i.e., attn_mask is not in the signature of the function registered for that backend) without any error being thrown.

This can cause unintended side effects during training or inference, where the attention is silently computed as if it were full (unmasked) attention.

This is also not caught by the optional checks used for debugging (DIFFUSERS_ATTN_CHECKS=yes).
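
One way to surface this would be for the dispatcher to compare the kwargs it receives against the registered backend's signature and raise instead of silently dropping them. Below is a minimal sketch of that idea; the names _validate_backend_kwargs and fake_flash_attention are hypothetical and not diffusers internals:

import inspect

def _validate_backend_kwargs(backend_fn, **kwargs):
    """Raise if a non-None kwarg is not accepted by the backend's signature."""
    params = inspect.signature(backend_fn).parameters
    # Note: a backend that itself takes **kwargs would need a VAR_KEYWORD check too.
    for name, value in kwargs.items():
        if value is not None and name not in params:
            raise ValueError(
                f"Backend function '{backend_fn.__name__}' does not accept "
                f"'{name}'; it would be silently ignored."
            )

# Stand-in for a flash-style backend whose signature has no attn_mask.
def fake_flash_attention(query, key, value, dropout_p=0.0, is_causal=False):
    ...

try:
    _validate_backend_kwargs(fake_flash_attention, attn_mask=object())
except ValueError as e:
    print(e)

A check like this could also live behind the existing DIFFUSERS_ATTN_CHECKS flag.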

Reproduction

import os
os.environ["DIFFUSERS_ATTN_CHECKS"] = "yes"

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend

query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
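# All-False boolean mask: every position is masked out, so any backend that
# honors it must produce a different output than the unmasked call.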
attn_mask = torch.zeros((1, 1, 10, 10), dtype=torch.bool, device="cuda")

with attention_backend("native"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)

    assert not torch.equal(output, output2), "native: These outputs should not be equal!"

with attention_backend("flash"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)

    assert not torch.equal(output, output2), "flash: These outputs should not be equal!"

Traceback (most recent call last):
  File "C:\repro.py", line 22, in <module>
    assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: flash: These outputs should not be equal!

System Info

  • 🤗 Diffusers version: 0.35.2
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.8.0+cu128 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 1.1.2
  • Transformers version: not installed
  • Accelerate version: not installed
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 5090, 32607 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response
