Skip to content

Fix SDPA sliding window compatibility#30127

Merged
fxmarty merged 5 commits intohuggingface:mainfrom
fxmarty:mistral-sdpa-window-attn
Apr 17, 2024
Merged

Fix SDPA sliding window compatibility#30127
fxmarty merged 5 commits intohuggingface:mainfrom
fxmarty:mistral-sdpa-window-attn

Conversation

@fxmarty
Copy link
Contributor

@fxmarty fxmarty commented Apr 8, 2024

As per title, fixes #28980

Supersedes #29220 #29407 as the implementation ends up being different (added you as co-author here @ehuaa).

This bug dates back to #26572 where sliding_window was not properly accounted for in the _prepare_4d_causal_attention_mask_for_sdpa method. Since then, SDPA support was added to models that use sliding window, but this bug was not yet fixed.

fxmarty and others added 3 commits April 8, 2024 18:03
@fxmarty fxmarty requested a review from ArthurZucker April 8, 2024 16:14
@ehuaa
Copy link
Contributor

ehuaa commented Apr 8, 2024

Hi, @fxmarty ,thanks for your great work! I think you can add some tests i mentioned in #29407, to check if the result with sliding window in SDPA is the same as flashattention2.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@fxmarty
Copy link
Contributor Author

fxmarty commented Apr 8, 2024

@ehuaa Thank you, for sure will do!

Runing mistral, mixtral, starcoder2 tests, those fail but are already failing on main:

FAILED tests/models/mistral/test_modeling_mistral.py::MistralIntegrationTest::test_speculative_generation - AssertionError: 'My f[19 chars]is 100% Sriracha. I love the heat, the tang and the fact costs' != 'My f[19 chars]is 100% Sriracha. I love the heat, the ...
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits - AssertionError: The values for attribute 'dtype' do not match: torch.float16 != torch.float32.
FAILED tests/models/mixtral/test_modeling_mixtral.py::MixtralIntegrationTest::test_small_model_logits_batched - AssertionError: The values for attribute 'device' do not match: cuda:0 != cpu.
FAILED tests/models/starcoder2/test_modeling_starcoder2.py::Starcoder2IntegrationTest::test_starcoder2_batched_generation_4bit - AssertionError: Lists differ: ['Hel[110 chars]t is related to the topic of "How to make a ga[179 chars]ute'] != ['Hel[110 chars]t is aimed at creating a...
FAILED tests/models/starcoder2/test_modeling_starcoder2.py::Starcoder2IntegrationTest::test_starcoder2_batched_generation_eager - AssertionError: Lists differ: ['Hel[181 chars]I am currently working on', "def hello_world()[114 chars]app"] != ['Hel[181 chars]I am looking for a', "de...

@fxmarty fxmarty requested review from ArthurZucker, LysandreJik and amyeroberts and removed request for ArthurZucker April 15, 2024 08:09
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know if I mentioned it offline, we'll refactor this to a single function without inheritance similar to update_causal_mask. See Recurrent Gemma, as it supports sliding window!

Thanks for re-enabling sliding window.

Comment on lines +319 to +324
ignore_causal_mask = False

if attention_mask is None:
if sliding_window is None or key_value_length < sliding_window:
ignore_causal_mask = not is_tracing
elif sliding_window is None or key_value_length < sliding_window:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are basically 2 cases:

  1. You ignore the causal mask
  2. You don't ignore it.
    Code is really not super clear but we will refactor this soon anyways.

@fxmarty fxmarty merged commit 40eb6d6 into huggingface:main Apr 17, 2024
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* fix sdpa + sliding window

* give credit

Co-authored-by: ehuaa <ehuamail@163.com>

* remove unnecessary warning

* fix typog

* add test

---------

Co-authored-by: ehuaa <ehuamail@163.com>
@gugarosa gugarosa mentioned this pull request Apr 24, 2024
5 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add sliding window attention to sdpa in mistral

4 participants