Mamba / FalconMamba: Fix mamba left padding #32677

ArthurZucker merged 10 commits into huggingface:main
Conversation
molbap
left a comment
Thanks @younesbelkada for tuning out the states! 😁 Left a couple of comments, mostly curious about some situations that were edge cases for Mamba 2.
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Can we propagate this to Jamba as well :D thx for this fix ❤️
molbap
left a comment
LGTM! pinging @ArthurZucker for merging 🙂
ArthurZucker
left a comment
Thanks for adding a test 🤗
```python
# In case cache is not used, manually add a new column in the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    pad_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :pad_length])], dim=-1)
```
Not sure I understand why we are adding a [1] x batch_size? (past_length is usually going to be 1 - current_generation_token, so imagine 20 input ids, then -19 to slice the input_ids? Unless the input_ids is 20, but then it always has the same shape as the mask.)
This is for users that run generation with use_cache=False; it makes sure to manually update the attention mask, because this is done nowhere else except here.
then this is more a problem with generate as it should pass the correct attention mask 😓
ArthurZucker
left a comment
Will include this in the patch 🤗
```python
# In case cache is not used, manually update the attention mask
if not use_cache and attention_mask is not None and input_ids.shape != attention_mask.shape:
    past_length = input_ids.shape[-1] - attention_mask.shape[-1]
    attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1)
```
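As a hedged sketch of what this branch does (toy tensors, not the actual modeling code): when generation runs with `use_cache=False`, `input_ids` grows each step while the user-provided `attention_mask` does not, so columns of ones are appended for the newly generated tokens.

```python
import torch

# Illustrative sketch, assuming a left-padded prompt of length 4 and two
# already-generated tokens, so input_ids is two columns longer than the mask.
batch_size = 2
attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])  # 0 = pad
input_ids = torch.ones(batch_size, 6, dtype=torch.long)

if attention_mask is not None and input_ids.shape != attention_mask.shape:
    past_length = input_ids.shape[-1] - attention_mask.shape[-1]
    # Append one column of ones per generated token.
    attention_mask = torch.cat(
        [attention_mask, torch.ones_like(input_ids[:, :past_length])], dim=-1
    )

print(attention_mask.shape)  # torch.Size([2, 6])
```

The generated tokens are always real (never padding), so extending with ones is safe; only the original left-pad columns stay zero.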
that's the only thing bothering me, as generate with use_cache=False should not alter the attention mask being passed
Yes, fixed it!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
```
this is breaking (having it as the second positional argument)
```python
if "attention_mask" in model_kwargs:
    attention_mask = model_kwargs["attention_mask"]
    model_kwargs["attention_mask"] = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )
```
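For illustration, a minimal sketch of what this update does on each decoding step (toy values; the surrounding `model_kwargs` plumbing is assumed): exactly one column of ones is appended for the single token just generated.

```python
import torch

# Illustrative sketch: mask for a left-padded batch before one decoding step.
model_kwargs = {"attention_mask": torch.tensor([[0, 1, 1], [1, 1, 1]])}

if "attention_mask" in model_kwargs:
    attention_mask = model_kwargs["attention_mask"]
    # new_ones matches the mask's dtype/device; one new column per step.
    model_kwargs["attention_mask"] = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
    )

print(model_kwargs["attention_mask"])
```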
* fix mamba left padding
* Apply suggestions from code review

  Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
* fix copies
* test with `inputs_embeds`
* Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* copies
* clairfy
* fix last comments
* remove

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This reverts commit 91b799b.
What does this PR do?
As pointed out in #32080 (comment), it is important to zero out the hidden states that correspond to the pad tokens before and after the causal convolution, so that the pad tokens have no impact on the computed hidden states.
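A minimal, illustrative sketch of the idea (not the actual FalconMamba modeling code; names and shapes are assumptions): the attention mask is broadcast over the hidden dimension and applied both before and after the causal convolution, so left-pad positions stay exactly zero.

```python
import torch

# Toy shapes: batch of 2, sequence of 4, hidden size 3.
batch, seq_len, hidden = 2, 4, 3
hidden_states = torch.ones(batch, seq_len, hidden)
attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])  # 0 = left pad

# Before the conv: zero out pad positions so they contribute nothing.
hidden_states = hidden_states * attention_mask.unsqueeze(-1)

# ... the causal convolution over the sequence dimension would run here;
# its receptive field can smear nonzero values back into pad slots ...

# After the conv: mask again so pad positions are zero going forward.
hidden_states = hidden_states * attention_mask.unsqueeze(-1)

print(hidden_states[0, 0])  # pad position -> all zeros
```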
This can be empirically verified by comparing generation quality before and after this fix (note that by default FalconMamba uses left padding):
Before the fix:
After the fix:
Propagated the changes to Mamba1 as well.
cc @ArthurZucker @molbap