
Add sliding window attention to sdpa in mistral #28980

@ehuaa


Feature request

https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L1006-L1023

In the code linked above, the latest version of transformers cannot use the sliding-window feature of the Mistral model with the SDPA attention implementation.
I suspect the reason is the one you mentioned here:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L687-L688
This PyTorch issue prevented SDPA from working with a custom attn_mask such as a sliding-window attention mask:
pytorch/pytorch#112577

This issue has been fixed since torch 2.2.0, which was released two weeks ago, so could you add this feature back to the SDPA path in Mistral?
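To make the request concrete, here is a minimal pure-Python sketch of what SDPA does with a custom boolean attn_mask (True = attend): masked positions are set to -inf before the softmax, so they receive zero attention weight. This is only an illustrative stand-in for `torch.nn.functional.scaled_dot_product_attention`, not the library code.

```python
import math

def masked_softmax(scores, mask):
    # Replace disallowed positions with -inf, then softmax the row.
    masked = [s if m else float("-inf") for s, m in zip(scores, mask)]
    mx = max(masked)
    exps = [math.exp(s - mx) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

# A sliding-window mask lets the current token see only the last few positions;
# masked positions get zero weight and the rest renormalize.
weights = masked_softmax([1.0, 1.0, 1.0, 1.0], [False, False, True, True])
# -> [0.0, 0.0, 0.5, 0.5]
```

Before the torch fix, passing such a custom mask to SDPA on some backends triggered the bug tracked in the issue above, which is why the sliding-window path was disabled.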

Motivation

I cannot use sliding window with SDPA right now. My GPU is a V100, so I cannot use FlashAttention-2.

Your contribution

I think we can pass a sliding_window param to the _prepare_4d_causal_attention_mask_for_sdpa function.
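A rough sketch of the intended mask semantics, assuming the helper gains a sliding_window keyword like its non-SDPA counterpart. The real helper lives in transformers.modeling_attn_mask_utils; the function below is a hypothetical stand-in that only illustrates how the parameter would change the mask, not the actual library implementation.

```python
def prepare_causal_mask_sketch(seq_len, sliding_window=None):
    """Return a [seq_len, seq_len] boolean mask (True = attend).

    Hypothetical stand-in for the proposed sliding_window plumbing in
    _prepare_4d_causal_attention_mask_for_sdpa (batch/head dims omitted).
    """
    def visible(i, j):
        if j > i:                       # causal: no attending to the future
            return False
        if sliding_window is not None:  # sliding window: limited lookback
            return i - j < sliding_window
        return True
    return [[visible(i, j) for j in range(seq_len)] for i in range(seq_len)]

# Without a window the mask is plain causal; with one, distant positions drop out.
full = prepare_causal_mask_sketch(4)
windowed = prepare_causal_mask_sketch(4, sliding_window=2)
```

The model's forward pass would then forward `self.config.sliding_window` into the mask-preparation call on the SDPA path, mirroring what the eager path already does.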


Labels

Good Second Issue: issues that are more difficult to do than "Good First" issues - give it a try if you want!
