... SDPA causal mask generation may be wrong.
transformers/src/transformers/modeling_attn_mask_utils.py
Lines 421 to 433 in 76fa17c
```python
if torch.all(mask == 1):
    if is_tracing:
        pass
    elif tgt_len == 1:
        # For query_length == 1, causal attention and bi-directional attention are the same.
        return None
    elif key_value_length == tgt_len:
        return None
    else:
        # Unfortunately, for query_length > 1 and key_value_length != query_length, we can not generally ignore the attention mask, as SDPA causal mask generation
        # may be wrong. We will set is_causal=False in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
        # Reference: https://github.com/pytorch/pytorch/issues/108108
        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
```
Would it be safe to just return `None` in the `else:` case as well? For causal attention, we can just use `_prepare_4d_causal_attention_mask_for_sdpa`.
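For context, here is a minimal sketch (my own, not from the repository) of the alignment mismatch the code comment refers to: with `is_causal=True`, SDPA derives its mask as `torch.ones(L, S).tril(diagonal=0)`, i.e. anchored to the top-left, while decoding with a KV cache needs a bottom-right-anchored mask so that each new query token can see the whole cached past plus itself:

```python
import torch

q_len, kv_len = 2, 5  # e.g. 2 new query tokens attending over a 5-entry KV cache

# Mask SDPA derives internally from is_causal=True (top-left aligned):
sdpa_causal = torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=0)

# Mask that cached causal decoding actually needs (bottom-right aligned):
needed_causal = torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=kv_len - q_len)

print(sdpa_causal)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False]])
print(needed_causal)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
```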
Related issues:
- pytorch/pytorch#108108
- Dao-AILab/flash-attention@9e5e8bc
- #28802
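The pytorch/pytorch#108108 link above tracks exactly this alignment behavior. As a sanity check (again my own sketch, with made-up tensor shapes), SDPA under `is_causal=True` does not reproduce an explicit bottom-right-aligned mask when `query_length != key_value_length`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 1, 2, 8)  # (batch, heads, q_len=2, head_dim)
k = torch.randn(1, 1, 5, 8)  # kv_len=5, as if 3 tokens were already cached
v = torch.randn(1, 1, 5, 8)

# Bottom-right-aligned boolean mask (True = attend), what cached decoding needs:
mask = torch.ones(2, 5, dtype=torch.bool).tril(diagonal=5 - 2)

out_explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_is_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The outputs differ because is_causal=True anchors its mask to the top-left:
print(torch.allclose(out_explicit, out_is_causal))  # False
```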