Describe the bug
When using the attention backends feature, it is possible to pass an attention mask to dispatch_attention_fn for a backend that does not support masks (i.e. attn_mask is not in the signature of the function registered for that backend), and no error is thrown.
This can cause unintended side effects during training or inference, because the mask is silently dropped and attention is computed as full attention.
This is not validated as part of the optional debugging checks either (DIFFUSERS_ATTN_CHECKS=yes).
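To make the expectation concrete, here is a rough sketch of the kind of signature check I would expect either the dispatcher or the DIFFUSERS_ATTN_CHECKS path to perform instead of silently dropping the mask. This is not diffusers code; check_backend_supports_kwargs and fake_flash_backend are made-up names used purely for illustration.

# Illustrative sketch only: check_backend_supports_kwargs and fake_flash_backend are
# hypothetical names, not diffusers API. The idea is to verify that the callable
# registered for a backend actually accepts the arguments the caller passed, and to
# fail loudly instead of silently ignoring attn_mask.
import inspect

def check_backend_supports_kwargs(registered_fn, **kwargs):
    params = inspect.signature(registered_fn).parameters
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    for name, value in kwargs.items():
        if value is not None and name not in params and not accepts_var_kwargs:
            raise ValueError(
                f"Backend function {registered_fn.__name__!r} does not accept {name!r}; "
                "the argument would be silently ignored."
            )

# A stand-in for a backend whose registered function has no attn_mask parameter.
def fake_flash_backend(query, key, value, dropout_p=0.0, is_causal=False, scale=None):
    ...

try:
    check_backend_supports_kwargs(fake_flash_backend, attn_mask=object())
except ValueError as exc:
    print(exc)  # -> "does not accept 'attn_mask'; the argument would be silently ignored."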
Reproduction

import os

# Enable the optional attention debugging checks before importing diffusers.
os.environ["DIFFUSERS_ATTN_CHECKS"] = "yes"

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend

query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")

# Boolean mask that masks out every position, so masked and unmasked outputs must differ.
attn_mask = torch.zeros((1, 1, 10, 10), dtype=torch.bool, device="cuda")

with attention_backend("native"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)
    assert not torch.equal(output, output2), "native: These outputs should not be equal!"

with attention_backend("flash"):
    output = dispatch_attention_fn(query, key, value, attn_mask)
    output2 = dispatch_attention_fn(query, key, value)
    assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
Logs

Traceback (most recent call last):
File "C:\repro.py", line 22, in <module>
assert not torch.equal(output, output2), "flash: These outputs should not be equal!"
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: flash: These outputs should not be equal!
System Info
Who can help?
No response