
Conversation

@aciddelgado
Contributor

Description

Implements a preliminary version of local (sliding window) attention. It is currently supported only by Flash Attention (sm >= 80, Linux) and only covers sliding-window attention with a large cached KV.
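For context, the sketch below is not code from this PR; it is a minimal illustration (the function name, shapes, and the single left-window convention are assumptions) of what a causal sliding-window mask looks like: each query may attend to itself and at most `window_size - 1` preceding tokens.

```python
import torch

def sliding_window_causal_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # True where attention is NOT allowed: future positions (causal constraint)
    # or positions more than window_size - 1 tokens in the past (sliding window).
    q_idx = torch.arange(seq_len).unsqueeze(-1)  # query positions, shape (seq_len, 1)
    k_idx = torch.arange(seq_len).unsqueeze(0)   # key positions,   shape (1, seq_len)
    causal = k_idx > q_idx                       # disallow future tokens
    too_far = k_idx < q_idx - window_size + 1    # disallow tokens outside the window
    return causal | too_far

# Example: each query attends to itself and at most 3 preceding tokens.
mask = sliding_window_causal_mask(seq_len=8, window_size=4)
scores = torch.randn(8, 8)
probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
```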

Motivation and Context

This change makes it possible to run Mistral and other models that use sliding window attention.

attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
    attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable

Local variable 'local_mask' may be used before it is initialized.
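The warning arises because `local_mask` is assigned only on some code paths before the guarded `masked_fill`. One way to avoid the report, sketched below as a standalone reproduction (the shapes, names, and the (left, right) window convention with -1 meaning "unbounded" are assumptions, not the PR's actual test code), is to build `local_mask` inside the same branch that uses it:

```python
import torch

# Hypothetical standalone version of the flagged pattern.
seq_len = 8
window_size = (4, -1)  # (left, right); -1 means "no limit" in this sketch
attention = torch.randn(seq_len, seq_len)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Rows that are completely masked out are filled with zero instead of NaN.
attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0)

if window_size[0] >= 0 or window_size[1] >= 0:
    # Constructing local_mask inside the branch that uses it guarantees it is
    # always initialized before the masked_fill below.
    q_idx = torch.arange(seq_len).unsqueeze(-1)
    k_idx = torch.arange(seq_len).unsqueeze(0)
    local_mask = k_idx > q_idx                            # causal part
    if window_size[0] >= 0:
        local_mask |= k_idx < q_idx - window_size[0] + 1  # limit look-back
    if window_size[1] >= 0:
        local_mask |= k_idx > q_idx + window_size[1]      # limit look-ahead
    attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
```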
yufenglee previously approved these changes Nov 15, 2023
Member

@yufenglee left a comment


:shipit:

@aciddelgado merged commit adb56df into main Nov 16, 2023
@aciddelgado deleted the aciddelgado/gqa_local branch Nov 16, 2023 23:01
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024