[Mistral&Mixtral] Add sliding window for sdpa #29407
ehuaa wants to merge 21 commits into huggingface:main from ehuaa:add_sliding_window_for_sdpa
Conversation
…https://github.com/ehuaa/transformers into add_sliding_window_for_sdpa
Thanks! Let's throw in a generation test as well and we should be good to go! 🤗
Ok, and the flash vs. sdpa test I submitted above does not pass. Have you debugged it? I'm also curious why it fails here.
No, I have not debugged it, and I won't have the bandwidth. Do you need help on this? cc @younesbelkada, I think this is pretty important.
As for the generation test you mentioned above, I think test_model_7b_long_prompt_sdpa is enough; it already covers generation with sdpa and the sliding window.
And I see that Gemma has a similar sdpa logits test (https://github.com/huggingface/transformers/blob/main/tests/models/gemma/test_modeling_gemma.py#L471), like the one I committed. I think that test passes, so maybe it can help with the debugging.
ArthurZucker left a comment
Late but glad we waited!
The _prepare_4d_causal_attention_mask_for_sdpa does not seem to fare well with sliding_window when there is no mask. Let's add one more full generation test similar to test_model_7b_logits_long_with_sdpa_and_flash2, but actually generating!
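Something along these lines could serve as that extra generation test; this is only a rough sketch under assumptions from this thread (the test name, max_new_tokens=20, greedy decoding, and bfloat16 are illustrative choices here, not taken from the PR):

```python
import torch
from transformers import MistralForCausalLM


def test_model_7b_long_prompt_generation_sdpa_vs_flash2():
    # Prompt longer than the 4096-token sliding window, as in the logits test above.
    input_ids = [1] + [306, 338] * 2048

    generations = {}
    for impl in ("flash_attention_2", "sdpa"):
        model = MistralForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-v0.1",
            device_map="auto",
            attn_implementation=impl,
            torch_dtype=torch.bfloat16,
        )
        input_tensor = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
        # Greedy decoding so both backends should produce identical continuations.
        out = model.generate(input_tensor, max_new_tokens=20, do_sample=False)
        # Only keep the newly generated tokens for the comparison.
        generations[impl] = out[0, len(input_ids):].tolist()
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    assert generations["sdpa"] == generations["flash_attention_2"]
```

Greedy decoding keeps the run deterministic, so any divergence between the two backends shows up directly in the generated tokens.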
```python
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
```
Suggested change:

```diff
  model = MistralForCausalLM.from_pretrained(
-     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2"
+     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16
  )
  input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
  with torch.no_grad():
      out = model(input_ids).logits.cpu()
  input_ids = [1] + [306, 338] * 2048
  model = MistralForCausalLM.from_pretrained(
-     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
+     "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa", torch_dtype=torch.bfloat16
  )
```
I am getting an error because by default it seems to be float32.
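For context: the FlashAttention kernels only support fp16 and bf16 inputs, so loading the model in the default float32 would be expected to fail on the flash_attention_2 path; passing torch_dtype=torch.bfloat16 as suggested also keeps both runs in the same dtype, which makes the logits comparison meaningful.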
```python
with torch.no_grad():
    out = model(input_ids).logits.cpu()

input_ids = [1] + [306, 338] * 2048
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", attn_implementation="sdpa"
)
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
    out1 = model(input_ids).logits.cpu()
torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
```
Let's make sure we test all logits, not just the mean.
Suggested change:

```diff
- torch.testing.assert_close(out.mean(-1), out1.mean(-1), atol=1e-2, rtol=1e-2)
+ torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
```
with this, the test is failing:

```
>       torch.testing.assert_close(out, out1, atol=1e-4, rtol=1e-4)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 90967735 / 131104000 (69.4%)
E       Greatest absolute difference: 0.328125 at index (0, 2310, 338) (up to 0.0001 allowed)
E       Greatest relative difference: 114689.0 at index (0, 1267, 4581) (up to 0.0001 allowed)
```

```python
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length,
    sliding_window=self.config.sliding_window if is_torch_version_greater_or_equal_than_2_2_0 else None,
```
The issue here is that _prepare_4d_causal_attention_mask_for_sdpa seems to return None if attention_mask is None (which is the case in the test), while if we actually want to use the sliding window we need to return the full causal mask. cc @fxmarty
@fxmarty if you want to take over in a new PR, this is fairly important IMO
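For reference, a minimal repro of the behaviour described above might look like the following; this is a sketch based on this thread (the shapes and the 4096 sliding window are illustrative assumptions), not code from the PR:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

# Illustrative shapes: a sequence one token longer than the sliding window,
# so the sliding-window part of the mask would actually matter.
batch_size, seq_length, sliding_window = 1, 4097, 4096
inputs_embeds = torch.zeros(batch_size, seq_length, 8)  # stand-in for real embeddings

mask = _prepare_4d_causal_attention_mask_for_sdpa(
    None,                        # attention_mask is None, as in the failing test
    (batch_size, seq_length),
    inputs_embeds,
    0,                           # past_key_values_length
    sliding_window=sliding_window,
)
# Per the comment above, this reportedly comes back as None, i.e. SDPA falls back
# to a plain causal mask, while a sliding-window causal mask is what is needed.
print(mask)
```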
This PR will solve #28980
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Closing as #30127 was merged and takes inspiration from this PR.
@ArthurZucker Arthur has reviewed this before, but my git change log got messed up, so I opened a new PR instead. I uploaded a new test comparing sliding-window flash attention vs. sdpa for checking.
Supersedes #29220