[Refactor Attention mask handling] Moves attention mask processing to the Attention class#28132
ArthurZucker wants to merge 20 commits into main from
…tor-attention-converesion
gante left a comment:
LGTM (yay, fewer config-dependent if/elses 🙌 )
BTW, for retrocompatibility, we may want to check whether the attention masks are 4D before expanding in the attention classes. As we've learned with the Cache refactor, other repos might rely on the interface of these internal classes, and this is technically an interface change (4D attention mask input -> 2D attention mask input).
We added support for 4D attention masks inside the converter, so it should be alright, but yeah, I will check related issues!
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    cached_mask = None
As discussed with you offline, this cannot work as-is due to sharing across model instances.
Yes, I will either cache it at the instance level (a tril) rather than the class level, or pass it as kwargs. JAX does not seem to care, so it should not be too bad.
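An instance-level tril cache could be sketched roughly like this (hypothetical class, not this PR's code):

```python
import torch
from torch import nn

class AttentionWithCachedTril(nn.Module):
    """Illustrative sketch: cache the causal tril at the *instance* level
    (created in __init__), never as a class attribute shared across models."""

    def __init__(self, max_positions: int):
        super().__init__()
        # register_buffer keeps the mask per-instance and moves it with .to(device)
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)),
            persistent=False,
        )

    def get_mask(self, q_len: int, kv_len: int) -> torch.Tensor:
        # Slice the rows for the current query positions against all key positions.
        return self.causal_mask[kv_len - q_len : kv_len, :kv_len]
```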
Why can't we pass the attention_mask just into the `cache.update(...)` function?
I'll check that as well, but that doesn't help for all the attention types we have that need a pre-processed mask; it will work after the pre-processing though.
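Passing the mask through the cache update could look roughly like this toy sketch (`TinyCache` is made up for illustration; it is not the actual transformers `Cache` API):

```python
import torch

class TinyCache:
    """Toy sketch of the idea in this thread: let update() also carry the
    attention_mask as a kwarg, so the cache can hand back a mask that
    matches the full (cached + new) key length."""

    def __init__(self):
        self.keys, self.values, self.mask = None, None, None

    def update(self, key, value, attention_mask=None):
        # key/value: (batch, heads, seq, head_dim) -> grow along seq (dim=-2)
        cat_kv = lambda old, new: new if old is None else torch.cat([old, new], dim=-2)
        self.keys = cat_kv(self.keys, key)
        self.values = cat_kv(self.values, value)
        if attention_mask is not None:
            # attention_mask: (batch, seq) -> grow along seq (dim=-1)
            self.mask = (
                attention_mask
                if self.mask is None
                else torch.cat([self.mask, attention_mask], dim=-1)
            )
        return self.keys, self.values, self.mask
```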
Can you specify how this helps with the static cache? The static cache should also work with the attention_mask being passed at every forward call (it'll always have the same shape). I don't think it's a good idea to have the mask cached as a class variable.

It will not be a class variable; I forgot to update the PR. I'll follow what we do with JAX.

I can give more details, but basically the new cache + attention was not behaving properly. This is going to be my priority this week anyway!

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
This is more aligned with our philosophy: it simplifies the code now and will simplify future changes.
Will help a lot with the static cache.
The only way to share the mask is to call `LlamaAttention`, but if you have a better way I'll update it! This makes the attention class self-contained, which is also pretty convenient for testing.
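As a rough illustration of why self-containment helps testing, such a module can be exercised in isolation without building the full model (sketch using `nn.MultiheadAttention` as a stand-in, not `LlamaAttention` itself):

```python
import torch
from torch import nn

class SelfContainedAttention(nn.Module):
    """Stand-in for an attention class that owns its own mask handling."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_states, padding_mask=None):
        # padding_mask: bool (batch, seq), True marks padded positions.
        out, _ = self.attn(
            hidden_states, hidden_states, hidden_states,
            key_padding_mask=padding_mask,
        )
        return out
```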
Ran the slow tests without FA2; will run them again on a DGX once approved.
cc @patrickvonplaten for visibility